Merge branch 'main' into ik/graph_reuse

Author: Kawrakow
Committed by: GitHub
Date: 2025-11-13 19:27:30 +02:00
3 changed files with 67 additions and 68 deletions

View File

@@ -1135,7 +1135,8 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
params.flash_attn = false;
return true;
}
if (arg == "-fa" || arg == "--flash-attention") {
if (arg == "-fa" || arg == "--flash-attn") {
CHECK_ARG
std::string next_arg{argv[i]};
for (auto& c : next_arg) c = std::tolower(c);
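With the rename to -fa/--flash-attn, the option now consumes a value that is lowercased before comparison, so the setting is case-insensitive; CHECK_ARG is taken to guard that a value actually follows the flag. A minimal sketch of that value handling, assuming an accepted value set of on/off (the real parser may accept more spellings):

#include <cctype>
#include <string>

// Sketch only: lowercases the value the same way the loop above does,
// then maps it to a boolean. The accepted spellings are an assumption.
static bool parse_flash_attn_value(std::string v, bool & flash_attn) {
    for (auto & c : v) c = std::tolower((unsigned char)c);
    if (v == "on"  || v == "1") { flash_attn = true;  return true; }
    if (v == "off" || v == "0") { flash_attn = false; return true; }
    return false; // unrecognized value
}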

View File

@@ -79,6 +79,37 @@ static __global__ void rope_norm(
dst[idst + 1] = x0*sin_theta + x1*cos_theta;
}
+static __global__ void rope_norm_fast(const float * src0, const float * src1, float * dst, int ne0, int ne1, int nelem,
+int s01, int s02, int n_dims) {
+int i = 2*(blockDim.x*blockIdx.x + threadIdx.x);
+if (i >= nelem) {
+return;
+}
+int i2 = i / (ne0*ne1); i -= i2*ne0*ne1;
+int i1 = i / ne0;
+int i0 = i - i1*ne0;
+const int idst = i2*ne0*ne1 + i1*ne0 + i0;
+const int ix = i2*s02 + i1*s01 + i0;
+if (i0 >= n_dims) {
+dst[idst + 0] = src0[ix + 0];
+dst[idst + 1] = src0[ix + 1];
+return;
+}
+const float x0 = src0[ix + 0];
+const float x1 = src0[ix + 1];
+const float cos_theta = src1[i2*ne0 + i0 + 0];
+const float sin_theta = src1[i2*ne0 + i0 + 1];
+dst[idst + 0] = x0*cos_theta - x1*sin_theta;
+dst[idst + 1] = x0*sin_theta + x1*cos_theta;
+}
template<bool forward, bool has_ff, typename T>
static __global__ void rope_neox(
const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
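This hunk adds rope_norm_fast at an earlier position in the file; the next hunk removes the identical definition from its old location, so the kernel is moved, not modified. Each thread handles one adjacent (x0, x1) pair: the flat pair index i is decomposed into plane i2, row i1, and column i0; dst is indexed contiguously while src0 goes through the element strides s01/s02, and columns at or beyond n_dims are copied through unrotated. A host-side reference of the same mapping (an illustrative sketch, not part of the commit):

// Scalar reference of rope_norm_fast's indexing: dst contiguous with row
// stride ne0 and plane stride ne0*ne1, src0 strided by s01/s02 (elements),
// cos/sin pairs read per plane from src1.
static void rope_norm_ref(const float * src0, const float * src1, float * dst,
        int ne0, int ne1, int ne2, int s01, int s02, int n_dims) {
    for (int i2 = 0; i2 < ne2; ++i2) {
        for (int i1 = 0; i1 < ne1; ++i1) {
            for (int i0 = 0; i0 < ne0; i0 += 2) {
                const int idst = i2*ne0*ne1 + i1*ne0 + i0;
                const int ix   = i2*s02 + i1*s01 + i0;
                if (i0 >= n_dims) { // past the rotated dims: plain copy
                    dst[idst + 0] = src0[ix + 0];
                    dst[idst + 1] = src0[ix + 1];
                    continue;
                }
                const float x0 = src0[ix + 0];
                const float x1 = src0[ix + 1];
                const float cos_theta = src1[i2*ne0 + i0 + 0];
                const float sin_theta = src1[i2*ne0 + i0 + 1];
                dst[idst + 0] = x0*cos_theta - x1*sin_theta;
                dst[idst + 1] = x0*sin_theta + x1*cos_theta;
            }
        }
    }
}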
@@ -261,37 +292,6 @@ static __global__ void fused_rms_rope_neox_fast(const float * src0_1, const floa
}
-static __global__ void rope_norm_fast(const float * src0, const float * src1, float * dst, int ne0, int ne1, int nelem,
-int s01, int s02, int n_dims) {
-int i = 2*(blockDim.x*blockIdx.x + threadIdx.x);
-if (i >= nelem) {
-return;
-}
-int i2 = i / (ne0*ne1); i -= i2*ne0*ne1;
-int i1 = i / ne0;
-int i0 = i - i1*ne0;
-const int idst = i2*ne0*ne1 + i1*ne0 + i0;
-const int ix = i2*s02 + i1*s01 + i0;
-if (i0 >= n_dims) {
-dst[idst + 0] = src0[ix + 0];
-dst[idst + 1] = src0[ix + 1];
-return;
-}
-const float x0 = src0[ix + 0];
-const float x1 = src0[ix + 1];
-const float cos_theta = src1[i2*ne0 + i0 + 0];
-const float sin_theta = src1[i2*ne0 + i0 + 1];
-dst[idst + 0] = x0*cos_theta - x1*sin_theta;
-dst[idst + 1] = x0*sin_theta + x1*cos_theta;
-}
static __global__ void fused_rope_norm_fast(const float * src0_1, const float * src0_2, const float * src1,
float * dst_1, float * dst_2, int ne0, int ne1_1, int ne1_2, int nelem1, int nelem,
int s01_1, int s02_1, int s01_2, int s02_2, int n_dims) {
@@ -508,7 +508,7 @@ static void rope_neox_fast_cuda(const float * src0, const float * src1, float *
const dim3 block_dims(CUDA_ROPE_BLOCK_SIZE, 1, 1);
const int n_blocks = (ne00*ne01*ne02 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
const dim3 block_nums(n_blocks, 1, 1);
-rope_neox_fast<<<block_nums, block_dims, 0, stream>>>(src0, src1, dst, ne00, ne01, ne01*ne02*ne02, s01, s02, n_dims);
+rope_neox_fast<<<block_nums, block_dims, 0, stream>>>(src0, src1, dst, ne00, ne01, ne00*ne01*ne02, s01, s02, n_dims);
}
static void fused_rope_neox_fast_cuda(const float * src0_1, const float * src0_2, const float * src1,
@@ -557,7 +557,7 @@ static void rope_norm_fast_cuda(const float * src0, const float * src1, float *
const dim3 block_dims(CUDA_ROPE_BLOCK_SIZE, 1, 1);
const int n_blocks = (ne00*ne01*ne02 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
const dim3 block_nums(n_blocks, 1, 1);
-rope_norm_fast<<<block_nums, block_dims, 0, stream>>>(src0, src1, dst, ne00, ne01, ne01*ne02*ne02, s01, s02, n_dims);
+rope_norm_fast<<<block_nums, block_dims, 0, stream>>>(src0, src1, dst, ne00, ne01, ne00*ne01*ne02, s01, s02, n_dims);
}
static void rope_multi_fast_cuda(const float * src0, const float * src1, float * dst, int ne00, int ne01, int ne02, int s01, int s02,
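The two launcher changes above fix the same typo: the element count passed as nelem was ne01*ne02*ne02 rather than ne00*ne01*ne02, so the kernels' bounds check (i >= nelem) was testing against the wrong total. Since each thread covers two elements, the grid is a ceiling division over pairs; a sketch of the launch arithmetic, with the block size an assumed stand-in for CUDA_ROPE_BLOCK_SIZE:

// Grid sizing used by rope_neox_fast_cuda/rope_norm_fast_cuda (sketch).
static int n_blocks_for(int ne00, int ne01, int ne02, int block_size = 256 /* assumed */) {
    const int nelem = ne00*ne01*ne02;                   // total elements; the corrected value
    return (nelem + 2*block_size - 1) / (2*block_size); // each thread processes two elements
}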
@@ -864,7 +864,6 @@ void ggml_cuda_op_rope_fast(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
const int64_t ne00 = src0->ne[0]; // head dims
const int64_t ne01 = src0->ne[1]; // num heads
const int64_t ne02 = src0->ne[2]; // num heads
-const int64_t nr = ggml_nrows(src0);
const size_t s01 = src0->nb[1] / ggml_type_size(src0->type);
const size_t s02 = src0->nb[2] / ggml_type_size(src0->type);
@@ -888,14 +887,14 @@ void ggml_cuda_op_rope_fast(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
(const float *)src0->data, (const float *)src1->data, (float *)dst->data, ne00, ne01, ne02, s01, s02, n_dims, stream);
} else if (is_mrope && !is_vision) {
rope_multi_fast_cuda(
-(const float *)src0->data, (const float *)src1->data, (float *)dst->data, ne00, ne01, s01, s02, n_dims, nr, stream);
+(const float *)src0->data, (const float *)src1->data, (float *)dst->data, ne00, ne01, ne02, s01, s02, n_dims, stream);
} else if (is_vision) {
rope_vision_fast_cuda(
-(const float *)src0->data, (const float *)src1->data, (float *)dst->data, ne00, ne01, s01, s02, n_dims, nr, stream);
+(const float *)src0->data, (const float *)src1->data, (float *)dst->data, ne00, ne01, ne02, s01, s02, n_dims, stream);
} else {
//printf("Using norm\n");
rope_norm_fast_cuda(
-(const float *)src0->data, (const float *)src1->data, (float *)dst->data, ne00, ne01, s01, s02, n_dims, nr, stream);
+(const float *)src0->data, (const float *)src1->data, (float *)dst->data, ne00, ne01, ne02, s01, s02, n_dims, stream);
}
}
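Dropping nr = ggml_nrows(src0) from these calls loses nothing: for the 3-D inputs handled here (taking ne03 == 1, consistent with the indexing above), the row count is just the product of the two outer dimensions, and passing ne02 explicitly is what lets the kernels do the i2/i1/i0 decomposition themselves. As a one-line reminder:

#include <cstdint>

// nr == ne01*ne02 for a [ne00, ne01, ne02, 1] tensor (sketch).
static int64_t nrows_3d(int64_t ne01, int64_t ne02) { return ne01*ne02; }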

View File

@@ -1003,11 +1003,10 @@ inline __m256i accum_q4_0_quants(const __m256i * v, const int8_t * qs) {
template <int nrc_y>
static void mul_mat_q4_0_r8_q8_2_avx2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%8 == 0);
-Q8<nrc_y, block_q8_1_x4> q8(info);
+Q8<nrc_y, block_q8_2_x4> q8(info);
auto m4 = _mm256_set1_epi8(0xf);
int nb = n / QK4_NL;
__m256i v[8];
-GGML_ASSERT(nb%4 == 0);
if constexpr (nrc_y == 1) {
union { __m256 vec; float val[8]; } helper;
for (int ix = 0; ix < nrc_x; ix += 8) {
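The GGML_ASSERT(nb%4 == 0) is removed, presumably because the Q8_2 path no longer requires the number of blocks to be a multiple of four: the bulk runs four quant blocks at a time through the *_x4 layout, and the remainder loops in the hunks below mop up the tail one block at a time. The shared loop structure, sketched:

// Two-phase traversal shared by these kernels (structure only; the real
// bodies do the AVX2/AVX512 dot products shown in the hunks below).
static void process_blocks_sketch(int nb) {
    for (int ib4 = 0; ib4 < nb/4; ++ib4) {
        // main path: four blocks at a time via the x4 block layout
    }
    for (int ib = 4*(nb/4); ib < nb; ++ib) {
        // tail path: one block at a time
    }
}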
@@ -1026,14 +1025,14 @@ static void mul_mat_q4_0_r8_q8_2_avx2(int n, const void * vx, size_t bx, const D
}
}
for (int ib = 4*(nb/4); ib < nb; ++ib) {
-auto qy = (const block_q8_1 *)q8.y[0];
+auto qy = (const block_q8_2 *)q8.y[0];
auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4[ib].d));
prepare_q4_0_quants_avx2(iq4[ib].qs, v, m4);
auto sumi = accum_q4_0_quants(v, qy[ib].qs);
-ggml_bf16_t d{qy[ib].d}, s{qy[ib].s};
-auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_BF16_TO_FP32(d)));
+auto [d8, m8] = ScaleHelperQ8_2::prepare1(qy + ib);
+auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(d8));
acc1 = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc1);
-acc2 = _mm256_fmadd_ps(scales, _mm256_set1_ps(GGML_BF16_TO_FP32(s)), acc2);
+acc2 = _mm256_fmadd_ps(scales, _mm256_set1_ps(m8), acc2);
}
acc1 = _mm256_fmadd_ps(acc2, _mm256_set1_ps(-8.f), acc1);
info.store(ix, 0, acc1);
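Throughout this file the tail loops stop reading the q8 scales as raw bf16 fields (ggml_bf16_t d{qy[ib].d}, s{qy[ib].s}) and instead call ScaleHelperQ8_2::prepare1(qy + ib), which yields the per-block scale d8 and sum term m8 as floats in one step. A standalone sketch of what such a helper does; the bf16 storage mirrors the removed GGML_BF16_TO_FP32 calls, and the exact block_q8_2 layout is an assumption:

#include <cstdint>
#include <cstring>
#include <utility>

// Hypothetical stand-in for block_q8_2: 16-bit scale d, 16-bit sum term s,
// 32 int8 quants. The real layout may differ.
struct block_q8_2_sketch { uint16_t d, s; int8_t qs[32]; };

static float bf16_to_f32(uint16_t h) {
    uint32_t u = (uint32_t)h << 16; // bf16 occupies the high half of an f32
    float f;
    std::memcpy(&f, &u, sizeof f);
    return f;
}

// prepare1-style helper: convert both 16-bit scales to float once.
static std::pair<float, float> prepare1_sketch(const block_q8_2_sketch * y) {
    return { bf16_to_f32(y->d), bf16_to_f32(y->s) };
}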
@@ -1077,12 +1076,12 @@ static void mul_mat_q4_0_r8_q8_2_avx2(int n, const void * vx, size_t bx, const D
auto scales_m = _mm256_mul_ps(scales, _mm256_set1_ps(-8.f));
prepare_q4_0_quants_avx2(iq4[ib].qs, v, m4);
for (int iy = 0; iy < nrc_y; ++iy) {
-auto qy = (const block_q8_1 *)q8.y[iy];
+auto qy = (const block_q8_2 *)q8.y[iy];
auto sumi = accum_q4_0_quants(v, qy[ib].qs);
-ggml_bf16_t d{qy[ib].d}, s{qy[ib].s};
-auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_BF16_TO_FP32(d)));
+auto [d8, m8] = ScaleHelperQ8_2::prepare1(qy + ib);
+auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(d8));
acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
-acc[iy] = _mm256_fmadd_ps(scales_m, _mm256_set1_ps(GGML_BF16_TO_FP32(s)), acc[iy]);
+acc[iy] = _mm256_fmadd_ps(scales_m, _mm256_set1_ps(m8), acc[iy]);
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
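The -8.f folded into scales_m here (and into set1_ps(-8.f) in the nrc_y == 1 path above) is the q4_0 offset correction: q4_0 stores 4-bit quants q in [0,15] representing d4*(q - 8), so the integer dot product against the q8 row needs a -8*sum(y) term, which is what the m8 value supplies. A scalar reference of the algebra, assuming m8 == d8*sum(y):

#include <cstdint>

// d4*d8*sum(q*y) - 8*d4*m8 == d4*d8*sum((q-8)*y): the two fmadds above.
static float q4_0_dot_ref(float d4, float d8, const uint8_t * q, const int8_t * y, int n) {
    int sumqy = 0, sumy = 0;
    for (int i = 0; i < n; ++i) {
        sumqy += q[i]*y[i];
        sumy  += y[i];
    }
    const float m8 = d8 * (float)sumy;      // assumed meaning of the q8 sum term
    return d4*d8*(float)sumqy - 8.f*d4*m8;  // matches fmadd(d4d8, sumi) + fmadd(scales_m, m8)
}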
@@ -1101,7 +1100,7 @@ static void mul_mat_q4_0_r8_q8_2(int n, const void * vx, size_t bx, const DataIn
return;
}
GGML_ASSERT(nrc_x%16 == 0);
-Q8<nrc_y, block_q8_1_x4> q8(info);
+Q8<nrc_y, block_q8_2_x4> q8(info);
auto m4 = _mm512_set1_epi8(0xf);
int nb = n / QK4_NL;
__m512 acc[2*nrc_y] = {};
@@ -1159,10 +1158,10 @@ static void mul_mat_q4_0_r8_q8_2(int n, const void * vx, size_t bx, const DataIn
for (int iy = 0; iy < nrc_y; ++iy) {
auto qy = (const block_q8_1 *)q8.y[iy];
auto sumi = dot(qy[ib].qs);
-ggml_bf16_t d{qy[ib].d}, s{qy[ib].s};
-auto dy = _mm512_set1_ps(GGML_BF16_TO_FP32(d));
+auto [d8, m8] = ScaleHelperQ8_2::prepare1(qy + ib);
+auto dy = _mm512_set1_ps(d8);
acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]);
-acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(GGML_BF16_TO_FP32(s)), acc[2*iy+1]);
+acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(m8), acc[2*iy+1]);
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
@@ -1245,12 +1244,12 @@ static void mul_mat_q5_0_r4_q8_2_avx2(int n, const void * vx, size_t bx, const D
for (int ib = 4*(nb/4); ib < nb; ++ib) {
auto scales = prepare(iq5[ib]);
for (int iy = 0; iy < nrc_y; ++iy) {
-auto qy = (const block_q8_1 *)q8.y[iy];
+auto qy = (const block_q8_2 *)q8.y[iy];
auto sumi = dot(_mm256_loadu_si256((const __m256i*)qy[ib].qs));
-ggml_bf16_t d{qy[ib].d}, s{qy[ib].s};
-auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_BF16_TO_FP32(d)));
+auto [d8, m8] = ScaleHelperQ8_2::prepare1(qy + ib);
+auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(d8));
acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
-acc[iy] = _mm256_fmadd_ps(scales, _mm256_set1_ps(-8.f*GGML_BF16_TO_FP32(s)), acc[iy]);
+acc[iy] = _mm256_fmadd_ps(scales, _mm256_set1_ps(-8.f*m8), acc[iy]);
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
@@ -1325,12 +1324,12 @@ static void mul_mat_q5_0_r4_q8_2(int n, const void * vx, size_t bx, const DataIn
for (int ib = 4*(nb/4); ib < nb; ++ib) {
auto scales = prepare(iq5l[ib], iq5h[ib]);
for (int iy = 0; iy < nrc_y; ++iy) {
-auto qy = (const block_q8_1 *)q8.y[iy];
+auto qy = (const block_q8_2 *)q8.y[iy];
auto sumi = dot(_mm256_loadu_si256((const __m256i*)qy[ib].qs));
-ggml_bf16_t d{qy[ib].d}, s{qy[ib].s};
-auto dy = _mm512_set1_ps(GGML_BF16_TO_FP32(d));
+auto [d8, m8] = ScaleHelperQ8_2::prepare1(qy + ib);
+auto dy = _mm512_set1_ps(d8);
acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]);
-acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(GGML_BF16_TO_FP32(s)), acc[2*iy+1]);
+acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(m8), acc[2*iy+1]);
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
@@ -1415,12 +1414,12 @@ static void mul_mat_q6_0_r4_q8_2_avx2(int n, const void * vx, size_t bx, const D
for (int ib = 4*(nb/4); ib < nb; ++ib) {
auto scales = prepare(iq6[ib]);
for (int iy = 0; iy < nrc_y; ++iy) {
-auto qy = (const block_q8_1 *)q8.y[iy];
+auto qy = (const block_q8_2 *)q8.y[iy];
auto sumi = dot(_mm256_loadu_si256((const __m256i*)qy[ib].qs));
-ggml_bf16_t d{qy[ib].d}, s{qy[ib].s};
-auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_BF16_TO_FP32(d)));
+auto [d8, m8] = ScaleHelperQ8_2::prepare1(qy + ib);
+auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(d8));
acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
-acc[iy] = _mm256_fmadd_ps(scales, _mm256_set1_ps(-16.f*GGML_BF16_TO_FP32(s)), acc[iy]);
+acc[iy] = _mm256_fmadd_ps(scales, _mm256_set1_ps(-16.f*m8), acc[iy]);
}
}
@@ -1495,12 +1494,12 @@ static void mul_mat_q6_0_r4_q8_2(int n, const void * vx, size_t bx, const DataIn
for (int ib = 4*(nb/4); ib < nb; ++ib) {
auto scales = prepare(iq6l[ib], iq6h[ib]);
for (int iy = 0; iy < nrc_y; ++iy) {
-auto qy = (const block_q8_1 *)q8.y[iy];
+auto qy = (const block_q8_2 *)q8.y[iy];
auto sumi = dot(_mm256_loadu_si256((const __m256i*)qy[ib].qs));
-ggml_bf16_t d{qy[ib].d}, s{qy[ib].s};
-auto dy = _mm512_set1_ps(GGML_BF16_TO_FP32(d));
+auto [d8, m8] = ScaleHelperQ8_2::prepare1(qy + ib);
+auto dy = _mm512_set1_ps(d8);
acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]);
-acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(GGML_BF16_TO_FP32(s)), acc[2*iy+1]);
+acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(m8), acc[2*iy+1]);
}
}
for (int iy = 0; iy < nrc_y; ++iy) {