mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-30 03:11:51 +00:00
Enable CUDA graphs for MoE models + GPT-OSS support (#689)
* gmp-oss: common * gpt-oss: attnetion sinks, swiglu_oai * gpt-oss: WIP llama Model loads and runs (CPU only), but PPL is much to high (~1500 for 1st batch vs ~200 in mainline). Is it because of SWA, because of vocab, or did I introduce a bug somewhere? * gpt-oss: CPU seems to be working It was the SWA thta was missing in the previous commit. There are issues with EOG tokens, so this still needs to be added. * CUDA: ADD_ID Just a copy from mainline * gpt-oss: Seems to be working on CUDA * gpt-oss: add sinks to the attn-vec kernels * CUDA: add head size of 64 to new mma Haven't turned it on yet, but observe slightly better PP and slightly worse TG performance with that. * gpt-oss: add ability to use -fmoe (only CUDA for now) * Move row sums to the write place * Add sinks to iqk flash attention * gpt_oss: Implement -fmoe on the CPU * Simdify swiglu_oai Turning it off for now as performance becomes more variable, so perhaps I'm running into thermal trottling imore often because of making the CPU work too hard. * llama: factor out model loader * Builds successfully * It runs, but mmap does not work * Fix llama_mmap so mmap works * Minor * Fix CUDA after latest changes * Attempt to use CUDA graphs with MoE models - not working * CUDA graphs WIP - still not working * CUDA graphs - seems to be working Likely not all MLA variants are working. I no longer remember why I added the q8_0 cpy that transposes the tensor, but if really needed, this is now missing. Also missing is q6_0. * Make q8_0 cache work for DeepSeek models with CUDA graphs * cuda: cpy for q6_0 * Fix llama_mmap on non-Linux platforms * Adding forgotten file * Iterating on Windows build failures * cuda: re-add q8_0 -> q8_0 transpose so mla = 2 can be used with CUDA graphs and q8_0 cache. * Disable graphs without -fmoe * Minor * Turn graphs on by default --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
@@ -132,7 +132,7 @@ set (GGML_CUDA_MIN_BATCH_OFFLOAD "32" CACHE STRING
|
||||
option(GGML_CUDA_NO_PEER_COPY "ggml: do not use peer to peer copies" OFF)
|
||||
option(GGML_CUDA_NO_VMM "ggml: do not try to use CUDA VMM" OFF)
|
||||
option(GGML_CUDA_FA_ALL_QUANTS "ggml: compile all quants for FlashAttention" OFF)
|
||||
option(GGML_CUDA_USE_GRAPHS "ggml: use CUDA graphs (llama.cpp only)" OFF)
|
||||
option(GGML_CUDA_USE_GRAPHS "ggml: use CUDA graphs (llama.cpp only)" ON)
|
||||
|
||||
option(GGML_IQK_FLASH_ATTENTION "ggml: enable the IQK FlashAttention CPU kernels" ON)
|
||||
option(GGML_IQK_FA_ALL_QUANTS "ggml: compile all quants for IQK FlashAttention" OFF)
|
||||
|
||||
@@ -325,6 +325,16 @@
|
||||
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \
|
||||
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
||||
|
||||
#define GGML_TENSOR_TERNARY_OP_LOCALS \
|
||||
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
|
||||
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \
|
||||
GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) \
|
||||
GGML_TENSOR_LOCALS(size_t, nb1, src1, nb) \
|
||||
GGML_TENSOR_LOCALS(int64_t, ne2, src2, ne) \
|
||||
GGML_TENSOR_LOCALS(size_t, nb2, src2, nb) \
|
||||
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \
|
||||
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
||||
|
||||
#define GGML_TENSOR_BINARY_OP_LOCALS01 \
|
||||
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
|
||||
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \
|
||||
@@ -571,6 +581,7 @@ extern "C" {
|
||||
|
||||
GGML_OP_DUP,
|
||||
GGML_OP_ADD,
|
||||
GGML_OP_ADD_ID,
|
||||
GGML_OP_ADD1,
|
||||
GGML_OP_ACC,
|
||||
GGML_OP_SUB,
|
||||
@@ -674,6 +685,7 @@ extern "C" {
|
||||
GGML_UNARY_OP_HARDSWISH,
|
||||
GGML_UNARY_OP_HARDSIGMOID,
|
||||
GGML_UNARY_OP_SWIGLU,
|
||||
GGML_UNARY_OP_SWIGLU_OAI,
|
||||
|
||||
GGML_UNARY_OP_COUNT,
|
||||
};
|
||||
@@ -1028,6 +1040,13 @@ extern "C" {
|
||||
struct ggml_tensor * b,
|
||||
enum ggml_type type);
|
||||
|
||||
// dst[i0, i1, i2] = a[i0, i1, i2] + b[i0, ids[i1, i2]]
|
||||
GGML_API struct ggml_tensor * ggml_add_id(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b,
|
||||
struct ggml_tensor * ids);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_add1(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
@@ -1268,6 +1287,13 @@ extern "C" {
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_swiglu_oai(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b,
|
||||
float alpha,
|
||||
float limit);
|
||||
|
||||
// a - x
|
||||
// b - dy
|
||||
GGML_API struct ggml_tensor * ggml_silu_back(
|
||||
@@ -1370,6 +1396,16 @@ extern "C" {
|
||||
struct ggml_tensor * ids,
|
||||
enum ggml_unary_op op);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_moe_up_gate_ext(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a_up,
|
||||
struct ggml_tensor * a_gate,
|
||||
struct ggml_tensor * b,
|
||||
struct ggml_tensor * ids,
|
||||
struct ggml_tensor * a_up_b,
|
||||
struct ggml_tensor * a_gate_b,
|
||||
enum ggml_unary_op op);
|
||||
|
||||
// A: m columns, n rows,
|
||||
// B: p columns, n rows,
|
||||
// result is m columns, p rows
|
||||
@@ -1662,6 +1698,11 @@ extern "C" {
|
||||
float scale,
|
||||
float max_bias);
|
||||
|
||||
GGML_API void ggml_soft_max_add_sinks(
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * sinks);
|
||||
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_soft_max_back(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
@@ -1998,6 +2039,10 @@ extern "C" {
|
||||
struct ggml_tensor * a,
|
||||
enum ggml_prec prec);
|
||||
|
||||
GGML_API void ggml_flash_attn_ext_add_sinks(
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * sinks);
|
||||
|
||||
// TODO: needs to be adapted to ggml_flash_attn_ext
|
||||
GGML_API struct ggml_tensor * ggml_flash_attn_back(
|
||||
struct ggml_context * ctx,
|
||||
|
||||
@@ -43,6 +43,7 @@ static bool ggml_op_can_inplace(enum ggml_op op) {
|
||||
case GGML_OP_DIAG_MASK_ZERO:
|
||||
case GGML_OP_DIAG_MASK_INF:
|
||||
case GGML_OP_ADD:
|
||||
case GGML_OP_ADD_ID:
|
||||
case GGML_OP_ADD1:
|
||||
case GGML_OP_SUB:
|
||||
case GGML_OP_MUL:
|
||||
|
||||
@@ -37,6 +37,8 @@
|
||||
#include "ggml-cuda/unary.cuh"
|
||||
#include "ggml-cuda/upscale.cuh"
|
||||
#include "ggml-cuda/conv-transpose-1d.cuh"
|
||||
#include "ggml-cuda/add-id.cuh"
|
||||
#include "ggml-cuda/graph.cuh"
|
||||
|
||||
#include <algorithm>
|
||||
#include <array>
|
||||
@@ -49,6 +51,7 @@
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <condition_variable>
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
#include <stdarg.h>
|
||||
@@ -77,6 +80,7 @@ GGML_API void ggml_backend_cuda_log_set_callback(ggml_log_callback log_callback,
|
||||
#define GGML_CUDA_LOG_INFO(...) ggml_cuda_log(GGML_LOG_LEVEL_INFO, __VA_ARGS__)
|
||||
#define GGML_CUDA_LOG_WARN(...) ggml_cuda_log(GGML_LOG_LEVEL_WARN, __VA_ARGS__)
|
||||
#define GGML_CUDA_LOG_ERROR(...) ggml_cuda_log(GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
|
||||
#define GGML_CUDA_LOG_DEBUG(...) ggml_cuda_log(GGML_LOG_LEVEL_DEBUG, __VA_ARGS__)
|
||||
|
||||
GGML_ATTRIBUTE_FORMAT(2, 3)
|
||||
static void ggml_cuda_log(enum ggml_log_level level, const char * format, ...) {
|
||||
@@ -444,6 +448,35 @@ std::unique_ptr<ggml_cuda_pool> ggml_backend_cuda_context::new_pool_for_device(i
|
||||
return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_leg(device));
|
||||
}
|
||||
|
||||
// Synchronization for CUDA graph capture vs. context teardown:
// the destructor waits on ggml_cuda_lock_cv until ggml_cuda_lock_counter
// drops to zero (the counter is decremented after cudaStreamEndCapture;
// NOTE(review): the matching increment is outside this chunk — confirm).
static std::mutex ggml_cuda_lock;
static std::condition_variable ggml_cuda_lock_cv;
static std::atomic<int> ggml_cuda_lock_counter;
|
||||
|
||||
// Per-device backend context. Stores the device index and builds a
// human-readable name of the form "<GGML_CUDA_NAME><device>".
ggml_backend_cuda_context::ggml_backend_cuda_context(int device) :
    device(device), name(GGML_CUDA_NAME + std::to_string(device)) {
}
|
||||
|
||||
// Tear down all per-device CUDA resources owned by this context.
// Blocks until no CUDA graph capture is in flight (lock counter == 0),
// because destroying streams/handles during stream capture would corrupt
// the capture.
ggml_backend_cuda_context::~ggml_backend_cuda_context() {

    // Wait for any in-progress graph capture to finish before destroying
    // the streams it may be capturing on.
    std::unique_lock<std::mutex> lock(ggml_cuda_lock);
    ggml_cuda_lock_cv.wait(lock, []{ return ggml_cuda_lock_counter.load(std::memory_order_relaxed) == 0; });

    // Destroy the cross-device copy event, if one was ever created.
    if (copy_event != nullptr) {
        CUDA_CHECK(cudaEventDestroy(copy_event));
    }
    // Destroy every lazily-created stream and cuBLAS handle, for all devices.
    for (int i = 0; i < GGML_CUDA_MAX_DEVICES; ++i) {
        for (int j = 0; j < GGML_CUDA_MAX_STREAMS; ++j) {
            if (streams[i][j] != nullptr) {
                CUDA_CHECK(cudaStreamDestroy(streams[i][j]));
            }
        }
        if (cublas_handles[i] != nullptr) {
            CUBLAS_CHECK(cublasDestroy(cublas_handles[i]));
        }
    }

}
|
||||
|
||||
// cuda buffer
|
||||
|
||||
struct ggml_backend_cuda_buffer_context {
|
||||
@@ -2220,6 +2253,24 @@ static __global__ void k_copy_dst_from_contiguous(char * __restrict__ dst_origin
|
||||
}
|
||||
}
|
||||
|
||||
// Row-wise broadcast add: dst[row, j] = src1[row, j] + src2[j].
// One thread block per row (blockIdx.x selects the row); the threads of the
// block stride over the n_per_row columns. Called from the fused MoE path to
// add a bias row (dst->src[4] / dst->src[5]) to each produced row.
// In-place use with dst == src1 is safe: each element is read exactly once
// before it is written by the same thread.
// (Removed: an older commented-out variant that took the total element count
// and used `j % n_per_row` — dead code superseded by this per-row version.)
static __global__ void k_quick_add(uint32_t n_per_row, const float * src1, const float * src2, float * dst) {

    const uint32_t row = blockIdx.x;
    const float * src1_row = src1 + row*n_per_row;
    float       * dst_row  = dst  + row*n_per_row;

    // Block-stride loop over the columns of this row.
    for (uint32_t j = threadIdx.x; j < n_per_row; j += blockDim.x) {
        dst_row[j] = src1_row[j] + src2[j];
    }
}
|
||||
|
||||
static inline bool prepare_row_mappigs(ggml_backend_cuda_context& ctx, int64_t n_as, int64_t n_ids,
|
||||
const ggml_tensor * ids, std::vector<int>& moe_counts, std::vector<int>& cum_moe_counts,
|
||||
ggml_cuda_pool_alloc<mmid_row_mapping>& dev_row_mapping) {
|
||||
@@ -2270,7 +2321,7 @@ static inline bool prepare_row_mappigs(ggml_backend_cuda_context& ctx, int64_t n
|
||||
return is_ser;
|
||||
}
|
||||
|
||||
static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
static bool ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * next) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const ggml_tensor * src1 = dst->src[1];
|
||||
const ggml_tensor * ids = dst->src[2];
|
||||
@@ -2319,7 +2370,25 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
|
||||
0, src0->ne[1], 1, src1_padded_col_size, stream);
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
|
||||
return;
|
||||
if (next && next->op == GGML_OP_MUL_MAT_ID && next->src[0]->type == src0->type && src1 == next->src[1] &&
|
||||
ggml_are_same_shape(src0, next->src[0]) &&
|
||||
ggml_backend_buffer_is_cuda(next->src[0]->buffer) &&
|
||||
ggml_backend_buffer_is_cuda(next->buffer) &&
|
||||
!ggml_backend_buffer_is_cuda_split(next->src[0]->buffer)) {
|
||||
ggml_backend_cuda_buffer_context * next_src0_ctx = (ggml_backend_cuda_buffer_context *) next->src[0]->buffer->context;
|
||||
ggml_backend_cuda_buffer_context * next_dst_ctx = (ggml_backend_cuda_buffer_context *) next->buffer->context;
|
||||
if (next_src0_ctx->device == device_id &&
|
||||
next_dst_ctx->device == device_id) {
|
||||
local_dst.data = next->data;
|
||||
ggml_cuda_op_mul_mat_vec_q_id(ctx, next->src[0], &local_src1, ids, &local_dst,
|
||||
(const char *)next->src[0]->data, nullptr, src1_quantized.get(), (float *)next->data,
|
||||
0, src0->ne[1], 1, src1_padded_col_size, stream);
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2356,7 +2425,7 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
|
||||
dst_row.nb[2] = nb1;
|
||||
dst_row.nb[3] = nb1;
|
||||
|
||||
if (ne12 == 1) {
|
||||
if (false && ne12 == 1) {
|
||||
std::vector<char> ids_host(ggml_nbytes(ids));
|
||||
const char * ids_dev = (const char *) ids->data;
|
||||
CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream));
|
||||
@@ -2442,6 +2511,7 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * next) {
|
||||
@@ -2470,6 +2540,8 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
|
||||
src0_2_ctx->device == device_id &&
|
||||
src1_ctx->device == device_id &&
|
||||
dst_ctx->device == device_id) {
|
||||
//printf("%s(%s, %s): %ld x %ld x %ld, %ld x %ld x %ld, %ld x %ld x %ld\n", __func__, src0_1->name, src0_2->name,
|
||||
// src0->ne[0], src0->ne[1], src0->ne[2], src1->ne[0], src1->ne[1], src1->ne[2], ids->ne[0], ids->ne[1], ids->ne[2]);
|
||||
// Fast TG path
|
||||
const int64_t n_ids = ids->ne[0];
|
||||
auto stream = ctx.stream(device_id, 0);
|
||||
@@ -2505,12 +2577,26 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
|
||||
0, src0_1->ne[1], 1, src1_padded_col_size, stream);
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
|
||||
if (dst->src[4]) {
|
||||
ggml_cuda_add_id((const float *)local_dst.data, (const float *)dst->src[4]->data,
|
||||
(const int32_t *)ids->data, (float *)local_dst.data,
|
||||
local_dst.ne[0], local_dst.ne[2], local_dst.ne[1], local_dst.ne[0], local_dst.ne[2],
|
||||
local_dst.nb[1], local_dst.nb[2], dst->src[4]->nb[1], ids->nb[2], stream);
|
||||
}
|
||||
|
||||
local_dst.data = dst_gate_contiguous.get();
|
||||
ggml_cuda_op_mul_mat_vec_q_id(ctx, src0_2, &local_src1, ids, &local_dst,
|
||||
(const char *)src0_2->data, (const float *)src1->data, src1_quantized.get(), (float *)dst_gate_contiguous.get(),
|
||||
0, src0_2->ne[1], 1, src1_padded_col_size, stream);
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
|
||||
if (dst->src[5]) {
|
||||
ggml_cuda_add_id((const float *)local_dst.data, (const float *)dst->src[5]->data,
|
||||
(const int32_t *)ids->data, (float *)local_dst.data,
|
||||
local_dst.ne[0], local_dst.ne[2], local_dst.ne[1], local_dst.ne[0], local_dst.ne[2],
|
||||
local_dst.nb[1], local_dst.nb[2], dst->src[5]->nb[1], ids->nb[2], stream);
|
||||
}
|
||||
|
||||
if (next && next->op == GGML_OP_MUL_MAT_ID && ggml_is_quantized(next->src[0]->type) &&
|
||||
ggml_backend_buffer_is_cuda(next->src[0]->buffer) &&
|
||||
!ggml_backend_buffer_is_cuda_split(next->src[0]->buffer) &&
|
||||
@@ -2518,8 +2604,15 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
|
||||
ggml_backend_buffer_is_cuda(next->buffer) &&
|
||||
((ggml_backend_cuda_buffer_context *)next->buffer->context)->device == device_id) {
|
||||
|
||||
ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], dst->ne[0]*n_ids,
|
||||
(const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), (float *)dst_gate_contiguous.get());
|
||||
auto unary_op = (ggml_unary_op)dst->op_params[0];
|
||||
if (unary_op == GGML_UNARY_OP_SWIGLU_OAI) {
|
||||
ggml_swiglu_oai_cuda_f32((const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(),
|
||||
(float *)dst_gate_contiguous.get(), dst->ne[0]*n_ids, dst->ne[0], dst->ne[0], dst->ne[0], 1.702f, 7.0f, stream);
|
||||
} else {
|
||||
ggml_fused_mul_unary(ctx, unary_op, dst->ne[0]*n_ids,
|
||||
(const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(),
|
||||
(float *)dst_gate_contiguous.get());
|
||||
}
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
|
||||
const int64_t dst_padded_col_size = GGML_PAD(dst->ne[0], MATRIX_ROW_PADDING);
|
||||
@@ -2555,8 +2648,14 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
|
||||
return true;
|
||||
} else {
|
||||
CUDA_CHECK(cudaMemsetAsync(dst->data, 0, ggml_nbytes(dst), stream));
|
||||
ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], ggml_nelements(dst),
|
||||
(const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), (float *)dst->data);
|
||||
auto unary_op = (ggml_unary_op)dst->op_params[0];
|
||||
if (unary_op == GGML_UNARY_OP_SWIGLU_OAI) {
|
||||
ggml_swiglu_oai_cuda_f32((const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(),
|
||||
(float *)dst->data, dst->ne[0]*n_ids, dst->ne[0], dst->ne[0], dst->ne[0], 1.702f, 7.0f, stream);
|
||||
} else {
|
||||
ggml_fused_mul_unary(ctx, unary_op, ggml_nelements(dst),
|
||||
(const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), (float *)dst->data);
|
||||
}
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
return false;
|
||||
}
|
||||
@@ -2624,7 +2723,7 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
|
||||
final_src.nb[3] = final_src.nb[2];
|
||||
}
|
||||
|
||||
if (ne12 == 1) {
|
||||
if (false && ne12 == 1) {
|
||||
ggml_cuda_pool_alloc<char> dst_up_contiguous(ctx.pool(), sizeof(float)*dst_row.ne[0]);
|
||||
ggml_cuda_pool_alloc<char> dst_gate_contiguous(ctx.pool(), sizeof(float)*dst_row.ne[0]);
|
||||
if (fuse_down) {
|
||||
@@ -2761,6 +2860,14 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
|
||||
}
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
|
||||
if (dst->src[4]) {
|
||||
dim3 block_dims(std::min(uint32_t(dst_row.ne[0]), 768u));
|
||||
dim3 grid_dims(num_src1_rows);
|
||||
k_quick_add<<<grid_dims, block_dims, 0, stream>>>(dst_row.ne[0], (const float *)dst_row.data,
|
||||
(const float *)((const char *)dst->src[4]->data + i02*dst->src[4]->nb[1]), (float *)dst_row.data);
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
}
|
||||
|
||||
dst_row.data = dst_gate_contiguous.get();
|
||||
if (use_quantized_src1) {
|
||||
ggml_cuda_op_mul_mat_q(ctx, &src0_2_row, &src1_row, &dst_row, (const char *)src0_2_row.data, nullptr, src1_quantized.get(), (float *)dst_row.data,
|
||||
@@ -2770,8 +2877,24 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
|
||||
}
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
|
||||
ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], ggml_nelements(&dst_row),
|
||||
(const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), (float *)dst_gate_contiguous.get());
|
||||
if (dst->src[5]) {
|
||||
dim3 block_dims(std::min(uint32_t(dst_row.ne[0]), 768u));
|
||||
dim3 grid_dims(num_src1_rows);
|
||||
k_quick_add<<<grid_dims, block_dims, 0, stream>>>(dst_row.ne[0], (const float *)dst_row.data,
|
||||
(const float *)((const char *)dst->src[5]->data + i02*dst->src[5]->nb[1]), (float *)dst_row.data);
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
}
|
||||
|
||||
auto unary_op = (ggml_unary_op)dst->op_params[0];
|
||||
if (unary_op == GGML_UNARY_OP_SWIGLU_OAI) {
|
||||
ggml_swiglu_oai_cuda_f32((const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(),
|
||||
(float *)dst_gate_contiguous.get(), ggml_nelements(&dst_row), dst_row.ne[0], dst_row.ne[0], dst_row.ne[0],
|
||||
1.702f, 7.0f, stream);
|
||||
} else {
|
||||
ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], ggml_nelements(&dst_row),
|
||||
(const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(),
|
||||
(float *)dst_gate_contiguous.get());
|
||||
}
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
|
||||
if (fuse_down) {
|
||||
@@ -2851,6 +2974,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
||||
case GGML_OP_ADD:
|
||||
ggml_cuda_op_add(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_ADD_ID:
|
||||
ggml_cuda_op_add_id(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_MULTI_ADD:
|
||||
ggml_cuda_op_multi_add(ctx, dst);
|
||||
break;
|
||||
@@ -2877,6 +3003,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
||||
case GGML_UNARY_OP_SWIGLU:
|
||||
ggml_cuda_op_swiglu(ctx, dst);
|
||||
break;
|
||||
case GGML_UNARY_OP_SWIGLU_OAI:
|
||||
ggml_cuda_op_swiglu_oai(ctx, dst);
|
||||
break;
|
||||
case GGML_UNARY_OP_GELU_QUICK:
|
||||
ggml_cuda_op_gelu_quick(ctx, dst);
|
||||
break;
|
||||
@@ -2938,7 +3067,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
||||
}
|
||||
break;
|
||||
case GGML_OP_MUL_MAT_ID:
|
||||
ggml_cuda_mul_mat_id(ctx, dst);
|
||||
skip_next = ggml_cuda_mul_mat_id(ctx, dst, next);
|
||||
break;
|
||||
case GGML_OP_MOE_FUSED_UP_GATE:
|
||||
skip_next = ggml_cuda_up_gate_unary(ctx, dst, next);
|
||||
@@ -3119,6 +3248,105 @@ GGML_CALL static void ggml_backend_cuda_synchronize(ggml_backend_t backend) {
|
||||
GGML_UNUSED(backend);
|
||||
}
|
||||
|
||||
#ifdef USE_CUDA_GRAPH
|
||||
|
||||
// Decide whether the ggml graph can be executed via a CUDA graph, and refresh
// the list of CPY destination pointers used for kernel-parameter indirection.
// Takes the current `use_cuda_graph` decision by value and returns the
// (possibly downgraded) decision; it never upgrades false -> true.
static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
        bool use_cuda_graph) {

    // Loop over nodes in GGML graph to obtain info needed for CUDA graph
    cuda_ctx->cuda_graph->cpy_dest_ptrs.clear();

    // Node names that are exempt from the "batch size > 1 disables graphs"
    // rule below (Gemma3n per-layer projection and the MoE bias adds).
    const std::string gemma3n_per_layer_proj_src0_name = "inp_per_layer_selected";
    const std::string gemma3n_per_layer_proj_src1_name = "per_layer_proj";
    const std::string ffn_moe_gate_bias_prefix = "ffn_moe_gate_biased";
    const std::string ffn_moe_up_bias_prefix = "ffn_moe_up_biased";
    const std::string ffn_moe_down_bias_prefix = "ffn_moe_down_biased";

    for (int i = 0; i < cgraph->n_nodes; i++) {
        ggml_tensor * node = cgraph->nodes[i];

        // No-op / view nodes launch no kernels and cannot affect capture.
        if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
            continue;
        }

        if (node->src[0] && node->src[0]->buffer && ggml_backend_buft_is_cuda_split(node->src[0]->buffer->buft)) {
            use_cuda_graph = false; // Split buffers are not supported by CUDA graph capture
#ifndef NDEBUG
            GGML_CUDA_LOG_DEBUG("%s: disabling CUDA graphs due to split buffer\n", __func__);
#endif
        }

        // MUL_MAT_ID is only graph-safe in the single-token, single-expert-
        // selection shape (ne[2] == 1 and one id per row).
        if (node->op == GGML_OP_MUL_MAT_ID && (node->ne[2] != 1 || node->src[2]->ne[0] != 1)) {
            use_cuda_graph = false; // This node type is not supported by CUDA graph capture
#ifndef NDEBUG
            GGML_CUDA_LOG_DEBUG("%s(%s): disabling CUDA graphs due to unsupported node type %ld %ld\n",
                    __func__, node->src[0]->name, node->ne[2], node->src[2]->ne[0]);
#endif
        }
        if (node->op == GGML_OP_MOE_FUSED_UP_GATE) {
            auto src0_1 = node->src[0];
            auto src0_2 = node->src[1];
            auto src1 = node->src[2];
            // Fused MoE up/gate is graph-safe only for a single f32 activation
            // row with quantized up and gate weights.
            if (src1->ne[1] != 1 || src1->ne[2] != 1 || src1->ne[3] != 1 || src1->type != GGML_TYPE_F32 ||
                !ggml_is_quantized(src0_1->type) || !ggml_is_quantized(src0_2->type)) {
                use_cuda_graph = false;
            } else {
                // A following quantized MUL_MAT_ID gets fused into this node at
                // execution time, so skip it here instead of applying the
                // MUL_MAT_ID restrictions above to it.
                if (i < cgraph->n_nodes-1) {
                    auto next = cgraph->nodes[i+1];
                    if (next->op == GGML_OP_MUL_MAT_ID && ggml_is_quantized(next->src[0]->type)) {
                        ++i;
                    }
                }
            }
        }

        if (node->op == GGML_OP_ADD &&
                node->src[1] && node->src[1]->ne[1] > 1 &&
                (node->src[0] ? node->src[0]->name != gemma3n_per_layer_proj_src0_name : true) &&
                (node->src[1] ? node->src[1]->name != gemma3n_per_layer_proj_src1_name : true) &&
                strncmp(node->name, ffn_moe_gate_bias_prefix.c_str(), ffn_moe_gate_bias_prefix.size()) != 0 &&
                strncmp(node->name, ffn_moe_up_bias_prefix.c_str(), ffn_moe_up_bias_prefix.size()) != 0 &&
                strncmp(node->name, ffn_moe_down_bias_prefix.c_str(), ffn_moe_down_bias_prefix.size()) != 0) {
            // disable CUDA graphs for batch size > 1 for now while excluding the matrix-matrix addition as part of Gemma3n's `project_per_layer_input` operation
            // by means of matching node names. See
            // https://github.com/ggml-org/llama.cpp/blob/f9a31eea06a859e34cecb88b4d020c7f03d86cc4/src/llama-model.cpp#L10199-L10241 and
            // https://github.com/huggingface/transformers/blob/bda75b4011239d065de84aa3e744b67ebfa7b245/src/transformers/models/gemma3n/modeling_gemma3n.py#L1773,
            // Generally, changes in batch size or context size can cause changes to the grid size of some kernels.
            use_cuda_graph = false;
#ifndef NDEBUG
            GGML_CUDA_LOG_DEBUG("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
#endif
        }

        if (node->op == GGML_OP_CPY) {

            // Store the pointers which are updated for each token, such that these can be sent
            // to the device and accessed using indirection from CUDA graph
            cuda_ctx->cuda_graph->cpy_dest_ptrs.push_back((char *) node->src[1]->data);

            // store a pointer to each copy op CUDA kernel to identify it later
            void * ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]);
            if (!ptr) {
                use_cuda_graph = false;
#ifndef NDEBUG
                GGML_CUDA_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported copy op\n", __func__);
#endif
            }
        }
        // Once disabled there is no way back; skip the remaining nodes.
        if (!use_cuda_graph) {
            break;
        }
    }

    if (use_cuda_graph) {
        cuda_ctx->cuda_graph->use_cpy_indirection = true;
        // copy pointers to GPU so they can be accessed via indirection within CUDA graph
        ggml_cuda_cpy_dest_ptrs_copy(cuda_ctx->cuda_graph.get(), cuda_ctx->cuda_graph->cpy_dest_ptrs.data(), cuda_ctx->cuda_graph->cpy_dest_ptrs.size(), cuda_ctx->stream());
    }

    return use_cuda_graph;
}
|
||||
|
||||
static void set_ggml_graph_node_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
|
||||
graph_node_properties->node_address = node->data;
|
||||
graph_node_properties->node_op = node->op;
|
||||
@@ -3129,6 +3357,7 @@ static void set_ggml_graph_node_properties(ggml_tensor * node, ggml_graph_node_p
|
||||
for (int i = 0; i < GGML_MAX_SRC; i++) {
|
||||
graph_node_properties->src_address[i] = node->src[i] ? node->src[i]->data : nullptr;
|
||||
}
|
||||
memcpy(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS);
|
||||
}
|
||||
|
||||
static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
|
||||
@@ -3160,9 +3389,246 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if (node->op == GGML_OP_SCALE &&
|
||||
memcmp(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// Determine whether the captured CUDA graph must be re-captured/updated for
// this ggml graph: true when no executable instance exists yet, when the node
// count changed, or when any node's properties (addresses, op, shapes, ...)
// differ from the previous launch. Always refreshes the stored per-node
// properties so the next call compares against the current graph.
static bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) {

    // No executable instance yet -> a (re)capture is mandatory.
    bool update_needed = cuda_ctx->cuda_graph->instance == nullptr;

    // A different node count also forces an update (and resizes the
    // properties store to match).
    if (cuda_ctx->cuda_graph->ggml_graph_properties.size() != (size_t)cgraph->n_nodes) {
        update_needed = true;
        cuda_ctx->cuda_graph->ggml_graph_properties.resize(cgraph->n_nodes);
    }

    // Compare each node against the properties recorded last time; once a
    // mismatch is found the remaining comparisons are skipped, but the
    // properties are still refreshed for every node.
    for (int i = 0; i < cgraph->n_nodes; i++) {
        ggml_graph_node_properties * props = &cuda_ctx->cuda_graph->ggml_graph_properties[i];
        if (!update_needed && !ggml_graph_node_has_matching_properties(cgraph->nodes[i], props)) {
            update_needed = true;
        }
        set_ggml_graph_node_properties(cgraph->nodes[i], props);
    }

    return update_needed;
}
|
||||
|
||||
// Apply the freshly captured cuda_graph->graph to the existing executable
// instance. Uses the CUDA 12 result-info API when available, the legacy
// error-node API otherwise. If the in-place update is rejected by the driver,
// the old instance is destroyed and a new one instantiated from scratch.
static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {

#if CUDART_VERSION >= 12000
    cudaGraphExecUpdateResultInfo result_info;
    cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info);
#else
    cudaGraphNode_t errorNode;
    cudaGraphExecUpdateResult result_info;
    cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &errorNode, &result_info);
#endif // CUDART_VERSION >= 12000

    if (stat == cudaErrorGraphExecUpdateFailure) {
#ifndef NDEBUG
        GGML_CUDA_LOG_DEBUG("%s: CUDA graph update failed\n", __func__);
#endif

        // The pre-existing graph exec cannot be updated due to violated constraints
        // so instead clear error and re-instantiate
        (void)cudaGetLastError();
        CUDA_CHECK(cudaGraphExecDestroy(cuda_ctx->cuda_graph->instance));
        cuda_ctx->cuda_graph->instance = nullptr;
        CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
    } else {
        // Any failure other than an update-constraint violation is fatal.
        GGML_ASSERT(stat == cudaSuccess);
    }
}
|
||||
#endif
|
||||
|
||||
// Execute the ggml graph, either directly node-by-node or under CUDA graph
// stream capture. When `use_cuda_graph` is set and an update is required, the
// node-by-node pass runs inside an active stream capture (started by the
// caller); the capture is then finalized, instantiated/updated, and launched.
// The bool references are in/out state shared with the caller's retry loop.
static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
        bool & graph_evaluated_or_captured, bool & use_cuda_graph, bool & cuda_graph_update_required) {
    // flag used to determine whether it is an integrated_gpu
    // TODO
    const bool integrated = false; //ggml_cuda_info().devices[cuda_ctx->device].integrated;

    while (!graph_evaluated_or_captured) {
        // Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph.
        // With the use of CUDA graphs, the execution will be performed by the graph launch.
        if (!use_cuda_graph || cuda_graph_update_required) {

            for (int i = 0; i < cgraph->n_nodes; i++) {
                ggml_tensor * node = cgraph->nodes[i];
                // The successor node is passed along so compute_forward can
                // fuse adjacent ops (e.g. MoE up/gate + down projection).
                ggml_tensor * next = i < cgraph->n_nodes-1 ? cgraph->nodes[i+1] : nullptr;

                if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
                    continue;
                }

// Disabled fusion experiments kept for reference.
#if 0
                static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr);
                if (!disable_fusion) {
                    if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL }, {})) {
                        ggml_cuda_op_rms_norm_fused(*cuda_ctx, node, cgraph->nodes[i+1]);
                        i++;
                        continue;
                    }

                    if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SCALE, GGML_OP_UNARY, GGML_OP_SCALE }, { GGML_UNARY_OP_TANH })) {
                        i += 2;
                        ggml_cuda_op_softcap(*cuda_ctx, cgraph->nodes[i], node);
                        continue;
                    }
                }
#endif
#ifndef NDEBUG
                assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device));
                for (int j = 0; j < GGML_MAX_SRC; j++) {
                    if (node->src[j] != nullptr) {
                        assert(node->src[j]->buffer);
                        //assert(node->src[j]->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) ||
                        //       ggml_backend_buft_is_cuda_split(node->src[j]->buffer->buft) || (integrated && ggml_backend_buft_is_cuda_host(node->src[j]->buffer->buft)));
                    }
                }
