mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-29 10:51:51 +00:00
WIP
This commit is contained in:
@@ -21555,8 +21555,12 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
|
||||
int nth_k = n_tasks/gcd_k;
|
||||
int rk2 = q->ne[2]/k->ne[2];
|
||||
if (rk2%nth_k == 0) {
|
||||
size_t size = Dk*sizeof(ggml_half)*rk2/nth_k;
|
||||
size += (Dv + 16)*rk2/nth_k*sizeof(float)*n_tasks;
|
||||
size_t size = (Dv + 16)*rk2/nth_k*sizeof(float)*n_tasks;
|
||||
if (ggml_is_quantized(k->type)) {
|
||||
enum ggml_type vec_dot_type = type_traits[k->type].vec_dot_type;
|
||||
size_t row_size = ggml_row_size(vec_dot_type, q->ne[0]);
|
||||
size += q->ne[2]*row_size;
|
||||
}
|
||||
cur = MAX(cur, size);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -87,7 +87,7 @@ bool iqk_flash_attn_noalibi(int type_q, int type_mask, float max_bias,
|
||||
// => (Dv + 2)*rk2/nth_k*sizeof(float). We use (Dv + 16) instead to make sure threads are not
|
||||
// writing onto the same cache line.
|
||||
auto size_thread = (Dv + 16)*rk2/nth_k*sizeof(float);
|
||||
auto result_buffer = work + rk2/nth_k*Dk*sizeof(uint16_t);
|
||||
auto result_buffer = work;
|
||||
auto work_this_thread = (float *)(result_buffer + ith*size_thread);
|
||||
//printf("Thread %d: computing k,v = %d, q = %d...%d\n", ith, ith_k*(nek1/gcd_k), ith_q*(rk2/nth_k), ith_q*(rk2/nth_k)+rk2/nth_k-1);
|
||||
if (!iqk_flash_attn_impl(int_type_k, int_type_v,
|
||||
|
||||
Reference in New Issue
Block a user