mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-01-26 17:20:01 +00:00
Port of Qwen3-VL support from mainline (#883)
* Port of Qwen3-VL for latest ik_llama.cpp - convert_hf_to_gguf.py - Not touched, use llama.cpp to convert model instead - sysl and metal support for imrope not added - Vulkan support for imrope not tested - Code not tested * Bugfix n_embd was declared multiple times https://github.com/ikawrakow/ik_llama.cpp/pull/883#issuecomment-3471179655 * Fix n_embd issue with qwen3vl * model.output tensor not required https://github.com/ikawrakow/ik_llama.cpp/pull/883#discussion_r2480388389 * Improved logic for qkv combined tensors59ceaf8fcb (r2480395800)59ceaf8fcb (r2480398187)* Fix n_embd for merge_qkv() + cleaner code https://github.com/ikawrakow/ik_llama.cpp/pull/883#discussion_r2481227395 * Revert TENSOR_NOT_REQUIRED
This commit is contained in:
@@ -259,6 +259,7 @@
|
||||
#define GGML_ROPE_TYPE_NEOX 2
|
||||
#define GGML_ROPE_TYPE_MROPE 8
|
||||
#define GGML_ROPE_TYPE_VISION 24
|
||||
#define GGML_ROPE_TYPE_IMROPE 40 // binary: 101000
|
||||
|
||||
#define GGML_MROPE_SECTIONS 4
|
||||
|
||||
|
||||
@@ -345,7 +345,7 @@ template<bool forward, bool has_ff, typename T>
|
||||
static __global__ void rope_multi(
|
||||
const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2,
|
||||
const int n_dims, const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor,
|
||||
const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors, const mrope_sections sections) {
|
||||
const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors, const mrope_sections sections, const bool is_imrope) {
|
||||
const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
|
||||
|
||||
if (i0 >= ne0) {
|
||||
@@ -372,17 +372,27 @@ static __global__ void rope_multi(
|
||||
const int sector = (i0 / 2) % sect_dims;
|
||||
|
||||
float theta_base = 0.0;
|
||||
if (sector < sections.v[0]) {
|
||||
theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
|
||||
}
|
||||
else if (sector >= sections.v[0] && sector < sec_w) {
|
||||
theta_base = pos[channel_x + ne2 * 1]*powf(theta_scale, i0/2.0f);
|
||||
}
|
||||
else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
|
||||
theta_base = pos[channel_x + ne2 * 2]*powf(theta_scale, i0/2.0f);
|
||||
}
|
||||
else if (sector >= sec_w + sections.v[2]) {
|
||||
theta_base = pos[channel_x + ne2 * 3]*powf(theta_scale, i0/2.0f);
|
||||
if (is_imrope) {
|
||||
if (sector % 3 == 1 && sector < 3 * sections.v[1]) { // h
|
||||
theta_base = pos[channel_x + ne2 * 1]*powf(theta_scale, i0/2.0f);
|
||||
} else if (sector % 3 == 2 && sector < 3 * sections.v[2]) { // w
|
||||
theta_base = pos[channel_x + ne2 * 2]*powf(theta_scale, i0/2.0f);
|
||||
} else if (sector % 3 == 0 && sector < 3 * sections.v[0]) { // t
|
||||
theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
|
||||
}
|
||||
} else {
|
||||
if (sector < sections.v[0]) {
|
||||
theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
|
||||
}
|
||||
else if (sector >= sections.v[0] && sector < sec_w) {
|
||||
theta_base = pos[channel_x + ne2 * 1]*powf(theta_scale, i0/2.0f);
|
||||
}
|
||||
else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
|
||||
theta_base = pos[channel_x + ne2 * 2]*powf(theta_scale, i0/2.0f);
|
||||
}
|
||||
else if (sector >= sec_w + sections.v[2]) {
|
||||
theta_base = pos[channel_x + ne2 * 3]*powf(theta_scale, i0/2.0f);
|
||||
}
|
||||
}
|
||||
|
||||
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
|
||||
@@ -574,7 +584,7 @@ template<bool forward, typename T>
|
||||
static void rope_multi_cuda(
|
||||
const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr,
|
||||
const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
|
||||
const rope_corr_dims corr_dims, const float * freq_factors, const mrope_sections sections, cudaStream_t stream) {
|
||||
const rope_corr_dims corr_dims, const float * freq_factors, const mrope_sections sections, const bool is_imrope, cudaStream_t stream) {
|
||||
GGML_ASSERT(ne0 % 2 == 0);
|
||||
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
|
||||
const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
|
||||
@@ -585,11 +595,11 @@ static void rope_multi_cuda(
|
||||
if (freq_factors == nullptr) {
|
||||
rope_multi<forward, false, T><<<block_nums, block_dims, 0, stream>>>(
|
||||
x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor,
|
||||
attn_factor, corr_dims, theta_scale, freq_factors, sections);
|
||||
attn_factor, corr_dims, theta_scale, freq_factors, sections, is_imrope);
|
||||
} else {
|
||||
rope_multi<forward, true, T><<<block_nums, block_dims, 0, stream>>>(
|
||||
x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor,
|
||||
attn_factor, corr_dims, theta_scale, freq_factors, sections);
|
||||
attn_factor, corr_dims, theta_scale, freq_factors, sections, is_imrope);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -667,6 +677,7 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
|
||||
|
||||
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
|
||||
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
|
||||
const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
|
||||
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
|
||||
|
||||
if (is_mrope) {
|
||||
@@ -704,11 +715,11 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
|
||||
if (src0->type == GGML_TYPE_F32) {
|
||||
rope_multi_cuda<forward>(
|
||||
(const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
|
||||
freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
|
||||
freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, is_imrope, stream);
|
||||
} else if (src0->type == GGML_TYPE_F16) {
|
||||
rope_multi_cuda<forward>(
|
||||
(const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
|
||||
freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
|
||||
freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, is_imrope, stream);
|
||||
} else {
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
|
||||
@@ -815,6 +815,7 @@ struct vk_op_rope_push_constants {
|
||||
uint32_t s1;
|
||||
uint32_t s2;
|
||||
int32_t sections[4];
|
||||
uint32_t is_imrope;
|
||||
uint32_t is_back;
|
||||
};
|
||||
|
||||
@@ -6754,6 +6755,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
||||
const int mode = ((const int32_t *) dst->op_params)[2];
|
||||
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
|
||||
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
|
||||
const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
|
||||
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
|
||||
|
||||
if (is_neox) {
|
||||
@@ -6763,7 +6765,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
||||
if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
|
||||
return ctx->device->pipeline_rope_neox_f16;
|
||||
}
|
||||
} else if (is_mrope && !is_vision) {
|
||||
} else if ((is_mrope || is_imrope) && !is_vision) {
|
||||
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||
return ctx->device->pipeline_rope_multi_f32;
|
||||
}
|
||||
@@ -7970,6 +7972,8 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, cons
|
||||
memcpy(sections, (int32_t *) dst->op_params + 11, sizeof(int)*4);
|
||||
}
|
||||
|
||||
const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
|
||||
|
||||
float corr_dims[2];
|
||||
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
|
||||
|
||||
@@ -7982,7 +7986,7 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, cons
|
||||
(uint32_t)src0->ne[0], (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1],
|
||||
freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale,
|
||||
src2 != nullptr, (uint32_t)src0->ne[2], s1, s2,
|
||||
sections[0], sections[1], sections[2], sections[3], backprop
|
||||
sections[0], sections[1], sections[2], sections[3], is_imrope, backprop
|
||||
}, dryrun);
|
||||
}
|
||||
|
||||
|
||||
@@ -18417,7 +18417,7 @@ static void ggml_rope_cache_init(
|
||||
}
|
||||
|
||||
static void ggml_mrope_cache_init(
|
||||
float theta_base_t, float theta_base_h, float theta_base_w, float theta_base_e, int sections[4], bool indep_sects,
|
||||
float theta_base_t, float theta_base_h, float theta_base_w, float theta_base_e, int sections[4], bool is_imrope, bool indep_sects,
|
||||
float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
|
||||
float * cache, float sin_sign, float theta_scale) {
|
||||
// ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
|
||||
@@ -18452,14 +18452,26 @@ static void ggml_mrope_cache_init(
|
||||
}
|
||||
|
||||
float theta = theta_t;
|
||||
if (sector >= sections[0] && sector < sec_w) {
|
||||
theta = theta_h;
|
||||
}
|
||||
else if (sector >= sec_w && sector < sec_w + sections[2]) {
|
||||
theta = theta_w;
|
||||
}
|
||||
else if (sector >= sec_w + sections[2]) {
|
||||
theta = theta_e;
|
||||
if (is_imrope) { // qwen3vl apply interleaved mrope
|
||||
if (sector % 3 == 1 && sector < 3 * sections[1]) {
|
||||
theta = theta_h;
|
||||
} else if (sector % 3 == 2 && sector < 3 * sections[2]) {
|
||||
theta = theta_w;
|
||||
} else if (sector % 3 == 0 && sector < 3 * sections[0]) {
|
||||
theta = theta_t;
|
||||
} else {
|
||||
theta = theta_e;
|
||||
}
|
||||
} else {
|
||||
if (sector >= sections[0] && sector < sec_w) {
|
||||
theta = theta_h;
|
||||
}
|
||||
else if (sector >= sec_w && sector < sec_w + sections[2]) {
|
||||
theta = theta_w;
|
||||
}
|
||||
else if (sector >= sec_w + sections[2]) {
|
||||
theta = theta_e;
|
||||
}
|
||||
}
|
||||
|
||||
rope_yarn(
|
||||
@@ -18707,6 +18719,7 @@ static void ggml_compute_forward_rope_f32(
|
||||
|
||||
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
|
||||
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE; // ggml_rope_multi, multimodal rotary position embedding
|
||||
const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE; // qwen3vl apply interleaved mrope
|
||||
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
|
||||
|
||||
if (is_mrope) {
|
||||
@@ -18745,7 +18758,7 @@ static void ggml_compute_forward_rope_f32(
|
||||
const int64_t p_w = pos[i2 + ne2 * 2];
|
||||
const int64_t p_e = pos[i2 + ne2 * 3];
|
||||
ggml_mrope_cache_init(
|
||||
p_t, p_h, p_w, p_e, sections, is_vision,
|
||||
p_t, p_h, p_w, p_e, sections, is_imrope, is_vision,
|
||||
freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
|
||||
}
|
||||
|
||||
@@ -18893,6 +18906,7 @@ static void ggml_compute_forward_rope_f16(
|
||||
|
||||
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
|
||||
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
|
||||
const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
|
||||
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
|
||||
|
||||
if (is_mrope) {
|
||||
@@ -18931,7 +18945,7 @@ static void ggml_compute_forward_rope_f16(
|
||||
const int64_t p_w = pos[i2 + ne2 * 2];
|
||||
const int64_t p_e = pos[i2 + ne2 * 3];
|
||||
ggml_mrope_cache_init(
|
||||
p_t, p_h, p_w, p_e, sections, is_vision,
|
||||
p_t, p_h, p_w, p_e, sections, is_imrope, is_vision,
|
||||
freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
|
||||
}
|
||||
|
||||
|
||||
@@ -29,6 +29,7 @@ layout (push_constant) uniform parameter {
|
||||
uint s1;
|
||||
uint s2;
|
||||
int sections[4];
|
||||
uint is_imrope;
|
||||
uint is_back;
|
||||
} p;
|
||||
|
||||
|
||||
@@ -32,17 +32,29 @@ void main() {
|
||||
const uint sector = (i0 / 2) % sect_dims;
|
||||
|
||||
float theta_base = 0.0;
|
||||
if (sector < p.sections[0]) {
|
||||
theta_base = data_pos[channel_x]*pow(p.theta_scale, i0/2.0f);
|
||||
}
|
||||
else if (sector >= p.sections[0] && sector < sec_w) {
|
||||
theta_base = data_pos[channel_x + ne2 * 1]*pow(p.theta_scale, i0/2.0f);
|
||||
}
|
||||
else if (sector >= sec_w && sector < sec_w + p.sections[2]) {
|
||||
theta_base = data_pos[channel_x + ne2 * 2]*pow(p.theta_scale, i0/2.0f);
|
||||
}
|
||||
else if (sector >= sec_w + p.sections[2]) {
|
||||
theta_base = data_pos[channel_x + ne2 * 3]*pow(p.theta_scale, i0/2.0f);
|
||||
if (p.is_imrope != 0) {
|
||||
if (sector % 3 == 1 && sector < 3 * p.sections[1]) {
|
||||
theta_base = data_pos[channel_x + ne2 * 1]*pow(p.theta_scale, i0/2.0f);
|
||||
} else if (sector % 3 == 2 && sector < 3 * p.sections[2]) {
|
||||
theta_base = data_pos[channel_x + ne2 * 2]*pow(p.theta_scale, i0/2.0f);
|
||||
} else if (sector % 3 == 0 && sector < 3 * p.sections[0]) {
|
||||
theta_base = data_pos[channel_x]*pow(p.theta_scale, i0/2.0f);
|
||||
} else {
|
||||
theta_base = data_pos[channel_x + ne2 * 3]*pow(p.theta_scale, i0/2.0f);
|
||||
}
|
||||
} else {
|
||||
if (sector < p.sections[0]) {
|
||||
theta_base = data_pos[channel_x]*pow(p.theta_scale, i0/2.0f);
|
||||
}
|
||||
else if (sector >= p.sections[0] && sector < sec_w) {
|
||||
theta_base = data_pos[channel_x + ne2 * 1]*pow(p.theta_scale, i0/2.0f);
|
||||
}
|
||||
else if (sector >= sec_w && sector < sec_w + p.sections[2]) {
|
||||
theta_base = data_pos[channel_x + ne2 * 2]*pow(p.theta_scale, i0/2.0f);
|
||||
}
|
||||
else if (sector >= sec_w + p.sections[2]) {
|
||||
theta_base = data_pos[channel_x + ne2 * 3]*pow(p.theta_scale, i0/2.0f);
|
||||
}
|
||||
}
|
||||
|
||||
const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f;
|
||||
|
||||
Reference in New Issue
Block a user