#else
                GGML_UNUSED(integrated);
#endif // NDEBUG

                // skip_next is set when compute_forward fused `next` into this
                // node, in which case `next` must not be executed again.
                bool skip_next = false;
                bool ok = ggml_cuda_compute_forward(*cuda_ctx, node, next, skip_next);
                if (!ok) {
                    GGML_CUDA_LOG_ERROR("%s: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
                }
                GGML_ASSERT(ok);
                if (skip_next) ++i;
            }
        }
#ifdef USE_CUDA_GRAPH
        if (use_cuda_graph && cuda_graph_update_required) { // End CUDA graph capture
            // Replace any previously captured graph with the new capture.
            if (cuda_ctx->cuda_graph->graph != nullptr) {
                CUDA_CHECK(cudaGraphDestroy(cuda_ctx->cuda_graph->graph));
                cuda_ctx->cuda_graph->graph = nullptr;
            }

            CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_ctx->cuda_graph->graph));
            graph_evaluated_or_captured = true; // CUDA graph has been captured

            // Capture finished: drop our hold on the global capture lock and
            // wake any context destructor waiting for it.
            std::lock_guard<std::mutex> lock(ggml_cuda_lock);
            if (ggml_cuda_lock_counter.fetch_sub(1, std::memory_order_relaxed) == 1) {
                ggml_cuda_lock_cv.notify_all();
            }
        } else {
            graph_evaluated_or_captured = true; // ggml graph has been directly evaluated
        }
    }

    if (use_cuda_graph) {
        if (cuda_ctx->cuda_graph->instance == nullptr) { // Create executable graph from captured graph.
            CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
        }
        if (cuda_graph_update_required) { // Update graph executable
            update_cuda_graph_executable(cuda_ctx);
        }
        // Launch graph
        CUDA_CHECK(cudaGraphLaunch(cuda_ctx->cuda_graph->instance, cuda_ctx->stream()));
#else
        graph_evaluated_or_captured = true;
#endif // USE_CUDA_GRAPH
    }
}
|
||||
|
||||
// Entry point for executing a ggml graph on the CUDA backend.
// Decides whether the graph can be run through a captured CUDA graph (fast
// re-launch path) or must be evaluated node-by-node, then delegates the
// actual work to evaluate_and_capture_cuda_graph().
GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
    ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;

    ggml_cuda_set_device(cuda_ctx->device);

#ifdef USE_CUDA_GRAPH
    // Environment override, evaluated once on the first call.
    static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr);

    // Objects required for CUDA Graph
    if (cuda_ctx->cuda_graph == nullptr) {
        cuda_ctx->cuda_graph.reset(new ggml_cuda_graph());
    }

    bool use_cuda_graph = true;
    bool cuda_graph_update_required = false;

    // Before the first capture, check whether this GPU supports graphs at all.
    if (cuda_ctx->cuda_graph->graph == nullptr) {
        if (ggml_cuda_info().devices[cuda_ctx->device].cc < CC_AMPERE) {
            cuda_ctx->cuda_graph->disable_due_to_gpu_arch = true;
#ifndef NDEBUG
            GGML_CUDA_LOG_DEBUG("%s: disabling CUDA graphs due to GPU architecture\n", __func__);
#endif
        }
    }

    // Disable CUDA graphs in presence of env var, old GPU, use-case which is changing too rapidly,
    // or previous graph capture failure.
    // Also disable for multi-gpu for now. TO DO investigate
    if (disable_cuda_graphs_due_to_env
        || cuda_ctx->cuda_graph->disable_due_to_gpu_arch
        || cuda_ctx->cuda_graph->disable_due_to_too_many_updates
        || cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture) {
        use_cuda_graph = false;
    }

    if (use_cuda_graph) {
        // A graph update is needed whenever the node properties of the new ggml
        // graph differ from the ones recorded at the last capture.
        cuda_graph_update_required = is_cuda_graph_update_required(cuda_ctx, cgraph);

        use_cuda_graph = check_node_graph_compatibility_and_refresh_copy_ops(cuda_ctx, cgraph, use_cuda_graph);

        // Disable CUDA graphs (from the next token) if the use-case is demanding too many consecutive graph updates.
        if (use_cuda_graph && cuda_graph_update_required) {
            cuda_ctx->cuda_graph->number_consecutive_updates++;
        } else {
            cuda_ctx->cuda_graph->number_consecutive_updates = 0;
        }

        // Heuristic threshold: 4 consecutive re-captures means the graph keeps
        // changing, so re-capture overhead would outweigh the launch savings.
        if (cuda_ctx->cuda_graph->number_consecutive_updates >= 4) {
            cuda_ctx->cuda_graph->disable_due_to_too_many_updates = true;
#ifndef NDEBUG
            GGML_CUDA_LOG_DEBUG("%s: disabling CUDA graphs due to too many consecutive updates\n", __func__);
#endif
        }
    }

    if (use_cuda_graph && cuda_graph_update_required) {
        // Start CUDA graph capture
        {
            // Hold the global lock only while bumping the capture counter;
            // the counter keeps other threads out of conflicting CUDA calls
            // for the duration of the capture.
            std::lock_guard<std::mutex> lock(ggml_cuda_lock);
            ggml_cuda_lock_counter.fetch_add(1, std::memory_order_relaxed);
        }

        CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed));
    }

    if (!use_cuda_graph) {
        // Indirect copy destinations are only needed while graphs are in use.
        cuda_ctx->cuda_graph->use_cpy_indirection = false;
    }

#else
    bool use_cuda_graph = false;
    bool cuda_graph_update_required = false;
#endif // USE_CUDA_GRAPH

    bool graph_evaluated_or_captured = false;

    evaluate_and_capture_cuda_graph(cuda_ctx, cgraph, graph_evaluated_or_captured, use_cuda_graph, cuda_graph_update_required);

    return GGML_STATUS_SUCCESS;
}
|
||||
|
||||
/*
|
||||
GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
|
||||
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
|
||||
|
||||
@@ -3431,6 +3897,7 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
|
||||
|
||||
return GGML_STATUS_SUCCESS;
|
||||
}
|
||||
*/
|
||||
|
||||
GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_tensor * op) {
|
||||
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context;
|
||||
@@ -3440,6 +3907,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
|
||||
case GGML_UNARY_OP_GELU:
|
||||
case GGML_UNARY_OP_SILU:
|
||||
case GGML_UNARY_OP_SWIGLU:
|
||||
case GGML_UNARY_OP_SWIGLU_OAI:
|
||||
case GGML_UNARY_OP_RELU:
|
||||
case GGML_UNARY_OP_SIGMOID:
|
||||
case GGML_UNARY_OP_HARDSIGMOID:
|
||||
@@ -3629,6 +4097,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
|
||||
case GGML_OP_PERMUTE:
|
||||
case GGML_OP_TRANSPOSE:
|
||||
case GGML_OP_ADD:
|
||||
case GGML_OP_ADD_ID:
|
||||
case GGML_OP_MULTI_ADD:
|
||||
case GGML_OP_MUL:
|
||||
case GGML_OP_DIV:
|
||||
|
||||
72
ggml/src/ggml-cuda/add-id.cu
Normal file
72
ggml/src/ggml-cuda/add-id.cu
Normal file
@@ -0,0 +1,72 @@
|
||||
#include "add-id.cuh"
|
||||
|
||||
// Kernel for GGML_OP_ADD_ID: one CUDA block per (row, token) pair.
// Each output row is the sum of the matching src0 row and the src1 row
// selected by the index tensor src2.
static __global__ void add_id_kernel(
    const float * src0, const float * src1, const int32_t * src2, float * dst,
    int64_t ne0, int64_t ne1,
    size_t nb01, size_t nb02,
    size_t nb11,
    size_t nb21
) {
    const int64_t row = blockIdx.x; // index along ne1 (e.g. expert slot)
    const int64_t tok = blockIdx.y; // index along ne2 (e.g. token)

    // Row of src1 to add, looked up in the i32 index tensor src2.
    const int sel = *(const int32_t *) ((const char *) src2 + row*sizeof(int32_t) + tok*nb21);

    // dst is contiguous: derive its byte strides from the element counts.
    const size_t dst_nb1 = ne0 * sizeof(float);
    const size_t dst_nb2 = ne1 * dst_nb1;

    float       * out = (float       *)((char *)dst        + row*dst_nb1 + tok*dst_nb2);
    const float * a   = (const float *)((const char *)src0 + row*nb01    + tok*nb02);
    const float * b   = (const float *)((const char *)src1 + sel*nb11);

    // Threads of the block stride across the row elements.
    for (int64_t i = threadIdx.x; i < ne0; i += blockDim.x) {
        out[i] = a[i] + b[i];
    }
}
|
||||
|
||||
// Tensor-level entry point for GGML_OP_ADD_ID.
// dst->src[0]: f32 rows to which a bias is added
// dst->src[1]: f32 table of bias rows
// dst->src[2]: i32 indices selecting which src1 row is added to each src0 row
// Validates types/strides, then delegates the launch to ggml_cuda_add_id()
// so the kernel-launch logic lives in exactly one place.
void ggml_cuda_op_add_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1];
    const ggml_tensor * src2 = dst->src[2];

    GGML_TENSOR_TERNARY_OP_LOCALS

    GGML_ASSERT(dst->type  == GGML_TYPE_F32);
    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(src1->type == GGML_TYPE_F32);
    GGML_ASSERT(src2->type == GGML_TYPE_I32);

    // Innermost dimension of all inputs must be contiguous.
    GGML_ASSERT(nb00 == sizeof(float));
    GGML_ASSERT(nb10 == sizeof(float));
    GGML_ASSERT(nb20 == sizeof(int32_t));

    const float   * src0_d = (const float   *)src0->data;
    const float   * src1_d = (const float   *)src1->data;
    const int32_t * src2_d = (const int32_t *)src2->data;
    float         * dst_d  = (float         *)dst->data;

    // Same launch configuration as before: min(ne00, 768) threads per row,
    // one block per (n_experts_used, n_tokens) pair.
    ggml_cuda_add_id(src0_d, src1_d, src2_d, dst_d,
        ne00, ne01, ne02,
        ne0, ne1, nb01, nb02, nb11, nb21, ctx.stream());
}
|
||||
|
||||
// Raw-pointer launcher for the ADD_ID kernel: adds the src1 row selected by
// src2 to each src0 row, writing the result to dst on the given stream.
void ggml_cuda_add_id(const float * src0, const float * src1, const int32_t * src2, float * dst,
        int64_t ne00, int64_t ne01, int64_t ne02,
        int64_t ne0, int64_t ne1, size_t nb01, size_t nb02, size_t nb11, size_t nb21, cudaStream_t stream) {
    // Threads cover one row's columns, capped at 768.
    const int  n_threads = ne00 < 768 ? (int)ne00 : 768;
    const dim3 grid(ne01, ne02); // (n_experts_used, n_tokens)
    add_id_kernel<<<grid, n_threads, 0, stream>>>(
        src0, src1, src2, dst, ne0, ne1, nb01, nb02, nb11, nb21);
}
|
||||
8
ggml/src/ggml-cuda/add-id.cuh
Normal file
8
ggml/src/ggml-cuda/add-id.cuh
Normal file
@@ -0,0 +1,8 @@
|
||||
#include "common.cuh"
|
||||
|
||||
void ggml_cuda_op_add_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_cuda_add_id(const float * src0, const float * src1, const int32_t * src2, float * dst,
|
||||
int64_t ne00, int64_t ne01, int64_t ne02,
|
||||
int64_t ne0, int64_t ne1, size_t nb01, size_t nb02, size_t nb11, size_t nb21, cudaStream_t stream);
|
||||
|
||||
@@ -108,6 +108,23 @@ static const char * cu_get_error_str(CUresult err) {
|
||||
#define CU_CHECK(err) CUDA_CHECK_GEN(err, CUDA_SUCCESS, cu_get_error_str)
|
||||
#endif
|
||||
|
||||
#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
// Raise the dynamic shared-memory limit of `kernel` to `nbytes`.
// The per-device static flag ensures cudaFuncSetAttribute is called at most
// once per device for each macro expansion site.
# define CUDA_SET_SHARED_MEMORY_LIMIT(kernel, nbytes) \
    do { \
        static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = { false }; \
        const int id = ggml_cuda_get_device(); \
        if (!shared_memory_limit_raised[id]) { \
            CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes)); \
            shared_memory_limit_raised[id] = true; \
        } \
    } while (0)
#else
// No-op for HIP/MUSA builds; the attribute is not set there.
# define CUDA_SET_SHARED_MEMORY_LIMIT(kernel, nbytes) \
    do { \
        GGML_UNUSED(nbytes); \
    } while (0)
