DeepSeek TG optimizations for TG (#928)

* Fuse concat and copy into K cache
* Avoid ggml_cont() when n_token = 1

Combined effect: about +2% in TG performance with full GPU offload

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
Kawrakow
2025-11-10 09:52:07 +02:00
committed by GitHub
parent eea6cc4433
commit adba641347
4 changed files with 79 additions and 5 deletions

View File

@@ -3224,7 +3224,14 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
ggml_cuda_op_group_norm(ctx, dst);
break;
case GGML_OP_CONCAT:
ggml_cuda_op_concat(ctx, dst);
if (fusion && i + 2 < cgraph->n_nodes &&
cgraph->nodes[i+1]->op == GGML_OP_VIEW &&
cgraph->nodes[i+2]->op == GGML_OP_CPY &&
ggml_cuda_concat_cpy(ctx, dst, cgraph->nodes[i+2])) {
i += 2;
} else {
ggml_cuda_op_concat(ctx, dst);
}
break;
case GGML_OP_UPSCALE:
ggml_cuda_op_upscale(ctx, dst);

View File

@@ -686,3 +686,65 @@ bool ggml_cuda_cpy_2(ggml_backend_cuda_context & ctx, const ggml_tensor * src1,
#endif
return true;
}
template <typename src_t, typename dst_t>
static __global__ void concat_cpy(const char * csrc1, const char * csrc2, char * cdst, int ne1, int ne,
char ** dest_ptrs, int copy_index) {
auto dst = (dst_t *)(dest_ptrs ? dest_ptrs[copy_index] : cdst);
auto src1 = (const src_t *)csrc1;
auto src2 = (const src_t *)csrc2;
for (int i = threadIdx.x; i < ne; i += blockDim.x) {
if constexpr (std::is_same_v<dst_t, nv_bfloat16>) {
dst[i] = __float2bfloat16(i < ne1 ? src1[i] : src2[i - ne1]);
} else {
dst[i] = (dst_t)(i < ne1 ? src1[i] : src2[i - ne1]);
}
}
}
template <typename src_t, typename dst_t>
static void ggml_concat_cpy_cuda(const char * src1, const char * src2, char * dst, int ne1, int ne, cudaStream_t stream,
char ** dest_ptrs, int& copy_index) {
int block_dim = std::min(ne, 768);
concat_cpy<src_t, dst_t><<<1, block_dim, 0, stream>>>(src1, src2, dst, ne1, ne, dest_ptrs, copy_index);
++copy_index;
}
bool ggml_cuda_concat_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * concat, const ggml_tensor * dst,
[[maybe_unused]] bool disable_indirection) {
if (dst->type != GGML_TYPE_F16 && dst->type != GGML_TYPE_BF16) return false;
//if (ggml_nrows(dst) > 1) return false;
if (dst->src[0] != concat) return false;
if (ggml_nrows(concat->src[0]) != 1 || ggml_nrows(concat->src[1]) != 1) return false;
if (concat->src[0]->type != GGML_TYPE_F32 || concat->src[1]->type != GGML_TYPE_F32) return false;
if (dst->ne[0] != concat->src[0]->ne[0] + concat->src[1]->ne[0]) return false;
char ** dest_ptrs = nullptr;
int graph_cpynode_index = -1;
#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) || defined(GGML_MUSA_GRAPHS)
if(ctx.cuda_graph->use_cpy_indirection && !disable_indirection) {
dest_ptrs = ctx.cuda_graph->dest_ptrs_d;
graph_cpynode_index = ctx.cuda_graph->graph_cpynode_index;
}
#endif
if (dst->type == GGML_TYPE_F16) {
ggml_concat_cpy_cuda<float, half>((const char *)concat->src[0]->data, (const char *)concat->src[1]->data,
(char *)dst->data, concat->src[0]->ne[0], dst->ne[0], ctx.stream(), dest_ptrs, graph_cpynode_index);
} else {
ggml_concat_cpy_cuda<float, nv_bfloat16>((const char *)concat->src[0]->data, (const char *)concat->src[1]->data,
(char *)dst->data, concat->src[0]->ne[0], dst->ne[0], ctx.stream(), dest_ptrs, graph_cpynode_index);
}
#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) || defined(GGML_MUSA_GRAPHS)
if(ctx.cuda_graph->use_cpy_indirection && !disable_indirection) {
ctx.cuda_graph->graph_cpynode_index = graph_cpynode_index;
}
#endif
return true;
}

View File

@@ -12,3 +12,6 @@ void ggml_cuda_cpy_dest_ptrs_copy(ggml_cuda_graph * cuda_graph, char ** host_des
bool ggml_cuda_cpy_2(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1,
ggml_tensor * dst1, ggml_tensor * dst2, bool disable_indirection = false);
bool ggml_cuda_concat_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * concat, const ggml_tensor * dst,
bool disable_indirection = false);