diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cu b/ggml/src/ggml-cuda/fattn-mma-f16.cu
index 3ed549df..2c36c1c0 100644
--- a/ggml/src/ggml-cuda/fattn-mma-f16.cu
+++ b/ggml/src/ggml-cuda/fattn-mma-f16.cu
@@ -75,10 +75,10 @@ static void glm45_flash_attention(ggml_backend_cuda_context & ctx, ggml_tensor *
     auto V = dst->src[2];
     GGML_ASSERT(Q->ne[2] / K->ne[2] == 12);
 
-    ggml_cuda_pool_alloc k_data(ctx.pool());
-    ggml_cuda_pool_alloc v_data(ctx.pool());
     ggml_cuda_pool_alloc q_data(ctx.pool(), ggml_nelements(Q));
     ggml_cuda_pool_alloc dst_data(ctx.pool(), ggml_nelements(dst));
+    ggml_cuda_pool_alloc k_data(ctx.pool());
+    ggml_cuda_pool_alloc v_data(ctx.pool());
 
     repack_q(Q, q_data.get(), 8, 4, K->ne[2], ctx.stream());
 
@@ -102,6 +102,12 @@ static void glm45_flash_attention(ggml_backend_cuda_context & ctx, ggml_tensor *
         to_fp_16(K->data, k_data.get(), 1, nelem, ctx.stream());
         local_K.type = GGML_TYPE_F16;
         local_K.data = k_data.get();
+        auto ts = ggml_type_size(K->type);
+        auto bs = ggml_blck_size(K->type);
+        local_K.nb[0] = sizeof(half);
+        local_K.nb[1] = sizeof(half)*bs * local_K.nb[1]/ts;
+        local_K.nb[2] = sizeof(half)*bs * local_K.nb[2]/ts;
+        local_K.nb[3] = sizeof(half)*bs * local_K.nb[3]/ts;
     }
     if (V->type != GGML_TYPE_F16) {
         auto nelem = ggml_nelements(V);
@@ -110,6 +116,12 @@ static void glm45_flash_attention(ggml_backend_cuda_context & ctx, ggml_tensor *
         to_fp_16(V->data, v_data.get(), 1, nelem, ctx.stream());
         local_V.type = GGML_TYPE_F16;
         local_V.data = v_data.get();
+        auto ts = ggml_type_size(V->type);
+        auto bs = ggml_blck_size(V->type);
+        local_V.nb[0] = sizeof(half);
+        local_V.nb[1] = sizeof(half)*bs * local_V.nb[1]/ts;
+        local_V.nb[2] = sizeof(half)*bs * local_V.nb[2]/ts;
+        local_V.nb[3] = sizeof(half)*bs * local_V.nb[3]/ts;
     }
 
     constexpr int n_op_params = GGML_MAX_OP_PARAMS / sizeof(int);
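Note on the added nb[] fix-ups: when K/V are dequantized to F16, the byte strides of the local tensor copies have to be rescaled as well. For a quantized ggml type, nb[i] counts bytes over blocks of bs = ggml_blck_size(type) elements occupying ts = ggml_type_size(type) bytes each; once the data is converted, every element is sizeof(half) bytes, so nb[0] becomes sizeof(half) and each higher stride scales by sizeof(half)*bs/ts. Since the same six assignments are repeated for K and V, they could be folded into a small helper, sketched below under the assumption that local_K/local_V are by-value copies of *K/*V whose nb[] still hold the original quantized strides (rewrite_strides_to_f16 is a hypothetical name, not an existing ggml function; ggml_type_size, ggml_blck_size, and GGML_MAX_DIMS are real ggml API):

    // Sketch only: rescale the byte strides of a tensor copy whose data
    // has been converted from a quantized type to F16.
    static void rewrite_strides_to_f16(ggml_tensor & t, ggml_type src_type) {
        const size_t  ts = ggml_type_size(src_type); // bytes per block of src_type
        const int64_t bs = ggml_blck_size(src_type); // elements per block
        t.nb[0] = sizeof(half);                      // contiguous F16 elements
        for (int i = 1; i < GGML_MAX_DIMS; ++i) {
            // nb[i]/ts blocks * bs elements/block * sizeof(half) bytes/element
            t.nb[i] = sizeof(half)*bs * t.nb[i]/ts;
        }
    }

Each branch would then reduce to rewrite_strides_to_f16(local_K, K->type) (and likewise for local_V) after setting type and data, with the same arithmetic as the lines added above.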