This seems to work

This commit is contained in:
Iwan Kawrakow
2025-08-24 16:46:41 +03:00
parent 0c3a3ffee6
commit af6b5365cc
2 changed files with 149 additions and 103 deletions

View File

@@ -39,6 +39,7 @@
#include "ggml-cuda/conv-transpose-1d.cuh"
#include "ggml-cuda/add-id.cuh"
#include "ggml-cuda/graph.cuh"
#include "ggml-cuda/mmq_id.cuh"
#include <algorithm>
#include <array>
@@ -2392,6 +2393,10 @@ static bool ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
}
}
ggml_cuda_mul_mat_q_id(ctx, src0, src1, ids, dst, nullptr, nullptr);
return false;
GGML_TENSOR_BINARY_OP_LOCALS
GGML_ASSERT(!ggml_backend_buffer_is_cuda_split(src0->buffer) && "mul_mat_id does not support split buffers");

View File

@@ -134,6 +134,40 @@ struct tile_x_sizes {
#define GGML_CUDA_CC_IS_QY2(cc) (cc >= GGML_CUDA_CC_QY2 && cc < GGML_CUDA_CC_NG)
#define GGML_CUDA_CC_IS_NG(cc) (cc >= GGML_CUDA_CC_NG)
// GGML_CUDA_ASSUME(x): optimizer hint.
// Forwards to __builtin_assume(x) when CUDART_VERSION >= 11010 or when building with GGML_USE_MUSA;
// otherwise it expands to nothing, so x is not evaluated at all (x must therefore be side-effect free).
#if CUDART_VERSION >= 11010 || defined(GGML_USE_MUSA)
#define GGML_CUDA_ASSUME(x) __builtin_assume(x)
#else
#define GGML_CUDA_ASSUME(x)
#endif // CUDART_VERSION >= 11010
// GGML_USE_VMM: defined by default; opt out with GGML_CUDA_NO_VMM on non-HIP builds
// or with GGML_HIP_NO_VMM on HIP builds.
// NOTE(review): VMM presumably stands for virtual memory management -- confirm against the pool allocator.
#if (!defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)) || (defined(GGML_USE_HIP) && !defined(GGML_HIP_NO_VMM))
#define GGML_USE_VMM
#endif // (!defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)) || (defined(GGML_USE_HIP) && !defined(GGML_HIP_NO_VMM))
// The *_AVAILABLE macros below are compile-time feature switches.
// Those that test __CUDA_ARCH__ can only become defined in a compilation pass where __CUDA_ARCH__ is defined:
// an undefined identifier evaluates to 0 inside #if, so the comparison is false wherever it is not defined.
//
// FP16_MMA_AVAILABLE: non-HIP builds with __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA, or any GGML_USE_MUSA build.
#if (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA)
#define FP16_MMA_AVAILABLE
#endif // (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA)
// FP16_MMA_AVAILABLE (second route): builds with GGML_HIP_ROCWMMA_FATTN targeting CDNA or RDNA3,
// or RDNA4 when GGML_HIP_ROCWMMA_FATTN_GFX12 is also defined.
#if defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || (defined(GGML_HIP_ROCWMMA_FATTN_GFX12) && defined(RDNA4)))
#define FP16_MMA_AVAILABLE
#endif // defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || (defined(GGML_HIP_ROCWMMA_FATTN_GFX12) && defined(RDNA4)))
// AMD_MFMA_AVAILABLE: HIP builds targeting CDNA, unless disabled with GGML_HIP_NO_MMQ_MFMA.
// Together with TURING_MMA_AVAILABLE it selects the *_mma vec_dot/write-back path over the dp4a one in this file.
#if defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA)
#define AMD_MFMA_AVAILABLE
#endif // defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA)
// TURING_MMA_AVAILABLE: non-HIP builds with __CUDA_ARCH__ >= GGML_CUDA_CC_TURING.
#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
#define TURING_MMA_AVAILABLE
#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
// AMPERE_MMA_AVAILABLE: non-HIP builds with __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE.
#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#define AMPERE_MMA_AVAILABLE
#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
// CP_ASYNC_AVAILABLE: same condition as AMPERE_MMA_AVAILABLE, kept as a separate switch.
// NOTE(review): presumably gates use of the cp.async copy instructions -- confirm at the use sites.
#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#define CP_ASYNC_AVAILABLE
#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#ifdef __CUDACC__
template<typename... Args>
__host__ __device__ constexpr inline void ggml_unused_vars_impl(Args&&...) noexcept {}
@@ -2961,7 +2995,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
}
template<int mmq_x, int mmq_y, bool need_check>
static __device__ __forceinline__ void mmq_write_back_dp4a(
static __device__ __forceinline__ void mmq_write_back_dp4a_id(
const float * __restrict__ sum, const int32_t * __restrict__ ids_dst, float * __restrict__ dst,
const int stride, const int i_max, const int j_max) {
constexpr int nwarps = mmq_get_nwarps_device();
@@ -2989,7 +3023,7 @@ static __device__ __forceinline__ void mmq_write_back_dp4a(
}
template<ggml_type type, int mmq_x, int mmq_y, bool need_check>
static __device__ __forceinline__ void mmq_write_back_mma(
static __device__ __forceinline__ void mmq_write_back_mma_id(
const float * __restrict__ sum, const int * __restrict__ ids_dst, float * __restrict__ dst,
const int stride, const int i_max, const int j_max) {
@@ -3040,10 +3074,10 @@ static __device__ __forceinline__ void mmq_write_back_mma(
// -------------------------------------------------------------------------------------------------------------------------------------
template <int mmq_x, int mmq_y, bool need_check, ggml_type type>
struct mmq_type_traits;
struct mmq_type_traits_id;
template <int mmq_x, int mmq_y, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q4_0> {
struct mmq_type_traits_id<mmq_x, mmq_y, need_check, GGML_TYPE_Q4_0> {
static constexpr int vdr = VDR_Q4_0_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_0<mmq_y, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_DS4>;
@@ -3051,7 +3085,7 @@ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q4_0> {
};
template <int mmq_x, int mmq_y, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q4_1> {
struct mmq_type_traits_id<mmq_x, mmq_y, need_check, GGML_TYPE_Q4_1> {
static constexpr int vdr = VDR_Q4_1_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_1<mmq_y, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y>;
@@ -3059,7 +3093,7 @@ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q4_1> {
};
template <int mmq_x, int mmq_y, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q5_0> {
struct mmq_type_traits_id<mmq_x, mmq_y, need_check, GGML_TYPE_Q5_0> {
static constexpr int vdr = VDR_Q5_0_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_0<mmq_y, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
@@ -3067,7 +3101,7 @@ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q5_0> {
};
template <int mmq_x, int mmq_y, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q5_1> {
struct mmq_type_traits_id<mmq_x, mmq_y, need_check, GGML_TYPE_Q5_1> {
static constexpr int vdr = VDR_Q5_1_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_1<mmq_y, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y>;
@@ -3075,7 +3109,7 @@ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q5_1> {
};
template <int mmq_x, int mmq_y, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q8_0> {
struct mmq_type_traits_id<mmq_x, mmq_y, need_check, GGML_TYPE_Q8_0> {
static constexpr int vdr = VDR_Q8_0_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q8_0<mmq_y, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
@@ -3083,7 +3117,7 @@ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q8_0> {
};
template <int mmq_x, int mmq_y, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_MXFP4> {
struct mmq_type_traits_id<mmq_x, mmq_y, need_check, GGML_TYPE_MXFP4> {
static constexpr int vdr = VDR_MXFP4_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_mxfp4<mmq_y, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
@@ -3091,7 +3125,7 @@ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_MXFP4> {
};
template <int mmq_x, int mmq_y, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q2_K> {
struct mmq_type_traits_id<mmq_x, mmq_y, need_check, GGML_TYPE_Q2_K> {
static constexpr int vdr = VDR_Q2_K_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q2_K<mmq_y, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q2_K_q8_1_mma<mmq_x, mmq_y>;
@@ -3099,7 +3133,7 @@ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q2_K> {
};
template <int mmq_x, int mmq_y, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q3_K> {
struct mmq_type_traits_id<mmq_x, mmq_y, need_check, GGML_TYPE_Q3_K> {
static constexpr int vdr = VDR_Q3_K_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q3_K<mmq_y, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y>;
@@ -3107,7 +3141,7 @@ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q3_K> {
};
template <int mmq_x, int mmq_y, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q4_K> {
struct mmq_type_traits_id<mmq_x, mmq_y, need_check, GGML_TYPE_Q4_K> {
static constexpr int vdr = VDR_Q4_K_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_K<mmq_y, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y>;
@@ -3115,7 +3149,7 @@ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q4_K> {
};
template <int mmq_x, int mmq_y, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q5_K> {
struct mmq_type_traits_id<mmq_x, mmq_y, need_check, GGML_TYPE_Q5_K> {
static constexpr int vdr = VDR_Q5_K_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_K<mmq_y, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y>;
@@ -3123,7 +3157,7 @@ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q5_K> {
};
template <int mmq_x, int mmq_y, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q6_K> {
struct mmq_type_traits_id<mmq_x, mmq_y, need_check, GGML_TYPE_Q6_K> {
static constexpr int vdr = VDR_Q6_K_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q6_K<mmq_y, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q6_K_q8_1_mma<mmq_x, mmq_y>;
@@ -3131,7 +3165,7 @@ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q6_K> {
};
template <int mmq_x, int mmq_y, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ2_XXS> {
struct mmq_type_traits_id<mmq_x, mmq_y, need_check, GGML_TYPE_IQ2_XXS> {
static constexpr int vdr = VDR_IQ2_XXS_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_xxs<mmq_y, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
@@ -3139,7 +3173,7 @@ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ2_XXS> {
};
template <int mmq_x, int mmq_y, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ2_XS> {
struct mmq_type_traits_id<mmq_x, mmq_y, need_check, GGML_TYPE_IQ2_XS> {
static constexpr int vdr = VDR_IQ2_XS_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_xs<mmq_y, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y>;
@@ -3147,7 +3181,7 @@ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ2_XS> {
};
template <int mmq_x, int mmq_y, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ2_S> {
struct mmq_type_traits_id<mmq_x, mmq_y, need_check, GGML_TYPE_IQ2_S> {
static constexpr int vdr = VDR_IQ2_S_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_s<mmq_y, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y>;
@@ -3155,7 +3189,7 @@ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ2_S> {
};
template <int mmq_x, int mmq_y, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ3_XXS> {
struct mmq_type_traits_id<mmq_x, mmq_y, need_check, GGML_TYPE_IQ3_XXS> {
static constexpr int vdr = VDR_IQ3_XXS_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_xxs<mmq_y, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
@@ -3163,7 +3197,7 @@ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ3_XXS> {
};
template <int mmq_x, int mmq_y, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ3_S> {
struct mmq_type_traits_id<mmq_x, mmq_y, need_check, GGML_TYPE_IQ3_S> {
static constexpr int vdr = VDR_IQ3_S_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_s<mmq_y, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
@@ -3171,7 +3205,7 @@ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ3_S> {
};
template <int mmq_x, int mmq_y, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ1_S> {
struct mmq_type_traits_id<mmq_x, mmq_y, need_check, GGML_TYPE_IQ1_S> {
static constexpr int vdr = VDR_IQ1_S_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq1_s<mmq_y, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y>;
@@ -3179,7 +3213,7 @@ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ1_S> {
};
template <int mmq_x, int mmq_y, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ4_NL> {
struct mmq_type_traits_id<mmq_x, mmq_y, need_check, GGML_TYPE_IQ4_NL> {
static constexpr int vdr = VDR_IQ4_NL_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_nl<mmq_y, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
@@ -3187,7 +3221,7 @@ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ4_NL> {
};
template <int mmq_x, int mmq_y, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ4_XS> {
struct mmq_type_traits_id<mmq_x, mmq_y, need_check, GGML_TYPE_IQ4_XS> {
static constexpr int vdr = VDR_IQ4_XS_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_xs<mmq_y, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
@@ -3195,7 +3229,7 @@ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ4_XS> {
};
template <ggml_type type, int mmq_x, bool need_check, bool fixup>
static __device__ __forceinline__ void mul_mat_q_process_tile(
static __device__ __forceinline__ void mul_mat_q_process_tile_id(
const char * __restrict__ x, const int offset_x, const int * __restrict__ y,
const int * __restrict__ ids_dst, float * __restrict__ dst, float * __restrict__ tmp_fixup,
const int stride_row_x, const int ncols_y, const int stride_col_dst,
@@ -3205,18 +3239,18 @@ static __device__ __forceinline__ void mul_mat_q_process_tile(
constexpr int nwarps = mmq_get_nwarps_device();
constexpr int qk = ggml_cuda_type_traits<type>::qk;
constexpr int mmq_y = get_mmq_y_device();
constexpr load_tiles_mmq_t load_tiles = mmq_type_traits<mmq_x, mmq_y, need_check, type>::load_tiles;
constexpr load_tiles_mmq_t load_tiles = mmq_type_traits_id<mmq_x, mmq_y, need_check, type>::load_tiles;
extern __shared__ int data_mul_mat_q[];
int * tile_y = data_mul_mat_q + mmq_x;
int * tile_x = tile_y + GGML_PAD(mmq_x*MMQ_TILE_Y_K, nwarps*warp_size);
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
constexpr vec_dot_mmq_t vec_dot = mmq_type_traits<mmq_x, mmq_y, need_check, type>::vec_dot_mma;
constexpr mmq_write_back_t write_back = mmq_write_back_mma<type, mmq_x, mmq_y, need_check>;
constexpr vec_dot_mmq_t vec_dot = mmq_type_traits_id<mmq_x, mmq_y, need_check, type>::vec_dot_mma;
constexpr mmq_write_back_t write_back = mmq_write_back_mma_id<type, mmq_x, mmq_y, need_check>;
#else
constexpr vec_dot_mmq_t vec_dot = mmq_type_traits<mmq_x, mmq_y, need_check, type>::vec_dot_dp4a;
constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, need_check>;
constexpr vec_dot_mmq_t vec_dot = mmq_type_traits_id<mmq_x, mmq_y, need_check, type>::vec_dot_dp4a;
constexpr mmq_write_back_t write_back = mmq_write_back_dp4a_id<mmq_x, mmq_y, need_check>;
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
constexpr int blocks_per_iter = MMQ_ITER_K / qk;
@@ -3267,7 +3301,7 @@ static __device__ __forceinline__ void mul_mat_q_process_tile(
}
// The mul_mat_q kernel implements "stream-k" work partitioning as described in https://arxiv.org/abs/2301.03598
// The mul_mat_q_id kernel implements "stream-k" work partitioning as described in https://arxiv.org/abs/2301.03598
template <ggml_type type, int mmq_x, bool need_check>
#if defined(GGML_USE_HIP)
@@ -3281,7 +3315,7 @@ template <ggml_type type, int mmq_x, bool need_check>
__launch_bounds__(ggml_cuda_get_physical_warp_size()*mmq_get_nwarps_device(), 2)
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
#endif // defined(GGML_USE_HIP)
static __global__ void mul_mat_q(
static __global__ void mul_mat_q_id(
const char * __restrict__ x, const int * __restrict__ y, const int32_t * __restrict__ ids_dst,
const int32_t * __restrict__ expert_bounds, float * __restrict__ dst, float * __restrict__ tmp_fixup,
const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_row_x, const int ncols_y, const int stride_col_dst,
@@ -3370,7 +3404,7 @@ static __global__ void mul_mat_q(
const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x;
constexpr bool fixup = false;
mul_mat_q_process_tile<type, mmq_x, need_check, fixup>
mul_mat_q_process_tile_id<type, mmq_x, need_check, fixup>
(x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst,
tile_x_max_i, tile_y_max_j, 0, ncols_x/qk);
return;
@@ -3448,7 +3482,7 @@ static __global__ void mul_mat_q(
const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x;
constexpr bool fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
mul_mat_q_process_tile<type, mmq_x, need_check, fixup>
mul_mat_q_process_tile_id<type, mmq_x, need_check, fixup>
(x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst,
tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop);
@@ -3515,14 +3549,14 @@ static __global__ void mul_mat_q(
const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x;
constexpr bool fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
mul_mat_q_process_tile<type, mmq_x, need_check, fixup>
mul_mat_q_process_tile_id<type, mmq_x, need_check, fixup>
(x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst,
tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop);
}
template <ggml_type type, int mmq_x, bool need_check>
static __global__ void mul_mat_q_stream_k_fixup(
static __global__ void mul_mat_q_stream_k_fixup_id(
const int32_t * ids_dst, const int32_t * expert_bounds, float * __restrict__ dst, const float * __restrict__ tmp_last_tile,
const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_col_dst,
const int nchannels_y, const int stride_channel_dst, const int nsamples_y, const int stride_sample_dst,
@@ -3673,7 +3707,7 @@ static __global__ void mul_mat_q_stream_k_fixup(
}
}
struct mmq_args {
struct mmq_args_id {
const char * x; ggml_type type_x; const int * y; const int32_t * ids_dst; const int32_t * expert_bounds; float * dst;
int64_t ncols_x; int64_t nrows_x; int64_t ncols_dst; int64_t stride_row_x; int64_t ncols_y; int64_t nrows_dst;
int64_t nchannels_x; int64_t nchannels_y; int64_t stride_channel_x; int64_t stride_channel_y; int64_t stride_channel_dst;
@@ -3692,7 +3726,7 @@ static size_t mmq_get_nbytes_shared(const int mmq_x, const int mmq_y, const int
}
template <ggml_type type, int mmq_x>
static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
static void launch_mul_mat_q_id(ggml_backend_cuda_context & ctx, const mmq_args_id & args, cudaStream_t stream) {
const int id = ggml_cuda_get_device();
const int cc = ggml_cuda_info().devices[id].cc;
const int nsm = ggml_cuda_info().devices[id].nsm;
@@ -3704,14 +3738,17 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
const int nbytes_shared = mmq_get_nbytes_shared<type>(mmq_x, mmq_y, cc, warp_size, nwarps);
CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q<type, mmq_x, false>), nbytes_shared);
CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q<type, mmq_x, true>), nbytes_shared);
CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q_id<type, mmq_x, false>), nbytes_shared);
CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q_id<type, mmq_x, true>), nbytes_shared);
const int nty = (args.nrows_x + mmq_y - 1) / mmq_y;
const int ntx = (args.ncols_max + mmq_x - 1) / mmq_x;
const int ntzw = args.nchannels_y * args.nsamples_y;
const dim3 block_nums_xy_tiling(nty, ntx, ntzw);
// Diagnostic: report the offending channel counts before the GGML_ASSERT that follows checks the same condition.
// The mmq_args_id channel fields are int64_t, so they are cast to long long and printed with %lld;
// passing a 64-bit integer to "%d" is a format/argument mismatch (undefined behavior).
// The message is written to stderr (like the mmq_x_best diagnostic elsewhere in this file) so that it is
// not left sitting in a stdout buffer if the following assertion terminates the process.
if (args.nchannels_y % args.nchannels_x != 0) {
    fprintf(stderr, "Oops: args.nchannels_y = %lld, args.nchannels_x = %lld\n",
            (long long) args.nchannels_y, (long long) args.nchannels_x);
}
GGML_ASSERT(args.nchannels_y % args.nchannels_x == 0);
GGML_ASSERT(args.nsamples_y % args.nsamples_x == 0);
const int channel_ratio = args.nchannels_y / args.nchannels_x;
@@ -3720,7 +3757,7 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
if (!args.use_stream_k) {
if (args.nrows_x % mmq_y == 0) {
constexpr bool need_check = false;
mul_mat_q<type, mmq_x, need_check><<<block_nums_xy_tiling, block_dims, nbytes_shared, stream>>>
mul_mat_q_id<type, mmq_x, need_check><<<block_nums_xy_tiling, block_dims, nbytes_shared, stream>>>
(args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr,
args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
@@ -3728,7 +3765,7 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
args.ncols_max);
} else {
constexpr bool need_check = true;
mul_mat_q<type, mmq_x, need_check><<<block_nums_xy_tiling, block_dims, nbytes_shared, stream>>>
mul_mat_q_id<type, mmq_x, need_check><<<block_nums_xy_tiling, block_dims, nbytes_shared, stream>>>
(args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr,
args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
@@ -3749,7 +3786,7 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
if (args.nrows_x % mmq_y == 0) {
constexpr bool need_check = false;
mul_mat_q<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>>
mul_mat_q_id<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>>
(args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr,
args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
@@ -3760,13 +3797,13 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
return;
}
mul_mat_q_stream_k_fixup<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, 0, stream>>>
mul_mat_q_stream_k_fixup_id<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, 0, stream>>>
(args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst,
args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst,
args.ncols_max);
} else {
constexpr bool need_check = true;
mul_mat_q<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>>
mul_mat_q_id<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>>
(args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr,
args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
@@ -3777,7 +3814,7 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
return;
}
mul_mat_q_stream_k_fixup<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, 0, stream>>>
mul_mat_q_stream_k_fixup_id<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, 0, stream>>>
(args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst,
args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst,
args.ncols_max);
@@ -3785,7 +3822,7 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
}
template <ggml_type type>
void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
void mul_mat_q_case_id(ggml_backend_cuda_context & ctx, const mmq_args_id & args, cudaStream_t stream) {
const int id = ggml_cuda_get_device();
const int cc = ggml_cuda_info().devices[id].cc;
const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
@@ -3815,52 +3852,52 @@ void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cuda
switch (mmq_x_best) {
case 8:
launch_mul_mat_q<type, 8>(ctx, args, stream);
launch_mul_mat_q_id<type, 8>(ctx, args, stream);
break;
case 16:
launch_mul_mat_q<type, 16>(ctx, args, stream);
launch_mul_mat_q_id<type, 16>(ctx, args, stream);
break;
case 24:
launch_mul_mat_q<type, 24>(ctx, args, stream);
launch_mul_mat_q_id<type, 24>(ctx, args, stream);
break;
case 32:
launch_mul_mat_q<type, 32>(ctx, args, stream);
launch_mul_mat_q_id<type, 32>(ctx, args, stream);
break;
case 40:
launch_mul_mat_q<type, 40>(ctx, args, stream);
launch_mul_mat_q_id<type, 40>(ctx, args, stream);
break;
case 48:
launch_mul_mat_q<type, 48>(ctx, args, stream);
launch_mul_mat_q_id<type, 48>(ctx, args, stream);
break;
case 56:
launch_mul_mat_q<type, 56>(ctx, args, stream);
launch_mul_mat_q_id<type, 56>(ctx, args, stream);
break;
case 64:
launch_mul_mat_q<type, 64>(ctx, args, stream);
launch_mul_mat_q_id<type, 64>(ctx, args, stream);
break;
case 72:
launch_mul_mat_q<type, 72>(ctx, args, stream);
launch_mul_mat_q_id<type, 72>(ctx, args, stream);
break;
case 80:
launch_mul_mat_q<type, 80>(ctx, args, stream);
launch_mul_mat_q_id<type, 80>(ctx, args, stream);
break;
case 88:
launch_mul_mat_q<type, 88>(ctx, args, stream);
launch_mul_mat_q_id<type, 88>(ctx, args, stream);
break;
case 96:
launch_mul_mat_q<type, 96>(ctx, args, stream);
launch_mul_mat_q_id<type, 96>(ctx, args, stream);
break;
case 104:
launch_mul_mat_q<type, 104>(ctx, args, stream);
launch_mul_mat_q_id<type, 104>(ctx, args, stream);
break;
case 112:
launch_mul_mat_q<type, 112>(ctx, args, stream);
launch_mul_mat_q_id<type, 112>(ctx, args, stream);
break;
case 120:
launch_mul_mat_q<type, 120>(ctx, args, stream);
launch_mul_mat_q_id<type, 120>(ctx, args, stream);
break;
case 128:
launch_mul_mat_q<type, 128>(ctx, args, stream);
launch_mul_mat_q_id<type, 128>(ctx, args, stream);
break;
default:
fprintf(stderr, "mmq_x_best=%d\n", mmq_x_best);
@@ -3870,27 +3907,27 @@ void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cuda
}
#define DECL_MMQ_CASE(type) \
template void mul_mat_q_case<type>(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) \
template void mul_mat_q_case_id<type>(ggml_backend_cuda_context & ctx, const mmq_args_id & args, cudaStream_t stream) \
extern DECL_MMQ_CASE(GGML_TYPE_Q4_0);
extern DECL_MMQ_CASE(GGML_TYPE_Q4_1);
extern DECL_MMQ_CASE(GGML_TYPE_Q5_0);
extern DECL_MMQ_CASE(GGML_TYPE_Q5_1);
extern DECL_MMQ_CASE(GGML_TYPE_Q8_0);
extern DECL_MMQ_CASE(GGML_TYPE_MXFP4);
extern DECL_MMQ_CASE(GGML_TYPE_Q2_K);
extern DECL_MMQ_CASE(GGML_TYPE_Q3_K);
extern DECL_MMQ_CASE(GGML_TYPE_Q4_K);
extern DECL_MMQ_CASE(GGML_TYPE_Q5_K);
extern DECL_MMQ_CASE(GGML_TYPE_Q6_K);
extern DECL_MMQ_CASE(GGML_TYPE_IQ2_XXS);
extern DECL_MMQ_CASE(GGML_TYPE_IQ2_XS);
extern DECL_MMQ_CASE(GGML_TYPE_IQ2_S);
extern DECL_MMQ_CASE(GGML_TYPE_IQ3_XXS);
extern DECL_MMQ_CASE(GGML_TYPE_IQ3_S);
extern DECL_MMQ_CASE(GGML_TYPE_IQ1_S);
extern DECL_MMQ_CASE(GGML_TYPE_IQ4_NL);
extern DECL_MMQ_CASE(GGML_TYPE_IQ4_XS);
DECL_MMQ_CASE(GGML_TYPE_Q4_0);
DECL_MMQ_CASE(GGML_TYPE_Q4_1);
DECL_MMQ_CASE(GGML_TYPE_Q5_0);
DECL_MMQ_CASE(GGML_TYPE_Q5_1);
DECL_MMQ_CASE(GGML_TYPE_Q8_0);
DECL_MMQ_CASE(GGML_TYPE_MXFP4);
DECL_MMQ_CASE(GGML_TYPE_Q2_K);
DECL_MMQ_CASE(GGML_TYPE_Q3_K);
DECL_MMQ_CASE(GGML_TYPE_Q4_K);
DECL_MMQ_CASE(GGML_TYPE_Q5_K);
DECL_MMQ_CASE(GGML_TYPE_Q6_K);
DECL_MMQ_CASE(GGML_TYPE_IQ2_XXS);
DECL_MMQ_CASE(GGML_TYPE_IQ2_XS);
DECL_MMQ_CASE(GGML_TYPE_IQ2_S);
DECL_MMQ_CASE(GGML_TYPE_IQ3_XXS);
DECL_MMQ_CASE(GGML_TYPE_IQ3_S);
DECL_MMQ_CASE(GGML_TYPE_IQ1_S);
DECL_MMQ_CASE(GGML_TYPE_IQ4_NL);
DECL_MMQ_CASE(GGML_TYPE_IQ4_XS);
// -------------------------------------------------------------------------------------------------------------------------
@@ -4029,64 +4066,64 @@ static void launch_mmq_ids_helper(
(ids, ids_src1, ids_dst, expert_bounds, n_tokens, n_expert_used_var, nchannels_y, si1, sis1);
}
static void ggml_cuda_mul_mat_q_switch_type(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
static void ggml_cuda_mul_mat_q_switch_type_id(ggml_backend_cuda_context & ctx, const mmq_args_id & args, cudaStream_t stream) {
switch (args.type_x) {
case GGML_TYPE_Q4_0:
mul_mat_q_case<GGML_TYPE_Q4_0>(ctx, args, stream);
mul_mat_q_case_id<GGML_TYPE_Q4_0>(ctx, args, stream);
break;
case GGML_TYPE_Q4_1:
mul_mat_q_case<GGML_TYPE_Q4_1>(ctx, args, stream);
mul_mat_q_case_id<GGML_TYPE_Q4_1>(ctx, args, stream);
break;
case GGML_TYPE_Q5_0:
mul_mat_q_case<GGML_TYPE_Q5_0>(ctx, args, stream);
mul_mat_q_case_id<GGML_TYPE_Q5_0>(ctx, args, stream);
break;
case GGML_TYPE_Q5_1:
mul_mat_q_case<GGML_TYPE_Q5_1>(ctx, args, stream);
mul_mat_q_case_id<GGML_TYPE_Q5_1>(ctx, args, stream);
break;
case GGML_TYPE_Q8_0:
mul_mat_q_case<GGML_TYPE_Q8_0>(ctx, args, stream);
mul_mat_q_case_id<GGML_TYPE_Q8_0>(ctx, args, stream);
break;
case GGML_TYPE_MXFP4:
mul_mat_q_case<GGML_TYPE_MXFP4>(ctx, args, stream);
mul_mat_q_case_id<GGML_TYPE_MXFP4>(ctx, args, stream);
break;
case GGML_TYPE_Q2_K:
mul_mat_q_case<GGML_TYPE_Q2_K>(ctx, args, stream);
mul_mat_q_case_id<GGML_TYPE_Q2_K>(ctx, args, stream);
break;
case GGML_TYPE_Q3_K:
mul_mat_q_case<GGML_TYPE_Q3_K>(ctx, args, stream);
mul_mat_q_case_id<GGML_TYPE_Q3_K>(ctx, args, stream);
break;
case GGML_TYPE_Q4_K:
mul_mat_q_case<GGML_TYPE_Q4_K>(ctx, args, stream);
mul_mat_q_case_id<GGML_TYPE_Q4_K>(ctx, args, stream);
break;
case GGML_TYPE_Q5_K:
mul_mat_q_case<GGML_TYPE_Q5_K>(ctx, args, stream);
mul_mat_q_case_id<GGML_TYPE_Q5_K>(ctx, args, stream);
break;
case GGML_TYPE_Q6_K:
mul_mat_q_case<GGML_TYPE_Q6_K>(ctx, args, stream);
mul_mat_q_case_id<GGML_TYPE_Q6_K>(ctx, args, stream);
break;
case GGML_TYPE_IQ2_XXS:
mul_mat_q_case<GGML_TYPE_IQ2_XXS>(ctx, args, stream);
mul_mat_q_case_id<GGML_TYPE_IQ2_XXS>(ctx, args, stream);
break;
case GGML_TYPE_IQ2_XS:
mul_mat_q_case<GGML_TYPE_IQ2_XS>(ctx, args, stream);
mul_mat_q_case_id<GGML_TYPE_IQ2_XS>(ctx, args, stream);
break;
case GGML_TYPE_IQ2_S:
mul_mat_q_case<GGML_TYPE_IQ2_S>(ctx, args, stream);
mul_mat_q_case_id<GGML_TYPE_IQ2_S>(ctx, args, stream);
break;
case GGML_TYPE_IQ3_XXS:
mul_mat_q_case<GGML_TYPE_IQ3_XXS>(ctx, args, stream);
mul_mat_q_case_id<GGML_TYPE_IQ3_XXS>(ctx, args, stream);
break;
case GGML_TYPE_IQ3_S:
mul_mat_q_case<GGML_TYPE_IQ3_S>(ctx, args, stream);
mul_mat_q_case_id<GGML_TYPE_IQ3_S>(ctx, args, stream);
break;
case GGML_TYPE_IQ1_S:
mul_mat_q_case<GGML_TYPE_IQ1_S>(ctx, args, stream);
mul_mat_q_case_id<GGML_TYPE_IQ1_S>(ctx, args, stream);
break;
case GGML_TYPE_IQ4_XS:
mul_mat_q_case<GGML_TYPE_IQ4_XS>(ctx, args, stream);
mul_mat_q_case_id<GGML_TYPE_IQ4_XS>(ctx, args, stream);
break;
case GGML_TYPE_IQ4_NL:
mul_mat_q_case<GGML_TYPE_IQ4_NL>(ctx, args, stream);
mul_mat_q_case_id<GGML_TYPE_IQ4_NL>(ctx, args, stream);
break;
default:
GGML_ABORT("fatal error");
@@ -4236,14 +4273,18 @@ void ggml_cuda_mul_mat_q_id(ggml_backend_cuda_context & ctx, const ggml_tensor *
const int64_t s13 = ne12*s12;
// Note that ne02 is used instead of ne12 because the number of y channels determines the z dimension of the CUDA grid.
const mmq_args args = {
const mmq_args_id args = {
src0_d, src0->type, (const int *) src1_q8_1, ids_dst, expert_bounds, dst_d,
ne00, ne01, ne_get_rows, s01, ne_get_rows, s1,
ne02, ne02, s02, s12, s2,
ne03, ne13, s03, s13, s3,
use_stream_k, ne12};
ggml_cuda_mul_mat_q_switch_type(ctx, args, stream);
//printf("ne00 = %ld, ne01 = %ld, ne_get_rows = %ld, s01 = %ld, s1 = %ld\n", ne00, ne01, ne_get_rows, s01, s1);
//printf("ne02 = %ld, s02 = %ld, s12 = %ld, s2 = %ld\n", ne02, s02, s12, s2);
//printf("ne03 = %ld, s03 = %ld, s13 = %ld, s3 = %ld\n", ne03, s03, s13, s3);
ggml_cuda_mul_mat_q_switch_type_id(ctx, args, stream);
}
bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {