Add CUDA implementation for GGML_OP_CONV_2D and GGML_OP_CONV_2D_DW

This commit is contained in:
Iwan Kawrakow
2025-09-25 18:29:52 +03:00
parent 8732eebc94
commit 879201c26d
6 changed files with 361 additions and 0 deletions

View File

@@ -42,6 +42,8 @@
#include "ggml-cuda/mmq_id.cuh"
#include "ggml-cuda/quantize_id.cuh"
#include "ggml-cuda/topk-moe.cuh"
#include "ggml-cuda/conv2d.cuh"
#include "ggml-cuda/conv2d-dw.cuh"
#include <algorithm>
#include <array>
@@ -3292,6 +3294,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_IM2COL:
ggml_cuda_op_im2col(ctx, dst);
break;
case GGML_OP_CONV_2D:
ggml_cuda_op_conv2d(ctx, dst);
break;
case GGML_OP_CONV_2D_DW:
ggml_cuda_op_conv2d_dw(ctx, dst);
break;
case GGML_OP_CONV_TRANSPOSE_1D:
ggml_cuda_op_conv_transpose_1d(ctx,dst);
break;

View File

@@ -0,0 +1,161 @@
#include "conv2d-dw.cuh"
struct conv_params {
int in_w, in_h;
int out_w, out_h;
int kernel_w, kernel_h;
int stride_x, stride_y;
int padding_x, padding_y;
int dilation_x, dilation_y;
int channels, batches;
};
struct kernel_bounds {
int y_min, y_max;
int x_min, x_max;
};
__device__ __forceinline__ kernel_bounds calculate_kernel_bounds(int out_x, int out_y, const conv_params & params) {
kernel_bounds bounds;
bounds.y_min = max(0, (params.padding_y - out_y * params.stride_y + params.dilation_y - 1) / params.dilation_y);
bounds.y_max =
min(params.kernel_h,
(params.in_h + params.padding_y - out_y * params.stride_y + params.dilation_y - 1) / params.dilation_y);
bounds.x_min = max(0, (params.padding_x - out_x * params.stride_x + params.dilation_x - 1) / params.dilation_x);
bounds.x_max =
min(params.kernel_w,
(params.in_w + params.padding_x - out_x * params.stride_x + params.dilation_x - 1) / params.dilation_x);
return bounds;
}
__device__ __forceinline__ int calculate_input_coord(int out_coord, int kern_coord, int stride, int dilation, int padding) {
return out_coord * stride + kern_coord * dilation - padding;
}
struct whcn_layout {
__device__ static int input_index(int n, int c, int y, int x, const conv_params & params) {
return n * (params.channels * params.in_w * params.in_h) + c * params.in_w * params.in_h + y * params.in_w + x;
}
__device__ static int kernel_index(int c, int ky, int kx, const conv_params & params) {
return c * params.kernel_h * params.kernel_w + ky * params.kernel_w + kx;
}
__device__ static int output_index(int n, int c, int y, int x, const conv_params & params) {
return n * (params.channels * params.out_w * params.out_h) + c * params.out_w * params.out_h +
y * params.out_w + x;
}
__device__ static void unpack_indices(int global_idx, const conv_params & params, int & n, int & c, int & out_y,
int & out_x) {
out_x = global_idx % params.out_w;
out_y = (global_idx / params.out_w) % params.out_h;
c = (global_idx / (params.out_w * params.out_h)) % params.channels;
n = global_idx / (params.out_w * params.out_h * params.channels);
}
};
struct cwhn_layout {
__device__ static int input_index(int n, int c, int y, int x, const conv_params & params) {
return n * (params.channels * params.in_w * params.in_h) + (y * params.in_w + x) * params.channels + c;
}
__device__ static int kernel_index(int c, int ky, int kx, const conv_params & params) {
return (ky * params.kernel_w + kx) * params.channels + c;
}
__device__ static int output_index(int n, int c, int y, int x, const conv_params & params) {
return n * (params.channels * params.out_w * params.out_h) + y * (params.out_w * params.channels) +
x * params.channels + c;
}
__device__ static void unpack_indices(int global_idx, const conv_params & params, int & n, int & c, int & out_y,
int & out_x) {
c = global_idx % params.channels;
out_x = (global_idx / params.channels) % params.out_w;
out_y = (global_idx / (params.channels * params.out_w)) % params.out_h;
n = global_idx / (params.channels * params.out_w * params.out_h);
}
};
template <typename T, typename Layout>
__global__ void conv2d_dw_kernel(const T * __restrict__ input, const T * __restrict__ kernel, T * __restrict__ output,
const int in_w, const int in_h, const int out_w, const int out_h,
const int kernel_w, const int kernel_h, const int stride_x, const int stride_y,
const int padding_x, const int padding_y, const int dilation_x, const int dilation_y,
const int channels, const int batches) {
const int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
const int total_elements = batches * channels * out_h * out_w;
if (global_idx >= total_elements) {
return;
}
conv_params params = { in_w, in_h, out_w, out_h, kernel_w, kernel_h, stride_x,
stride_y, padding_x, padding_y, dilation_x, dilation_y, channels, batches };
int batch_idx, channel_idx, out_y_idx, out_x_idx;
Layout::unpack_indices(global_idx, params, batch_idx, channel_idx, out_y_idx, out_x_idx);
T accumulator = 0;
kernel_bounds bounds = calculate_kernel_bounds(out_x_idx, out_y_idx, params);
for (int kern_y = bounds.y_min; kern_y < bounds.y_max; ++kern_y) {
int in_y_idx = calculate_input_coord(out_y_idx, kern_y, params.stride_y, params.dilation_y, params.padding_y);
for (int kern_x = bounds.x_min; kern_x < bounds.x_max; ++kern_x) {
int in_x_idx = calculate_input_coord(out_x_idx, kern_x, params.stride_x, params.dilation_x, params.padding_x);
const T input_val = input[Layout::input_index(batch_idx, channel_idx, in_y_idx, in_x_idx, params)];
const T kernel_val = kernel[Layout::kernel_index(channel_idx, kern_y, kern_x, params)];
accumulator += input_val * kernel_val;
}
}
output[Layout::output_index(batch_idx, channel_idx, out_y_idx, out_x_idx, params)] = accumulator;
}
void ggml_cuda_op_conv2d_dw(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * kernel = dst->src[0];
const ggml_tensor * input = dst->src[1];
GGML_ASSERT(kernel->type == GGML_TYPE_F32 && input->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
const float * w_d = (const float *) kernel->data;
const float * x_d = (const float *) input->data;
float * y_d = (float *) dst->data;
const int32_t * p = (const int32_t *) dst->op_params;
const int stride_x = p[0];
const int stride_y = p[1];
const int padding_x = p[2];
const int padding_y = p[3];
const int dilation_x = p[4];
const int dilation_y = p[5];
const int in_w = input->ne[0];
const int in_h = input->ne[1];
const int kernel_w = kernel->ne[0];
const int kernel_h = kernel->ne[1];
const int out_w = dst->ne[0];
const int out_h = dst->ne[1];
const int channels = dst->ne[2];
const int batches = dst->ne[3];
cudaStream_t st = ctx.stream();
const int total = batches * channels * out_h * out_w;
const int blocks = (total + CUDA_CONV2D_DW_BLOCK_SIZE - 1) / CUDA_CONV2D_DW_BLOCK_SIZE;
if (ggml_is_contiguous(input)) {
conv2d_dw_kernel<float, whcn_layout><<<blocks, CUDA_CONV2D_DW_BLOCK_SIZE, 0, st>>>(
x_d, w_d, y_d, in_w, in_h, out_w, out_h, kernel_w, kernel_h, stride_x, stride_y, padding_x, padding_y,
dilation_x, dilation_y, channels, batches);
} else if (ggml_is_contiguous_channels(input)) {
conv2d_dw_kernel<float, cwhn_layout><<<blocks, CUDA_CONV2D_DW_BLOCK_SIZE, 0, st>>>(
x_d, w_d, y_d, in_w, in_h, out_w, out_h, kernel_w, kernel_h, stride_x, stride_y, padding_x, padding_y,
dilation_x, dilation_y, channels, batches);
} else {
GGML_ABORT("Unsupported memory layout for conv_2d_dw");
}
}

View File

@@ -0,0 +1,5 @@
#pragma once
#include "common.cuh"
#define CUDA_CONV2D_DW_BLOCK_SIZE 256
void ggml_cuda_op_conv2d_dw(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@@ -0,0 +1,166 @@
#include "conv2d.cuh"
#include "convert.cuh"
struct conv_params {
const int64_t IW, IH;
const int64_t OW, OH;
const int64_t KW, KH;
const int64_t ST_X, ST_Y;
const int64_t PD_X, PD_Y;
const int64_t DL_X, DL_Y;
const int64_t IC, OC;
const int64_t B;
const int64_t TOTAL;
};
struct kernel_bounds {
int64_t y_min, y_max;
int64_t x_min, x_max;
};
__device__ __forceinline__ int64_t max64(int64_t a, int64_t b) {
return (a > b) ? a : b;
}
__device__ __forceinline__ int64_t min64(int64_t a, int64_t b) {
return (a < b) ? a : b;
}
__device__ __forceinline__ kernel_bounds calculate_kernel_bounds(int64_t out_x, int64_t out_y, const conv_params & P) {
kernel_bounds bounds;
bounds.y_min = max64(0, (P.PD_Y - out_y * P.ST_Y + P.DL_Y - 1) / P.DL_Y);
bounds.y_max = min64(P.KH, (P.IH + P.PD_Y - out_y * P.ST_Y + P.DL_Y - 1) / P.DL_Y);
bounds.x_min = max64(0, (P.PD_X - out_x * P.ST_X + P.DL_X - 1) / P.DL_X);
bounds.x_max = min64(P.KW, (P.IW + P.PD_X - out_x * P.ST_X + P.DL_X - 1) / P.DL_X);
return bounds;
}
__device__ __forceinline__ int calculate_input_coord(int64_t out_coord,
int64_t kern_coord,
int64_t stride,
int64_t dilation,
int64_t padding) {
return out_coord * stride + kern_coord * dilation - padding;
}
struct whcn_layout {
__device__ static int64_t input_index(int64_t n, int64_t c, int64_t y, int64_t x, const conv_params & P) {
return n * (P.IC * P.IW * P.IH) + c * P.IW * P.IH + y * P.IW + x;
}
__device__ static int64_t kernel_index(int64_t c_out, int64_t c_in, int64_t ky, int64_t kx, const conv_params & P) {
return c_out * (P.IC * P.KH * P.KW) + c_in * (P.KH * P.KW) + ky * P.KW + kx;
}
__device__ static int64_t output_index(int64_t n, int64_t c, int64_t y, int64_t x, const conv_params & P) {
return n * (P.OC * P.OW * P.OH) + c * P.OW * P.OH + y * P.OW + x;
}
__device__ static void unpack_indices(int64_t global_idx,
const conv_params & P,
int64_t & n,
int64_t & c,
int64_t & out_y,
int64_t & out_x) {
out_x = global_idx % P.OW;
out_y = (global_idx / P.OW) % P.OH;
c = (global_idx / (P.OW * P.OH)) % P.OC;
n = global_idx / (P.OW * P.OH * P.OC);
}
};
template <typename T, typename Layout>
static __global__ void conv2d_kernel(const float * __restrict__ input,
const T * __restrict__ kernel,
float * __restrict__ output,
const conv_params P) {
const int64_t global_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (global_idx >= P.TOTAL) {
return;
}
int64_t n, c_out, out_y, out_x;
Layout::unpack_indices(global_idx, P, n, c_out, out_y, out_x);
float acc = 0.0f;
for (int64_t c_in = 0; c_in < P.IC; ++c_in) {
kernel_bounds bounds = calculate_kernel_bounds(out_x, out_y, P);
for (int64_t ky = bounds.y_min; ky < bounds.y_max; ++ky) {
const int64_t in_y = calculate_input_coord(out_y, ky, P.ST_Y, P.DL_Y, P.PD_Y);
for (int64_t kx = bounds.x_min; kx < bounds.x_max; ++kx) {
const int64_t in_x = calculate_input_coord(out_x, kx, P.ST_X, P.DL_X, P.PD_X);
const float input_val = input[Layout::input_index(n, c_in, in_y, in_x, P)];
const T kernel_val = kernel[Layout::kernel_index(c_out, c_in, ky, kx, P)];
acc += (input_val * ggml_cuda_cast<float>(kernel_val));
}
}
}
// [N, OC, OH, OW]
output[Layout::output_index(n, c_out, out_y, out_x, P)] = acc;
}
template <typename T>
static void conv2d_cuda(const float * X_D, const T * K_D, float * Y_D, const conv_params P, cudaStream_t st) {
const int blocks = (P.TOTAL + CUDA_CONV2D_BLOCK_SIZE - 1) / CUDA_CONV2D_BLOCK_SIZE;
conv2d_kernel<T, whcn_layout><<<blocks, CUDA_CONV2D_BLOCK_SIZE, 0, st>>>(X_D, K_D, Y_D, P);
}
static void conv2d_cuda_f16(const float * X_D, const half * K_D, float * Y_D, const conv_params P, cudaStream_t st) {
conv2d_cuda<half>(X_D, K_D, Y_D, P, st);
}
static void conv2d_cuda_f32(const float * X_D, const float * K_D, float * Y_D, const conv_params P, cudaStream_t st) {
conv2d_cuda<float>(X_D, K_D, Y_D, P, st);
}
void ggml_cuda_op_conv2d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * kernel = dst->src[0];
const ggml_tensor * input = dst->src[1];
float * K_D = (float *) kernel->data;
const float * X_D = (const float *) input->data;
float * Y_D = (float *) dst->data;
GGML_ASSERT(ggml_is_contiguous(kernel));
GGML_ASSERT(kernel->type == GGML_TYPE_F16 || kernel->type == GGML_TYPE_F32);
// same number of input channels
GGML_ASSERT(input->ne[2] == kernel->ne[2]);
cudaStream_t st = ctx.stream();
const int32_t * p = (const int32_t *) dst->op_params;
const int ST_X = p[0]; // stride_x
const int ST_Y = p[1]; // stride_y
const int PD_X = p[2]; // padding_x
const int PD_Y = p[3]; // padding_y
const int DL_X = p[4]; // dilation_x
const int DL_Y = p[5]; // dilation_y
// No cwhn
GGML_ASSERT(p[6] == false);
const int IW = input->ne[0]; // input_w
const int IH = input->ne[1]; // input_h
const int OW = dst->ne[0]; // output_w
const int OH = dst->ne[1]; // output_h
const int KW = kernel->ne[0]; // kernel_w
const int KH = kernel->ne[1]; // kernel_h
const int IC = input->ne[2]; // input_channels
const int OC = kernel->ne[3]; // ouptut_chanles
const int B = input->ne[3]; // n_batches
const int64_t total = B * OC * OH * OW;
conv_params params = { IW, IH, OW, OH, KW, KH, ST_X, ST_Y, PD_X, PD_Y, DL_X, DL_Y, IC, OC, B, total };
if (kernel->type == GGML_TYPE_F16) {
conv2d_cuda_f16(X_D, (half *) K_D, Y_D, params, st);
} else {
conv2d_cuda_f32(X_D, K_D, Y_D, params, st);
}
}

View File

@@ -0,0 +1,5 @@
#pragma once
#include "common.cuh"
#define CUDA_CONV2D_BLOCK_SIZE 256
void ggml_cuda_op_conv2d(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@@ -21,3 +21,19 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type);
to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type);
to_bf16_cuda_t ggml_get_to_bf16_cuda(ggml_type type);
template<typename dst_t, typename src_t>
__host__ __device__ inline dst_t ggml_cuda_cast(src_t x) {
if constexpr (std::is_same_v<dst_t, src_t>) {
return x;
} else if constexpr(std::is_same_v<dst_t, nv_bfloat16>) {
return __float2bfloat16(float(x));
} else if constexpr(std::is_same_v<src_t, nv_bfloat16>) {
return __bfloat162float(x);
} else if constexpr(std::is_same_v<dst_t, int32_t>) {
return int32_t(x);
} else {
return float(x);
}
}