Mirror of https://github.com/ikawrakow/ik_llama.cpp.git

Merge branch 'main' into s6/dots
@@ -502,5 +502,7 @@ void llama_sampling_accept(
     if (ctx_sampling->grammar != NULL && apply_grammar) {
         llama_grammar_accept_token(ctx_sampling->grammar, ctx_main, id);
     }
-    llama_sampler_dry_accept(ctx_sampling->smpl, id);
+    if (ctx_sampling->smpl) {
+        llama_sampler_dry_accept(ctx_sampling->smpl, id);
+    }
 }

@@ -636,6 +636,9 @@ class Model:
         if chkhsh == "877081d19cf6996e2c4ff0e1236341e9b7bde288f5311a56a937f0afbbb3aeb5":
             # ref: https://huggingface.co/deepseek-ai/DeepSeek-V3
             res = "deepseek-v3"
+        if chkhsh == "d5f1dd6f980fec569fb218a81a7658ac45fc56b38c5a0adeb1c232fbe04ef5ec":
+            # ref: https://huggingface.co/ByteDance-Seed/Seed-Coder-8B-Base
+            res = "seed-coder"

         if res is None:
             logger.warning("\n")

@@ -1520,6 +1523,17 @@ class LlamaModel(Model):
             special_vocab._set_special_token("eot", 32010)
             special_vocab.add_to_gguf(self.gguf_writer)

+        # Apply to Seed-Coder only (and ignore otherwise)
+        if self.hparams.get("vocab_size", 32000) == 155136:
+            special_vocab = gguf.SpecialVocab(
+                self.dir_model, load_merges=False,
+                special_token_types = ['prefix', 'suffix', 'middle', 'eot']
+            )
+            special_vocab._set_special_token("prefix", 124)
+            special_vocab._set_special_token("suffix", 125)
+            special_vocab._set_special_token("middle", 126)
+            special_vocab.add_to_gguf(self.gguf_writer)
+
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
         hparams = self.hparams

@@ -95,6 +95,7 @@ models = [
     {"name": "tekken", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/mistralai/Mistral-Nemo-Base-2407", },
     {"name": "smollm", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/HuggingFaceTB/SmolLM-135M", },
     {"name": "deepseek-v3", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/deepseek-ai/DeepSeek-V3"},
+    {"name": "seed-coder", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ByteDance-Seed/Seed-Coder-8B-Base", },
 ]

@@ -1957,7 +1957,7 @@ extern "C" {
             int   min_entries,
             float thresh);

-#define GGML_KQ_MASK_PAD 32
+#define GGML_KQ_MASK_PAD 64

     // q: [n_embd, n_batch,     n_head,    1]
     // k: [n_embd, n_kv,        n_head_kv, 1]

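Editor's note: GGML_KQ_MASK_PAD doubles from 32 to 64, so KQ masks are now padded to multiples of 64 rows. A quick sanity check of the padding arithmetic (the macro is restated here for illustration; ggml's own GGML_PAD uses bit masking but is equivalent for these power-of-two pads):

    // Round x up to a multiple of n.
    #define PAD_TO(x, n) (((x) + (n) - 1) / (n) * (n))
    static_assert(PAD_TO(70, 64) == 128, "new 64-row pad");
    static_assert(PAD_TO(70, 32) ==  96, "old 32-row pad");
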
@@ -971,12 +971,14 @@ GGML_CALL static ggml_backend_t ggml_backend_reg_cpu_init(const char * params, v
     GGML_UNUSED(user_data);
 }

+#ifdef GGML_USE_RPC
 GGML_CALL static ggml_backend_t ggml_backend_reg_rpc_init(const char* params, void* user_data) {
     return ggml_backend_rpc_init((const char*)user_data);

     GGML_UNUSED(params);
     GGML_UNUSED(user_data);
 }
+#endif

 // multi-buffer buffer

@@ -2186,6 +2186,7 @@ struct mmid_row_mapping {
     int32_t i2;
 };

+template <typename data_t = float>
 static __global__ void k_copy_src_to_contiguous(const char * __restrict__ src_original, char * __restrict__ src_contiguous,
                                                 const mmid_row_mapping * __restrict__ row_mapping,
                                                 int64_t ne10, int64_t ne11, size_t nb11, size_t nb12) {

@@ -2194,8 +2195,8 @@ static __global__ void k_copy_src_to_contiguous(const char * __restrict__ src_or
     const int32_t i11 = row_mapping[i].i1 % ne11;
     const int32_t i12 = row_mapping[i].i2;

-    float * src_row_contiguous = (float *)(src_contiguous + i*nb11);
-    const float * src_row_original = (const float *)(src_original + i11*nb11 + i12*nb12);
+    data_t * src_row_contiguous = (data_t *)(src_contiguous + i*nb11);
+    const data_t * src_row_original = (const data_t *)(src_original + i11*nb11 + i12*nb12);

     for (int j = threadIdx.x; j < ne10; j += blockDim.x) {
         src_row_contiguous[j] = src_row_original[j];

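Editor's note: with the kernel now templated on data_t (defaulting to float so existing call sites compile unchanged), the same copy can stage non-f32 rows. A hypothetical half-precision launch, mirroring the f32 launch shape used further down; the variable names are illustrative, not part of this patch:

    // Stage num_src1_rows fp16 rows into contiguous memory; nb11/nb12 are byte strides.
    dim3 block_dims(std::min((unsigned int)ne10, 768u));
    dim3 grid_dims(num_src1_rows);
    k_copy_src_to_contiguous<half><<<grid_dims, block_dims, 0, stream>>>(
        src_original, src_contiguous, row_mapping, ne10, ne11, nb11, nb12);
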
@@ -2673,6 +2674,17 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
             }
         }
     } else {
         //printf("ne10 = %ld, ne11 = %ld, ne12 = %ld, nb10 = %zu nb11 = %zu nb12 = %zu\n", src1->ne[0], src1->ne[1], src1->ne[2], src1->nb[0], src1->nb[1], src1->nb[2]);
+        ggml_cuda_pool_alloc<char> src1_quantized(ctx.pool());
+        bool use_quantized_src1 = false;
+        int64_t src1_padded_num_cols = 0, src1_padded_row_size = 0, src1_quantized_size = 0;
+        if (ggml_is_quantized(src0_1->type) && src0_1->type == src0_2->type && src1->ne[1] == 1 && src1->ne[3] == 1) {
+            src1_padded_num_cols = GGML_PAD(src1->ne[0], MATRIX_ROW_PADDING);
+            src1_padded_row_size = src1_padded_num_cols/ggml_blck_size(GGML_TYPE_Q8_1)*ggml_type_size(GGML_TYPE_Q8_1);
+            src1_quantized_size = src1_padded_row_size*src1->ne[2] + get_mmq_x_max_host(ggml_cuda_info().devices[ctx.device].cc)*sizeof(block_q8_1_mmq);
+            src1_quantized.alloc(src1_quantized_size);
+            use_quantized_src1 = true;
+        }
         ggml_cuda_pool_alloc<char> src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1));
         ggml_cuda_pool_alloc<char> dst_up_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst));
         ggml_cuda_pool_alloc<char> dst_gate_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst));

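Editor's note: the staging-buffer size above follows from padding each activation row to MATRIX_ROW_PADDING columns and converting to Q8_1 block bytes. A self-contained restatement of that arithmetic; the constants are assumptions based on ggml's Q8_1 layout (32 values per block, a half2 scale/sum pair plus 32 int8 quants per block, MATRIX_ROW_PADDING = 512):

    #include <cstdint>
    #include <cstdio>

    static constexpr int64_t kBlckSizeQ8_1 = 32;   // values per Q8_1 block (QK8_1)
    static constexpr int64_t kTypeSizeQ8_1 = 36;   // bytes per block: half2 ds + 32 x int8
    static constexpr int64_t kRowPad = 512;        // assumed MATRIX_ROW_PADDING

    static constexpr int64_t pad(int64_t x, int64_t n) { return (x + n - 1) / n * n; }

    int main() {
        const int64_t ne00 = 7168; // example row length (model embedding dimension)
        const int64_t padded_cols = pad(ne00, kRowPad);
        const int64_t row_bytes = padded_cols / kBlckSizeQ8_1 * kTypeSizeQ8_1;
        std::printf("padded cols %lld -> %lld bytes per quantized row\n",
                    (long long)padded_cols, (long long)row_bytes);
        return 0;
    }
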
@@ -2704,7 +2716,13 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
         if (num_src1_rows == 0) continue;
         size_t mapping_offset = cum_moe_counts[i02];

-        {
+        if (use_quantized_src1) {
+            quantize_mmq_q8_1_id_cuda((const float *)src1->data, src1_quantized.get(), (const char *)(dev_row_mapping.get() + mapping_offset),
+                    src1->ne[0], num_src1_rows, src1_padded_num_cols, src0_1->type, stream);
+            CUDA_CHECK(cudaGetLastError());
+            src1_row.data = src1_quantized.get();
+        }
+        else {
             dim3 block_dims(std::min((unsigned int)ne10, 768u));
             dim3 grid_dims(num_src1_rows);
             k_copy_src_to_contiguous<<<grid_dims, block_dims, 0, stream>>>(

@@ -2719,9 +2737,9 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
         GGML_ASSERT(nb1 == sizeof(float)*ne0);

         src1_row.ne[1] = num_src1_rows;
-        src1_row.nb[1] = nb11;
-        src1_row.nb[2] = num_src1_rows*nb11;
-        src1_row.nb[3] = num_src1_rows*nb11;
+        src1_row.nb[1] = use_quantized_src1 ? src1_padded_row_size : nb11;
+        src1_row.nb[2] = num_src1_rows*src1_row.nb[1];
+        src1_row.nb[3] = num_src1_rows*src1_row.nb[1];

         dst_row.ne[1] = num_src1_rows;
         dst_row.nb[1] = nb1;

@@ -2729,11 +2747,21 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
         dst_row.nb[3] = num_src1_rows*nb1;

         dst_row.data = dst_up_contiguous.get();
-        ggml_cuda_mul_mat(ctx, &src0_1_row, &src1_row, &dst_row);
+        if (use_quantized_src1) {
+            ggml_cuda_op_mul_mat_q(ctx, &src0_1_row, &src1_row, &dst_row, (const char *)src0_1_row.data, nullptr, src1_quantized.get(), (float *)dst_row.data,
+                    0, src0_1_row.ne[1], num_src1_rows, src1_padded_num_cols, stream);
+        } else {
+            ggml_cuda_mul_mat(ctx, &src0_1_row, &src1_row, &dst_row);
+        }
         CUDA_CHECK(cudaGetLastError());

         dst_row.data = dst_gate_contiguous.get();
-        ggml_cuda_mul_mat(ctx, &src0_2_row, &src1_row, &dst_row);
+        if (use_quantized_src1) {
+            ggml_cuda_op_mul_mat_q(ctx, &src0_2_row, &src1_row, &dst_row, (const char *)src0_2_row.data, nullptr, src1_quantized.get(), (float *)dst_row.data,
+                    0, src0_2_row.ne[1], num_src1_rows, src1_padded_num_cols, stream);
+        } else {
+            ggml_cuda_mul_mat(ctx, &src0_2_row, &src1_row, &dst_row);
+        }
         CUDA_CHECK(cudaGetLastError());

         ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], ggml_nelements(&dst_row),

ggml/src/ggml-cuda/iqk_cuda_common.h (new file, 75 lines)
@@ -0,0 +1,75 @@
#pragma once

#include "common.cuh"

static const __device__ uint32_t iq2k_table[512] = {
0xe1e1e1e1, 0xe1e1e1f3, 0xe1e1e101, 0xe1e1e111, 0xe1e1f3e1, 0xe1e1f3f3, 0xe1e1f301, 0xe1e1f311,
0xe1e101e1, 0xe1e101f3, 0xe1e10101, 0xe1e10111, 0xe1e111e1, 0xe1e111f3, 0xe1e11101, 0xe1e11111,
0xe1f3e1e1, 0xe1f3e1f3, 0xe1f3e101, 0xe1f3e111, 0xe1f3f3e1, 0xe1f3f3f3, 0xe1f3f301, 0xe1f3f311,
0xe1f301e1, 0xe1f301f3, 0xe1f30101, 0xe1f30111, 0xe1f311e1, 0xe1f311f3, 0xe1f31101, 0xe1f31111,
0xe101e1e1, 0xe101e1f3, 0xe101e101, 0xe101e111, 0xe101f3e1, 0xe101f3f3, 0xe101f301, 0xe101f311,
0xe10101e1, 0xe10101f3, 0xe1010101, 0xe1010111, 0xe10111e1, 0xe10111f3, 0xe1011101, 0xe1011111,
0xe111e1e1, 0xe111e1f3, 0xe111e101, 0xe111e111, 0xe111f3e1, 0xe111f3f3, 0xe111f301, 0xe111f311,
0xe11101e1, 0xe11101f3, 0xe1110101, 0xe1110111, 0xe11111e1, 0xe11111f3, 0xe1111101, 0xe1111111,
0xf3e1e1e1, 0xf3e1e1f3, 0xf3e1e101, 0xf3e1e111, 0xf3e1f3e1, 0xf3e1f3f3, 0xf3e1f301, 0xf3e1f311,
0xf3e101e1, 0xf3e101f3, 0xf3e10101, 0xf3e10111, 0xf3e111e1, 0xf3e111f3, 0xf3e11101, 0xf3e11111,
0xf3f3e1e1, 0xf3f3e1f3, 0xf3f3e101, 0xf3f3e111, 0xf3f3f3e1, 0xf3f3f3f3, 0xf3f3f301, 0xf3f3f311,
0xf3f301e1, 0xf3f301f3, 0xf3f30101, 0xf3f30111, 0xf3f311e1, 0xf3f311f3, 0xf3f31101, 0xf3f31111,
0xf301e1e1, 0xf301e1f3, 0xf301e101, 0xf301e111, 0xf301f3e1, 0xf301f3f3, 0xf301f301, 0xf301f311,
0xf30101e1, 0xf30101f3, 0xf3010101, 0xf3010111, 0xf30111e1, 0xf30111f3, 0xf3011101, 0xf3011111,
0xf311e1e1, 0xf311e1f3, 0xf311e101, 0xf311e111, 0xf311f3e1, 0xf311f3f3, 0xf311f301, 0xf311f311,
0xf31101e1, 0xf31101f3, 0xf3110101, 0xf3110111, 0xf31111e1, 0xf31111f3, 0xf3111101, 0xf3111111,
0x01e1e1e1, 0x01e1e1f3, 0x01e1e101, 0x01e1e111, 0x01e1f3e1, 0x01e1f3f3, 0x01e1f301, 0x01e1f311,
0x01e101e1, 0x01e101f3, 0x01e10101, 0x01e10111, 0x01e111e1, 0x01e111f3, 0x01e11101, 0x01e11111,
0x01f3e1e1, 0x01f3e1f3, 0x01f3e101, 0x01f3e111, 0x01f3f3e1, 0x01f3f3f3, 0x01f3f301, 0x01f3f311,
0x01f301e1, 0x01f301f3, 0x01f30101, 0x01f30111, 0x01f311e1, 0x01f311f3, 0x01f31101, 0x01f31111,
0x0101e1e1, 0x0101e1f3, 0x0101e101, 0x0101e111, 0x0101f3e1, 0x0101f3f3, 0x0101f301, 0x0101f311,
0x010101e1, 0x010101f3, 0x01010101, 0x01010111, 0x010111e1, 0x010111f3, 0x01011101, 0x01011111,
0x0111e1e1, 0x0111e1f3, 0x0111e101, 0x0111e111, 0x0111f3e1, 0x0111f3f3, 0x0111f301, 0x0111f311,
0x011101e1, 0x011101f3, 0x01110101, 0x01110111, 0x011111e1, 0x011111f3, 0x01111101, 0x01111111,
0x11e1e1e1, 0x11e1e1f3, 0x11e1e101, 0x11e1e111, 0x11e1f3e1, 0x11e1f3f3, 0x11e1f301, 0x11e1f311,
0x11e101e1, 0x11e101f3, 0x11e10101, 0x11e10111, 0x11e111e1, 0x11e111f3, 0x11e11101, 0x11e11111,
0x11f3e1e1, 0x11f3e1f3, 0x11f3e101, 0x11f3e111, 0x11f3f3e1, 0x11f3f3f3, 0x11f3f301, 0x11f3f311,
0x11f301e1, 0x11f301f3, 0x11f30101, 0x11f30111, 0x11f311e1, 0x11f311f3, 0x11f31101, 0x11f31111,
0x1101e1e1, 0x1101e1f3, 0x1101e101, 0x1101e111, 0x1101f3e1, 0x1101f3f3, 0x1101f301, 0x1101f311,
0x110101e1, 0x110101f3, 0x11010101, 0x11010111, 0x110111e1, 0x110111f3, 0x11011101, 0x11011111,
0x1111e1e1, 0x1111e1f3, 0x1111e101, 0x1111e111, 0x1111f3e1, 0x1111f3f3, 0x1111f301, 0x1111f311,
0x111101e1, 0x111101f3, 0x11110101, 0x11110111, 0x111111e1, 0x111111f3, 0x11111101, 0x11111111,
0xe6e6e6e6, 0xe6e6e6f8, 0xe6e6e606, 0xe6e6e616, 0xe6e6f8e6, 0xe6e6f8f8, 0xe6e6f806, 0xe6e6f816,
0xe6e606e6, 0xe6e606f8, 0xe6e60606, 0xe6e60616, 0xe6e616e6, 0xe6e616f8, 0xe6e61606, 0xe6e61616,
0xe6f8e6e6, 0xe6f8e6f8, 0xe6f8e606, 0xe6f8e616, 0xe6f8f8e6, 0xe6f8f8f8, 0xe6f8f806, 0xe6f8f816,
0xe6f806e6, 0xe6f806f8, 0xe6f80606, 0xe6f80616, 0xe6f816e6, 0xe6f816f8, 0xe6f81606, 0xe6f81616,
0xe606e6e6, 0xe606e6f8, 0xe606e606, 0xe606e616, 0xe606f8e6, 0xe606f8f8, 0xe606f806, 0xe606f816,
0xe60606e6, 0xe60606f8, 0xe6060606, 0xe6060616, 0xe60616e6, 0xe60616f8, 0xe6061606, 0xe6061616,
0xe616e6e6, 0xe616e6f8, 0xe616e606, 0xe616e616, 0xe616f8e6, 0xe616f8f8, 0xe616f806, 0xe616f816,
0xe61606e6, 0xe61606f8, 0xe6160606, 0xe6160616, 0xe61616e6, 0xe61616f8, 0xe6161606, 0xe6161616,
0xf8e6e6e6, 0xf8e6e6f8, 0xf8e6e606, 0xf8e6e616, 0xf8e6f8e6, 0xf8e6f8f8, 0xf8e6f806, 0xf8e6f816,
0xf8e606e6, 0xf8e606f8, 0xf8e60606, 0xf8e60616, 0xf8e616e6, 0xf8e616f8, 0xf8e61606, 0xf8e61616,
0xf8f8e6e6, 0xf8f8e6f8, 0xf8f8e606, 0xf8f8e616, 0xf8f8f8e6, 0xf8f8f8f8, 0xf8f8f806, 0xf8f8f816,
0xf8f806e6, 0xf8f806f8, 0xf8f80606, 0xf8f80616, 0xf8f816e6, 0xf8f816f8, 0xf8f81606, 0xf8f81616,
0xf806e6e6, 0xf806e6f8, 0xf806e606, 0xf806e616, 0xf806f8e6, 0xf806f8f8, 0xf806f806, 0xf806f816,
0xf80606e6, 0xf80606f8, 0xf8060606, 0xf8060616, 0xf80616e6, 0xf80616f8, 0xf8061606, 0xf8061616,
0xf816e6e6, 0xf816e6f8, 0xf816e606, 0xf816e616, 0xf816f8e6, 0xf816f8f8, 0xf816f806, 0xf816f816,
0xf81606e6, 0xf81606f8, 0xf8160606, 0xf8160616, 0xf81616e6, 0xf81616f8, 0xf8161606, 0xf8161616,
0x06e6e6e6, 0x06e6e6f8, 0x06e6e606, 0x06e6e616, 0x06e6f8e6, 0x06e6f8f8, 0x06e6f806, 0x06e6f816,
0x06e606e6, 0x06e606f8, 0x06e60606, 0x06e60616, 0x06e616e6, 0x06e616f8, 0x06e61606, 0x06e61616,
0x06f8e6e6, 0x06f8e6f8, 0x06f8e606, 0x06f8e616, 0x06f8f8e6, 0x06f8f8f8, 0x06f8f806, 0x06f8f816,
0x06f806e6, 0x06f806f8, 0x06f80606, 0x06f80616, 0x06f816e6, 0x06f816f8, 0x06f81606, 0x06f81616,
0x0606e6e6, 0x0606e6f8, 0x0606e606, 0x0606e616, 0x0606f8e6, 0x0606f8f8, 0x0606f806, 0x0606f816,
0x060606e6, 0x060606f8, 0x06060606, 0x06060616, 0x060616e6, 0x060616f8, 0x06061606, 0x06061616,
0x0616e6e6, 0x0616e6f8, 0x0616e606, 0x0616e616, 0x0616f8e6, 0x0616f8f8, 0x0616f806, 0x0616f816,
0x061606e6, 0x061606f8, 0x06160606, 0x06160616, 0x061616e6, 0x061616f8, 0x06161606, 0x06161616,
0x16e6e6e6, 0x16e6e6f8, 0x16e6e606, 0x16e6e616, 0x16e6f8e6, 0x16e6f8f8, 0x16e6f806, 0x16e6f816,
0x16e606e6, 0x16e606f8, 0x16e60606, 0x16e60616, 0x16e616e6, 0x16e616f8, 0x16e61606, 0x16e61616,
0x16f8e6e6, 0x16f8e6f8, 0x16f8e606, 0x16f8e616, 0x16f8f8e6, 0x16f8f8f8, 0x16f8f806, 0x16f8f816,
0x16f806e6, 0x16f806f8, 0x16f80606, 0x16f80616, 0x16f816e6, 0x16f816f8, 0x16f81606, 0x16f81616,
0x1606e6e6, 0x1606e6f8, 0x1606e606, 0x1606e616, 0x1606f8e6, 0x1606f8f8, 0x1606f806, 0x1606f816,
0x160606e6, 0x160606f8, 0x16060606, 0x16060616, 0x160616e6, 0x160616f8, 0x16061606, 0x16061616,
0x1616e6e6, 0x1616e6f8, 0x1616e606, 0x1616e616, 0x1616f8e6, 0x1616f8f8, 0x1616f806, 0x1616f816,
0x161606e6, 0x161606f8, 0x16160606, 0x16160616, 0x161616e6, 0x161616f8, 0x16161606, 0x16161616,
};

__device__ __forceinline__ int int_from_table_4(const uint32_t idx, const int * values) {
return values[ggml_cuda_dp4a(idx, 0x40100401, 0)];
}
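
Editor's note: the magic constant 0x40100401 makes dp4a act as a shift-and-or combine. Its bytes are the weights 1, 4, 16, 64, so the per-byte dot product reassembles four 2-bit fields (one per byte of idx) into a single 8-bit table index. A host-side check of that identity, with dp4a emulated in plain C++; each byte of idx is assumed to hold a value in 0..3, as produced by masking with 0x03030303:

    #include <cassert>
    #include <cstdint>

    // Host emulation of dp4a(a, b, c): per-byte signed products accumulated into c.
    static int dp4a_ref(uint32_t a, uint32_t b, int c) {
        for (int k = 0; k < 4; ++k) {
            c += (int8_t)(a >> 8*k) * (int8_t)(b >> 8*k);
        }
        return c;
    }

    int main() {
        const uint32_t idx = 0x03000201u;                    // bytes: 1, 2, 0, 3
        const int via_dp4a  = dp4a_ref(idx, 0x40100401u, 0); // weights 1, 4, 16, 64
        const int via_shift = 1 | (2 << 2) | (0 << 4) | (3 << 6);
        assert(via_dp4a == via_shift);                       // same 8-bit table index
        return 0;
    }
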
@@ -5,6 +5,7 @@
 //

 #include "iqk_mmvq.cuh"
+#include "iqk_cuda_common.h"

 typedef void (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs, float *);

@@ -785,77 +786,6 @@ __device__ __forceinline__ void vec_dot_iq6_k_q8_1(
     *result += d6 * (__low2float(bq8_1[2*(i4/2)+0].ds) * sumi1 * bq6->scales[4*(i4/2)+(i4%2)] + __low2float(bq8_1[2*(i4/2)+1].ds) * sumi2 * bq6->scales[4*(i4/2)+(i4%2)+2]);
 }

-static const __device__ uint32_t iq2k_table[512] = {
-    [... 64 rows of table values elided: byte-for-byte identical to the iq2k_table added in iqk_cuda_common.h above ...]
-};
-
-__device__ __forceinline__ int int_from_table_4(const uint8_t * a8, const int * values) {
-    return values[a8[0] | (a8[1] << 2) | (a8[2] << 4) | (a8[3] << 6)];
-}

 #define VDR_IQ2_K_Q8_1_MMVQ 4
 #define VDR_IQ2_K_Q8_1_MMQ 4

@@ -881,7 +811,6 @@ __device__ __forceinline__ void vec_dot_iq2_k_q8_1(
     uint32_t val1 = q2[0], val2 = q2[1];

     uint32_t aux32[2];
-    const uint8_t * a8 = (const uint8_t *)&aux32;
     int v1, v2;

     // Block of 16: (32*(4*(i4/4)+k)+8*(i4%4))/16 = 8*(i4/4) + 2*k + (i4%4)/2

@@ -892,23 +821,23 @@ __device__ __forceinline__ void vec_dot_iq2_k_q8_1(
     const int8_t * s8 = (const int8_t *)&s32;

     aux32[0] = ((val1 >> 0) & 0x03030303); aux32[1] = ((val2 >> 0) & 0x03030303); values = all_values + ((extra & 0x01) << 8);
-    v1 = int_from_table_4(a8 + 0, values);
-    v2 = int_from_table_4(a8 + 4, values);
+    v1 = int_from_table_4(aux32[0], values);
+    v2 = int_from_table_4(aux32[1], values);
     int sumi1 = ggml_cuda_dp4a(v2, q8_1[1], ggml_cuda_dp4a(v1, q8_1[0], 0)) * s8[0];

     aux32[0] = ((val1 >> 2) & 0x03030303); aux32[1] = ((val2 >> 2) & 0x03030303); values = all_values + ((extra & 0x04) << 6);
-    v1 = int_from_table_4(a8 + 0, values);
-    v2 = int_from_table_4(a8 + 4, values);
+    v1 = int_from_table_4(aux32[0], values);
+    v2 = int_from_table_4(aux32[1], values);
     int sumi2 = ggml_cuda_dp4a(v2, q8_2[1], ggml_cuda_dp4a(v1, q8_2[0], 0)) * s8[1];

     aux32[0] = ((val1 >> 4) & 0x03030303); aux32[1] = ((val2 >> 4) & 0x03030303); values = all_values + ((extra & 0x10) << 4);
-    v1 = int_from_table_4(a8 + 0, values);
-    v2 = int_from_table_4(a8 + 4, values);
+    v1 = int_from_table_4(aux32[0], values);
+    v2 = int_from_table_4(aux32[1], values);
     int sumi3 = ggml_cuda_dp4a(v2, q8_3[1], ggml_cuda_dp4a(v1, q8_3[0], 0)) * s8[2];

     aux32[0] = ((val1 >> 6) & 0x03030303); aux32[1] = ((val2 >> 6) & 0x03030303); values = all_values + ((extra & 0x40) << 2);
-    v1 = int_from_table_4(a8 + 0, values);
-    v2 = int_from_table_4(a8 + 4, values);
+    v1 = int_from_table_4(aux32[0], values);
+    v2 = int_from_table_4(aux32[1], values);
     int sumi4 = ggml_cuda_dp4a(v2, q8_4[1], ggml_cuda_dp4a(v1, q8_4[0], 0)) * s8[3];

     *result += __half2float(bq2->d) * (__low2float(bq8_1[4*(i4/4)+0].ds) * sumi1

@@ -941,7 +870,6 @@ __device__ __forceinline__ void vec_dot_iq2_ks_q8_1(
     uint32_t val1 = q2[0] | (q2[1] << 16), val2 = q2[2] | (q2[3] << 16);

     uint32_t aux32[2];
-    const uint8_t * a8 = (const uint8_t *)&aux32;
     int v1, v2;

     int32_t scales32;

@@ -954,23 +882,23 @@ __device__ __forceinline__ void vec_dot_iq2_ks_q8_1(
     s8[3] += ((extra >> 7) & 0x10);

     aux32[0] = ((val1 >> 0) & 0x03030303); aux32[1] = ((val2 >> 0) & 0x03030303); values = all_values + ((extra & 0x01) << 8);
-    v1 = int_from_table_4(a8 + 0, values);
-    v2 = int_from_table_4(a8 + 4, values);
+    v1 = int_from_table_4(aux32[0], values);
+    v2 = int_from_table_4(aux32[1], values);
     int sumi1 = ggml_cuda_dp4a(v2, q8_1[1], ggml_cuda_dp4a(v1, q8_1[0], 0)) * s8[0];

     aux32[0] = ((val1 >> 2) & 0x03030303); aux32[1] = ((val2 >> 2) & 0x03030303); values = all_values + ((extra & 0x02) << 7);
-    v1 = int_from_table_4(a8 + 0, values);
-    v2 = int_from_table_4(a8 + 4, values);
+    v1 = int_from_table_4(aux32[0], values);
+    v2 = int_from_table_4(aux32[1], values);
     int sumi2 = ggml_cuda_dp4a(v2, q8_2[1], ggml_cuda_dp4a(v1, q8_2[0], 0)) * s8[2];

     aux32[0] = ((val1 >> 4) & 0x03030303); aux32[1] = ((val2 >> 4) & 0x03030303); values = all_values + ((extra & 0x04) << 6);
-    v1 = int_from_table_4(a8 + 0, values);
-    v2 = int_from_table_4(a8 + 4, values);
+    v1 = int_from_table_4(aux32[0], values);
+    v2 = int_from_table_4(aux32[1], values);
     int sumi3 = ggml_cuda_dp4a(v2, q8_3[1], ggml_cuda_dp4a(v1, q8_3[0], 0)) * s8[1];

     aux32[0] = ((val1 >> 6) & 0x03030303); aux32[1] = ((val2 >> 6) & 0x03030303); values = all_values + ((extra & 0x08) << 5);
-    v1 = int_from_table_4(a8 + 0, values);
-    v2 = int_from_table_4(a8 + 4, values);
+    v1 = int_from_table_4(aux32[0], values);
+    v2 = int_from_table_4(aux32[1], values);
     int sumi4 = ggml_cuda_dp4a(v2, q8_4[1], ggml_cuda_dp4a(v1, q8_4[0], 0)) * s8[3];

     *result += scale * (__low2float(bq8_1[4*(i4/4)+0].ds) * sumi1

@@ -1000,20 +928,19 @@ __device__ __forceinline__ void vec_dot_iq2_k_r4_q8_1(
     int2 val1;
     const int * q2 = (const int *)bq2->qs + 8*ib32 + 4*is;
     int aux32[2];
-    const uint8_t * aux8 = (const uint8_t *)aux32;
 #pragma unroll
     for (int i = 0; i < 4; ++i) {
         auto values1 = all_values + (((bq2->extra[i+4*is] >> ib32) & 1) << 8);
         int sumi1 = 0;
         aux32[0] = ((q2[i] >> 0) & 0x03030303);
         aux32[1] = ((q2[i] >> 2) & 0x03030303);
-        val1.x = int_from_table_4(aux8+0, values1);
-        val1.y = int_from_table_4(aux8+4, values1);
+        val1.x = int_from_table_4(aux32[0], values1);
+        val1.y = int_from_table_4(aux32[1], values1);
         sumi1 = ggml_cuda_dp4a(val1.x, q8[0], ggml_cuda_dp4a(val1.y, q8[1], sumi1));
         aux32[0] = ((q2[i] >> 4) & 0x03030303);
         aux32[1] = ((q2[i] >> 6) & 0x03030303);
-        val1.x = int_from_table_4(aux8+0, values1);
-        val1.y = int_from_table_4(aux8+4, values1);
+        val1.x = int_from_table_4(aux32[0], values1);
+        val1.y = int_from_table_4(aux32[1], values1);
         sumi1 = ggml_cuda_dp4a(val1.x, q8[2], ggml_cuda_dp4a(val1.y, q8[3], sumi1));
         const float d = __half2float(bq2->d[i]) * d8;
         result[i] += d * sumi1 * s8[i];

@@ -1114,7 +1041,6 @@ __device__ __forceinline__ void vec_dot_iq3_ks_q8_1(
     const int ib128 = iqs/4; // 0 or 1. 0 works on quants 0...127, 1 on quants 128...255
                              // Each thread processes 8 quants in each of the 4 32-blocks
     const int il8 = iqs%4;   // 0...3. 0 works on quants 0...7, 1 on quants 8...15, 2 on 16...23, 3 on 24...31
     const int shift = 4*(il8/2);

     const uint16_t * ql = (const uint16_t *)bq3->qs + 16*ib128 + 4*il8;
     const uint16_t * qh = (const uint16_t *)bq3->qh + 4*il8;

@@ -174,11 +174,13 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
             mmq_supported = ne11 < 1536;
             break;
         case GGML_TYPE_IQ2_K:
+        case GGML_TYPE_IQ2_K_R4:
+            mmq_supported = ne11 < 2048;
+            break;
         case GGML_TYPE_IQ3_K:
         case GGML_TYPE_IQ4_K:
        case GGML_TYPE_IQ5_K:
         case GGML_TYPE_IQ6_K:
-        case GGML_TYPE_IQ2_K_R4:
         case GGML_TYPE_IQ3_K_R4:
         case GGML_TYPE_IQ4_K_R4:
         case GGML_TYPE_IQ5_K_R4:

@@ -10,6 +10,7 @@
 #include "common.cuh"
 #include "vecdotq.cuh"
 #include "mma.cuh"
+#include "iqk_cuda_common.h"

 #include <climits>
 #include <cstdint>

@@ -2492,12 +2493,10 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     float * x_df = (float *) (x_qs + txs.qs);
 #endif // INT8_MMA_AVAILABLE

+    const int * all_values = (const int *)iq2k_table;
+
     const int kqsx = threadIdx.x%16;

-    auto values = iq2nl_values;
-
-    uint32_t aux32[4];
-    const uint8_t * aux8 = (const uint8_t *)aux32;
 #pragma unroll
     for (int i0 = 0; i0 < mmq_y; i0 += 2*nwarps) {
         int i = i0 + 2*threadIdx.y + threadIdx.x/16;

@@ -2511,26 +2510,16 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         uint16_t extra = bxi->extra >> 4*(kqsx/8);
         int q2 = get_int_b2(bxi->qs, kqsx);

-        aux32[0] = ((q2 >> 0) & 0x03030303) | (((extra << 2) & 4) * 0x01010101);
-        aux32[1] = ((q2 >> 2) & 0x03030303) | (((extra << 1) & 4) * 0x01010101);
-        aux32[2] = ((q2 >> 4) & 0x03030303) | (((extra >> 0) & 4) * 0x01010101);
-        aux32[3] = ((q2 >> 6) & 0x03030303) | (((extra >> 1) & 4) * 0x01010101);
-
-        const char4 val0 = make_char4(values[aux8[ 0]], values[aux8[ 1]], values[aux8[ 2]], values[aux8[ 3]]);
-        const char4 val1 = make_char4(values[aux8[ 4]], values[aux8[ 5]], values[aux8[ 6]], values[aux8[ 7]]);
-        const char4 val2 = make_char4(values[aux8[ 8]], values[aux8[ 9]], values[aux8[10]], values[aux8[11]]);
-        const char4 val3 = make_char4(values[aux8[12]], values[aux8[13]], values[aux8[14]], values[aux8[15]]);
-
 #ifdef INT8_MMA_AVAILABLE
-        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx%8 + 32*(kqsx/8) + 0] = *(const int *)&val0;
-        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx%8 + 32*(kqsx/8) + 8] = *(const int *)&val1;
-        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx%8 + 32*(kqsx/8) + 16] = *(const int *)&val2;
-        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx%8 + 32*(kqsx/8) + 24] = *(const int *)&val3;
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx%8 + 32*(kqsx/8) + 0] = int_from_table_4((q2 >> 0) & 0x03030303, all_values + ((extra & 1) << 8));
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx%8 + 32*(kqsx/8) + 8] = int_from_table_4((q2 >> 2) & 0x03030303, all_values + ((extra & 2) << 7));
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx%8 + 32*(kqsx/8) + 16] = int_from_table_4((q2 >> 4) & 0x03030303, all_values + ((extra & 4) << 6));
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx%8 + 32*(kqsx/8) + 24] = int_from_table_4((q2 >> 6) & 0x03030303, all_values + ((extra & 8) << 5));
 #else
-        x_qs[i*(2*WARP_SIZE + 1) + kqsx%8 + 32*(kqsx/8) + 0] = *(const int *)&val0;
-        x_qs[i*(2*WARP_SIZE + 1) + kqsx%8 + 32*(kqsx/8) + 8] = *(const int *)&val1;
-        x_qs[i*(2*WARP_SIZE + 1) + kqsx%8 + 32*(kqsx/8) + 16] = *(const int *)&val2;
-        x_qs[i*(2*WARP_SIZE + 1) + kqsx%8 + 32*(kqsx/8) + 24] = *(const int *)&val3;
+        x_qs[i*(2*WARP_SIZE + 1) + kqsx%8 + 32*(kqsx/8) + 0] = int_from_table_4((q2 >> 0) & 0x03030303, all_values + ((extra & 1) << 8));
+        x_qs[i*(2*WARP_SIZE + 1) + kqsx%8 + 32*(kqsx/8) + 8] = int_from_table_4((q2 >> 2) & 0x03030303, all_values + ((extra & 2) << 7));
+        x_qs[i*(2*WARP_SIZE + 1) + kqsx%8 + 32*(kqsx/8) + 16] = int_from_table_4((q2 >> 4) & 0x03030303, all_values + ((extra & 4) << 6));
+        x_qs[i*(2*WARP_SIZE + 1) + kqsx%8 + 32*(kqsx/8) + 24] = int_from_table_4((q2 >> 6) & 0x03030303, all_values + ((extra & 8) << 5));
 #endif // INT8_MMA_AVAILABLE
     }

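Editor's note: in the lines above, `(extra & 1) << 8` (and the shifted variants for the other bits) flips between the two 256-entry halves of iq2k_table viewed as ints, i.e. between the two non-linear value sets. A tiny self-contained illustration with a stand-in table:

    #include <cassert>
    #include <cstdint>

    int main() {
        // Stand-in for iq2k_table viewed as int[512]: two 256-entry halves.
        static int table[512];
        for (int i = 0; i < 256; ++i) { table[i] = i; table[256 + i] = 1000 + i; }

        const int * all_values = table;
        for (uint32_t extra : {0u, 1u}) {
            const int * values = all_values + ((extra & 0x01) << 8);  // +0 or +256 ints
            assert(values[5] == (extra ? 1005 : 5));
        }
        return 0;
    }
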
@@ -2573,10 +2562,6 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     constexpr int qstep = 8;
     const int kqsx = threadIdx.x % qstep;

-    auto values = iq2nl_values;
-
-    uint32_t aux32[4];
-    const uint8_t * aux8 = (const uint8_t *)aux32;
 #pragma unroll
     for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/qstep) {
         int i = i0 + threadIdx.y*(WARP_SIZE/qstep) + threadIdx.x/qstep;

@@ -2587,6 +2572,8 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin

         const block_iq2_k * bxi = (const block_iq2_k *)(x + i*stride) + kbx0;

+        auto all_values = (const int *)iq2k_table;
+
         const float d = bxi->d;

         uint16_t extra = bxi->extra >> (kqsx/4);

@@ -2595,28 +2582,20 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         for (int l = 0; l < qstep/4; ++l) {

             const int ql = get_int_b4(bxi->qs, kqsx + qstep*l);
-            aux32[0] = ((ql >> 0) & 0x03030303) | (((extra << 2) & 4) * 0x01010101);
-            aux32[1] = ((ql >> 2) & 0x03030303) | (((extra << 0) & 4) * 0x01010101);
-            aux32[2] = ((ql >> 4) & 0x03030303) | (((extra >> 2) & 4) * 0x01010101);
-            aux32[3] = ((ql >> 6) & 0x03030303) | (((extra >> 4) & 4) * 0x01010101);
-            extra >>= 8;
-
-            const char4 val0 = make_char4(values[aux8[ 0]], values[aux8[ 1]], values[aux8[ 2]], values[aux8[ 3]]);
-            const char4 val1 = make_char4(values[aux8[ 4]], values[aux8[ 5]], values[aux8[ 6]], values[aux8[ 7]]);
-            const char4 val2 = make_char4(values[aux8[ 8]], values[aux8[ 9]], values[aux8[10]], values[aux8[11]]);
-            const char4 val3 = make_char4(values[aux8[12]], values[aux8[13]], values[aux8[14]], values[aux8[15]]);
-
 #ifdef INT8_MMA_AVAILABLE
-            x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 0] = *(const int *)&val0;
-            x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 8] = *(const int *)&val1;
-            x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 16] = *(const int *)&val2;
-            x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 24] = *(const int *)&val3;
+            x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 0] = int_from_table_4((ql >> 0) & 0x03030303, all_values + ((extra & 0x01) << 8));
+            x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 8] = int_from_table_4((ql >> 2) & 0x03030303, all_values + ((extra & 0x04) << 6));
+            x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 16] = int_from_table_4((ql >> 4) & 0x03030303, all_values + ((extra & 0x10) << 4));
+            x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 24] = int_from_table_4((ql >> 6) & 0x03030303, all_values + ((extra & 0x40) << 2));
 #else
-            x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 0] = *(const int *)&val0;
-            x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 8] = *(const int *)&val1;
-            x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 16] = *(const int *)&val2;
-            x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 24] = *(const int *)&val3;
+            x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 0] = int_from_table_4((ql >> 0) & 0x03030303, all_values + ((extra & 0x01) << 8));
+            x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 8] = int_from_table_4((ql >> 2) & 0x03030303, all_values + ((extra & 0x04) << 6));
+            x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 16] = int_from_table_4((ql >> 4) & 0x03030303, all_values + ((extra & 0x10) << 4));
+            x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 24] = int_from_table_4((ql >> 6) & 0x03030303, all_values + ((extra & 0x40) << 2));
 #endif // INT8_MMA_AVAILABLE
+
+            extra >>= 8;
         }

 #ifdef INT8_MMA_AVAILABLE

@@ -166,6 +166,98 @@ static __global__ void quantize_mmq_q8_1(
     }
 }

+struct mmid_row_mapping {
+    int32_t i1;
+    int32_t i2;
+};
+
+template <mmq_q8_1_ds_layout ds_layout>
+static __global__ void quantize_mmq_q8_1_id(
+    const float * __restrict__ x, void * __restrict__ vy, const char * row_mapping, const int64_t kx0, const int64_t kx1, const int64_t kx0_padded) {
+
+    constexpr int vals_per_scale = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 64 : 32;
+    constexpr int vals_per_sum = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 16 : 32;
+
+    const int64_t ix0 = ((int64_t)blockDim.x*blockIdx.x + threadIdx.x)*4;
+
+    if (ix0 >= kx0_padded) {
+        return;
+    }
+
+    const float4 * x4 = (const float4 *) x;
+
+    const mmid_row_mapping * mapping = (const mmid_row_mapping *)row_mapping;
+    const int64_t ii = mapping[blockIdx.y].i2;
+
+    block_q8_1_mmq * y = (block_q8_1_mmq *) vy;
+
+    const int64_t ib0 = blockIdx.z*((int64_t)gridDim.y*gridDim.x*blockDim.x/QK8_1); // first block of channel
+    const int64_t ib = ib0 + (ix0 / (4*QK8_1))*kx1 + blockIdx.y; // block index in channel
+    const int64_t iqs = ix0 % (4*QK8_1); // quant index in block
+
+    // Load 4 floats per thread and calculate max. abs. value between them:
+    const float4 xi = ix0 < kx0 ? x4[(ii*kx0 + ix0)/4] : make_float4(0.0f, 0.0f, 0.0f, 0.0f);
+    float amax = fabsf(xi.x);
+    amax = fmaxf(amax, fabsf(xi.y));
+    amax = fmaxf(amax, fabsf(xi.z));
+    amax = fmaxf(amax, fabsf(xi.w));
+
+    // Exchange max. abs. value between vals_per_scale/4 threads.
+#pragma unroll
+    for (int mask = vals_per_scale/8; mask > 0; mask >>= 1) {
+        amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, mask, WARP_SIZE));
+    }
+
+    float sum;
+    if (ds_layout != MMQ_Q8_1_DS_LAYOUT_D4) {
+        sum = xi.x + xi.y + xi.z + xi.w;
+
+        // Exchange the calculated sum across vals_per_sum/4 threads.
+#pragma unroll
+        for (int mask = vals_per_sum/8; mask > 0; mask >>= 1) {
+            sum += __shfl_xor_sync(0xFFFFFFFF, sum, mask, WARP_SIZE);
+        }
+    }
+
+    const float d = amax/127.f;
+    const float d_inv = d > 0 ? 1/d : 0.f;
+    char4 q;
+    q.x = roundf(xi.x*d_inv);
+    q.y = roundf(xi.y*d_inv);
+    q.z = roundf(xi.z*d_inv);
+    q.w = roundf(xi.w*d_inv);
+
+    // Write back 4 int8 values as a single 32 bit value for better memory bandwidth:
+    char4 * yqs4 = (char4 *) y[ib].qs;
+    yqs4[iqs/4] = q;
+
+    if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6) {
+        if (iqs % 16 != 0 || iqs >= 96) {
+            return;
+        }
+
+        y[ib].d2s6[2 + iqs/16] = sum;
+
+        if (iqs % 64 != 0) {
+            return;
+        }
+
+        y[ib].d2s6[iqs/64] = d;
+
+        return;
+    }
+
+    if (iqs % 32 != 0) {
+        return;
+    }
+
+    if (ds_layout == MMQ_Q8_1_DS_LAYOUT_DS4) {
+        y[ib].ds4[iqs/32] = make_half2(d, sum);
+    } else {
+        y[ib].d4[iqs/32] = d;
+    }
+}
+
 void quantize_row_q8_1_cuda(
     const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels,
     const int64_t kx0_padded, const ggml_type type_x, cudaStream_t stream) {

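Editor's note: per group of four values the kernel applies the usual absmax scheme, scale d = amax/127 and quants q = round(x/d), with the inverse scale forced to 0 when a group is all zeros. The same step on the host, stripped of the warp shuffles:

    #include <cmath>
    #include <cstdio>

    int main() {
        const float x[4] = {0.12f, -0.50f, 0.33f, -0.01f};
        float amax = 0.0f;
        for (float v : x) amax = std::fmax(amax, std::fabs(v)); // group absmax
        const float d = amax/127.f;                             // scale
        const float d_inv = d > 0 ? 1/d : 0.f;                  // all-zero group -> quants stay 0
        for (float v : x) {
            const int q = (int)std::round(v*d_inv);             // int8 quant in [-127, 127]
            std::printf("%+.3f -> %4d (dequant %+.4f)\n", v, q, q*d);
        }
        return 0;
    }
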
@@ -208,6 +300,35 @@ void quantize_mmq_q8_1_cuda(
     }
 }

+void quantize_mmq_q8_1_id_cuda(
+    const float * x, void * vy, const char * row_mapping, const int64_t kx0, const int64_t kx1, const int64_t kx0_padded,
+    const ggml_type type_x, cudaStream_t stream) {
+
+    GGML_ASSERT(kx0_padded % (4*QK8_1) == 0);
+
+    const int64_t block_num_x = (kx0_padded + 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1) / (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ);
+    const dim3 num_blocks(block_num_x, kx1, 1);
+    const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE_MMQ, 1, 1);
+    switch (mmq_get_q8_1_ds_layout(type_x)) {
+        case MMQ_Q8_1_DS_LAYOUT_D4:
+            quantize_mmq_q8_1_id<MMQ_Q8_1_DS_LAYOUT_D4>
+                <<<num_blocks, block_size, 0, stream>>>(x, vy, row_mapping, kx0, kx1, kx0_padded);
+            break;
+        case MMQ_Q8_1_DS_LAYOUT_DS4:
+            quantize_mmq_q8_1_id<MMQ_Q8_1_DS_LAYOUT_DS4>
+                <<<num_blocks, block_size, 0, stream>>>(x, vy, row_mapping, kx0, kx1, kx0_padded);
+            break;
+        case MMQ_Q8_1_DS_LAYOUT_D2S6:
+            quantize_mmq_q8_1_id<MMQ_Q8_1_DS_LAYOUT_D2S6>
+                <<<num_blocks, block_size, 0, stream>>>(x, vy, row_mapping, kx0, kx1, kx0_padded);
+            break;
+        default:
+            GGML_ABORT("fatal error");
+            break;
+    }
+}
+
 void quantize_tensor_q8_1_cuda(const struct ggml_tensor * src, void * vy, const enum ggml_type type, cudaStream_t stream) {
     GGML_ASSERT(src->ne[1] == 1 && src->ne[3] == 1);
     GGML_ASSERT(src->type == GGML_TYPE_F32);

@@ -30,5 +30,9 @@ void quantize_mmq_q8_1_cuda(
     const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels, const int64_t kx0_padded,
     const ggml_type type_x, cudaStream_t stream);

+void quantize_mmq_q8_1_id_cuda(
+    const float * x, void * vy, const char * row_mapping, const int64_t kx0, const int64_t kx1, const int64_t kx0_padded,
+    const ggml_type type_x, cudaStream_t stream);
+
 // For now only applicable for tensors with ne[1] = 1, ne[3] = 1, and useful if ne[2] > 1
 void quantize_tensor_q8_1_cuda(const struct ggml_tensor * src, void * vy, const enum ggml_type type, cudaStream_t stream);

@@ -14,10 +14,10 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     float * x_df = (float *) (x_qs + txs.qs);
 #endif // INT8_MMA_AVAILABLE

+    const int * all_values = (const int *)iq2k_table;
+
     const int kqsx = threadIdx.x/4; // 0...7 -> block of 32

-    uint32_t aux32[4];
-    const uint8_t * aux8 = (const uint8_t *)aux32;
 #pragma unroll
     for (int i0 = 0; i0 < mmq_y; i0 += 4*nwarps) {
         int i = i0 + 4*threadIdx.y + threadIdx.x%4;

@@ -35,29 +35,20 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 #pragma unroll
         for (int l = 0; l < 2; ++l) {

-            auto values_l = iq2nl_values + (((bxi->extra[ir+4*l] >> kqsx) & 1) << 2);
+            auto values_l = all_values + (((bxi->extra[ir+4*l] >> kqsx) & 1) << 8);

             const int ql = get_int_b4(bxi->qs, 8*kqsx + ir + 4*l);
-            aux32[0] = (ql >> 0) & 0x03030303;
-            aux32[1] = (ql >> 2) & 0x03030303;
-            aux32[2] = (ql >> 4) & 0x03030303;
-            aux32[3] = (ql >> 6) & 0x03030303;
-
-            const char4 val0 = make_char4(values_l[aux8[ 0]], values_l[aux8[ 1]], values_l[aux8[ 2]], values_l[aux8[ 3]]);
-            const char4 val1 = make_char4(values_l[aux8[ 4]], values_l[aux8[ 5]], values_l[aux8[ 6]], values_l[aux8[ 7]]);
-            const char4 val2 = make_char4(values_l[aux8[ 8]], values_l[aux8[ 9]], values_l[aux8[10]], values_l[aux8[11]]);
-            const char4 val3 = make_char4(values_l[aux8[12]], values_l[aux8[13]], values_l[aux8[14]], values_l[aux8[15]]);
-
 #ifdef INT8_MMA_AVAILABLE
-            x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 0] = *(const int *)&val0;
-            x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 1] = *(const int *)&val1;
-            x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 2] = *(const int *)&val2;
-            x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 3] = *(const int *)&val3;
+            x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 0] = int_from_table_4((ql >> 0) & 0x03030303, values_l);
+            x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 1] = int_from_table_4((ql >> 2) & 0x03030303, values_l);
+            x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 2] = int_from_table_4((ql >> 4) & 0x03030303, values_l);
+            x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 3] = int_from_table_4((ql >> 6) & 0x03030303, values_l);
 #else
-            x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 0] = *(const int *)&val0;
-            x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 1] = *(const int *)&val1;
-            x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 2] = *(const int *)&val2;
-            x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 3] = *(const int *)&val3;
+            x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 0] = int_from_table_4((ql >> 0) & 0x03030303, values_l);
+            x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 1] = int_from_table_4((ql >> 2) & 0x03030303, values_l);
+            x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 2] = int_from_table_4((ql >> 4) & 0x03030303, values_l);
+            x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 3] = int_from_table_4((ql >> 6) & 0x03030303, values_l);
 #endif // INT8_MMA_AVAILABLE
         }

@@ -232,6 +232,21 @@ enum vk_device_architecture {
     INTEL_XE2,
 };

+// HSK x HSV
+enum FaHeadSizes {
+    FA_HEAD_SIZE_64,
+    FA_HEAD_SIZE_80,
+    FA_HEAD_SIZE_96,
+    FA_HEAD_SIZE_112,
+    FA_HEAD_SIZE_128,
+    FA_HEAD_SIZE_192,
+    FA_HEAD_SIZE_192_128,
+    FA_HEAD_SIZE_256,
+    FA_HEAD_SIZE_576_512,
+    FA_HEAD_SIZE_UNSUPPORTED,
+    FA_HEAD_SIZE_COUNT = FA_HEAD_SIZE_UNSUPPORTED,
+};
+
 static vk_device_architecture get_device_architecture(const vk::PhysicalDevice& device) {
     vk::PhysicalDeviceProperties props = device.getProperties();

@@ -431,6 +446,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_norm_f32;
     vk_pipeline pipeline_group_norm_f32;
     vk_pipeline pipeline_rms_norm_f32;
+    vk_pipeline pipeline_fused_rms_norm_f32;
     vk_pipeline pipeline_rms_norm_back_f32;

     // [src/dst 0=fp32,1=fp16]

@@ -441,6 +457,13 @@ struct vk_device_struct {
     vk_pipeline pipeline_tanh[2];
     vk_pipeline pipeline_sigmoid[2];

+    // [src/dst 0=fp32,1=fp16]
+    vk_pipeline pipeline_fused_mul_gelu[2];
+    vk_pipeline pipeline_fused_mul_silu[2];
+    vk_pipeline pipeline_fused_mul_relu[2];
+
+    vk_pipeline pipeline_multi_add_f32;
+
     vk_pipeline pipeline_leaky_relu_f32;
     vk_pipeline pipeline_silu_back_f32;
     vk_pipeline pipeline_diag_mask_inf_f32;

@@ -463,26 +486,11 @@ struct vk_device_struct {
     vk_pipeline pipeline_opt_step_adamw_f32;

     // [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned}
-    vk_pipeline pipeline_flash_attn_f32_f16_D64_cm2[GGML_TYPE_COUNT][2][2][2];
-    vk_pipeline pipeline_flash_attn_f32_f16_D80_cm2[GGML_TYPE_COUNT][2][2][2];
-    vk_pipeline pipeline_flash_attn_f32_f16_D96_cm2[GGML_TYPE_COUNT][2][2][2];
-    vk_pipeline pipeline_flash_attn_f32_f16_D112_cm2[GGML_TYPE_COUNT][2][2][2];
-    vk_pipeline pipeline_flash_attn_f32_f16_D128_cm2[GGML_TYPE_COUNT][2][2][2];
-    vk_pipeline pipeline_flash_attn_f32_f16_D256_cm2[GGML_TYPE_COUNT][2][2][2];
+    vk_pipeline pipeline_flash_attn_f32_f16_cm2[GGML_TYPE_COUNT][FA_HEAD_SIZE_COUNT][2][2][2];

-    vk_pipeline pipeline_flash_attn_f32_f16_D64_cm1[GGML_TYPE_COUNT][2][2][2];
-    vk_pipeline pipeline_flash_attn_f32_f16_D80_cm1[GGML_TYPE_COUNT][2][2][2];
-    vk_pipeline pipeline_flash_attn_f32_f16_D96_cm1[GGML_TYPE_COUNT][2][2][2];
-    vk_pipeline pipeline_flash_attn_f32_f16_D112_cm1[GGML_TYPE_COUNT][2][2][2];
-    vk_pipeline pipeline_flash_attn_f32_f16_D128_cm1[GGML_TYPE_COUNT][2][2][2];
-    vk_pipeline pipeline_flash_attn_f32_f16_D256_cm1[GGML_TYPE_COUNT][2][2][2];
+    vk_pipeline pipeline_flash_attn_f32_f16_cm1[GGML_TYPE_COUNT][FA_HEAD_SIZE_COUNT][2][2][2];

-    vk_pipeline pipeline_flash_attn_f32_f16_D64[GGML_TYPE_COUNT][2][2][2];
-    vk_pipeline pipeline_flash_attn_f32_f16_D80[GGML_TYPE_COUNT][2][2][2];
-    vk_pipeline pipeline_flash_attn_f32_f16_D96[GGML_TYPE_COUNT][2][2][2];
-    vk_pipeline pipeline_flash_attn_f32_f16_D112[GGML_TYPE_COUNT][2][2][2];
-    vk_pipeline pipeline_flash_attn_f32_f16_D128[GGML_TYPE_COUNT][2][2][2];
-    vk_pipeline pipeline_flash_attn_f32_f16_D256[GGML_TYPE_COUNT][2][2][2];
+    vk_pipeline pipeline_flash_attn_f32_f16[GGML_TYPE_COUNT][FA_HEAD_SIZE_COUNT][2][2][2];

     vk_pipeline pipeline_flash_attn_split_k_reduce;

@@ -677,6 +685,13 @@ struct vk_op_unary_push_constants {
 };
 static_assert(sizeof(vk_op_unary_push_constants) <= 128, "sizeof(vk_op_unary_push_constants) must be <= 128");

+struct vk_op_multiadd_push_constants {
+    uint32_t ne;
+    uint32_t ne0, ne1;
+    uint32_t nb0, nb01;
+    uint32_t nadd;
+};
+
 // See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1.
 // Precompute mp (m' in the paper) and L such that division
 // can be computed using a multiply (high 32b of 64b result)

@@ -1670,6 +1685,35 @@ enum FaCodePath {
     FA_COOPMAT2,
 };

+static FaHeadSizes fa_get_head_sizes(uint32_t hsk, uint32_t hsv) {
+    if (hsk != 192 && hsk != 576 && hsk != hsv) {
+        return FA_HEAD_SIZE_UNSUPPORTED;
+    }
+    switch (hsk) {
+    case 64: return FA_HEAD_SIZE_64;
+    case 80: return FA_HEAD_SIZE_80;
+    case 96: return FA_HEAD_SIZE_96;
+    case 112: return FA_HEAD_SIZE_112;
+    case 128: return FA_HEAD_SIZE_128;
+    case 192:
+        if (hsv == 192) {
+            return FA_HEAD_SIZE_192;
+        } else if (hsv == 128) {
+            return FA_HEAD_SIZE_192_128;
+        } else {
+            return FA_HEAD_SIZE_UNSUPPORTED;
+        }
+    case 256: return FA_HEAD_SIZE_256;
+    case 576:
+        if (hsv == 512) {
+            return FA_HEAD_SIZE_576_512;
+        } else {
+            return FA_HEAD_SIZE_UNSUPPORTED;
+        }
+    default: return FA_HEAD_SIZE_UNSUPPORTED;
+    }
+}
+
 // number of rows/cols for flash attention shader
 static constexpr uint32_t flash_attention_num_small_rows = 32;
 static constexpr uint32_t scalar_flash_attention_num_small_rows = 1;

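Editor's note: usage sketch for the mapping above. The trailing three indices pick accumulator precision, row mode, and alignment per the comment on the pipeline arrays; treat the exact index order as an assumption:

    FaHeadSizes hs = fa_get_head_sizes(192, 128);  // -> FA_HEAD_SIZE_192_128 (HSK 192, HSV 128)
    FaHeadSizes bad = fa_get_head_sizes(192, 96);  // -> FA_HEAD_SIZE_UNSUPPORTED, caller must fall back
    // Hypothetical pipeline lookup, e.g. f32 accumulators, large rows, aligned:
    // vk_pipeline p = device->pipeline_flash_attn_f32_f16[type][hs][1][0][1];
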
@@ -1690,8 +1734,9 @@ static uint32_t get_fa_num_small_rows(FaCodePath path) {
     }
 }

-static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) {
+static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) {
     GGML_UNUSED(clamp);
+    GGML_UNUSED(hsv);

     if (path == FA_SCALAR) {
         if (small_rows) {

@@ -1715,7 +1760,7 @@ static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t D, uint32_
     }

     // small cols to reduce register count
-    if (ggml_is_quantized(type) || D == 256) {
+    if (ggml_is_quantized(type) || hsk >= 256) {
         return {64, 32};
     }
     return {64, 64};

@@ -2008,19 +2053,21 @@ static void ggml_vk_load_shaders(vk_device& device) {
             parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size));
     };

-    auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::array<uint32_t, 3> {
-        return {fa_rows_cols(path, D, clamp, type, small_rows)[0], 1, 1};
+    auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) -> std::array<uint32_t, 3> {
+        return {fa_rows_cols(path, hsk, hsv, clamp, type, small_rows)[0], 1, 1};
     };

-    auto const &fa_spec_constants = [&](FaCodePath path, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector<uint32_t> {
+    auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector<uint32_t> {
         // For large number of rows, 128 invocations seems to work best.
         // For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we
         // can't use 256 for D==80.
         // For scalar, use 128 (arbitrary)
+        // The same D_split value is used for both HSK and HSV, so just base it on the union of the LSBs.
+        const uint32_t D = (hsk|hsv);
         uint32_t wg_size = (path == FA_SCALAR || path == FA_COOPMAT1)
             ? scalar_flash_attention_workgroup_size
             : ((small_rows && (D % 32) == 0) ? 256 : 128);
-        auto rows_cols = fa_rows_cols(path, D, clamp, type, small_rows);
+        auto rows_cols = fa_rows_cols(path, hsk, hsv, clamp, type, small_rows);

         // D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it.
         // D_split can't be larger than the LSB of D divided by 4 due to vectorization in the shader.

@@ -2029,26 +2076,29 @@ static void ggml_vk_load_shaders(vk_device& device) {

         // mask dim1 is padded to 64, we rely on this to avoid clamping mask loads
         GGML_ASSERT((GGML_KQ_MASK_PAD % rows_cols[0]) == 0);
-        return {wg_size, rows_cols[0], rows_cols[1], (D), clamp, D_split};
+        return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split};
     };

#define CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, D) \
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][0], "flash_attn_f32_f16_D" #D "_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,false), fa_spec_constants(FAPATH, D,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,false), fa_spec_constants(FAPATH, D,0,TYPE,false), fa_rows_cols(FAPATH,D,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][0], "flash_attn_f32_f16_D" #D "_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,false), fa_spec_constants(FAPATH, D,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,false), fa_spec_constants(FAPATH, D,0,TYPE,false), fa_rows_cols(FAPATH,D,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][0], "flash_attn_f32_f16_D" #D "_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,true), fa_spec_constants(FAPATH, D,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,true), fa_spec_constants(FAPATH, D,0,TYPE,true), fa_rows_cols(FAPATH,D,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][0], "flash_attn_f32_f16_D" #D "_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,true), fa_spec_constants(FAPATH, D,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,true), fa_spec_constants(FAPATH, D,0,TYPE,true), fa_rows_cols(FAPATH,D,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
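// A sketch of the new indexing, inferred from the pipeline names generated below:
// pipeline_flash_attn_f32_f16*[type][FA_HEAD_SIZE_*][f32acc][small_rows][aligned].
// HSK/HSV replace the single head size D (the K and V head sizes may now differ),
// and HEAD_SIZES is the token naming the (HSK, HSV) pair, e.g. 192_128.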
#define CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, HSK, HSV, HEAD_SIZES) \
        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][0][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][0][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,false), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][0][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][0][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,false), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][1][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][1][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,true), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][1][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][1][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,true), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \

#define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \
        CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 64) \
        CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 80) \
        CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 96) \
        CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 112) \
        CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 128) \
        CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 256)
        CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 64, 64, 64) \
        CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 80, 80, 80) \
        CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 96, 96, 96) \
        CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 112, 112, 112) \
        CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 128, 128, 128) \
        CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 192, 192, 192) \
        CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 192, 128, 192_128) \
        CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 256, 256, 256) \
        CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 576, 512, 576_512)
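// The two unequal pairs (192/128 and 576/512) presumably serve MLA-style attention
// (e.g. DeepSeek), where K carries extra RoPE dimensions that V does not have.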

    CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, )
    CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, )
@@ -2653,6 +2703,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
    ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
    ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
    ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1);
    ggml_vk_create_pipeline(device, device->pipeline_fused_rms_norm_f32, "fused_rms_norm_f32", fused_rms_norm_f32_len, fused_rms_norm_f32_data, "main", 3, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1);
    ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);

    ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
@@ -2745,6 +2796,15 @@ static void ggml_vk_load_shaders(vk_device& device) {
    CREATE_UNARY(sigmoid)
#undef CREATE_UNARY

    ggml_vk_create_pipeline(device, device->pipeline_fused_mul_silu[0], "fused_mul_silu_f32", fused_mul_silu_f32_len, fused_mul_silu_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
    ggml_vk_create_pipeline(device, device->pipeline_fused_mul_silu[1], "fused_mul_silu_f16", fused_mul_silu_f16_len, fused_mul_silu_f16_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
    ggml_vk_create_pipeline(device, device->pipeline_fused_mul_gelu[0], "fused_mul_gelu_f32", fused_mul_gelu_f32_len, fused_mul_gelu_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
    ggml_vk_create_pipeline(device, device->pipeline_fused_mul_gelu[1], "fused_mul_gelu_f16", fused_mul_gelu_f16_len, fused_mul_gelu_f16_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
    ggml_vk_create_pipeline(device, device->pipeline_fused_mul_relu[0], "fused_mul_relu_f32", fused_mul_relu_f32_len, fused_mul_relu_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
    ggml_vk_create_pipeline(device, device->pipeline_fused_mul_relu[1], "fused_mul_relu_f16", fused_mul_relu_f16_len, fused_mul_relu_f16_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);

    ggml_vk_create_pipeline(device, device->pipeline_multi_add_f32, "multi_add_f32", multi_add_f32_len, multi_add_f32_data, "main", 2, sizeof(vk_op_multiadd_push_constants), {512, 1, 1}, {}, 1);

    ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
    ggml_vk_create_pipeline(device, device->pipeline_silu_back_f32, "silu_back_f32", silu_back_f32_len, silu_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);

@@ -3640,7 +3700,6 @@ GGML_CALL void ggml_vk_instance_init() {

    }

    size_t num_available_devices = vk_instance.instance.enumeratePhysicalDevices().size();
    vk_perf_logger_enabled = getenv("GGML_VK_PERF_LOGGER") != nullptr;

    // Emulate behavior of CUDA_VISIBLE_DEVICES for Vulkan
@@ -5506,6 +5565,9 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&

    const uint64_t nei0 = ids->ne[0];
    const uint64_t nei1 = ids->ne[1];
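    // nei0*nei1 is one row id per (expert, token) pair; the shader keeps these ids in a
    // shared-memory array with room for 4096 entries, which the check below guards
    // (see also the 4096 check in ggml_backend_vk_supports_op and the split in
    // ggml_vk_mul_mat_id).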
    if (nei0*nei1 > 4096) {
        fprintf(stderr, "%s: nei0 = %d, nei1 = %d\n", __func__, (int)nei0, (int)nei1);
    }
    GGML_ASSERT(nei0 * nei1 <= 4096);

    const uint32_t nbi1 = ids->nb[1];
@@ -5901,28 +5963,74 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx
    if (src2->ne[1] == 1 && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type))) {
        ggml_vk_mul_mat_vec_id_q_f16(ctx, subctx, src0, src1, src2, dst, dryrun);
    } else {
        ggml_vk_mul_mat_id_q_f16(ctx, subctx, src0, src1, src2, dst, dryrun);
        // Split based on number of ids, to fit in shared memory
        const uint32_t nei0 = (uint32_t)src2->ne[0];
        const uint32_t nei1 = (uint32_t)src2->ne[1];

        GGML_ASSERT(nei0 <= 4096);
        const uint32_t split_size = std::min(nei1, 4096u / nei0);
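        // Worked example: with nei0 = 8 experts per token, split_size = 4096 / 8 = 512,
        // so a batch of 2000 tokens is dispatched as chunks of 512 + 512 + 512 + 464.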

        ggml_tensor src1_copy = *src1;
        ggml_tensor src2_copy = *src2;
        ggml_tensor dst_copy = *dst;

        for (uint32_t token_start = 0; token_start < nei1; token_start += split_size) {
            const uint32_t n_tokens = std::min(split_size, nei1 - token_start);

            src1_copy.view_offs = src1->view_offs + token_start * src1_copy.nb[2];
            src2_copy.view_offs = src2->view_offs + token_start * src2_copy.nb[1];
            dst_copy.view_offs = dst->view_offs + token_start * dst_copy.nb[2];

            src1_copy.ne[2] = n_tokens;
            src2_copy.ne[1] = n_tokens;
            dst_copy.ne[2] = n_tokens;

            ggml_vk_mul_mat_id_q_f16(ctx, subctx, src0, &src1_copy, &src2_copy, &dst_copy, dryrun);
        }
    }
}

static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const uint32_t D, bool f32acc) {
static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv) {
    // Needs to be kept up to date on shader changes
    GGML_UNUSED(hsv);
    const uint32_t wg_size = scalar_flash_attention_workgroup_size;
    const uint32_t Br = scalar_flash_attention_num_large_rows;
    const uint32_t Bc = scalar_flash_attention_Bc;

    const uint32_t tmpsh = wg_size * sizeof(float);
    const uint32_t tmpshv4 = wg_size * 4 * sizeof(float);

    const uint32_t masksh = Bc * Br * sizeof(float);

    const uint32_t Qf = Br * (hsk / 4 + 2) * 4 * sizeof(float);

    const uint32_t total_size = tmpsh + tmpshv4 + masksh + Qf;
    const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;

    VK_LOG_DEBUG("ggml_vk_flash_attn_scalar_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", total_size=" << total_size << ", supported=" << supported);

    return supported;
}

static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, bool f32acc) {
    // Needs to be kept up to date on shader changes
    GGML_UNUSED(hsv);
    const uint32_t wg_size = scalar_flash_attention_workgroup_size;
    const uint32_t Br = coopmat1_flash_attention_num_large_rows;
    const uint32_t Bc = scalar_flash_attention_Bc;

    const uint32_t acctype = f32acc ? 4 : 2;
    const uint32_t f16vec4 = 8;

    const uint32_t tmpsh = wg_size * sizeof(float);
    const uint32_t tmpshv4 = wg_size * 4 * acctype;

    const uint32_t Qf = Br * (D / 4 + 2) * f16vec4;
    const uint32_t Qf = Br * (hsk / 4 + 2) * f16vec4;

    const uint32_t sfshstride = (D <= 128) ? (Br + 8) : Br;
    const uint32_t sfshstride = (hsk <= 128) ? (Br + 8) : Br;
    const uint32_t sfsh = Bc * sfshstride * acctype;

    const uint32_t kshstride = D / 4 + 2;
    const uint32_t kshstride = hsk / 4 + 2;
    const uint32_t ksh = Bc * kshstride * f16vec4;

    const uint32_t slope = Br * sizeof(float);
@@ -5930,7 +6038,7 @@ static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, co
    const uint32_t total_size = tmpsh + tmpshv4 + Qf + sfsh + ksh + slope;
    const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;

    VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(D=" << D << ", f32acc=" << f32acc << ", total_size=" << total_size << ", supported=" << supported);
    VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", f32acc=" << f32acc << ", total_size=" << total_size << ", supported=" << supported);

    return supported;
}
@@ -5954,11 +6062,12 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
    const uint32_t nem1 = mask ? mask->ne[1] : 0;
    const uint32_t nbm1 = mask ? mask->nb[1] : 0;

    const uint32_t D = neq0;
    const uint32_t HSK = nek0;
    const uint32_t HSV = nev0;
    uint32_t N = neq1;
    const uint32_t KV = nek1;

    GGML_ASSERT(ne0 == D);
    GGML_ASSERT(ne0 == HSV);
    GGML_ASSERT(ne2 == N);

    // input tensor rows must be contiguous
@@ -5966,12 +6075,9 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
    GGML_ASSERT(nbk0 == ggml_type_size(k->type));
    GGML_ASSERT(nbv0 == ggml_type_size(v->type));

    GGML_ASSERT(neq0 == D);
    GGML_ASSERT(nek0 == D);
    GGML_ASSERT(nev0 == D);
    GGML_ASSERT(neq0 == HSK);

    GGML_ASSERT(neq1 == N);
    GGML_ASSERT(nev0 == D);

    GGML_ASSERT(nev1 == nek1);

@@ -5992,7 +6098,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
    const bool coopmat_shape_supported = (dst->op_params[3] == GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f32acc) ||
                                         (dst->op_params[3] != GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f16acc);

    const bool coopmat_shmem_supported = ggml_vk_flash_attn_coopmat_shmem_support(ctx->device, D, dst->op_params[3] == GGML_PREC_F32);
    const bool coopmat_shmem_supported = ggml_vk_flash_attn_coopmat_shmem_support(ctx->device, HSK, HSV, dst->op_params[3] == GGML_PREC_F32);

    if (!coopmat_shape_supported || !coopmat_shmem_supported) {
        path = FA_SCALAR;
@@ -6045,47 +6151,25 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
        path = FA_SCALAR;
    }

    // with large hsk/hsv, scalar path may need to use small_rows to fit in shared memory
    if (path == FA_SCALAR &&
        !ggml_vk_flash_attn_scalar_shmem_support(ctx->device, HSK, HSV)) {
        small_rows = true;
    }

    bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32;

    FaHeadSizes head_sizes = fa_get_head_sizes(k->ne[0], v->ne[0]);
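    // Maps (k->ne[0], v->ne[0]) to the FA_HEAD_SIZE_* enum used to index the pipeline
    // arrays; unsupported combinations were already rejected in supports_op.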

    switch (path) {
    case FA_SCALAR:
        switch (D) {
        case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64[k->type][f32acc][small_rows][0]; break;
        case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80[k->type][f32acc][small_rows][0]; break;
        case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96[k->type][f32acc][small_rows][0]; break;
        case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112[k->type][f32acc][small_rows][0]; break;
        case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128[k->type][f32acc][small_rows][0]; break;
        case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256[k->type][f32acc][small_rows][0]; break;
        default:
            GGML_ASSERT(!"unsupported D value");
            return;
        }
        pipelines = &ctx->device->pipeline_flash_attn_f32_f16[k->type][head_sizes][f32acc][small_rows][0];
        break;
    case FA_COOPMAT1:
        switch (D) {
        case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64_cm1[k->type][f32acc][small_rows][0]; break;
        case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80_cm1[k->type][f32acc][small_rows][0]; break;
        case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96_cm1[k->type][f32acc][small_rows][0]; break;
        case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112_cm1[k->type][f32acc][small_rows][0]; break;
        case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128_cm1[k->type][f32acc][small_rows][0]; break;
        case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256_cm1[k->type][f32acc][small_rows][0]; break;
        default:
            GGML_ASSERT(!"unsupported D value");
            return;
        }
        pipelines = &ctx->device->pipeline_flash_attn_f32_f16_cm1[k->type][head_sizes][f32acc][small_rows][0];
        break;
    case FA_COOPMAT2:
        switch (D) {
        case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64_cm2[k->type][f32acc][small_rows][0]; break;
        case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80_cm2[k->type][f32acc][small_rows][0]; break;
        case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96_cm2[k->type][f32acc][small_rows][0]; break;
        case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112_cm2[k->type][f32acc][small_rows][0]; break;
        case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128_cm2[k->type][f32acc][small_rows][0]; break;
        case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256_cm2[k->type][f32acc][small_rows][0]; break;
        default:
            GGML_ASSERT(!"unsupported D value");
            return;
        }
        pipelines = &ctx->device->pipeline_flash_attn_f32_f16_cm2[k->type][head_sizes][f32acc][small_rows][0];
        break;
    default:
        GGML_ASSERT(0);
@@ -6115,7 +6199,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
    // Try to use split_k when KV is large enough to be worth the overhead
    if (workgroups_x == 1 && shader_core_count > 0 && KV >= 512) {
        // Try to run two workgroups per SM.
        split_k = ctx->device->shader_core_count * 2 / workgroups_y;
        split_k = shader_core_count * 2 / (workgroups_y * workgroups_z);
        if (split_k > 1) {
            // Try to evenly split KV into split_k chunks, but it needs to be a multiple
            // of "align", so recompute split_k based on that.
@@ -6125,9 +6209,9 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
        }
    }

    // Reserve space for split_k temporaries. For each split, we need to store the O matrix (D x ne1)
    // and the per-row m and L values (ne1 rows).
    const uint64_t split_k_size = split_k > 1 ? (D * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k : 0;
    // Reserve space for split_k temporaries. For each split x batch, we need to store the O matrix (D x ne1)
    // and the per-row m and L values (ne1 rows). We store all the matrices first, followed by the rows.
    const uint64_t split_k_size = split_k > 1 ? (HSV * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k * ne3 : 0;
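    // Implied buffer layout, in floats: first all O matrices (HSV * ne1 each), ordered by
    // (batch, split); then, per (batch, split), an L row followed by an M row of ne1 values.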
    if (split_k_size > ctx->device->max_memory_allocation_size) {
        GGML_ABORT("Requested preallocation size is too large");
    }
@@ -6246,7 +6330,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
        sizeof(vk_flash_attn_push_constants), &pc, { workgroups_x * pipeline->wg_denoms[0], workgroups_y, workgroups_z });

    ggml_vk_sync_buffers(subctx);
    const std::array<uint32_t, 3> pc2 = { D, (uint32_t)ne1, split_k };
    const std::array<uint32_t, 4> pc2 = { HSV, (uint32_t)ne1, (uint32_t)ne3, split_k };
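    // Push constants for the reduce pass: output head size (HSV), rows (ne1),
    // batches (ne3), and the number of splits to merge.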
    ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_flash_attn_split_k_reduce,
        {
            vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE},
@@ -6381,11 +6465,42 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
            return ctx->device->pipeline_rms_norm_f32;
        }
        return nullptr;
    case GGML_OP_FUSED_RMS_NORM:
        if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
            return ctx->device->pipeline_fused_rms_norm_f32;
        }
        return nullptr;
    case GGML_OP_RMS_NORM_BACK:
        if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
            return ctx->device->pipeline_rms_norm_back_f32;
        }
        return nullptr;
    case GGML_OP_FUSED_MUL_UNARY:
        if ((src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16) ||
            (src1->type != GGML_TYPE_F32 && src1->type != GGML_TYPE_F16) ||
            (dst->type != GGML_TYPE_F32 && dst->type != GGML_TYPE_F16) ||
            (src0->type != dst->type) || (src1->type != dst->type)) {
            return nullptr;
        } else {
            ggml_unary_op unary_op = (ggml_unary_op)dst->op_params[0];
            switch (unary_op) {
            case GGML_UNARY_OP_SILU:
                return ctx->device->pipeline_fused_mul_silu[dst->type == GGML_TYPE_F16];
            case GGML_UNARY_OP_GELU:
                return ctx->device->pipeline_fused_mul_gelu[dst->type == GGML_TYPE_F16];
            case GGML_UNARY_OP_RELU:
                return ctx->device->pipeline_fused_mul_relu[dst->type == GGML_TYPE_F16];
            default:
                break;
            }
            return nullptr;
        }
    case GGML_OP_MULTI_ADD:
        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 &&
            dst->ne[2] == 1 && dst->ne[3] == 1) {
            return ctx->device->pipeline_multi_add_f32;
        }
        return nullptr;
    case GGML_OP_UNARY:
        if ((src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16) ||
            (dst->type != GGML_TYPE_F32 && dst->type != GGML_TYPE_F16) ||
@@ -6521,7 +6636,9 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
    case GGML_OP_REPEAT_BACK:
    case GGML_OP_ROPE:
    case GGML_OP_RMS_NORM:
    case GGML_OP_FUSED_RMS_NORM:
    case GGML_OP_IM2COL:
    case GGML_OP_MULTI_ADD:
        return true;
    default:
        return false;
@@ -6751,6 +6868,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
        elements = { (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne03 };
        break;

    case GGML_OP_FUSED_RMS_NORM:
        elements = { (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne03 };
        break;

    case GGML_OP_SUM:
        // We use GGML_OP_SUM_ROWS with 1 row.
        elements = { 1, 1, 1 };
@@ -6818,6 +6939,8 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
    case GGML_OP_CPY:
    case GGML_OP_CONCAT:
    case GGML_OP_UPSCALE:
    case GGML_OP_FUSED_MUL_UNARY:
    case GGML_OP_MULTI_ADD:
    case GGML_OP_UNARY:
        {
            uint32_t ne = ggml_nelements(dst);
@@ -7173,6 +7296,24 @@ static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx,
    }, dryrun);
}

static void ggml_vk_fused_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
    float * op_params = (float *)dst->op_params;
    const uint32_t src0_type_size = ggml_type_size(src0->type);
    const uint32_t src1_type_size = ggml_type_size(src1->type);
    const uint32_t dst_type_size = ggml_type_size(dst->type);
    GGML_ASSERT(src1->ne[1] == 1 && src1->ne[2] == 1 && src1->ne[3] == 1);
    GGML_ASSERT(src1->ne[0] == src0->ne[0]);
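    // The asserts pin down the contract: src1 is a single row of src0->ne[0] RMS-norm
    // weights, broadcast by the fused shader across every row of src0.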

    ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_FUSED_RMS_NORM, {
        (uint32_t)ggml_nelements(src0),
        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
        (uint32_t)src1->ne[0], 1u, 1u, 1u, (uint32_t)src1->nb[0] / src1_type_size, 0u, 0u, 0u,
        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
        0,
        op_params[0], 0.0f, 0,
    }, dryrun);
}

static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
    float * op_params = (float *)dst->op_params;
    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_RMS_NORM_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun);
@@ -7182,6 +7323,19 @@ static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, con
    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
}

static void ggml_vk_fused_mul_unary(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
    GGML_ASSERT(ggml_is_contiguous(src0));
    GGML_ASSERT(ggml_are_same_shape(src0, src1));
    GGML_ASSERT(ggml_are_same_shape(src0, dst));
    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_FUSED_MUL_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
}

static void ggml_vk_multi_add(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
    uint32_t nadd = (uint32_t)dst->op_params[0];
    ggml_vk_op_f32<vk_op_multiadd_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_MULTI_ADD,
        { (uint32_t)ggml_nelements(dst), (uint32_t)dst->ne[0], (uint32_t)dst->ne[1], (uint32_t)(dst->nb[1]/sizeof(float)), (uint32_t)(src0->nb[1]/sizeof(float)), nadd }, dryrun);
}

static void ggml_vk_diag_mask_inf(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
    int32_t * op_params = (int32_t *)dst->op_params;
    ggml_vk_op_f32<vk_op_diag_mask_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_DIAG_MASK_INF, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0] }, dryrun);
@@ -8366,6 +8520,8 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
            return false;
        }
        break;
    case GGML_OP_FUSED_MUL_UNARY:
    case GGML_OP_MULTI_ADD:
    case GGML_OP_REPEAT:
    case GGML_OP_REPEAT_BACK:
    case GGML_OP_GET_ROWS:
@@ -8386,6 +8542,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
    case GGML_OP_NORM:
    case GGML_OP_GROUP_NORM:
    case GGML_OP_RMS_NORM:
    case GGML_OP_FUSED_RMS_NORM:
    case GGML_OP_RMS_NORM_BACK:
    case GGML_OP_DIAG_MASK_INF:
    case GGML_OP_SOFT_MAX:
@@ -8444,8 +8601,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
    case GGML_OP_NORM:
    case GGML_OP_GROUP_NORM:
    case GGML_OP_RMS_NORM:
    case GGML_OP_FUSED_RMS_NORM:
    case GGML_OP_RMS_NORM_BACK:
    case GGML_OP_UNARY:
    case GGML_OP_FUSED_MUL_UNARY:
    case GGML_OP_MULTI_ADD:
    case GGML_OP_DIAG_MASK_INF:
    case GGML_OP_SOFT_MAX:
    case GGML_OP_SOFT_MAX_BACK:
@@ -8550,10 +8710,20 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
    case GGML_OP_RMS_NORM:
        ggml_vk_rms_norm(ctx, compute_ctx, src0, node, dryrun);

        break;
    case GGML_OP_FUSED_RMS_NORM:
        ggml_vk_fused_rms_norm(ctx, compute_ctx, src0, src1, node, dryrun);

        break;
    case GGML_OP_RMS_NORM_BACK:
        ggml_vk_rms_norm_back(ctx, compute_ctx, src0, src1, node, dryrun);

        break;
    case GGML_OP_FUSED_MUL_UNARY:
        ggml_vk_fused_mul_unary(ctx, compute_ctx, src0, src1, node, dryrun);
        break;
    case GGML_OP_MULTI_ADD:
        ggml_vk_multi_add(ctx, compute_ctx, src0, node, dryrun);
        break;
    case GGML_OP_UNARY:
        switch (ggml_get_unary_op(node)) {
@@ -8703,6 +8873,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
    case GGML_OP_NORM:
    case GGML_OP_GROUP_NORM:
    case GGML_OP_RMS_NORM:
    case GGML_OP_FUSED_RMS_NORM:
    case GGML_OP_RMS_NORM_BACK:
    case GGML_OP_DIAG_MASK_INF:
    case GGML_OP_SOFT_MAX:
@@ -8725,6 +8896,8 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
    case GGML_OP_LEAKY_RELU:
    case GGML_OP_REPEAT:
    case GGML_OP_REPEAT_BACK:
    case GGML_OP_FUSED_MUL_UNARY:
    case GGML_OP_MULTI_ADD:
        buf = tensor->buffer;

        break;
@@ -9408,15 +9581,36 @@ GGML_CALL static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const
            return false;
        }
        break;
    case GGML_OP_FUSED_MUL_UNARY:
        switch ((ggml_unary_op)op->op_params[0]) {
        case GGML_UNARY_OP_GELU:
        case GGML_UNARY_OP_SILU:
        case GGML_UNARY_OP_RELU:
            return ggml_is_contiguous(op->src[0]) && ggml_are_same_shape(op->src[0], op->src[1]) &&
                   (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
                   (op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16) &&
                   (op->src[0]->type == op->type) && (op->src[1]->type == op->type);
        default:
            return false;
        }
        break;
    case GGML_OP_MULTI_ADD:
        return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32 && op->ne[2] == 1 && op->ne[3] == 1;
    case GGML_OP_MUL_MAT:
    case GGML_OP_MUL_MAT_ID:
        {
            ggml_type src0_type = op->src[0]->type;
            ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
            const vk_device& device = ctx->device;
            if (op->op == GGML_OP_MUL_MAT_ID && !device->mul_mat_id_s[src0_type] && !device->mul_mat_id_m[src0_type] && !device->mul_mat_id_l[src0_type]) {
                // If there's not enough shared memory for row_ids and the result tile, fallback to CPU
                return false;
            if (op->op == GGML_OP_MUL_MAT_ID) {
                if (!device->mul_mat_id_s[src0_type] && !device->mul_mat_id_m[src0_type] && !device->mul_mat_id_l[src0_type]) {
                    // If there's not enough shared memory for row_ids and the result tile, fallback to CPU
                    return false;
                }
                // Check against size of shared memory variable
                if (op->src[2]->ne[0] > 4096) {
                    return false;
                }
            }
            switch (src0_type) {
            case GGML_TYPE_F32:
@@ -9473,15 +9667,8 @@ GGML_CALL static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const
        {
            ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
            bool coopmat2 = ctx->device->coopmat2;
            switch (op->src[0]->ne[0]) {
            case 64:
            case 80:
            case 96:
            case 112:
            case 128:
            case 256:
                break;
            default:
            FaHeadSizes head_sizes = fa_get_head_sizes(op->src[1]->ne[0], op->src[2]->ne[0]);
            if (head_sizes == FA_HEAD_SIZE_UNSUPPORTED) {
                return false;
            }
            if (op->src[0]->type != GGML_TYPE_F32) {
@@ -9625,6 +9812,7 @@ GGML_CALL static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const
    case GGML_OP_PERMUTE:
    case GGML_OP_TRANSPOSE:
    case GGML_OP_RMS_NORM:
    case GGML_OP_FUSED_RMS_NORM:
        return true;
    case GGML_OP_NORM:
    case GGML_OP_GROUP_NORM:
@@ -10064,6 +10252,8 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
        tensor_clone = ggml_group_norm(ggml_ctx, src_clone[0], tensor->op_params[0], float_params[1]);
    } else if (tensor->op == GGML_OP_RMS_NORM) {
        tensor_clone = ggml_rms_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params);
    } else if (tensor->op == GGML_OP_FUSED_RMS_NORM) {
        tensor_clone = ggml_fused_rms_norm(ggml_ctx, src_clone[0], src_clone[1], *(float *)tensor->op_params);
    } else if (tensor->op == GGML_OP_RMS_NORM_BACK) {
        const float eps = ((float *) tensor->op_params)[0];
        tensor_clone = ggml_rms_norm_back(ggml_ctx, src_clone[0], src_clone[1], eps);
@@ -10129,6 +10319,10 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
        std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
        GGML_ABORT("fatal error");
    }
    } else if (tensor->op == GGML_OP_FUSED_MUL_UNARY) {
        tensor_clone = ggml_fused_mul_unary(ggml_ctx, src_clone[0], src_clone[1], (ggml_unary_op)tensor->op_params[0]);
    } else if (tensor->op == GGML_OP_MULTI_ADD) {
        tensor_clone = ggml_multi_add(ggml_ctx, src_clone[0], tensor->op_params[0]);
    } else if (tensor->op == GGML_OP_CPY || tensor->op == GGML_OP_DUP) {
        if (src1 == nullptr) {
            tensor_clone = ggml_dup(ggml_ctx, src_clone[0]);

@@ -11,7 +11,8 @@
#include "types.comp"
#include "flash_attn_base.comp"

const uint32_t D_per_thread = D / D_split;
const uint32_t HSK_per_thread = HSK / D_split;
const uint32_t HSV_per_thread = HSV / D_split;
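// The single head size D is split in two: HSK is the Q/K head size used for the
// S = K*Q^T accumulation, HSV is the V head size and hence the width of each output row.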

const uint32_t cols_per_iter = WorkGroupSize / D_split;
const uint32_t cols_per_thread = Bc / cols_per_iter;
@@ -29,7 +30,7 @@ layout (binding = 3) readonly buffer M {float16_t data_m[];};
// Rows index by Q's dimension 2, and the first N rows are valid.
D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
{
    uint32_t offset = (iq2 + r) * D + c;
    uint32_t offset = (iq2 + r) * HSV + c;
    data_o[o_offset + offset] = D_TYPE(elem);
    return elem;
}
@@ -38,7 +39,7 @@ shared FLOAT_TYPE tmpsh[WorkGroupSize];
shared vec4 tmpshv4[WorkGroupSize];

shared float masksh[Bc][Br];
shared vec4 Qf[Br][D / 4];
shared vec4 Qf[Br][HSK / 4];

void main() {
#ifdef NEEDS_INIT_IQ_SHMEM
@@ -53,18 +54,18 @@ void main() {

    uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4;

    [[unroll]] for (uint32_t idx = 0; idx < Br * D / 4; idx += gl_WorkGroupSize.x) {
        uint32_t d = (idx + tid) % (D / 4);
        uint32_t r = (idx + tid) / (D / 4);
        if (r < Br && d < D / 4 &&
    [[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) {
        uint32_t d = (idx + tid) % (HSK / 4);
        uint32_t r = (idx + tid) / (HSK / 4);
        if (r < Br && d < HSK / 4 &&
            i * Br + r < N) {
            Qf[r][d] = vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d]) * p.scale;
        }
    }
    barrier();

    vec4 Of[Br][D_per_thread / 4];
    [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
    vec4 Of[Br][HSV_per_thread / 4];
    [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
            Of[r][d] = vec4(0.0);
        }
@@ -112,7 +113,7 @@ void main() {


        [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
            [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
            [[unroll]] for (uint32_t d = 0; d < HSK_per_thread / 4; ++d) {
#if BLOCK_SIZE > 1
                uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
                uint ib = coord / BLOCK_SIZE;
@@ -191,14 +192,14 @@ void main() {
            Lf[r] = eMf[r]*Lf[r] + rowsumf[r];
        }

        [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
        [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
            [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
                Of[r][d] = eMf[r] * Of[r][d];
            }
        }

        [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
            [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
            [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
#if BLOCK_SIZE > 1
                uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
                uint ib = coord / BLOCK_SIZE;
@@ -255,7 +256,7 @@ void main() {
            Lf[r] = tmpsh[d_tid];
            barrier();

            [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
            [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {

                Of[r][d] = eMf * Of[r][d];
                tmpshv4[tid] = Of[r][d];
@@ -277,11 +278,11 @@ void main() {
    // If there is split_k, then the split_k resolve shader does the final
    // division by L. Store the intermediate O value and per-row m and L values.
    if (p.k_num > 1) {
        uint32_t o_offset = D * p.ne1 * split_k_index;
        uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num);

        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
            if (r < N) {
                [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
                [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
                    [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
                        perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N);
                    }
@@ -289,7 +290,7 @@ void main() {
            }
        }

        o_offset = D * p.ne1 * p.k_num + p.ne1 * split_k_index * 2;
        o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
            if (r < N) {
                perElemOpStoreCol0(r, 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
@@ -305,18 +306,18 @@ void main() {
        Lfrcp[r] = 1.0 / Lf[r];
    }

    [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
    [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
            Of[r][d] *= Lfrcp[r];
        }
    }

    uint32_t o_offset = iq3*p.ne2*p.ne1;
    uint32_t o_offset = iq3*p.ne2*p.ne1*HSV;

    if (p.gqa_ratio > 1) {
        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
            if (r < N) {
                [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
                [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
                    [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
                        perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N);
                    }
@@ -326,9 +327,9 @@ void main() {
    } else {
        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
            if (i * Br + r < N) {
                [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
                [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
                    [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
                        data_o[o_offset + iq2 * D + (i * Br + r) * p.ne1 * D + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]);
                        data_o[o_offset + iq2 * HSV + (i * Br + r) * p.ne1 * HSV + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]);
                    }
                }
            }

@@ -4,10 +4,10 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (constant_id = 0) const uint32_t WorkGroupSize = 128;
layout (constant_id = 1) const uint32_t Br = 1;
layout (constant_id = 2) const uint32_t Bc = 32;
layout (constant_id = 3) const uint32_t D = 32;
layout (constant_id = 4) const uint32_t Clamp = 0;
layout (constant_id = 5) const uint32_t D_split = 16;

layout (constant_id = 3) const uint32_t HSK = 32;
layout (constant_id = 4) const uint32_t HSV = 32;
layout (constant_id = 5) const uint32_t Clamp = 0;
layout (constant_id = 6) const uint32_t D_split = 16;
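// Note: inserting HSV shifted the Clamp and D_split spec-constant IDs up by one; the
// host side must now emit {wg_size, Br, Bc, HSK, HSV, clamp, D_split} in exactly this
// order (see the spec-constant list built in ggml-vulkan.cpp).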

layout (push_constant) uniform parameter {
    uint32_t N;

@@ -13,7 +13,9 @@
#include "types.comp"
#include "flash_attn_base.comp"

const uint32_t D_per_thread = D / D_split;
const uint32_t HSK_per_thread = HSK / D_split;
const uint32_t HSV_per_thread = HSV / D_split;

const uint32_t row_split = 4;
const uint32_t rows_per_thread = Br / row_split;
const uint32_t cols_per_iter = gl_WorkGroupSize.x / D_split / row_split;
@@ -32,7 +34,7 @@ layout (binding = 3) readonly buffer M {float16_t data_m[];};
// Rows index by Q's dimension 2, and the first N rows are valid.
D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
{
    uint32_t offset = (iq2 + r) * D + c;
    uint32_t offset = (iq2 + r) * HSV + c;
    data_o[o_offset + offset] = D_TYPE(elem);
    return elem;
}
@@ -44,14 +46,14 @@ const uint32_t MatBc = 16;
shared FLOAT_TYPE tmpsh[gl_WorkGroupSize.x];
shared ACC_TYPEV4 tmpshv4[gl_WorkGroupSize.x];

const uint32_t qstride = D / 4 + 2; // in units of f16vec4
const uint32_t qstride = HSK / 4 + 2; // in units of f16vec4
shared f16vec4 Qf[Br * qstride];

// Avoid padding for D==256 to make it fit in 48KB shmem.
const uint32_t sfshstride = (D <= 128) ? (Br + 8) : Br;
// Avoid padding for hsk==256 to make it fit in 48KB shmem.
const uint32_t sfshstride = (HSK <= 128) ? (Br + 8) : Br;
shared ACC_TYPE sfsh[Bc * sfshstride];

const uint32_t kshstride = D / 4 + 2; // in units of f16vec4
const uint32_t kshstride = HSK / 4 + 2; // in units of f16vec4
shared f16vec4 ksh[Bc * kshstride];

shared float slope[Br];
@@ -74,18 +76,18 @@ void main() {

    uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4;

    [[unroll]] for (uint32_t idx = 0; idx < Br * D / 4; idx += gl_WorkGroupSize.x) {
        uint32_t d = (idx + tid) % (D / 4);
        uint32_t r = (idx + tid) / (D / 4);
        if (r < Br && d < D / 4 &&
    [[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) {
        uint32_t d = (idx + tid) % (HSK / 4);
        uint32_t r = (idx + tid) / (HSK / 4);
        if (r < Br && d < HSK / 4 &&
            i * Br + r < N) {
            Qf[r * qstride + d] = f16vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale);
        }
    }
    barrier();

    ACC_TYPEV4 Of[rows_per_thread][D_per_thread / 4];
    [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
    ACC_TYPEV4 Of[rows_per_thread][HSV_per_thread / 4];
    [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
        [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
            Of[r][d] = ACC_TYPEV4(0.0);
        }
@@ -127,10 +129,10 @@ void main() {
    [[dont_unroll]]
    for (uint32_t j = start_j; j < end_j; ++j) {

        [[unroll]] for (uint32_t idx = 0; idx < Bc * D / 4; idx += gl_WorkGroupSize.x) {
            uint32_t d = (idx + tid) % (D / 4);
            uint32_t c = (idx + tid) / (D / 4);
            if (c < Bc && d < D / 4) {
        [[unroll]] for (uint32_t idx = 0; idx < Bc * HSK / 4; idx += gl_WorkGroupSize.x) {
            uint32_t d = (idx + tid) % (HSK / 4);
            uint32_t c = (idx + tid) / (HSK / 4);
            if (c < Bc && d < HSK / 4) {
#if BLOCK_SIZE > 1
                uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d;
                uint ib = coord / BLOCK_SIZE;
@@ -145,14 +147,14 @@ void main() {
        }
        barrier();

        // K * Q^T -> S^T: Bc x D * D x Br -> Bc x Br
        // Bc split across workgroup (four subgroups), loop over D in chunks of 16: 16 x 16 * 16 x 16 -> 16 x 16
        // K * Q^T -> S^T: Bc x HSK * HSK x Br -> Bc x Br
        // Bc split across workgroup (four subgroups), loop over HSK in chunks of 16: 16 x 16 * 16 x 16 -> 16 x 16
        // This is written transposed in order to allow for N being 8 if implementations need it
        coopmat<ACC_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator> SfMat = coopmat<ACC_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator>(0);
        coopmat<float16_t, gl_ScopeSubgroup, MatBc, 16, gl_MatrixUseA> KMat;
        coopmat<float16_t, gl_ScopeSubgroup, 16, MatBr, gl_MatrixUseB> QMat;

        for (uint32_t d = 0; d < D / 16; ++d) {
        for (uint32_t d = 0; d < HSK / 16; ++d) {
            coopMatLoad(QMat, Qf, d * 16 / 4, qstride, gl_CooperativeMatrixLayoutColumnMajor);

            uint coord = (gl_SubgroupID * MatBc) * kshstride + d * 16 / 4;
@@ -202,7 +204,7 @@ void main() {
            eMf[r] = exp(Moldf - Mf[r]);
        }

        [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
        [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
            [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
                Of[r][d] = float16_t(eMf[r]) * Of[r][d];
            }
@@ -217,7 +219,7 @@ void main() {
            Pf[r] = exp(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride] - Mf[r]);
            Lf[r] += Pf[r];
        }
        [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
        [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
#if BLOCK_SIZE > 1
            uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
            uint ib = coord / BLOCK_SIZE;
@@ -280,7 +282,7 @@ void main() {
        }

        [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
            [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
            [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {

                Of[r][d] = float16_t(eMf[r]) * Of[r][d];
                tmpshv4[tid] = Of[r][d];
@@ -300,11 +302,11 @@ void main() {
    // If there is split_k, then the split_k resolve shader does the final
    // division by L. Store the intermediate O value and per-row m and L values.
    if (p.k_num > 1) {
        uint32_t o_offset = D * p.ne1 * split_k_index;
        uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num);

        [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
            if (tile_row(r) < N) {
                [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
                [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
                    [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
                        perElemOpGqaStore(tile_row(r), 4*(d * D_split + d_tid) + comp, float(Of[r][d][comp]), o_offset, iq2, N);
                    }
@@ -312,7 +314,7 @@ void main() {
            }
        }

        o_offset = D * p.ne1 * p.k_num + p.ne1 * split_k_index * 2;
        o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
        [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
            if (tile_row(r) < N) {
                perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
@@ -328,18 +330,18 @@ void main() {
        Lfrcp[r] = 1.0 / Lf[r];
    }

    [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
    [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
        [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
            Of[r][d] *= float16_t(Lfrcp[r]);
        }
    }

    uint32_t o_offset = iq3*p.ne2*p.ne1;
    uint32_t o_offset = iq3*p.ne2*p.ne1*HSV;

    if (p.gqa_ratio > 1) {
        [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
            if (tile_row(r) < N) {
                [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
                [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
                    [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
                        perElemOpGqaStore(tile_row(r), 4*(d * D_split + d_tid) + comp, float(Of[r][d][comp]), o_offset, iq2, N);
                    }
@@ -349,9 +351,9 @@ void main() {
    } else {
        [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
            if (i * Br + tile_row(r) < N) {
                [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
                [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
                    [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
                        data_o[o_offset + iq2 * D + (i * Br + tile_row(r)) * p.ne1 * D + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]);
                        data_o[o_offset + iq2 * HSV + (i * Br + tile_row(r)) * p.ne1 * HSV + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]);
                    }
                }
            }

@@ -61,8 +61,8 @@ ACC_TYPE Max(const in uint32_t row, const in uint32_t col, const in ACC_TYPE ele
// Rows index by Q's dimension 2, and the first N rows are valid.
D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
{
    if (r < N && c < D) {
        uint32_t offset = (iq2 + r) * D + c;
    if (r < N && c < HSV) {
        uint32_t offset = (iq2 + r) * HSV + c;
        data_o[o_offset + offset] = D_TYPE(elem);
    }
    return elem;
@@ -86,9 +86,9 @@ void main() {
    tensorLayoutV = setTensorLayoutBlockSizeNV(tensorLayoutV, 1, BLOCK_SIZE);
#endif

    tensorLayoutQ = setTensorLayoutDimensionNV(tensorLayoutQ, N, D);
    tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, D);
    tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, D);
    tensorLayoutQ = setTensorLayoutDimensionNV(tensorLayoutQ, N, HSK);
    tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, HSK);
    tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, HSV);

    // hint to the compiler that strides are aligned for the aligned variant of the shader
    if (Clamp != gl_CooperativeMatrixClampModeConstantNV)
@@ -104,16 +104,16 @@ void main() {
    tensorLayoutK = setTensorLayoutStrideNV(tensorLayoutK, k_stride, 1);
    tensorLayoutV = setTensorLayoutStrideNV(tensorLayoutV, v_stride, 1);

    coopmat<Q_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> Q;
    coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseA> Qf16;
    coopmat<Q_TYPE, gl_ScopeWorkgroup, Br, HSK, gl_MatrixUseAccumulator> Q;
    coopmat<float16_t, gl_ScopeWorkgroup, Br, HSK, gl_MatrixUseA> Qf16;

    uint32_t q_offset = iq2*p.nb02+iq3*p.nb03;
    coopMatLoadTensorNV(Q, data_q, q_offset, sliceTensorLayoutNV(tensorLayoutQ, i * Br, Br, 0, D));
    coopMatLoadTensorNV(Q, data_q, q_offset, sliceTensorLayoutNV(tensorLayoutQ, i * Br, Br, 0, HSK));

    Qf16 = coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseA>(Q);
    Qf16 = coopmat<float16_t, gl_ScopeWorkgroup, Br, HSK, gl_MatrixUseA>(Q);
    Qf16 *= float16_t(p.scale);

    coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> O = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(0);
    coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> O = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator>(0);

    coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> L, M;

@@ -135,10 +135,10 @@ void main() {

        coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> S = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);

        coopmat<float16_t, gl_ScopeWorkgroup, D, Bc, gl_MatrixUseB> K_T;
        coopmat<float16_t, gl_ScopeWorkgroup, HSK, Bc, gl_MatrixUseB> K_T;

        uint32_t k_offset = ik2*p.nb12 + ik3*p.nb13;
        coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, D), tensorViewTranspose DECODEFUNC);
        coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK), tensorViewTranspose DECODEFUNC);
        S = coopMatMulAdd(Qf16, K_T, S);
|
||||
|
||||
if (p.logit_softcap != 0.0f) {
|
||||
@@ -203,42 +203,42 @@ void main() {
|
||||
rowsum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0.0);
|
||||
rowsum = coopMatMulAdd(P_A, One, rowsum);
|
||||
|
||||
coopmat<float16_t, gl_ScopeWorkgroup, Bc, D, gl_MatrixUseB> V;
|
||||
coopmat<float16_t, gl_ScopeWorkgroup, Bc, HSV, gl_MatrixUseB> V;
|
||||
uint32_t v_offset = iv2*p.nb22 + iv3*p.nb23;
|
||||
coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, D) DECODEFUNC);
|
||||
coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV) DECODEFUNC);
|
||||
|
||||
L = eM*L + rowsum;
|
||||
|
||||
// This is the "diagonal" matrix in the paper, but since we do componentwise
|
||||
// multiply rather than matrix multiply it has the diagonal element smeared
|
||||
// across the row
|
||||
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> eMdiag;
|
||||
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> eMdiag;
|
||||
|
||||
// resize eM by using smear/reduce
|
||||
coopMatReduceNV(eMdiag, eM, gl_CooperativeMatrixReduceRowNV, smearReduce);
|
||||
|
||||
// multiply with fp16 accumulation, then add to O.
|
||||
coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> PV = coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(0);
|
||||
coopmat<float16_t, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> PV = coopmat<float16_t, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator>(0);
|
||||
PV = coopMatMulAdd(P_A, V, PV);
|
||||
|
||||
O = eMdiag * O + coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(PV);
|
||||
O = eMdiag * O + coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator>(PV);
|
||||
}

// If there is split_k, then the split_k resolve shader does the final
// division by L. Store the intermediate O value and per-row m and L values.
if (p.k_num > 1) {
coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(O);
coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator>(O);

uint32_t o_offset = D * p.ne1 * split_k_index;
uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num);
coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);

o_offset = D * p.ne1 * p.k_num + p.ne1 * split_k_index * 2;
o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
coopMatPerElementNV(L, L, perElemOpStoreCol0, o_offset, iq2, N);
coopMatPerElementNV(M, M, perElemOpStoreCol0, o_offset + p.ne1, iq2, N);
return;
}
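The resolve shader itself is not part of this diff, but a hedged scalar sketch of the combine step it presumably performs on the (O_i, L_i, M_i) triples stored above: bring every split onto a common row maximum, sum, then do the final division by the combined L. Names here are hypothetical.

#include <algorithm>
#include <cmath>
#include <vector>

void split_k_resolve(const std::vector<std::vector<float>> & O, // k_num x HSV, unnormalized
                     const std::vector<float> & L,              // per-split row sums
                     const std::vector<float> & M,              // per-split row maxima
                     std::vector<float> & out) {                // final row, length HSV
    const size_t k_num = O.size();
    float m = -INFINITY;
    for (size_t i = 0; i < k_num; ++i) m = std::max(m, M[i]);
    float l = 0.0f;
    std::vector<float> acc(out.size(), 0.0f);
    for (size_t i = 0; i < k_num; ++i) {
        const float w = std::exp(M[i] - m);     // rescale split i to the global max
        l += w * L[i];
        for (size_t k = 0; k < acc.size(); ++k) acc[k] += w * O[i][k];
    }
    for (size_t k = 0; k < out.size(); ++k) out[k] = acc[k] / l; // the division by L
}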

coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> Ldiag;
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> Ldiag;

// resize L by using smear/reduce
coopMatReduceNV(Ldiag, L, gl_CooperativeMatrixReduceRowNV, smearReduce);
@@ -250,18 +250,18 @@ void main() {

O = Ldiag*O;

uint32_t o_offset = iq3*p.ne2*p.ne1;
uint32_t o_offset = iq3*p.ne2*p.ne1*HSV;

coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(O);
coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator>(O);
if (p.gqa_ratio > 1) {
coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);
} else {
tensorLayoutNV<3, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(3, gl_CooperativeMatrixClampModeConstantNV);
tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.ne2, p.ne1, D);
tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.ne2, p.ne1, HSV);

// permute dimensions
tensorViewNV<3, false, 1, 0, 2> tensorViewPermute = createTensorViewNV(3, false, 1, 0, 2);

coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, N, 0, D), tensorViewPermute);
coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, N, 0, HSV), tensorViewPermute);
}
}

27
ggml/src/vulkan-shaders/fused_mul_gelu.comp
Normal file
@@ -0,0 +1,27 @@
#version 450

#include "generic_head.comp"
#include "types.comp"

#extension GL_EXT_control_flow_attributes : enable

layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;

layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) readonly buffer Y {B_TYPE data_b[];};
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};

void main() {
    const float GELU_COEF_A    = 0.044715f;
    const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
    const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;

    if (i >= p.KX) {
        return;
    }

    const float xi = float(data_a[i]);
    const float yi = float(data_b[i]);
    const float val = SQRT_2_OVER_PI*xi*(1.0f + GELU_COEF_A*xi*xi);
    data_d[i] = D_TYPE(0.5f*xi*yi*(2.0f - 2.0f / (exp(2 * val) + 1)));
}
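The store expression may look unfamiliar, but it is just the usual tanh approximation of GELU with the mul fused in: since tanh(v) = 1 - 2/(exp(2v)+1), the shader's 0.5*x*y*(2 - 2/(exp(2v)+1)) equals y * 0.5*x*(1 + tanh(v)). A quick scalar check (hypothetical helper, not repo code):

#include <cmath>
#include <cstdio>

int main() {
    const float A = 0.044715f;
    const float S = 0.79788456080286535587989211986876f; // sqrt(2/pi)
    const float y = 0.7f;                                // the multiplied-in second operand
    for (float x : {-2.0f, -0.5f, 0.0f, 1.0f, 3.0f}) {
        const float v      = S * x * (1.0f + A * x * x);
        const float shader = 0.5f * x * y * (2.0f - 2.0f / (std::exp(2.0f * v) + 1.0f));
        const float ref    = y * 0.5f * x * (1.0f + std::tanh(v));
        std::printf("x=%5.2f  shader=%.6f  ref=%.6f\n", x, shader, ref);
    }
    return 0;
}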
22
ggml/src/vulkan-shaders/fused_mul_relu.comp
Normal file
@@ -0,0 +1,22 @@
#version 450

#include "generic_head.comp"
#include "types.comp"

#extension GL_EXT_control_flow_attributes : enable

layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;

layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) readonly buffer Y {B_TYPE data_b[];};
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};

void main() {
    const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;

    if (i >= p.KX) {
        return;
    }

    data_d[i] = D_TYPE(float(data_b[i])*max(float(data_a[i]), 0));
}

24
ggml/src/vulkan-shaders/fused_mul_silu.comp
Normal file
@@ -0,0 +1,24 @@
#version 450

#include "generic_head.comp"
#include "types.comp"

#extension GL_EXT_control_flow_attributes : enable

layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;

layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) readonly buffer Y {B_TYPE data_b[];};
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};

void main() {
    const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;

    if (i >= p.KX) {
        return;
    }

    const float xi = float(data_a[i]);
    const float yi = float(data_b[i]);
    data_d[i] = D_TYPE(xi * yi / (1.0f + exp(-xi)));
}
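The SiLU variant stores xi*yi/(1+exp(-xi)), i.e. y * silu(x) with silu(x) = x * sigmoid(x). A one-line scalar reference (hypothetical name, for illustration):

#include <cmath>

float fused_mul_silu_ref(float x, float y) {
    return y * x / (1.0f + std::exp(-x)); // y * x * sigmoid(x)
}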
54
ggml/src/vulkan-shaders/fused_rms_norm.comp
Normal file
@@ -0,0 +1,54 @@
#version 450

#include "generic_binary_head.comp"
#include "types.comp"

#extension GL_EXT_control_flow_attributes : enable
#define BLOCK_SIZE 512

layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;

shared FLOAT_TYPE sum[BLOCK_SIZE];

void main() {
    const uint ncols     = p.ne00;
    const uint nrows     = gl_NumWorkGroups.x;
    const uint nchannels = gl_NumWorkGroups.y;

    const uint row     = gl_WorkGroupID.x;
    const uint channel = gl_WorkGroupID.y;
    const uint samp    = gl_WorkGroupID.z;
    const uint tid     = gl_LocalInvocationID.x;

    const uint stride_row_a     = p.nb01;
    const uint stride_channel_a = p.nb02;
    const uint stride_sample_a  = p.nb03;

    uint32_t a_offset = samp*stride_sample_a + channel*stride_channel_a + row*stride_row_a;
    uint32_t d_offset = ((samp*nchannels + channel)*nrows + row)*ncols + get_doffset();

    FLOAT_TYPE sumf = FLOAT_TYPE(0.0f);

    [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
        const FLOAT_TYPE xi = FLOAT_TYPE(data_a[a_offset + col]);
        sumf += xi * xi;
    }

    sum[tid] = sumf;

    // sum up partial sums and write back result
    barrier();
    [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
        if (tid < s) {
            sum[tid] += sum[tid + s];
        }
        barrier();
    }

    const FLOAT_TYPE mean  = sum[0] / FLOAT_TYPE(ncols);
    const FLOAT_TYPE scale = inversesqrt(mean + FLOAT_TYPE(p.param1));

    [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
        data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[col]));
    }
}
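A scalar sketch of what one workgroup computes per row (assuming contiguous data and ignoring the channel/sample strides): the RMS normalization and the per-column weight multiply are fused into a single pass, which is the whole point of the shader.

#include <cmath>
#include <cstddef>
#include <vector>

void fused_rms_norm_ref(const std::vector<float> & x,  // one row of the input
                        const std::vector<float> & w,  // per-column weights (data_b)
                        std::vector<float> & y, float eps) {
    float sumsq = 0.0f;
    for (float xi : x) sumsq += xi * xi;
    const float scale = 1.0f / std::sqrt(sumsq / x.size() + eps); // inversesqrt(mean + eps)
    for (size_t i = 0; i < x.size(); ++i) {
        y[i] = scale * x[i] * w[i];
    }
}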
37
ggml/src/vulkan-shaders/multi_add.comp
Normal file
@@ -0,0 +1,37 @@
#version 450

#include "types.comp"

#extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_control_flow_attributes : enable

layout (push_constant) uniform parameter
{
    uint ne;
    uint ne0;
    uint ne1;
    uint nb1;
    uint nb01;
    uint nadd;
} p;

layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;

layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};

void main() {
    const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;

    if (i >= p.ne) {
        return;
    }

    uint i1 = i / p.ne0;
    uint i0 = i % p.ne0;

    float sum = data_a[i1*p.nb01 + i0] + data_a[i1*p.nb01 + i0 + p.ne0];
    for (int j = 2; j < p.nadd; ++j) sum += data_a[i1*p.nb01 + i0 + j*p.ne0];

    data_d[i1*p.nb1 + i0] = D_TYPE(sum);
}
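A scalar reference for the layout the push constants describe (my reading of it, so treat as an assumption): the nadd addends sit side by side along ne0 with row stride nb01, and each output element sums the matching element of every addend.

void multi_add_ref(const float * a, float * d,
                   unsigned ne0, unsigned ne1,
                   unsigned nb1, unsigned nb01, unsigned nadd) {
    for (unsigned i1 = 0; i1 < ne1; ++i1) {
        for (unsigned i0 = 0; i0 < ne0; ++i0) {
            float sum = 0.0f;
            for (unsigned j = 0; j < nadd; ++j) {
                sum += a[i1 * nb01 + i0 + j * ne0]; // j-th addend at (i1, i0)
            }
            d[i1 * nb1 + i0] = sum;
        }
    }
}

Judging by the ggml_multi_add call further down in llm_build_moe_ffn, this appears to be the kernel behind the one-shot expert aggregation.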

@@ -498,6 +498,7 @@ void process_shaders() {
string_to_spv("norm_f32", "norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("group_norm_f32", "group_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("fused_rms_norm_f32", "fused_rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("rms_norm_back_f32", "rms_norm_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));

string_to_spv("cpy_f32_f32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
@@ -571,6 +572,15 @@ void process_shaders() {

string_to_spv("upscale_f32", "upscale.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});

string_to_spv("fused_mul_gelu_f16", "fused_mul_gelu.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("fused_mul_gelu_f32", "fused_mul_gelu.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("fused_mul_silu_f16", "fused_mul_silu.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("fused_mul_silu_f32", "fused_mul_silu.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("fused_mul_relu_f16", "fused_mul_relu.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("fused_mul_relu_f32", "fused_mul_relu.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});

string_to_spv("multi_add_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});

string_to_spv("gelu_f16", "gelu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("gelu_f32", "gelu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("gelu_quick_f16", "gelu_quick.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});

@@ -110,6 +110,8 @@ extern "C" {
LLAMA_VOCAB_PRE_TYPE_LLAMA4 = 33,
LLAMA_VOCAB_PRE_TYPE_FALCON_3 = 34,
LLAMA_VOCAB_PRE_TYPE_FALCON_E = 35,
LLAMA_VOCAB_PRE_TYPE_SEED_CODER = 36, //llama.cpp lists this as 35
LLAMA_VOCAB_PRE_TYPE_HUNYUAN = 37, //llama.cpp lists this as 36
};

// note: these values should be synchronized with ggml_rope

@@ -427,6 +427,7 @@ struct llm_tokenizer_bpe {
break;
case LLAMA_VOCAB_PRE_TYPE_STABLELM2:
case LLAMA_VOCAB_PRE_TYPE_QWEN2:
case LLAMA_VOCAB_PRE_TYPE_HUNYUAN:
regex_exprs = {
// original regex from tokenizer.json
// "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
@@ -477,6 +478,13 @@ struct llm_tokenizer_bpe {
"'(?:[sSdDmMtT]|[lL][lL]|[vV][eE]|[rR][eE])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]|\\s+(?!\\S)|\\s+",
};
break;
case LLAMA_VOCAB_PRE_TYPE_SEED_CODER:
regex_exprs = {
// original regex from tokenizer.json
// "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1}| ?[^\\s\\p{L}\\p{N}\r\n]+|\\s*[\r\n]+|\\s+(?!\\S)|\\s+"
"(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1}| ?[^\\s\\p{L}\\p{N}\\r\\n]+|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
};
break;
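A small self-contained illustration (not repo code) of why the contraction group is spelled out with explicit character classes: many regex engines, the ECMAScript grammar of std::regex included, do not support the inline case-insensitive group "(?i:...)" that appears in tokenizer.json, so it is expanded by hand into '[sS], '[tT], and so on.

#include <cassert>
#include <regex>
#include <string>

int main() {
    // the case-folded contraction alternation from the SEED_CODER regex above
    const std::regex contraction("'(?:[sS]|[tT]|[rR][eE]|[vV][eE]|[mM]|[lL][lL]|[dD])");
    assert(std::regex_search(std::string("it'S"),  contraction)); // matches regardless of case
    assert(std::regex_search(std::string("we'RE"), contraction));
    assert(!std::regex_search(std::string("cant"), contraction)); // no apostrophe, no match
    return 0;
}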
default:
// default regex for BPE tokenization pre-processing
regex_exprs = {

339
src/llama.cpp
@@ -236,6 +236,7 @@ enum llm_arch {
LLM_ARCH_GRANITE_MOE,
LLM_ARCH_COHERE2,
LLM_ARCH_DOTS1,
LLM_ARCH_HUNYUAN_MOE,
LLM_ARCH_UNKNOWN,
};

@@ -293,6 +294,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_GRANITE_MOE, "granitemoe" },
{ LLM_ARCH_COHERE2, "cohere2" },
{ LLM_ARCH_DOTS1, "dots1" },
{ LLM_ARCH_HUNYUAN_MOE, "hunyuan-moe" },
{ LLM_ARCH_UNKNOWN, "(unknown)" },
};

@@ -1624,6 +1626,28 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
{ LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" },
}
},
{
LLM_ARCH_HUNYUAN_MOE,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
{ LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
},
},
{
LLM_ARCH_UNKNOWN,
@@ -1669,6 +1693,7 @@ enum llm_chat_template {
LLM_CHAT_TEMPLATE_LLAMA4,
LLM_CHAT_TEMPLATE_BITNET,
LLM_CHAT_TEMPLATE_DOTS1,
LLM_CHAT_TEMPLATE_HUNYUAN_MOE,
LLM_CHAT_TEMPLATE_UNKNOWN,
};

@@ -1706,6 +1731,7 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
{ "gigachat", LLM_CHAT_TEMPLATE_GIGACHAT },
{ "megrez", LLM_CHAT_TEMPLATE_MEGREZ },
{ "llama4", LLM_CHAT_TEMPLATE_LLAMA4 },
{ "hunyuan-moe", LLM_CHAT_TEMPLATE_HUNYUAN_MOE },
{ "bitnet", LLM_CHAT_TEMPLATE_BITNET },
};

@@ -2602,6 +2628,7 @@ enum e_model {
MODEL_27B,
MODEL_17B_16E,
MODEL_17B_128E,
MODEL_80B_A13B,
};

static const size_t kiB = 1024;
@@ -5236,6 +5263,7 @@ static const char * llama_model_type_name(e_model type) {
case MODEL_27B: return "27B";
case MODEL_17B_16E: return "17Bx16E (Scout)";
case MODEL_17B_128E: return "17Bx128E (Maverick)";
case MODEL_80B_A13B: return "80B.A13B";
default: return "?B";
}
}
@@ -6084,6 +6112,17 @@ static void llm_load_hparams(
default: model.type = e_model::MODEL_UNKNOWN;
}
} break;
case LLM_ARCH_HUNYUAN_MOE:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp);

switch (hparams.n_layer) {
case 32: model.type = e_model::MODEL_80B_A13B; break;
default: model.type = e_model::MODEL_UNKNOWN;
}
} break;
default: (void)0;
}

@@ -6354,6 +6393,14 @@ static void llm_load_vocab(
tokenizer_pre == "bailingmoe") {
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_BAILINGMOE;
vocab.tokenizer_clean_spaces = false;
} else if (
tokenizer_pre == "seed-coder") {
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_SEED_CODER;
vocab.tokenizer_clean_spaces = false;
} else if (
tokenizer_pre == "hunyuan") {
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_HUNYUAN;
vocab.tokenizer_clean_spaces = false;
} else {
throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
}
@@ -9260,6 +9307,47 @@ static bool llm_load_tensors(
}
}
} break;
case LLM_ARCH_HUNYUAN_MOE:
{
model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);

// output
model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);

// if output is NULL, init from the input tok embed
if (model.output == NULL) {
model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
}

for (int i = 0; i < n_layer; ++i) {
ggml_context * ctx_layer = ctx_for_layer(i);
ggml_context * ctx_split = ctx_for_layer_split(i);

auto & layer = model.layers[i];

layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);

layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0);
layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0);
layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);

layer.attn_k_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0);
layer.attn_q_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0);

layer.ffn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);

layer.ffn_gate_inp = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
layer.ffn_gate_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0);
layer.ffn_down_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0);
layer.ffn_up_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0);

layer.ffn_gate_shexp = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0);
layer.ffn_up_shexp = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0);
layer.ffn_down_shexp = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, 0);
}
} break;
default:
throw std::runtime_error("unknown architecture");
}

@@ -9697,12 +9785,7 @@ static struct ggml_tensor * llm_build_norm(
const llm_build_cb & cb,
int il, float scale_eps = 1) {

#ifdef GGML_USE_VULKAN
constexpr bool use_fused_rms_norm = false;
#else
constexpr bool use_fused_rms_norm = true;
#endif
if (use_fused_rms_norm && type == LLM_NORM_RMS && mw) {
if (type == LLM_NORM_RMS && mw) {
cur = ggml_fused_rms_norm(ctx, cur, mw, scale_eps * hparams.f_norm_rms_eps);
if (mb) {
cb(cur, "fused_norm", il);
@@ -9793,13 +9876,7 @@ static struct ggml_tensor * llm_build_ffn(
cur = tmp;
}

#ifdef GGML_USE_VULKAN
constexpr bool use_fused_mul_unary = false;
#else
constexpr bool use_fused_mul_unary = true;
#endif

if (use_fused_mul_unary && type_gate == LLM_FFN_PAR &&
if (type_gate == LLM_FFN_PAR &&
(type_op == LLM_FFN_SILU || type_op == LLM_FFN_RELU || (type_op == LLM_FFN_GELU && !act_scales))) {
cur = ggml_fused_mul_unary(ctx, cur, tmp, type_op == LLM_FFN_SILU ? GGML_UNARY_OP_SILU :
type_op == LLM_FFN_RELU ? GGML_UNARY_OP_RELU : GGML_UNARY_OP_GELU);
@@ -9981,6 +10058,28 @@ llm_expert_gating_func_type gating_op,
cb(cur, "ffn_moe_weighted", il);
}

//#ifdef GGML_USE_VULKAN
//    // aggregate experts
//    ggml_tensor * moe_out = nullptr;
//    //ggml_tensor * first_expert = nullptr;
//    for (int i = 0; i < n_expert_used; ++i) {
//        ggml_tensor * cur_expert = ggml_view_2d(ctx, experts, n_embd, n_tokens,
//                experts->nb[2], i*experts->nb[1]);
//
//        if (i == 0) {
//            moe_out = cur_expert;
//        } else {
//            moe_out = ggml_add(ctx, moe_out, cur_expert);
//        }
//    }
//
//    if (n_expert_used == 1) {
//        // avoid returning a non-contiguous tensor
//        moe_out = ggml_cont(ctx, moe_out);
//    }
//
//    return moe_out;
//#else
if (n_expert_used == 1) {
return ggml_cont(ctx, ggml_view_2d(ctx, experts, n_embd, n_tokens, experts->nb[2], 0));
}
@@ -9989,32 +10088,8 @@ llm_expert_gating_func_type gating_op,
ggml_view_2d(ctx, experts, n_embd, n_tokens, experts->nb[2], experts->nb[1]));
}
return ggml_multi_add(ctx, ggml_view_2d(ctx, experts, n_embd, n_tokens, experts->nb[2], 0), n_expert_used);
//#endif

//// aggregate experts
//ggml_tensor * moe_out = nullptr;
////ggml_tensor * first_expert = nullptr;
//for (int i = 0; i < n_expert_used; ++i) {
//    ggml_tensor * cur_expert = ggml_view_2d(ctx, experts, n_embd, n_tokens,
//            experts->nb[2], i*experts->nb[1]);

//    if (i == 0) {
//        moe_out = cur_expert;
//        //first_expert = cur_expert;
//        //printf("%s: %d: %d x %d x %d x %d | %d x %d x %d x %d\n", __func__, ggml_is_contiguous(first_expert),
//        //        (int)cur_expert->ne[0], (int)cur_expert->ne[1], (int)cur_expert->ne[2], (int)cur_expert->ne[3],
//        //        (int)cur_expert->nb[0], (int)cur_expert->nb[1], (int)cur_expert->nb[2], (int)cur_expert->nb[3]);
//    } else {
//        moe_out = ggml_add(ctx, moe_out, cur_expert);
//        //printf("%s: %d %d\n", __func__, ggml_is_contiguous(cur_expert), ggml_are_same_shape(cur_expert, first_expert));
//    }
//}

//if (n_expert_used == 1) {
//    // avoid returning a non-contiguous tensor
//    moe_out = ggml_cont(ctx, moe_out);
//}

//return moe_out;
}

static struct ggml_tensor * llm_build_kqv(
@@ -16988,8 +17063,7 @@ struct llm_build_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();

//auto * inp_attn = build_attn_inp_kv_unified();
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();

for (int il = 0; il < n_layer; ++il) {
ggml_tensor * inpSA = inpL;
@@ -17038,7 +17112,7 @@ struct llm_build_context {
cb(Kcur, "Kcur", il);
//cb(Vcur, "Vcur", il);

cur = llm_build_kv(ctx0, lctx, kv_self, gf,
model.layers[il].wo, model.layers[il].bo,
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);

@@ -17098,7 +17172,7 @@ struct llm_build_context {
}

cur = ggml_add(ctx0, cur, ffn_inp);
cur = lctx.cvec.apply_to(ctx0, cur, il);
cb(cur, "l_out", il);

// input for next layer
@@ -17119,12 +17193,162 @@ struct llm_build_context {

ggml_build_forward_expand(gf, cur);

return gf;
}

struct ggml_cgraph * build_hunyuan_moe() {
    struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);

    const int64_t n_embd_head = hparams.n_embd_head_v;

    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
    GGML_ASSERT(n_embd_head == hparams.n_rot);

    ggml_tensor * cur;
    ggml_tensor * inpL;

    inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);

    // inp_pos - contains the positions
    ggml_tensor * inp_pos = build_inp_pos();

    struct ggml_tensor * KQ_mask = build_inp_KQ_mask();

    const float kq_scale = 1.0f / sqrtf(float(n_embd_head));

    ggml_tensor * inp_out_ids = build_inp_out_ids();

    for (int il = 0; il < n_layer; ++il) {
        ggml_tensor * inpSA = inpL;

        // norm
        cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, cb, il);
        cb(cur, "attn_norm", il);

        // self-attention
        {
            // rope freq factors for llama3; may return nullptr for llama2 and other models
            struct ggml_tensor * rope_factors = build_rope_factors(il);

            // compute Q and K and RoPE them
            ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
            cb(Qcur, "Qcur", il);
            if (model.layers[il].bq) {
                Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
                cb(Qcur, "Qcur", il);
            }

            ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
            cb(Kcur, "Kcur", il);
            if (model.layers[il].bk) {
                Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
                cb(Kcur, "Kcur", il);
            }

            ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
            cb(Vcur, "Vcur", il);
            if (model.layers[il].bv) {
                Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
                cb(Vcur, "Vcur", il);
            }

            Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
            Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);

            Qcur = ggml_rope_ext(
                ctx0, Qcur, inp_pos, rope_factors,
                n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                ext_factor, attn_factor, beta_fast, beta_slow
            );

            cb(Qcur, "Qcur", il);
            cb(Kcur, "Kcur", il);
            cb(Vcur, "Vcur", il);

            Kcur = ggml_rope_ext(
                ctx0, Kcur, inp_pos, rope_factors,
                n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                ext_factor, attn_factor, beta_fast, beta_slow
            );

            Kcur = llm_build_norm(ctx0, Kcur, hparams, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, cb, il);
            cb(Kcur, "Kcur_norm", il);

            Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, cb, il);
            cb(Qcur, "Qcur_norm", il);

            cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, model.layers[il].bo, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il);
            cb(cur, "attn_out", il);
        }

        if (il == n_layer - 1 && inp_out_ids) {
            cur = ggml_get_rows(ctx0, cur, inp_out_ids);
            inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
        }

        ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
        cb(ffn_inp, "ffn_inp", il);

        cur = llm_build_norm(ctx0, ffn_inp, hparams, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, cb, il);
        cb(cur, "ffn_norm", il);

        // feed-forward network (non-MoE)
        ggml_tensor * cur_mlp = llm_build_ffn(ctx0, lctx, cur,
                model.layers[il].ffn_up_shexp, NULL, NULL,
                model.layers[il].ffn_gate_shexp, NULL, NULL,
                model.layers[il].ffn_down_shexp, NULL, NULL,
                NULL,
                LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
        cb(cur_mlp, "ffn_mlp", il);

        // MoE branch
        ggml_tensor * cur_moe = llm_build_moe_ffn(ctx0, lctx, cur,
                model.layers[il].ffn_gate_inp,
                model.layers[il].ffn_up_exps,
                model.layers[il].ffn_gate_exps,
                model.layers[il].ffn_down_exps,
                nullptr,
                n_expert, n_expert_used,
                LLM_FFN_SILU,
                true, // norm_topk_prob
                false,
                0.0,
                LLM_EXPERT_GATING_FUNC_SOFTMAX,
                cb,
                il);
        cb(cur_moe, "ffn_moe_out", il);

        ggml_tensor * ffn_out = ggml_add(ctx0, cur_moe, cur_mlp);
        cb(ffn_out, "ffn_out", il);

        cur = ggml_add(ctx0, ffn_out, ffn_inp);

        cur = lctx.cvec.apply_to(ctx0, cur, il);
        cb(cur, "l_out", il);

        // input for next layer
        inpL = cur;
    }

    cur = inpL;

    cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, NULL, LLM_NORM_RMS, cb, -1);

    cb(cur, "result_norm", -1);
    //res->t_embd = cur;

    // lm_head
    cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);

    cb(cur, "result_output", -1);
    //res->t_logits = cur;

    ggml_build_forward_expand(gf, cur);

    return gf;
}
};


static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<uint32_t> & ids) {
llama_batch dummy;
dummy.n_tokens = 0;

@@ -17422,6 +17646,10 @@ static struct ggml_cgraph * llama_build_graph(
{
result = llm.build_dots1();
} break;
case LLM_ARCH_HUNYUAN_MOE:
{
result = llm.build_hunyuan_moe();
} break;
default:
GGML_ABORT("fatal error");
}

@@ -21195,6 +21423,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
case LLM_ARCH_GPTNEOX:
case LLM_ARCH_CODESHELL:
case LLM_ARCH_DOTS1:
case LLM_ARCH_HUNYUAN_MOE:
return LLAMA_ROPE_TYPE_NEOX;

// all model arches should be listed explicitly here

@@ -23010,6 +23239,8 @@ static llm_chat_template llama_chat_detect_template(const std::string & tmpl) {
return LLM_CHAT_TEMPLATE_LLAMA4;
} else if (tmpl_contains("<|endofuserprompt|>")) {
return LLM_CHAT_TEMPLATE_DOTS1;
} else if (tmpl_contains("<|startoftext|>") && tmpl_contains("<|extra_4|>")) {
return LLM_CHAT_TEMPLATE_HUNYUAN_MOE;
}
return LLM_CHAT_TEMPLATE_UNKNOWN;
}

@@ -23443,6 +23674,18 @@ static int32_t llama_chat_apply_template_internal(
if (add_ass) {
ss << "<|response|>";
}
} else if (tmpl == LLM_CHAT_TEMPLATE_HUNYUAN_MOE) {
// tencent/Hunyuan-A13B-Instruct
for (auto message : chat) {
    std::string role(message->role);
    if (role == "system") {
        ss << "<|startoftext|>" << message->content << "<|extra_4|>";
    } else if (role == "assistant") {
        ss << "<|startoftext|>" << message->content << "<|eos|>";
    } else {
        ss << "<|startoftext|>" << message->content << "<|extra_0|>";
    }
}
} else {
// template not supported
return -1;
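A hypothetical usage sketch of the Hunyuan branch above: every turn is wrapped in <|startoftext|>, and the closing marker picks the role (<|extra_4|> for system, <|eos|> for assistant, <|extra_0|> for everything else, which covers user turns).

#include <cstdio>
#include <sstream>
#include <string>
#include <vector>

int main() {
    struct Msg { std::string role, content; };
    const std::vector<Msg> chat = {
        { "system",    "You are helpful." },
        { "user",      "Hi"              },
        { "assistant", "Hello!"          },
    };
    std::ostringstream ss;
    for (const Msg & m : chat) {
        if (m.role == "system") {
            ss << "<|startoftext|>" << m.content << "<|extra_4|>";
        } else if (m.role == "assistant") {
            ss << "<|startoftext|>" << m.content << "<|eos|>";
        } else {
            ss << "<|startoftext|>" << m.content << "<|extra_0|>";
        }
    }
    // prints: <|startoftext|>You are helpful.<|extra_4|><|startoftext|>Hi<|extra_0|><|startoftext|>Hello!<|eos|>
    std::printf("%s\n", ss.str().c_str());
    return 0;
}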

@@ -23650,6 +23893,9 @@ struct llama_sampler_dry * llama_sampler_init_dry(const struct llama_vocab* voca
}

void llama_sampler_dry_reset(struct llama_sampler_dry* smpl) {
if (!smpl) {
return;
}
smpl->last_tokens.clear();
smpl->dry_repeat_count.clear();
smpl->dry_max_token_repeat.clear();
@@ -23675,6 +23921,9 @@ struct llama_sampler_dry* llama_sampler_dry_clone(struct llama_sampler_dry* smpl
}

void llama_sampler_dry_accept(struct llama_sampler_dry* smpl, llama_token token) {
if (!smpl) {
return;
}
if (smpl->dry_multiplier == 0.0f || smpl->dry_base < 1.0f || smpl->dry_penalty_last_n == 0) {
return;
}