#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
|
||||
|
||||
#if CUDART_VERSION >= 11100 || defined(GGML_USE_MUSA)
|
||||
#define GGML_CUDA_ASSUME(x) __builtin_assume(x)
|
||||
#else
|
||||
@@ -808,37 +825,7 @@ struct ggml_tensor_extra_gpu {
|
||||
#define USE_CUDA_GRAPH
|
||||
#endif
|
||||
|
||||
// Snapshot of one ggml graph node, recorded at CUDA graph capture time and
// later compared against the current ggml graph to decide whether the
// captured CUDA graph is still valid.
struct ggml_graph_node_properties {
    void * node_address;              // data pointer of the node's tensor
    ggml_op node_op;                  // operation the node performs
    int64_t ne[GGML_MAX_DIMS];        // tensor extents
    size_t nb[GGML_MAX_DIMS];         // tensor byte strides
    void * src_address[GGML_MAX_SRC]; // data pointers of the node's sources
};

// State for capturing and re-launching a ggml graph as a single CUDA graph.
struct ggml_cuda_graph {
#ifdef USE_CUDA_GRAPH
    ~ggml_cuda_graph() {
        // Release the executable instance and then the captured graph.
        if (instance != nullptr) {
            CUDA_CHECK(cudaGraphExecDestroy(instance));
        }
        if (graph != nullptr) {
            CUDA_CHECK(cudaGraphDestroy(graph));
        }
    }
    cudaGraph_t graph = nullptr;        // most recently captured graph
    cudaGraphExec_t instance = nullptr; // executable instantiation of `graph`
    size_t num_nodes = 0;
    std::vector<cudaGraphNode_t> nodes;
    std::vector<cudaKernelNodeParams> params;
    // Reasons for permanently falling back to node-by-node evaluation:
    bool disable_due_to_gpu_arch = false;
    bool disable_due_to_too_many_updates = false;
    bool disable_due_to_failed_graph_capture = false;
    int number_consecutive_updates = 0; // graph re-captures in a row
    std::vector<ggml_graph_node_properties> ggml_graph_properties; // per-node snapshots from last capture
    std::vector<char **> updated_kernel_arg; // kernel args refreshed between launches
#endif
};
|
||||
struct ggml_cuda_graph;
|
||||
|
||||
struct ggml_backend_cuda_context {
|
||||
int device;
|
||||
@@ -850,26 +837,9 @@ struct ggml_backend_cuda_context {
|
||||
|
||||
std::unique_ptr<ggml_cuda_graph> cuda_graph;
|
||||
|
||||
explicit ggml_backend_cuda_context(int device) :
|
||||
device(device),
|
||||
name(GGML_CUDA_NAME + std::to_string(device)) {
|
||||
}
|
||||
explicit ggml_backend_cuda_context(int device);
|
||||
|
||||
~ggml_backend_cuda_context() {
|
||||
if (copy_event != nullptr) {
|
||||
CUDA_CHECK(cudaEventDestroy(copy_event));
|
||||
}
|
||||
for (int i = 0; i < GGML_CUDA_MAX_DEVICES; ++i) {
|
||||
for (int j = 0; j < GGML_CUDA_MAX_STREAMS; ++j) {
|
||||
if (streams[i][j] != nullptr) {
|
||||
CUDA_CHECK(cudaStreamDestroy(streams[i][j]));
|
||||
}
|
||||
}
|
||||
if (cublas_handles[i] != nullptr) {
|
||||
CUBLAS_CHECK(cublasDestroy(cublas_handles[i]));
|
||||
}
|
||||
}
|
||||
}
|
||||
~ggml_backend_cuda_context();
|
||||
|
||||
cudaStream_t stream(int device, int stream) {
|
||||
if (streams[device][stream] == nullptr) {
|
||||
|
||||
262
ggml/src/ggml-cuda/cpy-utils.cuh
Normal file
262
ggml/src/ggml-cuda/cpy-utils.cuh
Normal file
@@ -0,0 +1,262 @@
|
||||
#pragma once
|
||||
|
||||
#include "ggml-common.h"
|
||||
|
||||
// Convert a single element from src_t to dst_t.
// Same-type copies are done directly; mixed types go through float.
template<typename src_t, typename dst_t>
static __device__ __forceinline__ void convert_flt(const src_t * src, dst_t * dst) {
    if constexpr (std::is_same_v<src_t, dst_t>) {
        *dst = *src;
    } else {
        *dst = float(*src);
    }
}
|
||||
|
||||
// Return the index of the entry of the sorted int8 table `val` (length n)
// that is closest to x: binary-search for the bracketing pair, then pick the
// nearer neighbor.
static __device__ __forceinline__ int best_index_int8(int n, const int8_t * val, float x) {
    if (x <= val[0]) {
        return 0;
    }
    if (x >= val[n-1]) {
        return n-1;
    }
    int lo = 0;
    int hi = n-1;
    while (hi - lo > 1) {
        const int mid = (lo + hi)/2;
        if (x < val[mid]) {
            hi = mid;
        } else {
            lo = mid;
        }
    }
    // val[hi-1] <= x <= val[hi]: choose whichever is closer.
    return x - val[hi-1] < val[hi] - x ? hi-1 : hi;
}
|
||||
|
||||
// Quantize one block of QK4_0 floats to q4_0: symmetric 4-bit quantization
// with a single scale d; nibbles store the value biased by +8.
static __device__ void quantize_f32_q4_0_block(const float * __restrict__ x, block_q4_0 * __restrict__ y) {
    float amax = 0.0f; // largest |x[j]| in the block
    float vmax = 0.0f; // signed value attaining that magnitude

    for (int j = 0; j < QK4_0; ++j) {
        const float v = x[j];
        if (amax < fabsf(v)) {
            amax = fabsf(v);
            vmax = v;
        }
    }

    // The extreme value maps to quant level -8; keeping its sign in d makes
    // that mapping exact.
    const float d = vmax / -8;
    const float id = d ? 1.0f/d : 0.0f;

    y->d = d;

    // Pack two 4-bit values per byte: first half of the block in the low
    // nibble, second half in the high nibble.
    for (int j = 0; j < QK4_0/2; ++j) {
        const float x0 = x[0 + j]*id;
        const float x1 = x[QK4_0/2 + j]*id;

        const uint8_t xi0 = min(15, (int8_t)(x0 + 8.5f));
        const uint8_t xi1 = min(15, (int8_t)(x1 + 8.5f));

        y->qs[j] = xi0;
        y->qs[j] |= xi1 << 4;
    }
}
|
||||
|
||||
// Quantize one block of QK4_1 floats to q4_1: affine 4-bit quantization with
// scale d and offset vmin (x ~= d*q + vmin).
static __device__ void quantize_f32_q4_1_block(const float * __restrict__ x, block_q4_1 * __restrict__ y) {
    float vmin = FLT_MAX;
    float vmax = -FLT_MAX;

    for (int j = 0; j < QK4_1; ++j) {
        const float v = x[j];
        if (v < vmin) vmin = v;
        if (v > vmax) vmax = v;
    }

    // 16 quant levels span [vmin, vmax].
    const float d = (vmax - vmin) / ((1 << 4) - 1);
    const float id = d ? 1.0f/d : 0.0f;

    y->dm.x = d;
    y->dm.y = vmin;

    // Two 4-bit values per byte: low nibble = first half, high nibble = second half.
    for (int j = 0; j < QK4_1/2; ++j) {
        const float x0 = (x[0 + j] - vmin)*id;
        const float x1 = (x[QK4_1/2 + j] - vmin)*id;

        const uint8_t xi0 = min(15, (int8_t)(x0 + 0.5f));
        const uint8_t xi1 = min(15, (int8_t)(x1 + 0.5f));

        y->qs[j] = xi0;
        y->qs[j] |= xi1 << 4;
    }
}
|
||||
|
||||
// Quantize one block of QK5_0 floats to q5_0: symmetric 5-bit quantization.
// Low 4 bits of each value go into qs (two per byte); the 5th bits of all
// values are collected into the 32-bit field qh.
static __device__ void quantize_f32_q5_0_block(const float * __restrict__ x, block_q5_0 * __restrict__ y) {
    float amax = 0.0f; // largest |x[j]| in the block
    float vmax = 0.0f; // signed value attaining that magnitude

    for (int j = 0; j < QK5_0; ++j) {
        const float v = x[j];
        if (amax < fabsf(v)) {
            amax = fabsf(v);
            vmax = v;
        }
    }

    // The extreme value maps to quant level -16.
    const float d = vmax / -16;
    const float id = d ? 1.0f/d : 0.0f;

    y->d = d;

    uint32_t qh = 0;
    for (int j = 0; j < QK5_0/2; ++j) {
        const float x0 = x[0 + j]*id;
        const float x1 = x[QK5_0/2 + j]*id;

        const uint8_t xi0 = min(31, (int8_t)(x0 + 16.5f));
        const uint8_t xi1 = min(31, (int8_t)(x1 + 16.5f));

        y->qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
        // 5th bits: first half occupies qh bits [0, QK5_0/2), second half the rest.
        qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
        qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2);
    }
    memcpy(y->qh, &qh, sizeof(qh));
}
|
||||
|
||||
// Quantize one block of QK5_1 floats to q5_1: affine 5-bit quantization with
// scale d and offset min; 5th bits packed into qh as in q5_0.
static __device__ void quantize_f32_q5_1_block(const float * __restrict__ x, block_q5_1 * __restrict__ y) {
    float min = x[0];
    float max = x[0];

    for (int j = 1; j < QK5_1; ++j) {
        const float v = x[j];
        min = v < min ? v : min;
        max = v > max ? v : max;
    }

    // 32 quant levels span [min, max].
    const float d = (max - min) / 31;
    const float id = d ? 1.0f/d : 0.0f;

    y->dm.x = d;
    y->dm.y = min;

    uint32_t qh = 0;
    for (int j = 0; j < QK5_1/2; ++j) {
        const float x0 = (x[0 + j] - min)*id;
        const float x1 = (x[QK5_1/2 + j] - min)*id;

        const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
        const uint8_t xi1 = (uint8_t)(x1 + 0.5f);

        // Low 4 bits into qs, 5th bits into the shared qh bitfield.
        y->qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
        qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
        qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2);
    }
    memcpy(y->qh, &qh, sizeof(qh));
}
|
||||
|
||||
// Quantize one block of QK8_0 floats to q8_0: symmetric 8-bit quantization
// with scale d = amax/127; each element is stored as round(x/d).
static __device__ void quantize_f32_q8_0_block(const float * __restrict__ x, block_q8_0 * __restrict__ y) {
    float max_abs = 0.0f;
    for (int i = 0; i < QK8_0; ++i) {
        max_abs = fmaxf(max_abs, fabsf(x[i]));
    }

    const float scale     = max_abs / 127.0f;
    const float inv_scale = scale != 0.0f ? 1.0f/scale : 0.0f;

    y->d = scale;

    for (int i = 0; i < QK8_0; ++i) {
        y->qs[i] = roundf(x[i]*inv_scale);
    }
}
|
||||
|
||||
// Quantize one block of QK4_NL floats to iq4_nl: 4-bit indices into the
// non-uniform lookup table kvalues_iq4nl, with the scale refined by a
// weighted least-squares fit after rounding.
static __device__ void quantize_f32_iq4_nl_block(const float * __restrict__ x, block_iq4_nl * __restrict__ y) {
    float amax = 0.0f; // largest |x[j]| in the block
    float vmax = 0.0f; // signed value attaining that magnitude

    for (int j = 0; j < QK4_NL; ++j) {
        const float v = x[j];
        if (amax < fabsf(v)) {
            amax = fabsf(v);
            vmax = v;
        }
    }

    // Initial scale maps the extreme value onto the table's first (most
    // negative) entry.
    float d = vmax / kvalues_iq4nl[0];
    const float id = d ? 1.0f/d : 0.0f;

    // Accumulate sums for the least-squares refinement of d.
    float sumqx = 0, sumq2 = 0;
    for (int j = 0; j < QK4_NL/2; ++j) {
        const float x0 = x[0 + j]*id;
        const float x1 = x[QK4_NL/2 + j]*id;
        // Nearest table entry for each scaled value.
        const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl, x0);
        const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl, x1);
        y->qs[j] = xi0 | (xi1 << 4);
        const float v0 = kvalues_iq4nl[xi0];
        const float v1 = kvalues_iq4nl[xi1];
        // Weights x^2 emphasize large-magnitude elements in the fit.
        const float w0 = x[0 + j]*x[0 + j];
        const float w1 = x[QK4_NL/2 + j]*x[QK4_NL/2 + j];
        sumqx += w0*v0*x[j] + w1*v1*x[QK4_NL/2 + j];
        sumq2 += w0*v0*v0 + w1*v1*v1;
    }

    // Refined scale minimizing the weighted error; fall back to the initial d
    // if the denominator vanishes (e.g. all-zero block).
    y->d = sumq2 > 0 ? sumqx/sumq2 : d;
}
|
||||
|
||||
// Quantize one block of QK6_0 floats to q6_0: symmetric 6-bit quantization.
// Low 4 bits of each value go into qs (two per byte); the top 2 bits of all
// values are packed into qh.
static __device__ void quantize_f32_q6_0_block(const float * __restrict__ xi, block_q6_0 * __restrict__ y) {
    float amax = 0.0f; // largest |xi[j]| in the block
    float vmax = 0.0f; // signed value attaining that magnitude

    for (int j = 0; j < QK6_0; ++j) {
        const float v = xi[j];
        const float av = fabsf(xi[j]);
        if (amax < av) {
            amax = av;
            vmax = v;
        }
    }

    // The extreme value maps to quant level -32.
    const float d = vmax / -32;
    const float id = d ? 1.0f/d : 0.0f;

    y->d = d;
    memset(y->qh, 0, QK6_0/4);

    for (int j = 0; j < QK6_0/2; ++j) {
        const float x0 = xi[0 + j]*id;
        // Fixed: the second-half offset used QK4_0/2 (a copy-paste from the
        // q4_0 routine); it only worked because QK4_0 == QK6_0.
        const float x1 = xi[QK6_0/2 + j]*id;

        const uint8_t xi0 = min(63, (int8_t)(x0 + 32.5f));
        const uint8_t xi1 = min(63, (int8_t)(x1 + 32.5f));

        y->qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
        // Top 2 bits of the pair, packed 4 pairs per qh byte.
        const uint8_t h = (xi0 >> 4) | ((xi1 >> 4) << 2);
        y->qh[j%(QK6_0/4)] |= (h << 4*(j/(QK6_0/4)));
    }
}
|
||||
|
||||
// Wrapper functions for cpy.cu compatibility: adapt the typed block
// quantizers above to the (const char *, char *) signature the copy kernels
// expect.
static __device__ void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) {
    quantize_f32_q4_0_block((const float *)cxi, (block_q4_0 *)cdsti);
}

static __device__ void cpy_blck_f32_q4_1(const char * cxi, char * cdsti) {
    quantize_f32_q4_1_block((const float *)cxi, (block_q4_1 *)cdsti);
}

static __device__ void cpy_blck_f32_q5_0(const char * cxi, char * cdsti) {
    quantize_f32_q5_0_block((const float *)cxi, (block_q5_0 *)cdsti);
}

static __device__ void cpy_blck_f32_q5_1(const char * cxi, char * cdsti) {
    quantize_f32_q5_1_block((const float *)cxi, (block_q5_1 *)cdsti);
}

static __device__ void cpy_blck_f32_q6_0(const char * cxi, char * cdsti) {
    quantize_f32_q6_0_block((const float *)cxi, (block_q6_0 *)cdsti);
}

static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
    quantize_f32_q8_0_block((const float *)cxi, (block_q8_0 *)cdsti);
}

static __device__ void cpy_blck_f32_iq4_nl(const char * cxi, char * cdsti) {
    quantize_f32_iq4_nl_block((const float *)cxi, (block_iq4_nl *)cdsti);
}

// Copy a single element with optional float conversion (see convert_flt).
template<typename src_t, typename dst_t>
static __device__ void cpy_1_flt(const char * cxi, char * cdsti) {
    convert_flt((const src_t *)cxi, (dst_t *)cdsti);
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,9 +1,11 @@
|
||||
#include "common.cuh"
|
||||
|
||||
#define CUDA_CPY_BLOCK_SIZE 32
|
||||
#define CUDA_CPY_BLOCK_SIZE 64
|
||||
|
||||
void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1);
|
||||
void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1, bool disable_indirection = false);
|
||||
|
||||
void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1);
|
||||
|
||||
void ggml_cuda_cpy_dest_ptrs_copy(ggml_cuda_graph * cuda_graph, char ** host_dest_ptrs, const int host_dest_ptrs_size, cudaStream_t stream);
|
||||
|
||||
@@ -86,6 +86,24 @@ static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const in
|
||||
#endif // GGML_CUDA_F16
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ void dequantize_q6_0(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
|
||||
const block_q6_0 * x = (const block_q6_0 *) vx;
|
||||
|
||||
const dfloat d = x[ib].d;
|
||||
|
||||
const uint8_t h = x[ib].qh[iqs%8] >> 2*(iqs/8);
|
||||
v.x = ((x[ib].qs[iqs] & 0xf) | ((h & 0x3) << 4));
|
||||
v.y = ((x[ib].qs[iqs] >> 4) | ((h & 0xc) << 2));
|
||||
|
||||
#ifdef GGML_CUDA_F16
|
||||
v = __hsub2(v, {32.0f, 32.0f});
|
||||
v = __hmul2(v, {d, d});
|
||||
#else
|
||||
v.x = (v.x - 32.0f) * d;
|
||||
v.y = (v.y - 32.0f) * d;
|
||||
#endif // GGML_CUDA_F16
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
|
||||
const block_q8_0 * x = (const block_q8_0 *) vx;
|
||||
|
||||
|
||||
@@ -22,6 +22,7 @@ typedef void (* fattn_kernel_t)(
|
||||
const char * __restrict__ K,
|
||||
const char * __restrict__ V,
|
||||
const char * __restrict__ mask,
|
||||
const char * __restrict__ sinks,
|
||||
float * __restrict__ dst,
|
||||
float2 * __restrict__ dst_meta,
|
||||
const float scale,
|
||||
@@ -747,6 +748,7 @@ void launch_fattn(
|
||||
const ggml_tensor * V = dst->src[2];
|
||||
|
||||
const ggml_tensor * mask = dst->src[3];
|
||||
const ggml_tensor * sinks = dst->src[4];
|
||||
|
||||
ggml_tensor * KQV = dst;
|
||||
|
||||
@@ -837,6 +839,7 @@ void launch_fattn(
|
||||
K_data,
|
||||
V_data,
|
||||
mask ? ((const char *) mask->data) : nullptr,
|
||||
sinks ? ((const char *) sinks->data) : nullptr,
|
||||
(parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
|
||||
scale, max_bias, m0, m1, softcap, n_head_log2,
|
||||
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||
@@ -1008,7 +1011,8 @@ void launch_fattn_mma(
|
||||
const ggml_tensor * K = dst->src[1];
|
||||
const ggml_tensor * V = dst->src[2];
|
||||
|
||||
const ggml_tensor * mask = dst->src[3];
|
||||
const ggml_tensor * mask = dst->src[3];
|
||||
const ggml_tensor * sinks = dst->src[4];
|
||||
|
||||
ggml_tensor * KQV = dst;
|
||||
|
||||
@@ -1162,6 +1166,7 @@ void launch_fattn_mma(
|
||||
K_data,
|
||||
V_data,
|
||||
mask ? ((const char *) mask->data) : nullptr,
|
||||
sinks ? ((const char *)sinks->data) : nullptr,
|
||||
!stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
|
||||
scale, max_bias, m0, m1, n_head_log2, logit_softcap,
|
||||
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||
|
||||
@@ -425,6 +425,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
const half2 * const __restrict__ K_h2,
|
||||
const half2 * const __restrict__ V_h2,
|
||||
const half2 * const __restrict__ mask_h2,
|
||||
const float * const __restrict__ sinks_f,
|
||||
float2 * const __restrict__ dstk,
|
||||
float2 * const __restrict__ dstk_fixup,
|
||||
const float scale,
|
||||
@@ -584,6 +585,52 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
}
|
||||
}
|
||||
|
||||
// If attention sinks are used, potentially re-scale if KQ_max is small.
|
||||
// Also add the sink as a value to KQ_rowsum, this is done after synchonization of KQ_rowsum
|
||||
// so it's being done unconditionally for every thread.
|
||||
if (!is_fixup && (np == 1 || threadIdx.y % np == 0) && sinks_f) {
|
||||
float KQ_max_scale[cols_per_thread];
|
||||
#pragma unroll
|
||||
for (int col = 0; col < cols_per_thread; ++col) {
|
||||
static_assert(ntiles == 1 || ntiles == 2, "ntiles > 2 not implemented");
|
||||
const int jc = ntiles == 1 ? 2*tile_C_VKQ::get_j(col/2) + col % 2 : tile_C_VKQ_16::get_i(col);
|
||||
const float sink = sinks_f[jc % ncols2];
|
||||
|
||||
const float KQ_max_new = fmaxf(KQ_max[col], sink);
|
||||
const float KQ_max_diff = KQ_max[col] - KQ_max_new;
|
||||
KQ_max_scale[col] = expf(KQ_max_diff);
|
||||
KQ_max[col] = KQ_max_new;
|
||||
|
||||
*((uint32_t *) &KQ_max_scale[col]) *= KQ_max_diff >= SOFTMAX_FTZ_THRESHOLD;
|
||||
|
||||
const float KQ_max_add = expf(sink - KQ_max_new);
|
||||
KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_max_add;
|
||||
}
|
||||
|
||||
if (ntiles == 1) {
|
||||
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < D/tile_C_VKQ::I; ++i) {
|
||||
#pragma unroll
|
||||
for (int l = 0; l < tile_C_VKQ::ne; ++l) {
|
||||
VKQ_C[i].x[l] *= KQ_max_scale_h2;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int col = 0; col < cols_per_thread; ++col) {
|
||||
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < D/tile_C_VKQ_16::J; ++i) {
|
||||
#pragma unroll
|
||||
for (int l0 = 0; l0 < tile_C_VKQ_16::ne; l0 += 2) {
|
||||
VKQ_C_16[i*ntiles/2 + col/2].x[l0 + col % 2] *= KQ_max_scale_h2;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Write VKQ accumulators to shared memory in column-major format.
|
||||
// It's faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM.
|
||||
// Also for np > 1 the combination is done via these values in shared memory.
|
||||
@@ -823,6 +870,7 @@ static __global__ void flash_attn_mma_ext_f16(
|
||||
const char * __restrict__ K,
|
||||
const char * __restrict__ V,
|
||||
const char * __restrict__ mask,
|
||||
const char * __restrict__ sinks,
|
||||
float * __restrict__ dst,
|
||||
float2 * __restrict__ dst_meta,
|
||||
const float scale,
|
||||
@@ -896,6 +944,7 @@ static __global__ void flash_attn_mma_ext_f16(
|
||||
const half2 * V_h2 = (const half2 *) (V + nb12*(channel*ncols2 / gqa_ratio)); // K and V have same shape
|
||||
const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
|
||||
float2 * dstk = ((float2 *) dst) + channel*(ncols2 * D/2);
|
||||
const float * sinks_f = sinks ? (const float *) sinks + channel * ncols2 : nullptr;
|
||||
|
||||
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
|
||||
|
||||
@@ -906,12 +955,12 @@ static __global__ void flash_attn_mma_ext_f16(
|
||||
if (kb0_start == 0) {
|
||||
constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
|
||||
flash_attn_ext_f16_process_tile<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap, needs_fixup, is_fixup>
|
||||
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
|
||||
(Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
|
||||
ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
|
||||
} else {
|
||||
constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile.
|
||||
flash_attn_ext_f16_process_tile<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap, needs_fixup, is_fixup>
|
||||
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
|
||||
(Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
|
||||
ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
|
||||
}
|
||||
|
||||
@@ -934,6 +983,7 @@ static __global__ void flash_attn_mma_ext_f16(
|
||||
const half2 * V_h2 = (const half2 *) (V + nb12*(channel*ncols2 / gqa_ratio)); // K and V have same shape
|
||||
const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
|
||||
float2 * dstk = ((float2 *) dst) + channel*(ncols2 * D/2);
|
||||
const float * sinks_f = sinks ? (const float *) sinks + channel*ncols2 : nullptr;
|
||||
|
||||
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
|
||||
|
||||
@@ -943,10 +993,10 @@ static __global__ void flash_attn_mma_ext_f16(
|
||||
constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
|
||||
constexpr bool needs_fixup = false;
|
||||
flash_attn_ext_f16_process_tile<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap, needs_fixup, is_fixup>
|
||||
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
|
||||
(Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
|
||||
ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
|
||||
#else
|
||||
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
|
||||
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(sinks);
|
||||
GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
|
||||
GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
|
||||
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00);
|
||||
|
||||
@@ -43,37 +43,37 @@ struct fattn_mma_f16_config;
|
||||
// Perhaps the 256 head size needs a closer look
|
||||
// to see if this implementation is better.
|
||||
//
|
||||
//template <>
|
||||
//struct fattn_mma_f16_config< 64, 64> {
|
||||
// static constexpr int nbatch_fa = 64;
|
||||
// static constexpr int nwarps_max = 4;
|
||||
// static constexpr bool Q_in_reg = true;
|
||||
// static constexpr int nstages_target = 2;
|
||||
//
|
||||
// static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
|
||||
// return 32;
|
||||
// }
|
||||
//
|
||||
// static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
|
||||
// return 32;
|
||||
// }
|
||||
//
|
||||
// static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
|
||||
// return 32;
|
||||
// }
|
||||
//
|
||||
// static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
|
||||
// return 32;
|
||||
// }
|
||||
//
|
||||
// static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
|
||||
// return 32;
|
||||
// }
|
||||
//
|
||||
// static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
|
||||
// return 32;
|
||||
// }
|
||||
//};
|
||||
// Tile configuration of the mma FlashAttention kernel for K/V head size 64.
// NOTE(review): batch values mirror the other head-size specializations in
// this file; 32 half2 units per load/combine step for all code paths.
template <>
struct fattn_mma_f16_config< 64, 64> {
    static constexpr int nbatch_fa = 64;     // K/V rows handled per FA iteration
    static constexpr int nwarps_max = 4;     // upper bound on warps per block
    static constexpr bool Q_in_reg = true;   // keep the Q tile in registers
    static constexpr int nstages_target = 2; // target number of load pipeline stages

    static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
        return 32;
    }

    static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
        return 32;
    }

    static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
        return 32;
    }

    static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
        return 32;
    }

    static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
        return 32;
    }

    static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
        return 32;
    }
};
|
||||
//
|
||||
//template <>
|
||||
//struct fattn_mma_f16_config< 80, 80> {
|
||||
@@ -493,7 +493,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
(V_h2 + k_VKQ_0*stride_V, tile_V, nbatch_V2, stride_V);
|
||||
} else {
|
||||
constexpr bool use_cp_async = nstages == 1;
|
||||
if constexpr (ncols2 > 1 || mask_h2) {
|
||||
if (ncols2 > 1 || mask_h2) {
|
||||
flash_attn_ext_f16_load_mask<ncols1, nwarps, c::nbatch_fa, use_cp_async>(mask_h2 + k_VKQ_0/2, tile_mask, stride_mask);
|
||||
}
|
||||
}
|
||||
@@ -576,7 +576,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
float KQ_rowsum_add[cols_per_thread] = {0.0f};
|
||||
|
||||
if constexpr (ntiles == 1) {
|
||||
if constexpr (ncols2 > 1 || mask_h2) {
|
||||
if (ncols2 > 1 || mask_h2) {
|
||||
#pragma unroll
|
||||
for (int i00 = 0; i00 < c::nbatch_fa; i00 += np*tile_C_KQ::I) {
|
||||
const int i0 = i00 + (threadIdx.y % np)*tile_C_KQ::I;
|
||||
@@ -818,6 +818,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
const half2 * const __restrict__ K_h2,
|
||||
const half2 * const __restrict__ V_h2,
|
||||
const half2 * const __restrict__ mask_h2,
|
||||
const float * const __restrict__ sinks_f,
|
||||
float2 * const __restrict__ dstk,
|
||||
float2 * const __restrict__ dstk_fixup,
|
||||
const float scale,
|
||||
@@ -975,6 +976,52 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
// If attention sinks are used, potentially re-scale if KQ_max is small.
|
||||
// Also add the sink as a value to KQ_rowsum, this is done after synchronization of KQ_rowsum
|
||||
// so it's being done unconditionally for every thread.
|
||||
if (!is_fixup && (np == 1 || threadIdx.y % np == 0) && sinks_f) {
|
||||
float KQ_max_scale[cols_per_thread];
|
||||
#pragma unroll
|
||||
for (int col = 0; col < cols_per_thread; ++col) {
|
||||
static_assert(ntiles == 1 || ntiles == 2, "ntiles > 2 not implemented");
|
||||
const int jc = ntiles == 1 ? 2*tile_C_VKQ::get_j(col/2) + col % 2 : tile_C_VKQ_16::get_i(col);
|
||||
const float sink = sinks_f[jc % ncols2];
|
||||
|
||||
const float KQ_max_new = fmaxf(KQ_max[col], sink);
|
||||
const float KQ_max_diff = KQ_max[col] - KQ_max_new;
|
||||
KQ_max_scale[col] = expf(KQ_max_diff);
|
||||
KQ_max[col] = KQ_max_new;
|
||||
|
||||
*((uint32_t *) &KQ_max_scale[col]) *= KQ_max_diff >= SOFTMAX_FTZ_THRESHOLD;
|
||||
|
||||
const float KQ_max_add = expf(sink - KQ_max_new);
|
||||
KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_max_add;
|
||||
}
|
||||
|
||||
if (ntiles == 1) {
|
||||
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < DV/tile_C_VKQ::I; ++i) {
|
||||
#pragma unroll
|
||||
for (int l = 0; l < tile_C_VKQ::ne; ++l) {
|
||||
VKQ_C[i].x[l] *= KQ_max_scale_h2;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int col = 0; col < cols_per_thread; ++col) {
|
||||
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < DV/tile_C_VKQ_16::J; ++i) {
|
||||
#pragma unroll
|
||||
for (int l0 = 0; l0 < tile_C_VKQ_16::ne; l0 += 2) {
|
||||
VKQ_C_16[i*ntiles/2 + col/2].x[l0 + col % 2] *= KQ_max_scale_h2;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Finally, sum up partial KQ rowsums.
|
||||
// The partial sums are spread across 8/4 threads each, does not need full reduce.
|
||||
{
|
||||
@@ -1222,7 +1269,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
}
|
||||
}
|
||||
#else
|
||||
GGML_UNUSED(Q_f2); GGML_UNUSED(K_h2); GGML_UNUSED(V_h2);
|
||||
GGML_UNUSED(Q_f2); GGML_UNUSED(K_h2); GGML_UNUSED(V_h2); GGML_UNUSED(sinks_f);
|
||||
GGML_UNUSED(mask_h2); GGML_UNUSED(dstk); GGML_UNUSED(dstk_fixup);
|
||||
GGML_UNUSED(scale); GGML_UNUSED(slope); GGML_UNUSED(logit_softcap);
|
||||
GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(stride_Q1);
|
||||
@@ -1239,6 +1286,7 @@ static __global__ void flash_attn_ext_f16(
|
||||
const char * __restrict__ K,
|
||||
const char * __restrict__ V,
|
||||
const char * __restrict__ mask,
|
||||
const char * __restrict__ sinks,
|
||||
float * __restrict__ dst,
|
||||
float2 * __restrict__ dst_meta,
|
||||
const float scale,
|
||||
@@ -1323,6 +1371,7 @@ static __global__ void flash_attn_ext_f16(
|
||||
const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
|
||||
const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
|
||||
float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);
|
||||
const float * sinks_f = sinks ? (const float *) sinks + channel*ncols2 : nullptr;
|
||||
|
||||
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
|
||||
|
||||
@@ -1335,12 +1384,12 @@ static __global__ void flash_attn_ext_f16(
|
||||
if (kb0_start == 0) {
|
||||
constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
|
||||
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
|
||||
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
|
||||
(Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
|
||||
ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
|
||||
} else {
|
||||
constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile.
|
||||
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
|
||||
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
|
||||
(Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
|
||||
ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
|
||||
}
|
||||
|
||||
@@ -1362,6 +1411,7 @@ static __global__ void flash_attn_ext_f16(
|
||||
const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
|
||||
const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
|
||||
float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);
|
||||
const float * sinks_f = sinks ? (const float *) sinks + channel*ncols2 : nullptr;
|
||||
|
||||
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
|
||||
|
||||
@@ -1373,7 +1423,7 @@ static __global__ void flash_attn_ext_f16(
|
||||
constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
|
||||
constexpr bool needs_fixup = false;
|
||||
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
|
||||
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
|
||||
(Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
|
||||
ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
|
||||
#else
|
||||
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
|
||||
@@ -1535,7 +1585,8 @@ static void launch_fattn_new_mma(
|
||||
const ggml_tensor * K = dst->src[1];
|
||||
const ggml_tensor * V = dst->src[2];
|
||||
|
||||
const ggml_tensor * mask = dst->src[3];
|
||||
const ggml_tensor * mask = dst->src[3];
|
||||
const ggml_tensor * sinks = dst->src[4];
|
||||
|
||||
ggml_tensor * KQV = dst;
|
||||
|
||||
@@ -1709,6 +1760,7 @@ static void launch_fattn_new_mma(
|
||||
K_data,
|
||||
V_data,
|
||||
mask ? ((const char *) mask->data) : nullptr,
|
||||
sinks ? ((const char *)sinks->data) : nullptr,
|
||||
!stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
|
||||
scale, max_bias, m0, m1, logit_softcap, n_head_log2,
|
||||
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||
@@ -1853,6 +1905,11 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_con
|
||||
GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
|
||||
const int gqa_ratio = Q->ne[2] / K->ne[2];
|
||||
|
||||
if (use_gqa_opt && gqa_ratio % 16 == 0) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 16>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
|
||||
if (use_gqa_opt && gqa_ratio % 8 == 0) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 8>(ctx, dst);
|
||||
return;
|
||||
@@ -1878,8 +1935,6 @@ void ggml_cuda_flash_attn_ext_mma_new(ggml_backend_cuda_context & ctx, ggml_tens
|
||||
const ggml_tensor * V = dst->src[2];
|
||||
const ggml_tensor * mask = dst->src[3];
|
||||
|
||||
GGML_ASSERT(Q->ne[0] == 576 && K->ne[0] == 576 && V->ne[0] == 512);
|
||||
|
||||
float max_bias = 0.0f;
|
||||
memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
|
||||
|
||||
@@ -1888,6 +1943,12 @@ void ggml_cuda_flash_attn_ext_mma_new(ggml_backend_cuda_context & ctx, ggml_tens
|
||||
|
||||
GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
|
||||
const int gqa_ratio = Q->ne[2] / K->ne[2];
|
||||
|
||||
if (K->ne[0] == 64 && V->ne[0] == 64) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<64, 64>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
GGML_ASSERT(Q->ne[0] == 576 && K->ne[0] == 576 && V->ne[0] == 512);
|
||||
GGML_ASSERT(gqa_ratio % 16 == 0);
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
|
||||
|
||||
|
||||
@@ -20,6 +20,7 @@ static __global__ void flash_attn_tile_ext_f16(
|
||||
const char * __restrict__ K,
|
||||
const char * __restrict__ V,
|
||||
const char * __restrict__ mask,
|
||||
const char * __restrict__ sinks,
|
||||
float * __restrict__ dst,
|
||||
float2 * __restrict__ dst_meta,
|
||||
const float scale,
|
||||
|
||||
@@ -20,6 +20,7 @@ static __global__ void flash_attn_tile_ext_f32(
|
||||
const char * __restrict__ K,
|
||||
const char * __restrict__ V,
|
||||
const char * __restrict__ mask,
|
||||
const char * __restrict__ sinks,
|
||||
float * __restrict__ dst,
|
||||
float2 * __restrict__ dst_meta,
|
||||
const float scale,
|
||||
|
||||
@@ -17,6 +17,7 @@ static __global__ void flash_attn_vec_ext_f16(
|
||||
const char * __restrict__ K,
|
||||
const char * __restrict__ V,
|
||||
const char * __restrict__ mask,
|
||||
const char * __restrict__ sinks,
|
||||
float * __restrict__ dst,
|
||||
float2 * __restrict__ dst_meta,
|
||||
const float scale,
|
||||
@@ -71,6 +72,7 @@ static __global__ void flash_attn_vec_ext_f16(
|
||||
V += nb22*(blockIdx.y / gqa_ratio);
|
||||
|
||||
const half * maskh = (const half *) mask + ne11*ic0;
|
||||
const float * sinksf = (const float *) (sinks);
|
||||
|
||||
const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
|
||||
const half slopeh = __float2half(slopef);
|
||||
@@ -270,6 +272,39 @@ static __global__ void flash_attn_vec_ext_f16(
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
if (sinksf) {
|
||||
const half sink = __float2half(sinksf[blockIdx.y]);
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
if (threadIdx.x == 0) {
|
||||
kqmax_shared[j][threadIdx.y] = fmaxf(kqmax[j], sink);
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
half kqmax_new_j = kqmax_shared[j][threadIdx.y];
|
||||
kqmax_new_j = warp_reduce_max(kqmax_new_j);
|
||||
|
||||
const half KQ_max_scale = hexp(kqmax[j] - kqmax_new_j);
|
||||
kqmax[j] = kqmax_new_j;
|
||||
|
||||
const half val = hexp(sink - kqmax[j]);
|
||||
kqsum[j] = kqsum[j]*KQ_max_scale;
|
||||
|
||||
if (tid == 0) {
|
||||
kqsum[j] += val;
|
||||
}
|
||||
|
||||
VKQ[j] *= __half2half2(KQ_max_scale);
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
kqsum[j] = warp_reduce_sum(kqsum[j]);
|
||||
|
||||
@@ -17,6 +17,7 @@ static __global__ void flash_attn_vec_ext_f32(
|
||||
const char * __restrict__ K,
|
||||
const char * __restrict__ V,
|
||||
const char * __restrict__ mask,
|
||||
const char * __restrict__ sinks,
|
||||
float * __restrict__ dst,
|
||||
float2 * __restrict__ dst_meta,
|
||||
const float scale,
|
||||
@@ -69,6 +70,7 @@ static __global__ void flash_attn_vec_ext_f32(
|
||||
K += nb12*(blockIdx.y / gqa_ratio);
|
||||
V += nb22*(blockIdx.y / gqa_ratio); // K and V have same shape
|
||||
const half * maskh = (const half *) mask + ne11*ic0;
|
||||
const float * sinksf = (const float *) (sinks);
|
||||
|
||||
const float slope = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
|
||||
|
||||
@@ -254,6 +256,39 @@ static __global__ void flash_attn_vec_ext_f32(
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
if (sinksf) {
|
||||
const float sink = sinksf[blockIdx.y];
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
if (threadIdx.x == 0) {
|
||||
kqmax_shared[j][threadIdx.y] = fmaxf(kqmax[j], sink);
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
float kqmax_new_j = kqmax_shared[j][threadIdx.y];
|
||||
kqmax_new_j = warp_reduce_max(kqmax_new_j);
|
||||
|
||||
const float KQ_max_scale = expf(kqmax[j] - kqmax_new_j);
|
||||
kqmax[j] = kqmax_new_j;
|
||||
|
||||
const float val = expf(sink - kqmax[j]);
|
||||
kqsum[j] = kqsum[j]*KQ_max_scale;
|
||||
|
||||
if (tid == 0) {
|
||||
kqsum[j] += val;
|
||||
}
|
||||
|
||||
VKQ[j] *= KQ_max_scale;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
kqsum[j] = warp_reduce_sum(kqsum[j]);
|
||||
|
||||
@@ -5,6 +5,8 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
//
|
||||
|
||||
// TODO: attention sinks !!!
|
||||
|
||||
#include "common.cuh"
|
||||
#include "fattn-common.cuh"
|
||||
|
||||
@@ -22,6 +24,7 @@ static __global__ void flash_attn_ext_f16(
|
||||
const char * __restrict__ K,
|
||||
const char * __restrict__ V,
|
||||
const char * __restrict__ mask,
|
||||
const char * __restrict__ sinks,
|
||||
float * __restrict__ dst,
|
||||
float2 * __restrict__ dst_meta,
|
||||
const float scale,
|
||||
@@ -93,6 +96,7 @@ static __global__ void flash_attn_ext_f16(
|
||||
const half * V_h = (const half *) (V + nb22*(blockIdx.y / gqa_ratio)); // K and V have same shape
|
||||
const half * maskh = (const half *) mask + (nb31/sizeof(half))* ic0;
|
||||
const half2 * mask2 = (const half2 *) mask + (nb31/sizeof(half))*(ic0/2);
|
||||
const float * sinks_f = sinks ? (const float *)sinks + blockIdx.y : nullptr;
|
||||
|
||||
const int stride_Q = nb01 / sizeof(float);
|
||||
const int stride_K = nb11 / sizeof(half);
|
||||
|
||||
@@ -539,7 +539,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
|
||||
return;
|
||||
}
|
||||
|
||||
// As mentioned above, the new new MMA is slower than then the new MMA.
|
||||
// As mentioned above, the new-new MMA is slower then the new MMA.
|
||||
ggml_cuda_flash_attn_ext_mma_f16(ctx, dst);
|
||||
//ggml_cuda_flash_attn_ext_mma_new(ctx, dst);
|
||||
}
|
||||
|
||||
41
ggml/src/ggml-cuda/graph.cuh
Normal file
41
ggml/src/ggml-cuda/graph.cuh
Normal file
@@ -0,0 +1,41 @@
|
||||
#pragma once
|
||||
|
||||
struct ggml_graph_node_properties {
|
||||
void * node_address;
|
||||
ggml_op node_op;
|
||||
int64_t ne[GGML_MAX_DIMS];
|
||||
size_t nb[GGML_MAX_DIMS];
|
||||
void * src_address[GGML_MAX_SRC];
|
||||
int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];
|
||||
};
|
||||
|
||||
// State for capturing and replaying a ggml compute graph as a CUDA graph.
// Owns the captured cudaGraph_t/cudaGraphExec_t and the bookkeeping used to
// decide when graph reuse must be disabled.
struct ggml_cuda_graph {
#ifdef USE_CUDA_GRAPH
    ~ggml_cuda_graph() {
        if (instance != nullptr) {
            CUDA_CHECK(cudaGraphExecDestroy(instance));
        }
        if (graph != nullptr) {
            CUDA_CHECK(cudaGraphDestroy(graph));
        }
    }
    cudaGraph_t graph = nullptr;          // captured graph
    cudaGraphExec_t instance = nullptr;   // executable instantiation of `graph`
    size_t num_nodes = 0;
    std::vector<cudaGraphNode_t> nodes;
    std::vector<cudaKernelNodeParams> params;
    bool disable_due_to_gpu_arch = false;
    bool disable_due_to_too_many_updates = false;
    bool disable_due_to_failed_graph_capture = false;
    int number_consecutive_updates = 0;
    std::vector<ggml_graph_node_properties> ggml_graph_properties;
    bool use_cpy_indirection = false;
    std::vector<char *> cpy_dest_ptrs;
    // Device-side pointer array for cpy indirection.
    // Fixed: initialize to nullptr — this was the only member left with an
    // indeterminate value, risking a bogus free/compare before first assignment.
    char ** dest_ptrs_d = nullptr;
    int dest_ptrs_size = 0;
    // Index to allow each cpy kernel to be aware of its position within the graph
    // relative to other cpy nodes.
    int graph_cpynode_index = -1;
#endif
};
|
||||
|
||||
@@ -19,7 +19,7 @@ __device__ float __forceinline__ t2f32<half>(half val) {
|
||||
}
|
||||
|
||||
template <bool vals_smem, int ncols_template, int block_size_template, typename T>
|
||||
static __global__ void soft_max_f32(const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2, float cap_params0, float cap_params1, bool do_softcap) {
|
||||
static __global__ void soft_max_f32_nosinks(const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2, float cap_params0, float cap_params1, bool do_softcap) {
|
||||
const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
|
||||
|
||||
const int tid = threadIdx.x;
|
||||
@@ -124,7 +124,7 @@ static __global__ void soft_max_f32(const float * x, const T * mask, float * dst
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, float cap_params0, float cap_params1, bool do_softcap, cudaStream_t stream) {
|
||||
static void soft_max_f32_cuda_nosinks(const float * x, const T * mask, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, float cap_params0, float cap_params1, bool do_softcap, cudaStream_t stream) {
|
||||
int nth = WARP_SIZE;
|
||||
while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
|
||||
const dim3 block_dims(nth, 1, 1);
|
||||
@@ -142,39 +142,40 @@ static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, cons
|
||||
if (shmem < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) {
|
||||
switch (ncols_x) {
|
||||
case 32:
|
||||
soft_max_f32<true, 32, 32><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
|
||||
soft_max_f32_nosinks<true, 32, 32><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
|
||||
break;
|
||||
case 64:
|
||||
soft_max_f32<true, 64, 64><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
|
||||
soft_max_f32_nosinks<true, 64, 64><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
|
||||
break;
|
||||
case 128:
|
||||
soft_max_f32<true, 128, 128><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
|
||||
soft_max_f32_nosinks<true, 128, 128><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
|
||||
break;
|
||||
case 256:
|
||||
soft_max_f32<true, 256, 256><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
|
||||
soft_max_f32_nosinks<true, 256, 256><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
|
||||
break;
|
||||
case 512:
|
||||
soft_max_f32<true, 512, 512><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
|
||||
soft_max_f32_nosinks<true, 512, 512><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
|
||||
break;
|
||||
case 1024:
|
||||
soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
|
||||
soft_max_f32_nosinks<true, 1024, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
|
||||
break;
|
||||
case 2048:
|
||||
soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
|
||||
soft_max_f32_nosinks<true, 2048, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
|
||||
break;
|
||||
case 4096:
|
||||
soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
|
||||
soft_max_f32_nosinks<true, 4096, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
|
||||
break;
|
||||
default:
|
||||
soft_max_f32<true, 0, 0><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
|
||||
soft_max_f32_nosinks<true, 0, 0><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
const size_t shmem_low = WARP_SIZE*sizeof(float);
|
||||
soft_max_f32<false, 0, 0><<<block_nums, block_dims, shmem_low, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
|
||||
soft_max_f32_nosinks<false, 0, 0><<<block_nums, block_dims, shmem_low, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
|
||||
}
|
||||
}
|
||||
|
||||
#if 0
|
||||
void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const ggml_tensor * src1 = dst->src[1];
|
||||
@@ -205,13 +206,14 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
if (use_f16) {
|
||||
const half * src1_dd = (const half *)src1_d;
|
||||
|
||||
soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, 0, 0, false, stream);
|
||||
soft_max_f32_cuda_nosinks(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, 0, 0, false, stream);
|
||||
} else {
|
||||
const float * src1_dd = (const float *)src1_d;
|
||||
|
||||
soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, 0, 0, false, stream);
|
||||
soft_max_f32_cuda_nosinks(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, 0, 0, false, stream);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
void ggml_cuda_op_soft_cap_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
@@ -241,10 +243,283 @@ void ggml_cuda_op_soft_cap_max(ggml_backend_cuda_context & ctx, ggml_tensor * ds
|
||||
if (use_f16) {
|
||||
const half * src1_dd = (const half *)src1_d;
|
||||
|
||||
soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, params[0], params[1], params[2], params[3], true, stream);
|
||||
soft_max_f32_cuda_nosinks(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, params[0], params[1], params[2], params[3], true, stream);
|
||||
} else {
|
||||
const float * src1_dd = (const float *)src1_d;
|
||||
|
||||
soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, params[0], params[1], params[2], params[3], true, stream);
|
||||
soft_max_f32_cuda_nosinks(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, params[0], params[1], params[2], params[3], true, stream);
|
||||
}
|
||||
}
|
||||
|
||||
struct soft_max_params {
|
||||
|
||||
int64_t nheads;
|
||||
uint32_t n_head_log2;
|
||||
int64_t ncols;
|
||||
int64_t nrows_x;
|
||||
int64_t nrows_y;
|
||||
int64_t ne00;
|
||||
int64_t ne01;
|
||||
int64_t ne02;
|
||||
int64_t ne03;
|
||||
int64_t nb11;
|
||||
int64_t nb12;
|
||||
int64_t nb13;
|
||||
|
||||
int64_t ne12;
|
||||
int64_t ne13;
|
||||
float scale;
|
||||
float max_bias;
|
||||
float m0;
|
||||
float m1;
|
||||
};
|
||||
|
||||
#ifdef __clang__
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wpass-failed"
|
||||
#endif // __clang__
|
||||
template <bool use_shared, int ncols_template, int block_size_template, typename T>
|
||||
static __global__ void soft_max_f32(
|
||||
const float * x, const T * mask, const float * sinks, float * dst, const soft_max_params p) {
|
||||
const int ncols = ncols_template == 0 ? p.ncols : ncols_template;
|
||||
|
||||
const int tid = threadIdx.x;
|
||||
|
||||
const int64_t i03 = blockIdx.z;
|
||||
const int64_t i02 = blockIdx.y;
|
||||
const int64_t i01 = blockIdx.x;
|
||||
|
||||
//TODO: noncontigous inputs/outputs
|
||||
const int rowx = blockIdx.x + blockIdx.y * gridDim.x + blockIdx.z * gridDim.x * gridDim.y;
|
||||
|
||||
const int64_t i11 = i01;
|
||||
const int64_t i12 = i02 % p.ne12;
|
||||
const int64_t i13 = i03 % p.ne13;
|
||||
|
||||
x += int64_t(rowx)*ncols;
|
||||
mask += (i11*p.nb11 + i12*p.nb12 + i13*p.nb13) / sizeof(T) * (mask != nullptr);
|
||||
dst += int64_t(rowx)*ncols;
|
||||
|
||||
const int block_size = block_size_template == 0 ? blockDim.x : block_size_template;
|
||||
|
||||
const int warp_id = threadIdx.x / WARP_SIZE;
|
||||
const int lane_id = threadIdx.x % WARP_SIZE;
|
||||
|
||||
const float slope = get_alibi_slope(p.max_bias, i02, p.n_head_log2, p.m0, p.m1);
|
||||
|
||||
extern __shared__ float data_soft_max_f32[];
|
||||
float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication
|
||||
// shared memory buffer to cache values between iterations:
|
||||
float * vals = use_shared ? buf_iw + WARP_SIZE : dst;
|
||||
|
||||
float max_val = sinks ? sinks[i02] : -INFINITY;
|
||||
|
||||
#pragma unroll
|
||||
for (int col0 = 0; col0 < ncols; col0 += block_size) {
|
||||
const int col = col0 + tid;
|
||||
|
||||
if (ncols_template == 0 && col >= ncols) {
|
||||
break;
|
||||
}
|
||||
|
||||
const float val = x[col]*p.scale + (mask ? slope*t2f32(mask[col]) : 0.0f);
|
||||
|
||||
vals[col] = val;
|
||||
max_val = max(max_val, val);
|
||||
}
|
||||
|
||||
// find the max value in the block
|
||||
max_val = warp_reduce_max(max_val);
|
||||
if (block_size > WARP_SIZE) {
|
||||
if (warp_id == 0) {
|
||||
buf_iw[lane_id] = -INFINITY;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
if (lane_id == 0) {
|
||||
buf_iw[warp_id] = max_val;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
max_val = buf_iw[lane_id];
|
||||
max_val = warp_reduce_max(max_val);
|
||||
}
|
||||
|
||||
float tmp = 0.0f; // partial sum
|
||||
#pragma unroll
|
||||
for (int col0 = 0; col0 < ncols; col0 += block_size) {
|
||||
const int col = col0 + tid;
|
||||
|
||||
if (ncols_template == 0 && col >= ncols) {
|
||||
break;
|
||||
}
|
||||
|
||||
const float val = expf(vals[col] - max_val);
|
||||
tmp += val;
|
||||
vals[col] = val;
|
||||
}
|
||||
|
||||
// find the sum of exps in the block
|
||||
tmp = warp_reduce_sum(tmp);
|
||||
if (block_size > WARP_SIZE) {
|
||||
__syncthreads();
|
||||
if (warp_id == 0) {
|
||||
buf_iw[lane_id] = 0.0f;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
if (lane_id == 0) {
|
||||
buf_iw[warp_id] = tmp;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
tmp = buf_iw[lane_id];
|
||||
tmp = warp_reduce_sum(tmp);
|
||||
}
|
||||
|
||||
if (sinks) {
|
||||
tmp += expf(sinks[i02] - max_val);
|
||||
}
|
||||
|
||||
const float inv_sum = 1.0f / tmp;
|
||||
|
||||
#pragma unroll
|
||||
for (int col0 = 0; col0 < ncols; col0 += block_size) {
|
||||
const int col = col0 + tid;
|
||||
|
||||
if (ncols_template == 0 && col >= ncols) {
|
||||
return;
|
||||
}
|
||||
|
||||
dst[col] = vals[col] * inv_sum;
|
||||
}
|
||||
}
|
||||
#ifdef __clang__
|
||||
#pragma clang diagnostic pop
|
||||
#endif // __clang__
|
||||
|
||||
template<int... Ns, typename T>
|
||||
static void launch_soft_max_kernels(const float * x, const T * mask, const float * sinks, float * dst,
|
||||
const soft_max_params & p, cudaStream_t stream, dim3 block_dims, dim3 block_nums, size_t nbytes_shared)
|
||||
{
|
||||
const int id = ggml_cuda_get_device();
|
||||
const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
|
||||
|
||||
auto launch_kernel = [=](auto I) -> bool {
|
||||
constexpr int ncols = decltype(I)::value;
|
||||
constexpr int block = (ncols > 1024 ? 1024 : ncols);
|
||||
|
||||
if (p.ncols == ncols) {
|
||||
CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32<true, ncols, block, T>), smpbo);
|
||||
soft_max_f32<true, ncols, block><<<block_nums, block_dims, nbytes_shared, stream>>>
|
||||
(x, mask, sinks, dst, p);
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
};
|
||||
|
||||
// unary fold over launch_kernel
|
||||
if ((launch_kernel(std::integral_constant<int, Ns>{}) || ...)) {
|
||||
return;
|
||||
}
|
||||
|
||||
//default case
|
||||
CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32<true, 0, 0, T>), smpbo);
|
||||
soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>(x, mask, sinks, dst, p);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static void soft_max_f32_cuda(const float * x, const T * mask, const float * sinks, float * dst, const soft_max_params & params, cudaStream_t stream) {
|
||||
int nth = WARP_SIZE;
|
||||
const int64_t ncols_x = params.ncols;
|
||||
|
||||
while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
|
||||
const dim3 block_dims(nth, 1, 1);
|
||||
const dim3 block_nums(params.ne01, params.ne02, params.ne03);
|
||||
const size_t nbytes_shared = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float);
|
||||
static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
|
||||
|
||||
|
||||
const int id = ggml_cuda_get_device();
|
||||
const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
|
||||
|
||||
|
||||
if (nbytes_shared <= smpbo) {
|
||||
launch_soft_max_kernels<32, 64, 128, 256, 512, 1024, 2048, 4096>(x, mask, sinks, dst, params, stream, block_dims, block_nums, nbytes_shared);
|
||||
} else {
|
||||
const size_t nbytes_shared_low = WARP_SIZE*sizeof(float);
|
||||
soft_max_f32<false, 0, 0><<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, sinks, dst, params);
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const ggml_tensor * src1 = dst->src[1];
|
||||
const ggml_tensor * src2 = dst->src[2];
|
||||
|
||||
const float * src0_d = (const float *) src0->data;
|
||||
const void * src1_d = src1 ? (const void *) src1->data : nullptr;
|
||||
const void * src2_d = src2 ? (const void *) src2->data : nullptr;
|
||||
float * dst_d = (float *) dst->data;
|
||||
|
||||
cudaStream_t stream = ctx.stream();
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
|
||||
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
|
||||
|
||||
const int64_t nrows_x = ggml_nrows(src0);
|
||||
const int64_t nrows_y = src0->ne[1];
|
||||
|
||||
const int64_t ne00 = src0->ne[0];
|
||||
|
||||
float scale = 1.0f;
|
||||
float max_bias = 0.0f;
|
||||
|
||||
memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float));
|
||||
memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));
|
||||
|
||||
const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
|
||||
|
||||
const int64_t nb11 = src1 ? src1->nb[1] : 1;
|
||||
const int64_t nb12 = src1 ? src1->nb[2] : 1;
|
||||
const int64_t nb13 = src1 ? src1->nb[3] : 1;
|
||||
|
||||
const int64_t ne12 = src1 ? src1->ne[2] : 1;
|
||||
const int64_t ne13 = src1 ? src1->ne[3] : 1;
|
||||
|
||||
const uint32_t n_head = src0->ne[2];
|
||||
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
|
||||
|
||||
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
||||
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
||||
|
||||
|
||||
soft_max_params params = {};
|
||||
params.nheads = src0->ne[2];
|
||||
params.n_head_log2 = n_head_log2;
|
||||
params.ncols = ne00;
|
||||
params.nrows_x = nrows_x;
|
||||
params.nrows_y = nrows_y;
|
||||
params.ne00 = src0->ne[0];
|
||||
params.ne01 = src0->ne[1];
|
||||
params.ne02 = src0->ne[2];
|
||||
params.ne03 = src0->ne[3];
|
||||
params.nb11 = nb11;
|
||||
params.nb12 = nb12;
|
||||
params.nb13 = nb13;
|
||||
params.ne12 = ne12;
|
||||
params.ne13 = ne13;
|
||||
params.scale = scale;
|
||||
params.max_bias = max_bias;
|
||||
params.m0 = m0;
|
||||
params.m1 = m1;
|
||||
|
||||
if (use_f16) {
|
||||
soft_max_f32_cuda(src0_d, (const half *) src1_d, (const float *) src2_d, dst_d, params, stream);
|
||||
} else {
|
||||
soft_max_f32_cuda(src0_d, (const float *) src1_d, (const float *) src2_d, dst_d, params, stream);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -470,3 +470,83 @@ void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
|
||||
sqrt_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static __global__ void swiglu_oai_kernel(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1, float alpha, float limit) {
|
||||
const int64_t i = int64_t(blockDim.x)*blockIdx.x + threadIdx.x;
|
||||
|
||||
if (i >= k) {
|
||||
return;
|
||||
}
|
||||
|
||||
// perform base op and multiply with gate (either offset in same tensor or a separate one)
|
||||
const int64_t j0 = (i / n) * o0 + (i % n);
|
||||
const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n);
|
||||
|
||||
float xi = x[j0];
|
||||
float gi = g[j1];
|
||||
xi = fminf(xi, limit);
|
||||
gi = fmaxf(fminf(gi, limit), -limit);
|
||||
|
||||
float out_glu = xi / (1.0f + expf(-xi * alpha));
|
||||
out_glu = out_glu * (1.0f + gi);
|
||||
|
||||
dst[i] = out_glu;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static void swiglu_oai_cuda(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1, const float alpha, const float limit, cudaStream_t stream) {
|
||||
const int64_t num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE;
|
||||
swiglu_oai_kernel<<<num_blocks, CUDA_GELU_BLOCK_SIZE, 0, stream>>>(x, g, dst, k, n, o0, o1, alpha, limit);
|
||||
}
|
||||
|
||||
void ggml_cuda_op_swiglu_oai(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const ggml_tensor * src1 = dst->src[1];
|
||||
void * src0_d = src0->data;
|
||||
void * src1_d = src1 ? src1->data : src0->data;
|
||||
const int64_t src0_o = src0->nb[1];
|
||||
const int64_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
|
||||
void * dst_d = dst->data;
|
||||
const int64_t nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
|
||||
cudaStream_t stream = ctx.stream();
|
||||
|
||||
GGML_ASSERT(ggml_is_contiguous_1(src0));
|
||||
GGML_ASSERT(src0->nb[0] == ggml_element_size(src0));
|
||||
GGML_ASSERT(ggml_is_contiguous(dst));
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(src0->type == dst->type);
|
||||
GGML_ASSERT(dst->ne[0] == nc);
|
||||
GGML_ASSERT(ggml_nrows(dst) == ggml_nrows(src0));
|
||||
|
||||
if (src1) {
|
||||
GGML_ASSERT(ggml_is_contiguous_1(src1));
|
||||
GGML_ASSERT(src1->nb[0] == ggml_element_size(src1));
|
||||
GGML_ASSERT(src1->ne[0] == nc);
|
||||
GGML_ASSERT(src0->type == src1->type);
|
||||
}
|
||||
|
||||
//const int32_t swapped = ((const int32_t *) dst->op_params)[1];
|
||||
const int32_t swapped = false; //ggml_get_op_params_i32(dst, 1);
|
||||
const float * op_params = (const float *)dst->op_params;
|
||||
const float alpha = op_params[2];
|
||||
const float limit = op_params[3];
|
||||
|
||||
float * src0_p = (float *) src0_d;
|
||||
float * src1_p = (float *) src1_d;
|
||||
|
||||
if (!src1) {
|
||||
src0_p += swapped ? nc : 0;
|
||||
src1_p += swapped ? 0 : nc;
|
||||
}
|
||||
|
||||
swiglu_oai_cuda(src0_p, src1_p, (float *)dst_d, ggml_nelements(dst), nc,
|
||||
src0_o / sizeof(float), src1_o / sizeof(float), alpha, limit, stream);
|
||||
}
|
||||
|
||||
void ggml_swiglu_oai_cuda_f32(const float * x, const float * g, float * dst, const int64_t k, const int64_t n,
|
||||
const int64_t o0, const int64_t o1, const float alpha, const float limit, cudaStream_t stream) {
|
||||
swiglu_oai_cuda(x, g, dst, k, n, o0, o1, alpha, limit, stream);
|
||||
}
|
||||
|
||||
@@ -47,3 +47,9 @@ void ggml_fused_mul_unary(ggml_backend_cuda_context & ctx, ggml_unary_op op,
|
||||
int64_t nelements, const float * x, const float * y, float * z);
|
||||
|
||||
void ggml_cuda_op_multi_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_cuda_op_swiglu_oai(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_swiglu_oai_cuda_f32(const float * x, const float * g, float * dst, const int64_t k, const int64_t n,
|
||||
const int64_t o0, const int64_t o1, const float alpha, const float limit, cudaStream_t stream);
|
||||
|
||||
|
||||
509
ggml/src/ggml.c
509
ggml/src/ggml.c
@@ -2823,7 +2823,6 @@ inline static void ggml_vec_set_f16(const int n, ggml_fp16_t * x, const int32_t
|
||||
|
||||
inline static void ggml_vec_set_bf16(const int n, ggml_bf16_t * x, const ggml_bf16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
|
||||
|
||||
inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] + y[i]; }
|
||||
inline static void ggml_vec_add1_f32(const int n, float * z, const float * x, const float v) { for (int i = 0; i < n; ++i) z[i] = x[i] + v; }
|
||||
inline static void ggml_vec_acc_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] += x[i]; }
|
||||
inline static void ggml_vec_acc1_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] += v; }
|
||||
@@ -2834,6 +2833,19 @@ inline static void ggml_vec_neg_f32 (const int n, float * y, const float * x)
|
||||
inline static void ggml_vec_mul_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]*y[i]; }
|
||||
inline static void ggml_vec_div_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]/y[i]; }
|
||||
|
||||
inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) {
|
||||
int i = 0;
|
||||
#if defined(__AVX2__)
|
||||
for (; i + 7 < n; i += 8) {
|
||||
__m256 vx = _mm256_loadu_ps(x + i);
|
||||
__m256 vy = _mm256_loadu_ps(y + i);
|
||||
__m256 vz = _mm256_add_ps(vx, vy);
|
||||
_mm256_storeu_ps(z + i, vz);
|
||||
}
|
||||
#endif
|
||||
for (; i < n; ++i) z[i] = x[i] + y[i];
|
||||
}
|
||||
|
||||
static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float * restrict x, size_t bx, const float * restrict y, size_t by, int nrc) {
|
||||
assert(nrc == 1);
|
||||
UNUSED(nrc);
|
||||
@@ -4004,6 +4016,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
||||
|
||||
"DUP",
|
||||
"ADD",
|
||||
"ADD_ID",
|
||||
"ADD1",
|
||||
"ACC",
|
||||
"SUB",
|
||||
@@ -4092,13 +4105,14 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
||||
"CROSS_ENTROPY_LOSS_BACK",
|
||||
};
|
||||
|
||||
static_assert(GGML_OP_COUNT == 81, "GGML_OP_COUNT != 81");
|
||||
static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82");
|
||||
|
||||
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
||||
"none",
|
||||
|
||||
"x",
|
||||
"x+y",
|
||||
"x[i]+y",
|
||||
"x+y",
|
||||
"view(x,nb,offset)+=y->x",
|
||||
"x-y",
|
||||
@@ -4187,7 +4201,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
||||
"cross_entropy_loss_back(x,y)",
|
||||
};
|
||||
|
||||
static_assert(GGML_OP_COUNT == 81, "GGML_OP_COUNT != 81");
|
||||
static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82");
|
||||
|
||||
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
|
||||
|
||||
@@ -4207,9 +4221,10 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
|
||||
"HARDSWISH",
|
||||
"HARDSIGMOID",
|
||||
"SWIGLU",
|
||||
"SWIGLU_OAI",
|
||||
};
|
||||
|
||||
static_assert(GGML_UNARY_OP_COUNT == 14, "GGML_UNARY_OP_COUNT != 14");
|
||||
static_assert(GGML_UNARY_OP_COUNT == 15, "GGML_UNARY_OP_COUNT != 15");
|
||||
|
||||
|
||||
static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
|
||||
@@ -5917,6 +5932,29 @@ struct ggml_tensor * ggml_add_cast(
|
||||
return ggml_add_cast_impl(ctx, a, b, type);
|
||||
}
|
||||
|
||||
// ggml_add_id
|
||||
|
||||
struct ggml_tensor * ggml_add_id(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b,
|
||||
struct ggml_tensor * ids) {
|
||||
|
||||
GGML_ASSERT(a->ne[0] == b->ne[0]);
|
||||
GGML_ASSERT(a->ne[1] == ids->ne[0]);
|
||||
GGML_ASSERT(a->ne[2] == ids->ne[1]);
|
||||
GGML_ASSERT(ids->type == GGML_TYPE_I32);
|
||||
|
||||
struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
|
||||
|
||||
result->op = GGML_OP_ADD_ID;
|
||||
result->src[0] = a;
|
||||
result->src[1] = b;
|
||||
result->src[2] = ids;
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// ggml_add1
|
||||
|
||||
static struct ggml_tensor * ggml_add1_impl(
|
||||
@@ -6662,6 +6700,36 @@ struct ggml_tensor * ggml_swiglu(
|
||||
return result;
|
||||
}
|
||||
|
||||
struct ggml_tensor * ggml_swiglu_oai(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b,
|
||||
float alpha,
|
||||
float limit) {
|
||||
|
||||
GGML_ASSERT(ggml_is_contiguous_1(a));
|
||||
if (b) {
|
||||
GGML_ASSERT(ggml_is_contiguous_1(b));
|
||||
GGML_ASSERT(ggml_are_same_shape(a, b));
|
||||
GGML_ASSERT(a->type == b->type);
|
||||
}
|
||||
|
||||
int64_t ne[4] = {a->ne[0]/2, a->ne[1], a->ne[2], a->ne[3]};
|
||||
|
||||
struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, GGML_MAX_DIMS, b ? a->ne : ne, NULL, 0);
|
||||
|
||||
result->op = GGML_OP_UNARY;
|
||||
result->grad = NULL;
|
||||
result->src[0] = a;
|
||||
result->src[1] = b;
|
||||
|
||||
ggml_set_op_params_i32(result, 0, (int32_t) GGML_UNARY_OP_SWIGLU_OAI);
|
||||
ggml_set_op_params_f32(result, 2, alpha);
|
||||
ggml_set_op_params_f32(result, 3, limit);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// ggml_silu_back
|
||||
|
||||
struct ggml_tensor * ggml_silu_back(
|
||||
@@ -7017,6 +7085,66 @@ struct ggml_tensor * ggml_moe_up_gate(
|
||||
result->src[1] = as_gate;
|
||||
result->src[2] = b;
|
||||
result->src[3] = ids;
|
||||
result->src[4] = NULL;
|
||||
result->src[5] = NULL;
|
||||
|
||||
ggml_set_op_params_i32(result, 0, (int32_t) op);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
struct ggml_tensor * ggml_moe_up_gate_ext(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * as_up,
|
||||
struct ggml_tensor * as_gate,
|
||||
struct ggml_tensor * b,
|
||||
struct ggml_tensor * ids,
|
||||
struct ggml_tensor * as_up_b,
|
||||
struct ggml_tensor * as_gate_b,
|
||||
enum ggml_unary_op op) {
|
||||
|
||||
if (!as_up_b && !as_gate_b) {
|
||||
return ggml_moe_up_gate(ctx, as_up, as_gate, b, ids, op);
|
||||
}
|
||||
|
||||
if (as_up->type != as_gate->type || !ggml_are_same_shape(as_up, as_gate)) {
|
||||
struct ggml_tensor * result_up = ggml_mul_mat_id(ctx, as_up, b, ids);
|
||||
if (as_up_b) {
|
||||
result_up = ggml_add_id(ctx, result_up, as_up_b, ids);
|
||||
}
|
||||
struct ggml_tensor * result_gate = ggml_mul_mat_id(ctx, as_gate, b, ids);
|
||||
if (as_gate_b) {
|
||||
result_gate = ggml_add_id(ctx, result_gate, as_gate_b, ids);
|
||||
}
|
||||
return ggml_fused_mul_unary(ctx, result_gate, result_up, op);
|
||||
}
|
||||
|
||||
GGML_ASSERT(!ggml_is_transposed(as_up));
|
||||
GGML_ASSERT(!ggml_is_transposed(as_gate));
|
||||
GGML_ASSERT(ids->type == GGML_TYPE_I32);
|
||||
|
||||
GGML_ASSERT(as_up->ne[3] == 1); // as is 3d (one matrix per expert)
|
||||
GGML_ASSERT(b->ne[3] == 1); // b is 3d
|
||||
GGML_ASSERT(ids->ne[2] == 1 && ids->ne[3] == 1); // ids is 2d
|
||||
GGML_ASSERT(ids->ne[1] == b->ne[2]); // must have an expert list per b row
|
||||
GGML_ASSERT(as_up->ne[0] == b->ne[0]); // can_mul_mat
|
||||
GGML_ASSERT(ids->ne[0] % b->ne[1] == 0); // can broadcast
|
||||
|
||||
GGML_ASSERT(as_up->ne[1] == as_up_b->ne[0]);
|
||||
GGML_ASSERT(as_gate->ne[1] == as_gate_b->ne[0]);
|
||||
bool is_node = false;
|
||||
|
||||
const int64_t ne[4] = { as_up->ne[1], ids->ne[0], b->ne[2], 1 };
|
||||
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
|
||||
|
||||
result->op = GGML_OP_MOE_FUSED_UP_GATE;
|
||||
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
||||
result->src[0] = as_up;
|
||||
result->src[1] = as_gate;
|
||||
result->src[2] = b;
|
||||
result->src[3] = ids;
|
||||
result->src[4] = as_up_b;
|
||||
result->src[5] = as_gate_b;
|
||||
|
||||
ggml_set_op_params_i32(result, 0, (int32_t) op);
|
||||
|
||||
@@ -7970,6 +8098,22 @@ struct ggml_tensor * ggml_soft_max_ext(
|
||||
return ggml_soft_max_impl(ctx, a, mask, scale, max_bias, false);
|
||||
}
|
||||
|
||||
void ggml_soft_max_add_sinks(
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * sinks) {
|
||||
if (!sinks) {
|
||||
a->src[2] = NULL;
|
||||
return;
|
||||
}
|
||||
|
||||
GGML_ASSERT(a->op == GGML_OP_SOFT_MAX);
|
||||
GGML_ASSERT(a->src[2] == NULL);
|
||||
GGML_ASSERT(a->src[0]->ne[2] == sinks->ne[0]);
|
||||
GGML_ASSERT(sinks->type == GGML_TYPE_F32);
|
||||
|
||||
a->src[2] = sinks;
|
||||
}
|
||||
|
||||
// ggml_soft_max_back
|
||||
|
||||
static struct ggml_tensor * ggml_soft_max_back_impl(
|
||||
@@ -8833,6 +8977,22 @@ void ggml_flash_attn_ext_set_prec(
|
||||
ggml_set_op_params_i32(a, 3, prec_i32); // scale is on first pos, max_bias on second
|
||||
}
|
||||
|
||||
void ggml_flash_attn_ext_add_sinks(
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * sinks) {
|
||||
if (!sinks) {
|
||||
a->src[4] = NULL;
|
||||
return;
|
||||
}
|
||||
|
||||
GGML_ASSERT(a->op == GGML_OP_FLASH_ATTN_EXT);
|
||||
GGML_ASSERT(a->src[4] == NULL);
|
||||
GGML_ASSERT(a->src[0]->ne[2] == sinks->ne[0]);
|
||||
GGML_ASSERT(sinks->type == GGML_TYPE_F32);
|
||||
|
||||
a->src[4] = sinks;
|
||||
}
|
||||
|
||||
// ggml_flash_attn_back
|
||||
|
||||
struct ggml_tensor * ggml_flash_attn_back(
|
||||
@@ -11497,6 +11657,77 @@ static void ggml_compute_forward_multi_add(
|
||||
}
|
||||
}
|
||||
|
||||
// ggml_compute_forward_add_id
|
||||
|
||||
static void ggml_compute_forward_add_id_f32(
|
||||
const struct ggml_compute_params * params,
|
||||
struct ggml_tensor * dst) {
|
||||
|
||||
const struct ggml_tensor * src0 = dst->src[0];
|
||||
const struct ggml_tensor * src1 = dst->src[1];
|
||||
const struct ggml_tensor * src2 = dst->src[2];
|
||||
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(src2->type == GGML_TYPE_I32);
|
||||
|
||||
GGML_ASSERT(src0->nb[0] == sizeof(float));
|
||||
GGML_ASSERT(src1->nb[0] == sizeof(float));
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
const int nr = ggml_nrows(src0);
|
||||
|
||||
GGML_TENSOR_TERNARY_OP_LOCALS
|
||||
|
||||
GGML_ASSERT( nb0 == sizeof(float));
|
||||
GGML_ASSERT(nb10 == sizeof(float));
|
||||
|
||||
// rows per thread
|
||||
const int dr = (nr + nth - 1)/nth;
|
||||
|
||||
// row range for this thread
|
||||
const int ir0 = dr*ith;
|
||||
const int ir1 = MIN(ir0 + dr, nr);
|
||||
|
||||
for (int ir = ir0; ir < ir1; ++ir) {
|
||||
// src0 indices
|
||||
const int i3 = ir/(ne2*ne1);
|
||||
const int i2 = (ir - i3*ne2*ne1)/ne1;
|
||||
const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
|
||||
|
||||
// src1 indices
|
||||
const int i11 = *(int32_t *) ((char *) src2->data + i1*nb20 + i2*nb21);
|
||||
|
||||
GGML_ASSERT(i11 >= 0 && i11 < ne11);
|
||||
|
||||
ggml_vec_add_f32(ne0,
|
||||
(float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ),
|
||||
(float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01),
|
||||
(float *) ((char *) src1->data + i11*nb11));
|
||||
}
|
||||
}
|
||||
|
||||
static void ggml_compute_forward_add_id(
|
||||
const struct ggml_compute_params * params,
|
||||
struct ggml_tensor * dst) {
|
||||
|
||||
const struct ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32:
|
||||
{
|
||||
ggml_compute_forward_add_id_f32(params, dst);
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
GGML_ABORT("unsupported type for ggml_compute_forward_add_id: %s", ggml_type_name(src0->type));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ggml_compute_forward_add1
|
||||
|
||||
static void ggml_compute_forward_add1_f32(
|
||||
@@ -13760,6 +13991,93 @@ static void ggml_compute_forward_swiglu(
|
||||
}
|
||||
}
|
||||
|
||||
// ggml_compute_forward_swiglu_oai
|
||||
|
||||
static void ggml_compute_forward_swiglu_oai_f32(
|
||||
const struct ggml_compute_params * params,
|
||||
struct ggml_tensor * dst) {
|
||||
|
||||
const struct ggml_tensor * src0 = dst->src[0];
|
||||
const struct ggml_tensor * src1 = dst->src[1];
|
||||
char * src0_d = (char *) src0->data;
|
||||
char * src1_d = (char *) (src1 ? src1->data : src0->data);
|
||||
const size_t src0_o = src0->nb[1];
|
||||
const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
|
||||
|
||||
GGML_ASSERT(ggml_is_contiguous_1(src0));
|
||||
GGML_ASSERT(ggml_is_contiguous_1(dst));
|
||||
|
||||
if (src1) {
|
||||
GGML_ASSERT(ggml_is_contiguous_1(src1));
|
||||
GGML_ASSERT(src0->type == src1->type);
|
||||
}
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
|
||||
const int nr = ggml_nrows(src0);
|
||||
|
||||
GGML_ASSERT(dst->ne[0] == nc);
|
||||
GGML_ASSERT(ggml_nrows(dst) == nr);
|
||||
|
||||
const int32_t swapped = false; //ggml_get_op_params_i32(dst, 1);
|
||||
const float alpha = ggml_get_op_params_f32(dst, 2);
|
||||
const float limit = ggml_get_op_params_f32(dst, 3);
|
||||
|
||||
// rows per thread
|
||||
const int dr = (nr + nth - 1)/nth;
|
||||
|
||||
// row range for this thread
|
||||
const int ir0 = dr*ith;
|
||||
const int ir1 = MIN(ir0 + dr, nr);
|
||||
|
||||
for (int i1 = ir0; i1 < ir1; i1++) {
|
||||
float * src0_p = (float *) (src0_d + i1*src0_o);
|
||||
float * src1_p = (float *) (src1_d + i1*src1_o);
|
||||
float * dst_p = (float *) ((char *) dst->data + i1*(dst->nb[1]));
|
||||
|
||||
if (!src1) {
|
||||
src0_p += swapped ? nc : 0;
|
||||
src1_p += swapped ? 0 : nc;
|
||||
}
|
||||
|
||||
for (int k = 0; k < nc; k++) {
|
||||
const float x = MIN(src0_p[k], limit);
|
||||
const float y = MAX(MIN(src1_p[k], limit), -limit);
|
||||
const float out_glu = x / (1.f + expf(alpha * (-x)));
|
||||
dst_p[k] = out_glu * (y + 1.f);
|
||||
}
|
||||
|
||||
#ifndef NDEBUG
|
||||
for (int k = 0; k < nc; k++) {
|
||||
const float x = dst_p[k];
|
||||
GGML_UNUSED(x);
|
||||
assert(!isnan(x));
|
||||
assert(!isinf(x));
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
static void ggml_compute_forward_swiglu_oai(
|
||||
const struct ggml_compute_params * params,
|
||||
struct ggml_tensor * dst) {
|
||||
|
||||
const struct ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32:
|
||||
{
|
||||
ggml_compute_forward_swiglu_oai_f32(params, dst);
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ggml_compute_forward_fused_mul_unary
|
||||
|
||||
static void ggml_compute_forward_fused_mul_unary_f32(
|
||||
@@ -15167,6 +15485,8 @@ static void ggml_compute_forward_mul_mat_id_up_gate(
|
||||
|
||||
const struct ggml_tensor * src1 = dst->src[2];
|
||||
const struct ggml_tensor * ids = dst->src[3];
|
||||
const struct ggml_tensor * up_b = dst->src[4];
|
||||
const struct ggml_tensor * gate_b = dst->src[5];
|
||||
const struct ggml_tensor * src0_1 = dst->src[0];
|
||||
const struct ggml_tensor * src0_2 = dst->src[1];
|
||||
const struct ggml_tensor * src0 = src0_1; // so GGML_TENSOR_BINARY_OP_LOCALS works
|
||||
@@ -15191,6 +15511,9 @@ static void ggml_compute_forward_mul_mat_id_up_gate(
|
||||
GGML_ASSERT(nb2 <= nb3);
|
||||
GGML_ASSERT(ne13 == 1);
|
||||
|
||||
const size_t nb41 = up_b ? up_b->nb[1] : 0;
|
||||
const size_t nb51 = up_b ? gate_b->nb[1] : 0;
|
||||
|
||||
// row groups
|
||||
const int n_ids = ids->ne[0]; // n_expert_used
|
||||
const int n_as = ne02; // n_expert
|
||||
@@ -15278,6 +15601,8 @@ static void ggml_compute_forward_mul_mat_id_up_gate(
|
||||
|
||||
const char * src0_1_cur = (const char *) src0_1->data + cur_a*nb02;
|
||||
const char * src0_2_cur = (const char *) src0_2->data + cur_a*nb02;
|
||||
const char * up_b_cur = up_b ? (const char *)up_b->data + cur_a*nb41 : NULL;
|
||||
const char * gate_b_cur = gate_b ? (const char *)gate_b->data + cur_a*nb51 : NULL;
|
||||
|
||||
const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
|
||||
const size_t row_size = ggml_row_size(vec_dot_type, ne10);
|
||||
@@ -15288,6 +15613,7 @@ static void ggml_compute_forward_mul_mat_id_up_gate(
|
||||
if (!iqk_moe_fused_up_gate(nr0, nr1, ne00, ne11, dst->op_params[0],
|
||||
type, src0_1_cur, src0_2_cur, nb01,
|
||||
vec_dot_type, (const char *)wdata, row_size,
|
||||
up_b_cur, gate_b_cur,
|
||||
(float *)dst->data, nb1, nb2,
|
||||
matrix_rows + cur_a*ne12, ith, nth)) GGML_ABORT("fatal error");
|
||||
|
||||
@@ -16645,6 +16971,7 @@ static void ggml_compute_forward_soft_max_f32(
|
||||
|
||||
const struct ggml_tensor * src0 = dst->src[0];
|
||||
const struct ggml_tensor * src1 = dst->src[1];
|
||||
const struct ggml_tensor * src2 = dst->src[2];
|
||||
|
||||
assert(ggml_is_contiguous(dst));
|
||||
assert(ggml_are_same_shape(src0, dst));
|
||||
@@ -16662,6 +16989,13 @@ static void ggml_compute_forward_soft_max_f32(
|
||||
|
||||
GGML_TENSOR_UNARY_OP_LOCALS
|
||||
|
||||
const int64_t nb11 = src1 ? src1->nb[1] : 1;
|
||||
const int64_t nb12 = src1 ? src1->nb[2] : 1;
|
||||
const int64_t nb13 = src1 ? src1->nb[3] : 1;
|
||||
|
||||
const int64_t ne12 = src1 ? src1->ne[2] : 1;
|
||||
const int64_t ne13 = src1 ? src1->ne[3] : 1;
|
||||
|
||||
//const int64_t ne11 = src1 ? src1->ne[1] : 1;
|
||||
|
||||
// TODO: is this supposed to be ceil instead of floor?
|
||||
@@ -16673,67 +17007,80 @@ static void ggml_compute_forward_soft_max_f32(
|
||||
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
||||
|
||||
const int nc = src0->ne[0];
|
||||
const int nr = ggml_nrows(src0);
|
||||
|
||||
// rows per thread
|
||||
const int dr = (nr + nth - 1)/nth;
|
||||
|
||||
// row range for this thread
|
||||
const int ir0 = dr*ith;
|
||||
const int ir1 = MIN(ir0 + dr, nr);
|
||||
|
||||
float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
|
||||
|
||||
const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
|
||||
|
||||
for (int i1 = ir0; i1 < ir1; i1++) {
|
||||
// ALiBi
|
||||
const uint32_t h = (i1/ne01)%ne02; // head
|
||||
const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
|
||||
// sinks
|
||||
const float * sk = src2 ? (float *)((char *) src2->data) : NULL;
|
||||
|
||||
float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
|
||||
float * dp = (float *)((char *) dst->data + i1*dst->nb[1]);
|
||||
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
||||
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
||||
for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
|
||||
const int64_t i11 = i01;
|
||||
const int64_t i12 = i02%ne12;
|
||||
const int64_t i13 = i03%ne13;
|
||||
|
||||
// broadcast the mask across rows
|
||||
ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
|
||||
float * mp_f32 = src1 ? (float *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
|
||||
// ALiBi
|
||||
const uint32_t h = i02; // head
|
||||
const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
|
||||
|
||||
ggml_vec_cpy_f32 (nc, wp, sp);
|
||||
ggml_vec_scale_f32(nc, wp, scale);
|
||||
if (mp_f32) {
|
||||
if (use_f16) {
|
||||
for (int i = 0; i < nc; ++i) {
|
||||
wp[i] += slope*GGML_FP16_TO_FP32(mp_f16[i]);
|
||||
float * sp = (float *)((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
|
||||
float * dp = (float *)((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
|
||||
|
||||
// broadcast the mask across rows
|
||||
ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13) : NULL;
|
||||
float * mp_f32 = src1 ? (float *)((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13) : NULL;
|
||||
|
||||
ggml_vec_cpy_f32 (ne00, wp, sp);
|
||||
ggml_vec_scale_f32(ne00, wp, scale);
|
||||
if (mp_f32) {
|
||||
if (use_f16) {
|
||||
for (int i = 0; i < ne00; ++i) {
|
||||
wp[i] += slope*GGML_FP16_TO_FP32(mp_f16[i]);
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < ne00; ++i) {
|
||||
wp[i] += slope*mp_f32[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < nc; ++i) {
|
||||
wp[i] += slope*mp_f32[i];
|
||||
|
||||
#ifndef NDEBUG
|
||||
for (int i = 0; i < ne00; ++i) {
|
||||
//printf("p[%d] = %f\n", i, p[i]);
|
||||
assert(!isnan(wp[i]));
|
||||
}
|
||||
#endif
|
||||
|
||||
float max = -INFINITY;
|
||||
ggml_vec_max_f32(ne00, &max, wp);
|
||||
|
||||
// if we have sinks, make a correction as if they were included in the softmax
|
||||
if (sk) {
|
||||
max = MAX(max, sk[i02]);
|
||||
}
|
||||
|
||||
ggml_float sum = ggml_vec_soft_max_f32(ne00, dp, wp, max);
|
||||
assert(sum > 0.0);
|
||||
|
||||
if (sk) {
|
||||
sum += (ggml_float) expf(sk[i02] - max);
|
||||
}
|
||||
|
||||
sum = 1.0/sum;
|
||||
ggml_vec_scale_f32(ne00, dp, sum);
|
||||
|
||||
#ifndef NDEBUG
|
||||
for (int i = 0; i < ne00; ++i) {
|
||||
assert(!isnan(dp[i]));
|
||||
assert(!isinf(dp[i]));
|
||||
}
|
||||
#endif
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
//#ifndef NDEBUG
|
||||
// for (int i = 0; i < nc; ++i) {
|
||||
// //printf("p[%d] = %f\n", i, p[i]);
|
||||
// assert(!isnan(wp[i]));
|
||||
// }
|
||||
//#endif
|
||||
|
||||
float max = -INFINITY;
|
||||
ggml_vec_max_f32(nc, &max, wp);
|
||||
|
||||
ggml_float sum = ggml_vec_soft_max_f32(nc, dp, wp, max);
|
||||
//assert(sum > 0.0);
|
||||
|
||||
sum = 1.0/sum;
|
||||
ggml_vec_scale_f32(nc, dp, sum);
|
||||
|
||||
//#ifndef NDEBUG
|
||||
// for (int i = 0; i < nc; ++i) {
|
||||
// assert(!isnan(dp[i]));
|
||||
// assert(!isinf(dp[i]));
|
||||
// }
|
||||
//#endif
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16755,7 +17102,6 @@ static void ggml_compute_forward_soft_max(
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// ggml_compute_forward_soft_max_back
|
||||
|
||||
static void ggml_compute_forward_soft_max_back_f32(
|
||||
@@ -18308,12 +18654,14 @@ static void ggml_compute_forward_argsort_thresh(
|
||||
|
||||
static void ggml_compute_forward_flash_attn_ext_f16(
|
||||
const struct ggml_compute_params * params,
|
||||
const struct ggml_tensor * q,
|
||||
const struct ggml_tensor * k,
|
||||
const struct ggml_tensor * v,
|
||||
const struct ggml_tensor * mask,
|
||||
struct ggml_tensor * dst) {
|
||||
|
||||
const struct ggml_tensor * q = dst->src[0];
|
||||
const struct ggml_tensor * k = dst->src[1];
|
||||
const struct ggml_tensor * v = dst->src[2];
|
||||
const struct ggml_tensor * mask = dst->src[3];
|
||||
const struct ggml_tensor * sinks = dst->src[4];
|
||||
|
||||
GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
|
||||
GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
|
||||
GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
|
||||
@@ -18383,6 +18731,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
||||
}
|
||||
|
||||
#if GGML_USE_IQK_MULMAT
|
||||
// For now we do not implement sinks in the iqk FA implementation
|
||||
if (iqk_flash_attn_noalibi(q->type, mask->type, max_bias,
|
||||
q->ne[3], q->ne[2], q->nb[3], q->nb[2],
|
||||
k->ne[3], k->ne[2], k->nb[3], k->nb[2],
|
||||
@@ -18390,7 +18739,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
||||
dst->ne[2], dst->ne[1], dst->nb[1],
|
||||
k->type, v->type,
|
||||
Dk, Dv, neq1, nek1, q->nb[1], k->nb[1], v->nb[1], mask->nb[1],
|
||||
q->data, k->data, v->data, mask->data,
|
||||
q->data, k->data, v->data, mask->data, sinks ? sinks->data : NULL,
|
||||
scale, softcap, (float *)dst->data,
|
||||
params->wdata, (barrier_t)ggml_barrier, (void *)params->shared, ith, nth)) return;
|
||||
|
||||
@@ -18447,6 +18796,9 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
||||
ggml_vec_dot_t const kq_vec_dot = type_traits[k->type].vec_dot;
|
||||
ggml_to_float_t const v_to_float = type_traits[v->type].to_float;
|
||||
|
||||
GGML_ASSERT(( q_to_vec_dot) && "fattn: unsupported K-type");
|
||||
GGML_ASSERT((v->type == GGML_TYPE_F32 || v_to_float ) && "fattn: unsupported V-type");
|
||||
|
||||
const int64_t Dkv = MAX(Dk, Dv);
|
||||
|
||||
// loop over n_batch and n_head
|
||||
@@ -18552,6 +18904,22 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
||||
}
|
||||
}
|
||||
|
||||
if (sinks) {
|
||||
const float s = ((float *)((char *) sinks->data))[h];
|
||||
|
||||
float ms = 1.0f;
|
||||
float vs = 1.0f;
|
||||
|
||||
if (s > M) {
|
||||
ms = expf(M - s);
|
||||
ggml_vec_scale_f32(Dv, VKQ32, ms);
|
||||
} else {
|
||||
vs = expf(s - M);
|
||||
}
|
||||
|
||||
S = S*ms + vs;
|
||||
}
|
||||
|
||||
// V /= S
|
||||
const float S_inv = 1.0f/S;
|
||||
ggml_vec_scale_f32(Dv, VKQ32, S_inv);
|
||||
@@ -18571,17 +18939,13 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
||||
|
||||
static void ggml_compute_forward_flash_attn_ext(
|
||||
const struct ggml_compute_params * params,
|
||||
const struct ggml_tensor * q,
|
||||
const struct ggml_tensor * k,
|
||||
const struct ggml_tensor * v,
|
||||
const struct ggml_tensor * mask,
|
||||
struct ggml_tensor * dst) {
|
||||
switch (dst->op_params[3]) {
|
||||
case GGML_PREC_DEFAULT:
|
||||
case GGML_PREC_F32:
|
||||
{
|
||||
// uses F32 accumulators
|
||||
ggml_compute_forward_flash_attn_ext_f16(params, q, k, v, mask, dst);
|
||||
ggml_compute_forward_flash_attn_ext_f16(params, dst);
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
@@ -19350,6 +19714,10 @@ static void ggml_compute_forward_unary(
|
||||
{
|
||||
ggml_compute_forward_swiglu(params, dst);
|
||||
} break;
|
||||
case GGML_UNARY_OP_SWIGLU_OAI:
|
||||
{
|
||||
ggml_compute_forward_swiglu_oai(params, dst);
|
||||
} break;
|
||||
case GGML_UNARY_OP_HARDSWISH:
|
||||
{
|
||||
ggml_compute_forward_hardswish(params, dst);
|
||||
@@ -19898,6 +20266,10 @@ static bool ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
||||
{
|
||||
ggml_compute_forward_add(params, tensor);
|
||||
} break;
|
||||
case GGML_OP_ADD_ID:
|
||||
{
|
||||
ggml_compute_forward_add_id(params, tensor);
|
||||
} break;
|
||||
case GGML_OP_ADD1:
|
||||
{
|
||||
ggml_compute_forward_add1(params, tensor);
|
||||
@@ -20136,7 +20508,7 @@ static bool ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
||||
} break;
|
||||
case GGML_OP_FLASH_ATTN_EXT:
|
||||
{
|
||||
ggml_compute_forward_flash_attn_ext(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor);
|
||||
ggml_compute_forward_flash_attn_ext(params, tensor);
|
||||
} break;
|
||||
case GGML_OP_FLASH_ATTN_BACK:
|
||||
{
|
||||
@@ -20486,6 +20858,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
||||
src1->grad = ggml_add_or_set(ctx, src1->grad, tensor->grad, zero_table);
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_ADD_ID:
|
||||
{
|
||||
GGML_ABORT("fatal error"); // TODO: implement
|
||||
} break;
|
||||
case GGML_OP_ADD1:
|
||||
{
|
||||
if (src0->grad) {
|
||||
@@ -21719,6 +22095,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
|
||||
case GGML_OP_DUP:
|
||||
case GGML_OP_CONT:
|
||||
case GGML_OP_ADD:
|
||||
case GGML_OP_ADD_ID:
|
||||
case GGML_OP_ADD1:
|
||||
case GGML_OP_ACC:
|
||||
case GGML_OP_MULTI_ADD:
|
||||
@@ -21758,6 +22135,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
|
||||
case GGML_UNARY_OP_GELU_QUICK:
|
||||
case GGML_UNARY_OP_SILU:
|
||||
case GGML_UNARY_OP_SWIGLU:
|
||||
case GGML_UNARY_OP_SWIGLU_OAI:
|
||||
{
|
||||
n_tasks = n_threads;
|
||||
} break;
|
||||
@@ -21952,6 +22330,7 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_ADD:
|
||||
case GGML_OP_ADD_ID:
|
||||
case GGML_OP_ADD1:
|
||||
{
|
||||
if (ggml_is_quantized(node->src[0]->type)) {
|
||||
|
||||
@@ -19,26 +19,26 @@ IQK_FA_CASE(iqk_fa_128_128) {
|
||||
if (type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 k-cache with other types
|
||||
if (nk%64 == 0) {
|
||||
iqk_flash_helper_T<128, 128, 64>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
||||
q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
|
||||
return true;
|
||||
}
|
||||
iqk_flash_helper_T<128, 128, 32>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
||||
q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
|
||||
return true;
|
||||
}
|
||||
#endif
|
||||
|
||||
if (nk%128 == 0) {
|
||||
return iqk_flash_helper_T<128, 128, 128>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
||||
q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
|
||||
}
|
||||
if (nk%64 == 0) {
|
||||
return iqk_flash_helper_T<128, 128, 64>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
||||
q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
|
||||
}
|
||||
|
||||
return iqk_flash_helper_T<128, 128, 32>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
||||
q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
|
||||
|
||||
}
|
||||
|
||||
|
||||
@@ -19,26 +19,26 @@ IQK_FA_CASE(iqk_fa_192_128) {
|
||||
if (type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 k-cache with other types
|
||||
if (nk%64 == 0) {
|
||||
iqk_flash_helper_T<192, 128, 64>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
||||
q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
|
||||
return true;
|
||||
}
|
||||
iqk_flash_helper_T<192, 128, 32>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
||||
q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
|
||||
return true;
|
||||
}
|
||||
#endif
|
||||
|
||||
if (nk%128 == 0) {
|
||||
return iqk_flash_helper_T<192, 128, 128>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
||||
q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
|
||||
}
|
||||
if (nk%64 == 0) {
|
||||
return iqk_flash_helper_T<192, 128, 64>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
||||
q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
|
||||
}
|
||||
|
||||
return iqk_flash_helper_T<192, 128, 32>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
||||
q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
|
||||
|
||||
}
|
||||
|
||||
|
||||
@@ -19,26 +19,26 @@ IQK_FA_CASE(iqk_fa_256_256) {
|
||||
if (type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 k-cache with other types
|
||||
if (nk%64 == 0) {
|
||||
iqk_flash_helper_T<256, 256, 64>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
||||
q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
|
||||
return true;
|
||||
}
|
||||
iqk_flash_helper_T<256, 256, 32>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
||||
q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
|
||||
return true;
|
||||
}
|
||||
#endif
|
||||
|
||||
if (nk%128 == 0) {
|
||||
return iqk_flash_helper_T<256, 256, 128>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
||||
q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
|
||||
}
|
||||
if (nk%64 == 0) {
|
||||
return iqk_flash_helper_T<256, 256, 64>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
||||
q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
|
||||
}
|
||||
|
||||
return iqk_flash_helper_T<256, 256, 32>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
||||
q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
|
||||
|
||||
}
|
||||
|
||||
|
||||
@@ -9,7 +9,8 @@ namespace {
|
||||
template <int step_k, typename KHelper, typename VHelper>
|
||||
inline void iqk_deepseek_helper(KHelper& kh, VHelper& vh,
|
||||
int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
|
||||
const float * q, const char * mask, float scale, float softcap, float * qkv, float * M, float * S) {
|
||||
const float * q, const char * mask, float scale, float softcap, float * qkv,
|
||||
const float * sinkf, float * M, float * S) {
|
||||
auto update = [&nq1, &mask, &q, &qkv, &M, &S, stride_q, stride_m, stride_qkv] (int n) {
|
||||
nq1 -= n;
|
||||
if (nq1 == 0) return true;
|
||||
@@ -21,29 +22,29 @@ inline void iqk_deepseek_helper(KHelper& kh, VHelper& vh,
|
||||
};
|
||||
if (nq1 >= 16) {
|
||||
int n_step = nq1/16;
|
||||
FlashAttn<576, 512, 16, step_k> fa(scale, softcap);
|
||||
FlashAttn<576, 512, 16, step_k> fa(scale, softcap, sinkf);
|
||||
fa.compute(kh, vh, 16*n_step, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
|
||||
if (update(16*n_step)) return;
|
||||
}
|
||||
if (nq1 >= 8) {
|
||||
int n_step = nq1/8;
|
||||
FlashAttn<576, 512, 8, step_k> fa(scale, softcap);
|
||||
FlashAttn<576, 512, 8, step_k> fa(scale, softcap, sinkf);
|
||||
fa.compute(kh, vh, 8*n_step, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
|
||||
if (update(8*n_step)) return;
|
||||
}
|
||||
if (nq1 >= 4) {
|
||||
int n_step = nq1/4;
|
||||
FlashAttn<576, 512, 4, step_k> fa(scale, softcap);
|
||||
FlashAttn<576, 512, 4, step_k> fa(scale, softcap, sinkf);
|
||||
fa.compute(kh, vh, 4*n_step, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
|
||||
if (update(4*n_step)) return;
|
||||
}
|
||||
if (nq1 >= 2) {
|
||||
int n_step = nq1/2;
|
||||
FlashAttn<576, 512, 2, step_k> fa(scale, softcap);
|
||||
FlashAttn<576, 512, 2, step_k> fa(scale, softcap, sinkf);
|
||||
fa.compute(kh, vh, 2*n_step, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
|
||||
if (update(2*n_step)) return;
|
||||
}
|
||||
FlashAttn<576, 512, 1, step_k> fa(scale, softcap);
|
||||
FlashAttn<576, 512, 1, step_k> fa(scale, softcap, sinkf);
|
||||
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
|
||||
}
|
||||
|
||||
@@ -51,37 +52,37 @@ template <int step_k>
|
||||
inline bool iqk_deepseek_helper(ggml_type type_k,
|
||||
int nq1, int nk1, int stride_q, int stride_k, int stride_v, int stride_m, int stride_qkv,
|
||||
const float * q, const char * k, const char * v, const char * mask,
|
||||
float scale, float softcap, float * qkv, float * M, float * S) {
|
||||
float scale, float softcap, float * qkv, const float * sinkf, float * M, float * S) {
|
||||
if (type_k == GGML_TYPE_Q8_0) {
|
||||
HelperQ80 kh((const char *)k, stride_k);
|
||||
HelperQ80 vh((const char *)v, stride_v);
|
||||
iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
|
||||
iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, sinkf, M, S);
|
||||
return true;
|
||||
}
|
||||
if (type_k == GGML_TYPE_Q8_0_R8) {
|
||||
HelperQ80R8<576> kh((const char *)k, stride_k);
|
||||
HelperQ80 vh((const char *)v, stride_v);
|
||||
iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
|
||||
iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, sinkf, M, S);
|
||||
return true;
|
||||
}
|
||||
if (type_k == GGML_TYPE_Q6_0) {
|
||||
HelperQ60 kh((const char *)k, stride_k);
|
||||
HelperQ60 vh((const char *)v, stride_v);
|
||||
iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
|
||||
iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, sinkf, M, S);
|
||||
return true;
|
||||
}
|
||||
#if GGML_IQK_FA_ALL_QUANTS
|
||||
if (type_k == GGML_TYPE_Q8_KV) {
|
||||
HelperQ8KV<576> kh((const char *)k, stride_k);
|
||||
HelperQ8KV<512> vh((const char *)v, stride_v);
|
||||
iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
|
||||
iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, sinkf, M, S);
|
||||
return true;
|
||||
}
|
||||
#endif
|
||||
if (type_k == GGML_TYPE_F16) {
|
||||
HelperF16 kh((const char *)k, stride_k);
|
||||
HelperF16 vh((const char *)v, stride_v);
|
||||
iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
|
||||
iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, sinkf, M, S);
|
||||
return true;
|
||||
}
|
||||
#ifdef __AVX512BF16__
|
||||
@@ -89,10 +90,10 @@ inline bool iqk_deepseek_helper(ggml_type type_k,
|
||||
HelperBF16<576, step_k> kh((const char *)k, stride_k);
|
||||
HelperBF16<512, step_k> vh((const char *)v, stride_v);
|
||||
if (nq1 % 8 == 0) {
|
||||
FlashAttnBF16<576, 512, 8, step_k> fa(scale, softcap);
|
||||
FlashAttnBF16<576, 512, 8, step_k> fa(scale, softcap, sinkf);
|
||||
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
|
||||
} else {
|
||||
FlashAttnBF16<576, 512, 1, step_k> fa(scale, softcap);
|
||||
FlashAttnBF16<576, 512, 1, step_k> fa(scale, softcap, sinkf);
|
||||
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
|
||||
}
|
||||
return true;
|
||||
@@ -113,7 +114,7 @@ IQK_FA_CASE(iqk_fa_576_512) {
|
||||
}
|
||||
stride_q /= sizeof(float); // q stride as float
|
||||
return iqk_deepseek_helper<32>(type_k, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||
q, (const char *)k, (const char *)v, (const char *)mask, scale, softcap, qkv, M, S);
|
||||
q, (const char *)k, (const char *)v, (const char *)mask, scale, softcap, qkv, sinkf, M, S);
|
||||
|
||||
}
|
||||
|
||||
|
||||
@@ -19,26 +19,26 @@ IQK_FA_CASE(iqk_fa_64_64) {
|
||||
if (type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 k-cache with other types
|
||||
if (nk%64 == 0) {
|
||||
iqk_flash_helper_T<64, 64, 64>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
||||
q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
|
||||
return true;
|
||||
}
|
||||
iqk_flash_helper_T<64, 64, 32>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
||||
q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
|
||||
return true;
|
||||
}
|
||||
#endif
|
||||
|
||||
if (nk%128 == 0) {
|
||||
return iqk_flash_helper_T<64, 64, 128>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
||||
q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
|
||||
}
|
||||
if (nk%64 == 0) {
|
||||
return iqk_flash_helper_T<64, 64, 64>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
||||
q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
|
||||
}
|
||||
|
||||
return iqk_flash_helper_T<64, 64, 32>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
||||
q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
|
||||
|
||||
}
|
||||
|
||||
|
||||
@@ -19,26 +19,26 @@ IQK_FA_CASE(iqk_fa_96_96) {
|
||||
if (type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 k-cache with other types
|
||||
if (nk%64 == 0) {
|
||||
iqk_flash_helper_T<96, 96, 64>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
||||
q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
|
||||
return true;
|
||||
}
|
||||
iqk_flash_helper_T<96, 96, 32>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
||||
q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
|
||||
return true;
|
||||
}
|
||||
#endif
|
||||
|
||||
if (nk%128 == 0) {
|
||||
return iqk_flash_helper_T<96, 96, 128>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
||||
q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
|
||||
}
|
||||
if (nk%64 == 0) {
|
||||
return iqk_flash_helper_T<96, 96, 64>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
||||
q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
|
||||
}
|
||||
|
||||
return iqk_flash_helper_T<96, 96, 32>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||
q, ck, cv, cm, scale, softcap, qkv, M, S);
|
||||
q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
|
||||
|
||||
}
|
||||
|
||||
|
||||
@@ -1141,10 +1141,25 @@ struct FlashQKV {
|
||||
}
|
||||
|
||||
template <typename FMS>
|
||||
inline void normalize_and_store_1row(const FMS& fms, int j, const qkv_cache_t * R, float * qkv) const {
|
||||
inline void normalize_and_store_1row(const FMS& fms, int j, qkv_cache_t * R, float * qkv, const float * sinkf) const {
|
||||
static_assert(q_step == FMS::q_step);
|
||||
GGML_ASSERT(fms.S[j] > 0);
|
||||
auto norm = F16::set1(1/fms.S[j]);
|
||||
float S = fms.S[j];
|
||||
if (sinkf) {
|
||||
float s = *sinkf;
|
||||
if (s > fms.M[j]) {
|
||||
float m = expf(fms.M[j] - s);
|
||||
auto vm = F16::set1(m);
|
||||
for (int i = 0; i < D/F16::block_size; ++i) {
|
||||
auto Ri = R + F16::block_size*i;
|
||||
F16::store(Ri, F16::mul(vm, F16::load(Ri)));
|
||||
}
|
||||
S = S*m + 1;
|
||||
} else {
|
||||
S += expf(s - fms.M[j]);
|
||||
}
|
||||
}
|
||||
GGML_ASSERT(S > 0);
|
||||
auto norm = F16::set1(1/S);
|
||||
//auto norm = F16::set1(fms.S[j] > 0 ? 1/fms.S[j] : 0.f);
|
||||
for (int i = 0; i < D/F16::block_size; ++i) {
|
||||
auto r = F16::load(R + F16::block_size*i);
|
||||
@@ -1153,7 +1168,7 @@ struct FlashQKV {
|
||||
}
|
||||
|
||||
template <typename FMS>
|
||||
inline void normalize_and_store(const FMS& fms, int nq1, int stride_qkv, float * qkv, float * M, float * S) const {
|
||||
inline void normalize_and_store(const FMS& fms, int nq1, int stride_qkv, float * qkv, const float * sinkf, float * M, float * S) {
|
||||
static_assert(q_step == FMS::q_step);
|
||||
if (M && S) {
|
||||
std::memcpy(M, fms.M, nq1*sizeof(float));
|
||||
@@ -1173,7 +1188,7 @@ struct FlashQKV {
|
||||
} else {
|
||||
auto R = qkv_cache;
|
||||
for (int j = 0; j < nq1; ++j) {
|
||||
normalize_and_store_1row(fms, j, R, qkv);
|
||||
normalize_and_store_1row(fms, j, R, qkv, sinkf);
|
||||
qkv += stride_qkv;
|
||||
R += D;
|
||||
}
|
||||
@@ -1181,7 +1196,7 @@ struct FlashQKV {
|
||||
}
|
||||
|
||||
template <typename FMS>
|
||||
inline void normalize_and_store(const FMS& fms, int stride_qkv, float * qkv, float * M, float * S) const {
|
||||
inline void normalize_and_store(const FMS& fms, int stride_qkv, float * qkv, const float * sinkf, float * M, float * S) {
|
||||
static_assert(q_step == FMS::q_step);
|
||||
if (M && S) {
|
||||
std::memcpy(M, fms.M, q_step*sizeof(float));
|
||||
@@ -1201,7 +1216,7 @@ struct FlashQKV {
|
||||
} else {
|
||||
auto R = qkv_cache;
|
||||
for (int j = 0; j < q_step; ++j) {
|
||||
normalize_and_store_1row(fms, j, R, qkv);
|
||||
normalize_and_store_1row(fms, j, R, qkv, sinkf);
|
||||
qkv += stride_qkv;
|
||||
R += D;
|
||||
}
|
||||
@@ -1332,7 +1347,7 @@ void compute_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, in
|
||||
FlashMS<q_step, k_step>& fms,
|
||||
FlashQKV<Dv, q_step, k_step>& fqkv,
|
||||
const float * q, const char * mask, float * qkv,
|
||||
float * M, float * S) {
|
||||
const float * sinkf, float * M, float * S) {
|
||||
#ifdef __aarch64__
|
||||
float16_t q_f16[Dk*q_step];
|
||||
#endif
|
||||
@@ -1356,7 +1371,7 @@ void compute_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, in
|
||||
vh.next_block(k_step);
|
||||
mr += k_step*sizeof(ggml_half);
|
||||
}
|
||||
fqkv.normalize_and_store(fms, stride_qkv, qkv, M, S);
|
||||
fqkv.normalize_and_store(fms, stride_qkv, qkv, sinkf, M, S);
|
||||
|
||||
q += q_step*stride_q;
|
||||
mask += q_step*stride_m;
|
||||
@@ -1383,7 +1398,7 @@ void compute_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, in
|
||||
vh.next_block(k_step);
|
||||
mr += k_step*sizeof(ggml_half);
|
||||
}
|
||||
fqkv.normalize_and_store(fms, n_left, stride_qkv, qkv, M, S);
|
||||
fqkv.normalize_and_store(fms, n_left, stride_qkv, qkv, sinkf, M, S);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1392,7 +1407,7 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q,
|
||||
FlashMS<q_step, k_step>& fms,
|
||||
FlashQKV<Dv, q_step, k_step>& fqkv,
|
||||
const float * q, const char * mask, float * qkv,
|
||||
float * M, float * S, char * qptr) {
|
||||
const float * sinkf, float * M, float * S, char * qptr) {
|
||||
auto q8 = (typename KHelper::block_q8 *)qptr;
|
||||
if constexpr (q_step > 1 && std::is_same_v<KHelper, HelperQ80>) {
|
||||
if (nq1 == q_step) {
|
||||
@@ -1412,7 +1427,7 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q,
|
||||
vh.next_block(k_step);
|
||||
mr += k_step*sizeof(ggml_half);
|
||||
}
|
||||
fqkv.normalize_and_store(fms, stride_qkv, qkv, M, S);
|
||||
fqkv.normalize_and_store(fms, stride_qkv, qkv, sinkf, M, S);
|
||||
return;
|
||||
}
|
||||
}
|
||||
@@ -1449,10 +1464,10 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q,
|
||||
}
|
||||
#if FA_TIMING
|
||||
t1 = Perf::cur_time();
|
||||
fqkv.normalize_and_store(fms, stride_qkv, qkv, M, S);
|
||||
fqkv.normalize_and_store(fms, stride_qkv, qkv, sinkf, M, S);
|
||||
perf.accum_nolock(3, t1);
|
||||
#else
|
||||
fqkv.normalize_and_store(fms, stride_qkv, qkv, M, S);
|
||||
fqkv.normalize_and_store(fms, stride_qkv, qkv, sinkf, M, S);
|
||||
#endif
|
||||
|
||||
q += q_step*stride_q;
|
||||
@@ -1474,7 +1489,7 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q,
|
||||
vh.next_block(k_step);
|
||||
mr += k_step*sizeof(ggml_half);
|
||||
}
|
||||
fqkv.normalize_and_store(fms, n_left, stride_qkv, qkv, M, S);
|
||||
fqkv.normalize_and_store(fms, n_left, stride_qkv, qkv, sinkf, M, S);
|
||||
}
|
||||
#if FA_TIMING
|
||||
Perf::instance().add(perf);
|
||||
@@ -1504,7 +1519,7 @@ struct FlashAttn {
|
||||
static_assert(k_step%F16::block_size == 0);
|
||||
static_assert(q_step <= 4 || q_step%4 == 0);
|
||||
|
||||
FlashAttn(float scale, float softcap) : fms(scale, softcap) {}
|
||||
FlashAttn(float scale, float softcap, const float * sinkf) : fms(scale, softcap), sinkf(sinkf) {}
|
||||
|
||||
template <typename KHelper, typename VHelper>
|
||||
void compute(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
|
||||
@@ -1533,7 +1548,7 @@ struct FlashAttn {
|
||||
HelperQ80R8<Dk> khr4(nk1, kh);
|
||||
#endif
|
||||
compute_helper_q<Dk, Dv, q_step, k_step, HelperQ80R8<Dk>, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
|
||||
khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S, qptr);
|
||||
khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, sinkf, M, S, qptr);
|
||||
return;
|
||||
|
||||
}
|
||||
@@ -1547,29 +1562,30 @@ struct FlashAttn {
|
||||
HelperQ8KVR8<Dk> khr4(nk1, kh);
|
||||
#endif
|
||||
compute_helper_q<Dk, Dv, q_step, k_step, HelperQ8KVR8<Dk>, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
|
||||
khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S, qptr);
|
||||
khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, sinkf, M, S, qptr);
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
compute_helper_q<Dk, Dv, q_step, k_step, KHelper, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
|
||||
kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S, qptr);
|
||||
kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, sinkf, M, S, qptr);
|
||||
|
||||
}
|
||||
else {
|
||||
typename KHelper::block_q8 q8[q_step*(Dk/KHelper::block_size_q)];
|
||||
compute_helper_q<Dk, Dv, q_step, k_step, KHelper, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
|
||||
kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S, (char *)q8);
|
||||
kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, sinkf, M, S, (char *)q8);
|
||||
}
|
||||
}
|
||||
else {
|
||||
compute_helper<Dk, Dv, q_step, k_step, KHelper, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
|
||||
kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S);
|
||||
kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, sinkf, M, S);
|
||||
}
|
||||
}
|
||||
|
||||
FlashMS<q_step, k_step> fms;
|
||||
FlashQKV<Dv, q_step, k_step> fqkv;
|
||||
const float * sinkf;
|
||||
|
||||
};
|
||||
|
||||
@@ -1927,7 +1943,7 @@ struct FlashAttnBF16 {
|
||||
static_assert(k_step%32 == 0);
|
||||
static_assert(q_step <= 4 || q_step%4 == 0);
|
||||
|
||||
FlashAttnBF16(float scale, float softcap) : fms(scale, softcap) {}
|
||||
FlashAttnBF16(float scale, float softcap, const float * sinkf) : fms(scale, softcap), sinkf(sinkf) {}
|
||||
|
||||
template <typename KHelper, typename VHelper>
|
||||
void compute(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
|
||||
@@ -1967,7 +1983,7 @@ struct FlashAttnBF16 {
|
||||
#if FA_TIMING
|
||||
t1 = Perf::cur_time();
|
||||
#endif
|
||||
fqkv.normalize_and_store(fms, stride_qkv, qkv, M, S);
|
||||
fqkv.normalize_and_store(fms, stride_qkv, qkv, sinkf, M, S);
|
||||
#if FA_TIMING
|
||||
perf.accum_nolock(4, t1);
|
||||
#endif
|
||||
@@ -1990,7 +2006,7 @@ struct FlashAttnBF16 {
|
||||
vh.next_block(k_step);
|
||||
mr += k_step*sizeof(ggml_half);
|
||||
}
|
||||
fqkv.normalize_and_store(fms, n_left, stride_qkv, qkv, M, S);
|
||||
fqkv.normalize_and_store(fms, n_left, stride_qkv, qkv, sinkf, M, S);
|
||||
}
|
||||
#if FA_TIMING
|
||||
Perf::instance().add(perf);
|
||||
@@ -1999,12 +2015,14 @@ struct FlashAttnBF16 {
|
||||
|
||||
FlashMS<q_step, k_step> fms;
|
||||
FlashQKV<Dv, q_step, k_step> fqkv;
|
||||
const float * sinkf;
|
||||
};
|
||||
#endif
|
||||
|
||||
template <int Dk, int Dv, int k_step, typename KHelper, typename VHelper>
|
||||
inline void iqk_flash_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
|
||||
const float * q, const char * mask, float scale, float softcap, float * qkv, float * M, float * S) {
|
||||
const float * q, const char * mask, float scale, float softcap, float * qkv,
|
||||
const float * sinkf, float * M, float * S) {
|
||||
|
||||
auto update = [&nq1, &mask, &q, &qkv, &M, &S, stride_q, stride_m, stride_qkv] (int n) {
|
||||
nq1 -= n;
|
||||
@@ -2018,48 +2036,48 @@ inline void iqk_flash_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int str
|
||||
if (nk1 >= 512) {
|
||||
if (nq1 >= 128) {
|
||||
int n_step = nq1/128;
|
||||
FlashAttn<Dk, Dv, 64, k_step> fa(scale, softcap);
|
||||
FlashAttn<Dk, Dv, 64, k_step> fa(scale, softcap, sinkf);
|
||||
fa.compute(kh, vh, 128*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
|
||||
if (update(128*n_step)) return;
|
||||
}
|
||||
if (nq1 >= 64) {
|
||||
int n_step = nq1/64;
|
||||
FlashAttn<Dk, Dv, 64, k_step> fa(scale, softcap);
|
||||
FlashAttn<Dk, Dv, 64, k_step> fa(scale, softcap, sinkf);
|
||||
fa.compute(kh, vh, 64*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
|
||||
if (update(64*n_step)) return;
|
||||
}
|
||||
if (nq1 >= 32) {
|
||||
int n_step = nq1/32;
|
||||
FlashAttn<Dk, Dv, 32, k_step> fa(scale, softcap);
|
||||
FlashAttn<Dk, Dv, 32, k_step> fa(scale, softcap, sinkf);
|
||||
fa.compute(kh, vh, 32*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
|
||||
if (update(32*n_step)) return;
|
||||
}
|
||||
if (nq1 >= 16) {
|
||||
int n_step = nq1/16;
|
||||
FlashAttn<Dk, Dv, 16, k_step> fa(scale, softcap);
|
||||
FlashAttn<Dk, Dv, 16, k_step> fa(scale, softcap, sinkf);
|
||||
fa.compute(kh, vh, 16*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
|
||||
if (update(16*n_step)) return;
|
||||
}
|
||||
}
|
||||
if (nq1 >= 8) {
|
||||
int n_step = nq1/8;
|
||||
FlashAttn<Dk, Dv, 8, k_step> fa(scale, softcap);
|
||||
FlashAttn<Dk, Dv, 8, k_step> fa(scale, softcap, sinkf);
|
||||
fa.compute(kh, vh, 8*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
|
||||
if (update(8*n_step)) return;
|
||||
}
|
||||
else if (nq1 >= 4) {
|
||||
int n_step = nq1/4;
|
||||
FlashAttn<Dk, Dv, 4, k_step> fa(scale, softcap);
|
||||
FlashAttn<Dk, Dv, 4, k_step> fa(scale, softcap, sinkf);
|
||||
fa.compute(kh, vh, 4*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
|
||||
if (update(4*n_step)) return;
|
||||
}
|
||||
else if (nq1 >= 2) {
|
||||
int n_step = nq1/2;
|
||||
FlashAttn<Dk, Dv, 2, k_step> fa(scale, softcap);
|
||||
FlashAttn<Dk, Dv, 2, k_step> fa(scale, softcap, sinkf);
|
||||
fa.compute(kh, vh, 2*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
|
||||
if (update(2*n_step)) return;
|
||||
}
|
||||
FlashAttn<Dk, Dv, 1, k_step> fa(scale, softcap);
|
||||
FlashAttn<Dk, Dv, 1, k_step> fa(scale, softcap, sinkf);
|
||||
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
|
||||
}
|
||||
|
||||
@@ -2067,26 +2085,26 @@ inline void iqk_flash_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int str
|
||||
template <int Dk, int Dv, int k_step>
|
||||
inline void iqk_flash_helper_T(int nq1, int nk1, int stride_q, int stride_k, int stride_v, int stride_m, int stride_qkv,
|
||||
const float * q, const char * k, const char * v, const char * mask,
|
||||
float scale, float softcap, float * qkv, float * M, float * S) {
|
||||
float scale, float softcap, float * qkv, const float * sinkf, float * M, float * S) {
|
||||
HelperBF16<Dk, k_step> kh(k, stride_k);
|
||||
HelperBF16<Dv, k_step> vh(v, stride_v);
|
||||
if (nk1 >= 4096) {
|
||||
if (nq1 >= 64) {
|
||||
FlashAttnBF16<Dk, Dv, 64, k_step> fa(scale, softcap);
|
||||
FlashAttnBF16<Dk, Dv, 64, k_step> fa(scale, softcap, sinkf);
|
||||
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
|
||||
return;
|
||||
}
|
||||
else if (nq1 >= 16) {
|
||||
FlashAttnBF16<Dk, Dv, 16, k_step> fa(scale, softcap);
|
||||
FlashAttnBF16<Dk, Dv, 16, k_step> fa(scale, softcap, sinkf);
|
||||
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
|
||||
return;
|
||||
}
|
||||
}
|
||||
if (nq1 >= 8) {
|
||||
FlashAttnBF16<Dk, Dv, 8, k_step> fa(scale, softcap);
|
||||
FlashAttnBF16<Dk, Dv, 8, k_step> fa(scale, softcap, sinkf);
|
||||
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
|
||||
} else {
|
||||
FlashAttnBF16<Dk, Dv, 1, k_step> fa(scale, softcap);
|
||||
FlashAttnBF16<Dk, Dv, 1, k_step> fa(scale, softcap, sinkf);
|
||||
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
|
||||
}
|
||||
}
|
||||
@@ -2096,43 +2114,43 @@ template <int Dk, int Dv, int k_step, typename KHelper>
|
||||
inline bool iqk_flash_helper_T(KHelper& kh, ggml_type type_v,
|
||||
int nq1, int nk1, int stride_q, int stride_v, int stride_m, int stride_qkv,
|
||||
const float * q, const char * v, const char * mask,
|
||||
float scale, float softcap, float * qkv, float * M, float * S) {
|
||||
float scale, float softcap, float * qkv, const float * sinkf, float * M, float * S) {
|
||||
|
||||
switch (type_v) {
|
||||
case GGML_TYPE_F16: {
|
||||
HelperF16 vh(v, stride_v);
|
||||
iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
|
||||
iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, sinkf, M, S);
|
||||
} break;
|
||||
#ifdef __AVX512BF16__
|
||||
case GGML_TYPE_BF16: {
|
||||
HelperBF16<Dv, k_step> vh(v, stride_v);
|
||||
iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
|
||||
iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, sinkf, M, S);
|
||||
} break;
|
||||
#endif
|
||||
case GGML_TYPE_Q8_0: {
|
||||
HelperQ80 vh(v, stride_v);
|
||||
iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
|
||||
iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, sinkf, M, S);
|
||||
} break;
|
||||
case GGML_TYPE_Q8_KV: {
|
||||
HelperQ8KV<Dv> vh(v, stride_v);
|
||||
iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
|
||||
iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, sinkf, M, S);
|
||||
} break;
|
||||
case GGML_TYPE_Q6_0: {
|
||||
HelperQ60 vh(v, stride_v);
|
||||
iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
|
||||
iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, sinkf, M, S);
|
||||
} break;
|
||||
#if GGML_IQK_FA_ALL_QUANTS
|
||||
case GGML_TYPE_Q4_0: {
|
||||
HelperQ40 vh(v, stride_v);
|
||||
iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
|
||||
iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, sinkf, M, S);
|
||||
} break;
|
||||
case GGML_TYPE_Q4_1: {
|
||||
HelperQ41 vh(v, stride_v);
|
||||
iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
|
||||
iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, sinkf, M, S);
|
||||
} break;
|
||||
case GGML_TYPE_IQ4_NL: {
|
||||
HelperIQ4nl vh(v, stride_v);
|
||||
iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
|
||||
iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, sinkf, M, S);
|
||||
} break;
|
||||
#endif
|
||||
default: return false;
|
||||
@@ -2144,42 +2162,42 @@ template <int Dk, int Dv, int k_step>
|
||||
inline bool iqk_flash_helper_T(ggml_type type_k, ggml_type type_v,
|
||||
int nq1, int nk1, int stride_q, int stride_k, int stride_v, int stride_m, int stride_qkv,
|
||||
const float * q, const char * k, const char * v, const char * mask,
|
||||
float scale, float softcap, float * qkv, float * M, float * S) {
|
||||
float scale, float softcap, float * qkv, const float * sinkf, float * M, float * S) {
|
||||
|
||||
bool result = false;
|
||||
switch (type_k) {
|
||||
case GGML_TYPE_F16: {
|
||||
HelperF16 kh(k, stride_k);
|
||||
result = iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S);
|
||||
result = iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, sinkf, M, S);
|
||||
} break;
|
||||
case GGML_TYPE_Q8_0: {
|
||||
HelperQ80 kh(k, stride_k);
|
||||
result = iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S);
|
||||
result = iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, sinkf, M, S);
|
||||
} break;
|
||||
case GGML_TYPE_Q8_0_R8: {
|
||||
HelperQ80R8<Dk> kh(k, stride_k);
|
||||
result = iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S);
|
||||
result = iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, sinkf, M, S);
|
||||
} break;
|
||||
case GGML_TYPE_Q6_0: {
|
||||
HelperQ60 kh(k, stride_k);
|
||||
result = iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S);
|
||||
result = iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, sinkf, M, S);
|
||||
} break;
|
||||
#if GGML_IQK_FA_ALL_QUANTS
|
||||
case GGML_TYPE_Q8_KV: {
|
||||
HelperQ8KV<Dk> kh(k, stride_k);
|
||||
result = iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S);
|
||||
result = iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, sinkf, M, S);
|
||||
} break;
|
||||
case GGML_TYPE_Q4_0: {
|
||||
HelperQ40 kh(k, stride_k);
|
||||
result = iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S);
|
||||
result = iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, sinkf, M, S);
|
||||
} break;
|
||||
case GGML_TYPE_Q4_1: {
|
||||
HelperQ41 kh(k, stride_k);
|
||||
result = iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S);
|
||||
result = iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, sinkf, M, S);
|
||||
} break;
|
||||
case GGML_TYPE_IQ4_NL: {
|
||||
HelperIQ4nl kh(k, stride_k);
|
||||
result = iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S);
|
||||
result = iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, sinkf, M, S);
|
||||
} break;
|
||||
#endif
|
||||
default: break;
|
||||
@@ -2194,7 +2212,7 @@ inline bool iqk_flash_helper_T(ggml_type type_k, ggml_type type_v,
|
||||
int stride_q, int stride_k, int stride_v, int stride_m, int stride_qkv,\
|
||||
const float * q, const void * k, const void * v, const void * mask,\
|
||||
float scale, float softcap,\
|
||||
float * qkv, float * M, float * S)
|
||||
float * qkv, const float * sinkf, float * M, float * S)
|
||||
|
||||
IQK_FA_CASE(iqk_fa_576_512);
|
||||
IQK_FA_CASE(iqk_fa_192_128);
|
||||
|
||||
@@ -66,6 +66,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
|
||||
const void * k, // k matrix. Assumed to be fp16, nq x nk elements
|
||||
const void * v, // v matrix. Assumed to be fp16, nq x nk elements
|
||||
const void * mask, // mask. If not null, assumed to be fp16. nq x nk elements
|
||||
const void * sinks, // mask. If not null, assumed to be fp16. nq x nk elements
|
||||
float scale, // scale applied before softmax
|
||||
float softcap, // if > 0, a "soft-cap" operation is applied before softmax
|
||||
float * qkv, // v*softmax(scale*(k*q))
|
||||
@@ -139,7 +140,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
|
||||
auto work_this_thread = (float *)(result_buffer + ith*size_thread);
|
||||
if (!iqk_flash_attn_impl(int_type_k, int_type_v,
|
||||
Dk, Dv, nq_this_thread, nek1/gcd_k, nbq2, stride_k, stride_v, 0, Dv, //Dk*sizeof(uint16_t), Dv,
|
||||
(const float *)qth, (const void *)kth, (const void *)vth, (const void *)mth,
|
||||
(const float *)qth, (const void *)kth, (const void *)vth, (const void *)mth, nullptr, 0,
|
||||
scale, softcap,
|
||||
work_this_thread, work_this_thread + (Dv+0)*nq_this_thread, work_this_thread + (Dv+1)*nq_this_thread)) return false;
|
||||
|
||||
@@ -182,51 +183,6 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
|
||||
if (neq3 == 1 && rk2 > 1 && rk2 == rv2 && neq1 == 1 && nth >= 1 && nek2*nek1 >= 32*nth) {
|
||||
auto result_size = (Dv + 16)*rk2*sizeof(float);
|
||||
int gcd = simple_gcd(nek2, nth);
|
||||
if (false && gcd > 1) {
|
||||
int nth_g = nth/gcd;
|
||||
int ith_g = ith%nth_g;
|
||||
int nek1_32 = nek1/32;
|
||||
int nek1_pt = (nek1_32 + nth_g - 1)/nth_g;
|
||||
int ith_mid = nth_g;
|
||||
if (nek1_pt*nth_g > nek1_32) {
|
||||
ith_mid = nek1_32 - nth_g*(nek1_pt - 1);
|
||||
}
|
||||
nek1_pt *= 32;
|
||||
int nek1_mid = ith_mid*nek1_pt;
|
||||
int nek1_thread = ith_g < ith_mid ? nek1_pt : nek1_pt - 32;
|
||||
for (int ik02 = ith/nth_g; ik02 < nek2; ik02 += gcd) {
|
||||
int ik01 = ith_g < ith_mid ? ith_g*nek1_pt : nek1_mid + (ith_g - ith_mid)*nek1_thread;
|
||||
auto this_result = (float *)((char *)work_buffer + (ik02*nth_g + ith_g)*result_size);
|
||||
auto this_q = (const float *)((const char *)q + ik02*rk2*nbq2);
|
||||
auto this_k = (const char *)k + ik01*stride_k + ik02*nbk2;
|
||||
auto this_v = (const char *)v + ik01*stride_v + ik02*nbv2;
|
||||
auto this_m = (const char *)mask + ik01*sizeof(uint16_t); // we don't have ggml_half available here
|
||||
if (!iqk_flash_attn_impl(int_type_k, int_type_v,
|
||||
Dk, Dv, rk2, nek1_thread, nbq2, stride_k, stride_v, 0, Dv,
|
||||
this_q, (const void *)this_k, (const void *)this_v, (const void *)this_m,
|
||||
scale, softcap, this_result, this_result + (Dv+0)*rk2, this_result + (Dv+1)*rk2)) return false;
|
||||
}
|
||||
|
||||
barrier(barrier_data);
|
||||
|
||||
for (int iq2 = ith; iq2 < neq2; iq2 += nth) {
|
||||
int ik02 = iq2/rk2;
|
||||
int il = iq2 - ik02*rk2;
|
||||
auto Racc = qkv + iq2*nb1/sizeof(float);
|
||||
float M = -INFINITY, S = 0;
|
||||
for (int ig = 0; ig < nth_g; ++ig) {
|
||||
int istep_k = ik02*nth_g + ig;
|
||||
auto this_result = (float *)((char *)work_buffer + istep_k*result_size);
|
||||
const float * R = this_result + il*Dv;
|
||||
const float * Mj = this_result + Dv*rk2;
|
||||
const float * Sj = Mj + rk2;
|
||||
accumulate_qkv(Dv, M, S, Mj[il], Sj[il], Racc, R);
|
||||
}
|
||||
float norm = S > 0 ? 1/S : 1;
|
||||
for (int i = 0; i < Dv; ++i) Racc[i] *= norm;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
int nth_k = nth/gcd;
|
||||
int nek2_k = nek2/gcd;
|
||||
int nchunk = nek2_k*nek1/32;
|
||||
@@ -259,7 +215,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
|
||||
auto this_m = (const char *)mask + ik01*sizeof(uint16_t); // we don't have ggml_half available here
|
||||
if (!iqk_flash_attn_impl(int_type_k, int_type_v,
|
||||
Dk, Dv, rk2, this_nk, nbq2, stride_k, stride_v, 0, Dv,
|
||||
this_q, (const void *)this_k, (const void *)this_v, (const void *)this_m,
|
||||
this_q, (const void *)this_k, (const void *)this_v, (const void *)this_m, nullptr, 0,
|
||||
scale, softcap, this_result, this_result + (Dv+0)*rk2, this_result + (Dv+1)*rk2)) return false;
|
||||
}
|
||||
|
||||
@@ -281,6 +237,16 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
|
||||
const float * Sj = Mj + rk2;
|
||||
accumulate_qkv(Dv, M, S, Mj[il], Sj[il], Racc, R);
|
||||
}
|
||||
if (sinks) {
|
||||
float s = ((const float *)sinks)[iq2];
|
||||
if (s > M) {
|
||||
float m = expf(M - s);
|
||||
for (int i = 0; i < Dv; ++i) Racc[i] *= m;
|
||||
S = S*m + 1;
|
||||
} else {
|
||||
S += expf(s - M);
|
||||
}
|
||||
}
|
||||
float norm = S > 0 ? 1/S : 1;
|
||||
for (int i = 0; i < Dv; ++i) Racc[i] *= norm;
|
||||
}
|
||||
@@ -306,6 +272,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
|
||||
int counter = 0;
|
||||
for (int64_t iq3 = 0; iq3 < neq3; iq3++) {
|
||||
for (int64_t iq2 = 0; iq2 < neq2; iq2++) {
|
||||
auto sinksf = sinks ? (const float *)sinks + iq2 : nullptr;
|
||||
if (counter++ % (nth/ntg) == ith/ntg) {
|
||||
int iq1 = (ith%ntg)*neq1g;
|
||||
int this_neq1 = std::min(neq1g, neq1-iq1);
|
||||
@@ -314,7 +281,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
|
||||
(const float *)((const char *)q + iq2*nbq2 + iq3*nbq3 + iq1*stride_q),
|
||||
(const void *)((const char *)k + iq2/rk2*nbk2 + iq3/rk3*nbk3),
|
||||
(const void *)((const char *)v + iq2/rv2*nbv2 + iq3/rv3*nbv3),
|
||||
(const void *)((const char *)mask + iq1*stride_m),
|
||||
(const void *)((const char *)mask + iq1*stride_m), sinksf, 1,
|
||||
scale, softcap,
|
||||
(float *)((char *)qkv + (iq3*ne2*ne1 + iq2 + iq1*ne1)*nb1), nullptr, nullptr)) return false;
|
||||
}
|
||||
|
||||
@@ -23,6 +23,8 @@ bool iqk_flash_attn_impl(int type_k, // type of k
|
||||
const void * k, // k matrix. Assumed to be fp16, nq x nk elements
|
||||
const void * v, // v matrix. Assumed to be fp16, nq x nk elements
|
||||
const void * mask, // mask. If not null, assumed to be fp16. nq x nk elements
|
||||
const float * sinksf, // attention sinks
|
||||
int nsinks, // number of sinks
|
||||
float scale, // scale applied before softmax
|
||||
float softcap, // if > 0, a "soft-cap" operation is applied before softmax
|
||||
float * qkv, // v*softmax(scale*(k*q))
|
||||
|
||||
@@ -120,16 +120,21 @@ struct MulMat {
|
||||
funcs[n_left-1](n, vx, bx, info, nrc_x);
|
||||
}
|
||||
}
|
||||
inline void gelu(int n, const float * src, float * dst);
|
||||
inline void relu(int n, const float * src, float * dst);
|
||||
inline void silu(int n, const float * src, float * dst);
|
||||
inline void activate(ggml_unary_op op, int n, const float * src, float * dst) {
|
||||
inline static void gelu(int n, const float * src, float * dst);
|
||||
inline static void relu(int n, const float * src, float * dst);
|
||||
inline static void silu(int n, const float * src, float * dst);
|
||||
inline static void swiglu_oai(int n, const float * src, float * dst);
|
||||
inline static void clamp_oai(int n, float *x);
|
||||
inline static void activate(ggml_unary_op op, int n, const float * src, float * dst) {
|
||||
if (op == GGML_UNARY_OP_GELU) gelu(n, src, dst);
|
||||
else if (op == GGML_UNARY_OP_RELU) relu(n, src, dst);
|
||||
else if (op == GGML_UNARY_OP_SILU) silu(n, src, dst);
|
||||
else if (op == GGML_UNARY_OP_SWIGLU_OAI) swiglu_oai(n, src, dst);
|
||||
else GGML_ABORT("fatal error");
|
||||
}
|
||||
inline void mul_mat_up_gate_NxM(int n, const void * vx_up, const void * vx_gate, size_t bx, DataInfo& info, int nrc_x, int nrc_y, int unary_op) {
|
||||
inline void mul_mat_up_gate_NxM(int n, const void * vx_up, const void * vx_gate, size_t bx,
|
||||
const float * up_b, const float * gate_b,
|
||||
DataInfo& info, int nrc_x, int nrc_y, int unary_op) {
|
||||
#ifdef __aarch64__
|
||||
constexpr int k_x_step = 64; //8192; // Tiling does not seem to help on my M2 Max (but difference to tiling is small)
|
||||
#else
|
||||
@@ -137,6 +142,29 @@ struct MulMat {
|
||||
#endif
|
||||
auto op = ggml_unary_op(unary_op);
|
||||
float tmp[k_x_step*16];
|
||||
auto process = [&tmp, n, op, vx_gate, vx_up, gate_b, up_b, bx, xstep = k_x_step] (mul_mat_t func, const DataInfo& this_info, int ix, int this_nrc_x, int ny) {
|
||||
func(n, (const void *)((const char *)vx_gate + ix*bx), bx, this_info, this_nrc_x);
|
||||
for (int ky = 0; ky < ny; ++ky) {
|
||||
if (gate_b) {
|
||||
auto b = gate_b + ix;
|
||||
auto x = this_info.dst_row(ky);
|
||||
for (int j = 0; j < this_nrc_x; ++j) x[j] += b[j];
|
||||
}
|
||||
activate(op, this_nrc_x, this_info.dst_row(ky), tmp + ky*xstep);
|
||||
}
|
||||
func(n, (const void *)((const char *)vx_up + ix*bx), bx, this_info, this_nrc_x);
|
||||
for (int ky = 0; ky < ny; ++ky) {
|
||||
auto result = this_info.dst_row(ky);
|
||||
if (up_b) {
|
||||
auto b = up_b + ix;
|
||||
for (int j = 0; j < this_nrc_x; ++j) result[j] += b[j];
|
||||
}
|
||||
if (op == GGML_UNARY_OP_SWIGLU_OAI) {
|
||||
clamp_oai(this_nrc_x, result);
|
||||
}
|
||||
for (int j = 0; j < this_nrc_x; ++j) result[j] *= tmp[ky*xstep + j];
|
||||
}
|
||||
};
|
||||
if (func16 && nrc_y >= 16) {
|
||||
int n_step = (nrc_y - info.cur_y)/16;
|
||||
for (int ix = 0; ix < nrc_x; ix += k_x_step) {
|
||||
@@ -144,15 +172,7 @@ struct MulMat {
|
||||
this_info.s += ix;
|
||||
int this_nrc_x = ix + k_x_step <= nrc_x ? k_x_step : nrc_x - ix;
|
||||
for (int iy = 0; iy < n_step; ++iy) {
|
||||
func16(n, (const void *)((const char *)vx_gate + ix*bx), bx, this_info, this_nrc_x);
|
||||
for (int ky = 0; ky < 16; ++ky) {
|
||||
activate(op, this_nrc_x, this_info.dst_row(ky), tmp + ky*k_x_step);
|
||||
}
|
||||
func16(n, (const void *)((const char *)vx_up + ix*bx), bx, this_info, this_nrc_x);
|
||||
for (int ky = 0; ky < 16; ++ky) {
|
||||
auto result = this_info.dst_row(ky);
|
||||
for (int j = 0; j < this_nrc_x; ++j) result[j] *= tmp[ky*k_x_step + j];
|
||||
}
|
||||
process(func16, this_info, ix, this_nrc_x, 16);
|
||||
this_info.cur_y += 16;
|
||||
}
|
||||
}
|
||||
@@ -175,23 +195,11 @@ struct MulMat {
|
||||
this_info.s += ix;
|
||||
int this_nrc_x = ix + k_x_step <= nrc_x ? k_x_step : nrc_x - ix;
|
||||
for (int iy = 0; iy < my1; ++iy) {
|
||||
funcs[ny1-1](n, (const void *)((const char *)vx_gate + ix*bx), bx, this_info, this_nrc_x);
|
||||
for (int ky = 0; ky < ny1; ++ky) activate(op, this_nrc_x, this_info.dst_row(ky), tmp + ky*k_x_step);
|
||||
funcs[ny1-1](n, (const void *)((const char *)vx_up + ix*bx), bx, this_info, this_nrc_x);
|
||||
for (int ky = 0; ky < ny1; ++ky) {
|
||||
auto result = this_info.dst_row(ky);
|
||||
for (int j = 0; j < this_nrc_x; ++j) result[j] *= tmp[ky*k_x_step + j];
|
||||
}
|
||||
process(funcs[ny1-1], this_info, ix, this_nrc_x, ny1);
|
||||
this_info.cur_y += ny1;
|
||||
}
|
||||
for (int iy = 0; iy < my2; ++iy) {
|
||||
funcs[ny2-1](n, (const void *)((const char *)vx_gate + ix*bx), bx, this_info, this_nrc_x);
|
||||
for (int ky = 0; ky < ny2; ++ky) activate(op, this_nrc_x, this_info.dst_row(ky), tmp + ky*k_x_step);
|
||||
funcs[ny2-1](n, (const void *)((const char *)vx_up + ix*bx), bx, this_info, this_nrc_x);
|
||||
for (int ky = 0; ky < ny2; ++ky) {
|
||||
auto result = this_info.dst_row(ky);
|
||||
for (int j = 0; j < this_nrc_x; ++j) result[j] *= tmp[ky*k_x_step + j];
|
||||
}
|
||||
process(funcs[ny2-1], this_info, ix, this_nrc_x, ny2);
|
||||
this_info.cur_y += ny2;
|
||||
}
|
||||
}
|
||||
@@ -203,13 +211,7 @@ struct MulMat {
|
||||
this_info.s += ix;
|
||||
int this_nrc_x = ix + k_x_step <= nrc_x ? k_x_step : nrc_x - ix;
|
||||
for (int iy = 0; iy < n_step; ++iy) {
|
||||
funcs[ny-1](n, (const void *)((const char *)vx_gate + ix*bx), bx, this_info, this_nrc_x);
|
||||
for (int ky = 0; ky < ny; ++ky) activate(op, this_nrc_x, this_info.dst_row(ky), tmp + ky*k_x_step);
|
||||
funcs[ny-1](n, (const void *)((const char *)vx_up + ix*bx), bx, this_info, this_nrc_x);
|
||||
for (int ky = 0; ky < ny; ++ky) {
|
||||
auto result = this_info.dst_row(ky);
|
||||
for (int j = 0; j < this_nrc_x; ++j) result[j] *= tmp[ky*k_x_step + j];
|
||||
}
|
||||
process(funcs[ny-1], this_info, ix, this_nrc_x, ny);
|
||||
this_info.cur_y += ny;
|
||||
}
|
||||
}
|
||||
@@ -222,13 +224,7 @@ struct MulMat {
|
||||
auto this_info = info;
|
||||
this_info.s += ix;
|
||||
int this_nrc_x = ix + k_x_step <= nrc_x ? k_x_step : nrc_x - ix;
|
||||
funcs[n_left-1](n, (const void *)((const char *)vx_gate + ix*bx), bx, this_info, this_nrc_x);
|
||||
for (int ky = 0; ky < n_left; ++ky) activate(op, this_nrc_x, this_info.dst_row(ky), tmp + ky*k_x_step);
|
||||
funcs[n_left-1](n, (const void *)((const char *)vx_up + ix*bx), bx, this_info, this_nrc_x);
|
||||
for (int ky = 0; ky < n_left; ++ky) {
|
||||
auto result = this_info.dst_row(ky);
|
||||
for (int j = 0; j < this_nrc_x; ++j) result[j] *= tmp[ky*k_x_step + j];
|
||||
}
|
||||
process(funcs[n_left-1], this_info, ix, this_nrc_x, n_left);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -731,6 +727,7 @@ extern "C" IQK_API bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11,
|
||||
extern "C" IQK_API bool iqk_moe_fused_up_gate(long Nx, long Ny, long ne00, int ne11, int unary_op,
|
||||
int typeA, const void * Aup, const void * Agate, long strideA,
|
||||
int typeB, const void * B, long strideB,
|
||||
const char * up_b_c, const char * gate_b_c,
|
||||
float * C, long nb1, long nb2, const void * vrow_mapping, int ith, int nth) {
|
||||
|
||||
const mmid_row_mapping * row_mapping = (const mmid_row_mapping *)vrow_mapping;
|
||||
@@ -774,7 +771,9 @@ extern "C" IQK_API bool iqk_moe_fused_up_gate(long Nx, long Ny, long ne00, int n
|
||||
if (!iqk_convert_repack(typeA, ne00, (const char *)Agate + (first_x + ix)*strideA, strideA, Xg, ne00, this_nrc_x)) {
|
||||
GGML_ABORT("Fatal error");
|
||||
}
|
||||
mm.mul_mat_up_gate_NxM(ne00, Xu, Xg, row_size_qx, this_info, this_nrc_x, Ny, unary_op);
|
||||
auto up_b = up_b_c ? (const float *)up_b_c + first_x + ix : nullptr;
|
||||
auto gate_b = gate_b_c ? (const float *)gate_b_c + first_x + ix : nullptr;
|
||||
mm.mul_mat_up_gate_NxM(ne00, Xu, Xg, row_size_qx, up_b, gate_b, this_info, this_nrc_x, Ny, unary_op);
|
||||
}
|
||||
|
||||
return true;
|
||||
@@ -795,7 +794,10 @@ extern "C" IQK_API bool iqk_moe_fused_up_gate(long Nx, long Ny, long ne00, int n
|
||||
nrc_x *= num_rows;
|
||||
DataInfo info{C + first_x, (const char *)B, nb1/sizeof(float),
|
||||
row_size_qy, 0, ne11, row_mapping, nb2/sizeof(float)};
|
||||
mm.mul_mat_up_gate_NxM(ne00, (const char *)Aup + row_size_qx*first_x, (const char *)Agate + row_size_qx*first_x, row_size_qx, info, nrc_x, Ny, unary_op);
|
||||
auto up_b = up_b_c ? (const float *)up_b_c + first_x : nullptr;
|
||||
auto gate_b = gate_b_c ? (const float *)gate_b_c + first_x : nullptr;
|
||||
mm.mul_mat_up_gate_NxM(ne00, (const char *)Aup + row_size_qx*first_x, (const char *)Agate + row_size_qx*first_x, row_size_qx,
|
||||
up_b, gate_b, info, nrc_x, Ny, unary_op);
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -993,6 +995,46 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
|
||||
|
||||
namespace {
|
||||
|
||||
// TODO: these swiglu_oai constants shouldn't be hard coded
|
||||
constexpr float k_swiglu_oai_alpha = 1.702f;
|
||||
constexpr float k_swiglu_oai_limit = 7.f;
|
||||
|
||||
void MulMat::swiglu_oai(int n, const float * x, float * y) {
|
||||
// int i = 0;
|
||||
//#if defined __AVX512F__ && defined __AVX512DQ__
|
||||
// {
|
||||
// auto max = _mm512_set1_ps(k_swiglu_oai_limit);
|
||||
// auto alpha = _mm512_set1_ps(-k_swiglu_oai_alpha);
|
||||
// for (; i + 15 < n; i += 16) {
|
||||
// auto xc = v_clamp_max(_mm512_loadu_ps(x + i), max);
|
||||
// _mm512_storeu_ps(y + i, v_silu_oai(xc, alpha));
|
||||
// }
|
||||
// }
|
||||
//#endif
|
||||
//#if defined __AVX2__ && defined __FMA__
|
||||
// if (i + 7 < n) {
|
||||
// auto max = _mm256_set1_ps(k_swiglu_oai_limit);
|
||||
// auto alpha = _mm256_set1_ps(-k_swiglu_oai_alpha);
|
||||
// for (; i + 7 < n; i += 8) {
|
||||
// auto xc = v_clamp_max(_mm256_loadu_ps(x + i), max);
|
||||
// _mm256_storeu_ps(y + i, v_silu_oai(xc, alpha));
|
||||
// }
|
||||
// }
|
||||
//#endif
|
||||
// for (; i < n; ++i) {
|
||||
// auto xi = std::min(x[i], k_swiglu_oai_limit);
|
||||
// y[i] = xi / (1.0f + expf(-xi * k_swiglu_oai_alpha));
|
||||
// }
|
||||
for (int i = 0; i < n; ++i) {
|
||||
auto xi = std::min(x[i], k_swiglu_oai_limit);
|
||||
y[i] = xi / (1.0f + expf(-xi * k_swiglu_oai_alpha));
|
||||
}
|
||||
}
|
||||
|
||||
void MulMat::clamp_oai(int n, float * x) {
|
||||
for (int i = 0; i < n; ++i) x[i] = 1.f + std::max(std::min(x[i], k_swiglu_oai_limit), -k_swiglu_oai_limit);
|
||||
}
|
||||
|
||||
#if defined(__ARM_NEON) && defined(__aarch64__)
|
||||
void MulMat::gelu(int n, const float * x, float * y) {
|
||||
constexpr float GELU_COEF_A = 0.044715f;
|
||||
@@ -1040,6 +1082,37 @@ void MulMat::gelu(int n, const float * x, float * y) {
|
||||
for (; i < n; ++i) y[i] = 0.5f*x[i]*(1.0f + tanhf(SQRT_2_OVER_PI*x[i]*(1.0f + GELU_COEF_A*x[i]*x[i])));
|
||||
}
|
||||
|
||||
//void MulMat::swiglu_oai(int n, const float * x, float * y) {
|
||||
// int i = 0;
|
||||
//#if defined __AVX512F__ && defined __AVX512DQ__
|
||||
// {
|
||||
// auto limit = _mm512_set1_ps(k_swiglu_oai_limit);
|
||||
// auto alpha = _mm512_set1_ps(k_swiglu_oai_alpha);
|
||||
// for (; i + 15 < n; i += 16) {
|
||||
// auto xi = _mm512_loadu_ps(x + i);
|
||||
// auto mask = _mm512_cmp
|
||||
//
|
||||
// }
|
||||
// __m512 c1 = _mm512_set1_ps(GELU_COEF_A);
|
||||
// __m512 c2 = _mm512_set1_ps(2.f*SQRT_2_OVER_PI);
|
||||
// for (; i + 15 < n; i += 16) _mm512_storeu_ps(y + i, v_gelu(_mm512_loadu_ps(x + i), c1, c2));
|
||||
// }
|
||||
//#endif
|
||||
//#if defined __AVX2__ && defined __FMA__
|
||||
// if (i + 7 < n) {
|
||||
// __m256 c1 = _mm256_set1_ps(GELU_COEF_A);
|
||||
// __m256 c2 = _mm256_set1_ps(2.f*SQRT_2_OVER_PI);
|
||||
// for (; i + 7 < n; i += 8) _mm256_storeu_ps(y + i, v_gelu(_mm256_loadu_ps(x + i), c1, c2));
|
||||
//
|
||||
// }
|
||||
//#endif
|
||||
// for (; i < n; ++i) {
|
||||
// auto xi = std::min(x[i], k_swiglu_oai_limit);
|
||||
// y[i] = xi / (1.0f + expf(-xi * k_swiglu_oai_alpha));
|
||||
// }
|
||||
//}
|
||||
|
||||
|
||||
void MulMat::silu(int n, const float * x, float * y) {
|
||||
int i = 0;
|
||||
#if defined __AVX512F__ && defined __AVX512DQ__
|
||||
@@ -1188,6 +1261,8 @@ bool iqk_flash_attn_impl(int int_type_k, // type of k
|
||||
const void * k, // k matrix. Assumed to be fp16, nq x nk elements
|
||||
const void * v, // v matrix. Assumed to be fp16, nq x nk elements
|
||||
const void * mask, // mask. If not null, assumed to be fp16. nq x nk elements
|
||||
const float * sinksf, // mask. If not null, assumed to be fp16. nq x nk elements
|
||||
[[maybe_unused]] int nsinks,
|
||||
float scale, // scale applied before softmax
|
||||
float softcap, // if > 0, a "soft-cap" operation is applied before softmax
|
||||
float * qkv, // v*softmax(scale*(k*q))
|
||||
@@ -1197,32 +1272,32 @@ bool iqk_flash_attn_impl(int int_type_k, // type of k
|
||||
|
||||
if (Dk == 576 && Dv == 512) {
|
||||
return iqk_fa_576_512(int_type_k, int_type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||
q, k, v, mask, scale, softcap, qkv, M, S);
|
||||
q, k, v, mask, scale, softcap, qkv, sinksf, M, S);
|
||||
}
|
||||
|
||||
if (Dk == 192 && Dv == 128) {
|
||||
return iqk_fa_192_128(int_type_k, int_type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||
q, k, v, mask, scale, softcap, qkv, M, S);
|
||||
q, k, v, mask, scale, softcap, qkv, sinksf, M, S);
|
||||
}
|
||||
|
||||
if (Dk == 256 && Dv == 256) {
|
||||
return iqk_fa_256_256(int_type_k, int_type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||
q, k, v, mask, scale, softcap, qkv, M, S);
|
||||
q, k, v, mask, scale, softcap, qkv, sinksf, M, S);
|
||||
}
|
||||
|
||||
if (Dk == 128 && Dv == 128) {
|
||||
return iqk_fa_128_128(int_type_k, int_type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||
q, k, v, mask, scale, softcap, qkv, M, S);
|
||||
q, k, v, mask, scale, softcap, qkv, sinksf, M, S);
|
||||
}
|
||||
|
||||
if (Dk == 96 && Dv == 96) {
|
||||
return iqk_fa_96_96(int_type_k, int_type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||
q, k, v, mask, scale, softcap, qkv, M, S);
|
||||
q, k, v, mask, scale, softcap, qkv, sinksf, M, S);
|
||||
}
|
||||
|
||||
if (Dk == 64 && Dv == 64) {
|
||||
return iqk_fa_64_64(int_type_k, int_type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||
q, k, v, mask, scale, softcap, qkv, M, S);
|
||||
q, k, v, mask, scale, softcap, qkv, sinksf, M, S);
|
||||
}
|
||||
|
||||
return false;
|
||||
|
||||
@@ -32,6 +32,7 @@ IQK_API bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11,
|
||||
IQK_API bool iqk_moe_fused_up_gate(long Nx, long Ny, long ne00, int ne11, int unary_op,
|
||||
int typeA, const void * Aup, const void * Agate, long strideA,
|
||||
int typeB, const void * B, long strideB,
|
||||
const char * up_b, const char * gate_b,
|
||||
float * C, long nb1, long nb2, const void * vrow_mapping, int ith, int nth);
|
||||
|
||||
IQK_API int iqk_dequant_type(int type, int Ny);
|
||||
@@ -57,6 +58,7 @@ IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float max_bias,
|
||||
const void * k, // k matrix. Assumed to be fp16, nq x nk elements
|
||||
const void * v, // v matrix. Assumed to be fp16, nq x nk elements
|
||||
const void * mask, // mask. If not null, assumed to be fp16. nq x nk elements
|
||||
const void * sinks, // mask. If not null, assumed to be fp16. nq x nk elements
|
||||
float scale, // scale applied before softmax
|
||||
float softcap, // if > 0, a "soft-cap" operation is applied before softmax
|
||||
float * qkv, // v*softmax(scale*(k*q))
|
||||
|
||||
@@ -61,6 +61,13 @@ static inline float32x4_t v_silu(float32x4_t x) {
|
||||
const float32x4_t one_plus_exp_neg_x = vaddq_f32(one, exp_neg_x);
|
||||
return vdivq_f32(x, one_plus_exp_neg_x);
|
||||
}
|
||||
static inline float32x4_t v_silu_oai(float32x4_t x, float32x4_t alpha) {
|
||||
const float32x4_t one = vdupq_n_f32(1.0f);
|
||||
const float32x4_t neg_x = vmulq_f32(alpha, x);
|
||||
const float32x4_t exp_neg_x = v_expf(neg_x);
|
||||
const float32x4_t one_plus_exp_neg_x = vaddq_f32(one, exp_neg_x);
|
||||
return vdivq_f32(x, one_plus_exp_neg_x);
|
||||
}
|
||||
static inline float32x4_t v_gelu(float32x4_t x, float32x4_t c1, float32x4_t c2) {
|
||||
const float32x4_t one = vdupq_n_f32(1.0f);
|
||||
float32x4_t arg = vfmaq_f32(one, c1, vmulq_f32(x, x));
|
||||
@@ -131,6 +138,17 @@ static inline __m512 v_silu(__m512 x) {
|
||||
const __m512 one_plus_exp_neg_x = _mm512_add_ps(one, exp_neg_x);
|
||||
return _mm512_div_ps(x, one_plus_exp_neg_x);
|
||||
}
|
||||
static inline __m512 v_silu_oai(__m512 x, __m512 alpha) {
|
||||
const __m512 one = _mm512_set1_ps(1);
|
||||
const __m512 neg_x = _mm512_mul_ps(alpha, x);
|
||||
const __m512 exp_neg_x = v_expf(neg_x);
|
||||
const __m512 one_plus_exp_neg_x = _mm512_add_ps(one, exp_neg_x);
|
||||
return _mm512_div_ps(x, one_plus_exp_neg_x);
|
||||
}
|
||||
static inline __m512 v_clamp_max(__m512 x, __m512 max) {
|
||||
auto mask = _mm512_cmp_ps_mask(x, max, _CMP_GT_OQ);
|
||||
return _mm512_mask_blend_ps(mask, x, max);
|
||||
}
|
||||
#endif // __AVX512__
|
||||
|
||||
#if defined(__AVX2__) && defined(__FMA__)
|
||||
@@ -195,12 +213,23 @@ static inline __m256 v_gelu(__m256 x, __m256 c1, __m256 c2) {
|
||||
}
|
||||
static inline __m256 v_silu(__m256 x) {
|
||||
const __m256 one = _mm256_set1_ps(1);
|
||||
const __m256 zero = _mm256_setzero_ps();
|
||||
const __m256 zero = _mm256_setzero_ps();
|
||||
const __m256 neg_x = _mm256_sub_ps(zero, x);
|
||||
const __m256 exp_neg_x = v_expf(neg_x);
|
||||
const __m256 one_plus_exp_neg_x = _mm256_add_ps(one, exp_neg_x);
|
||||
return _mm256_div_ps(x, one_plus_exp_neg_x);
|
||||
}
|
||||
static inline __m256 v_silu_oai(__m256 x, __m256 alpha) {
|
||||
const __m256 one = _mm256_set1_ps(1);
|
||||
const __m256 neg_x = _mm256_mul_ps(alpha, x);
|
||||
const __m256 exp_neg_x = v_expf(neg_x);
|
||||
const __m256 one_plus_exp_neg_x = _mm256_add_ps(one, exp_neg_x);
|
||||
return _mm256_div_ps(x, one_plus_exp_neg_x);
|
||||
}
|
||||
static inline __m256 v_clamp_max(__m256 x, __m256 max) {
|
||||
auto mask = _mm256_cmp_ps(x, max, _CMP_GT_OQ);
|
||||
return _mm256_or_ps(_mm256_and_ps(mask, max), _mm256_andnot_ps(mask, x));
|
||||
}
|
||||
|
||||
#endif // __AVX2__
|
||||
|
||||
|
||||
Reference in New Issue
Block a user