mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-21 15:09:40 +00:00
Work around compiler bug
It issues a warning that there is an extra semicolon outside of a function, but there isn't. If I remove the anonymous namespace and turn the functions inside into static, the warning disapears, so clearly a compiler bug.
This commit is contained in:
@@ -18,17 +18,8 @@ struct ggml_cuda_type_traits<GGML_TYPE_IQ1_M_R4> {
|
||||
static constexpr int qi = 4;
|
||||
};
|
||||
|
||||
// Reminder:
|
||||
// constexpr int qk = ggml_cuda_type_traits<type>::qk;
|
||||
// constexpr int qi = ggml_cuda_type_traits<type>::qi;
|
||||
// constexpr int vdr = get_vdr_mmvq(type);
|
||||
|
||||
// QI4_XS = 256/(4*2) = 32
|
||||
// vdr = 4, qi = 32 -> qi/vdr = 8, kqs = 4*(tid%8), blocks_per_iter = 4*1*32/32 = 4
|
||||
// vdr = 2, qi = 32 -> qi/vdr =16, kqs = 2*(tid%16), blocks_per_iter = 2*1*32/32 = 2
|
||||
namespace {
|
||||
template <ggml_type type, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda, int ncols_y, int n_interleaved = 1>
|
||||
__device__ void iqk_mul_mat_vec_q_kerne(
|
||||
static __device__ void iqk_mul_mat_vec_q_kerne(
|
||||
const void * __restrict__ vx, const void * __restrict__ vy,
|
||||
const float * bias, float * __restrict__ dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst, const int64_t row_size) {
|
||||
@@ -110,7 +101,7 @@ __device__ void iqk_mul_mat_vec_q_kerne(
|
||||
}
|
||||
|
||||
template <ggml_type type, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda, int ncols_y, int n_interleaved = 1>
|
||||
__device__ void iqk_fused_mul_mat_vec_q_kernel(
|
||||
static __device__ void iqk_fused_mul_mat_vec_q_kernel(
|
||||
const void * __restrict__ vup, const void * __restrict__ vgate, const void * __restrict__ vy, float * __restrict__ dst,
|
||||
const float * __restrict__ bias_u, const float * __restrict__ bias_g,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst, const int64_t row_size,
|
||||
@@ -228,7 +219,7 @@ template <ggml_type type, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda, int ncols_y,
|
||||
// tell the compiler to use as many registers as it wants, see nwarps definition below
|
||||
__launch_bounds__((ncols_y <= 4 ? 4 : 2)*WARP_SIZE, 1)
|
||||
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
||||
__global__ void iqk_mul_mat_vec_q(
|
||||
static __global__ void iqk_mul_mat_vec_q(
|
||||
const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
|
||||
const char * __restrict__ ids_data, const void * __restrict__ bias,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst, const int64_t row_size,
|
||||
@@ -248,7 +239,7 @@ template <ggml_type type, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda, int ncols_y,
|
||||
// tell the compiler to use as many registers as it wants, see nwarps definition below
|
||||
__launch_bounds__((ncols_y <= 4 ? 4 : 2)*WARP_SIZE, 1)
|
||||
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
||||
__global__ void iqk_fused_mul_mat_vec_q(
|
||||
static __global__ void iqk_fused_mul_mat_vec_q(
|
||||
const void * __restrict__ vx_u, const void * __restrict__ vx_g, const void * __restrict__ vy, float * __restrict__ dst,
|
||||
const char * __restrict__ ids_data, const void * __restrict__ bias_u, const void * __restrict__ bias_g, const uint64_t bias_nb1,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst, const int64_t row_size,
|
||||
@@ -269,7 +260,7 @@ __global__ void iqk_fused_mul_mat_vec_q(
|
||||
}
|
||||
|
||||
template <ggml_type type, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda, int n_interleaved = 1>
|
||||
void iqk_mul_mat_vec_q_cuda(const mmvq_args & args, cudaStream_t stream) {
|
||||
static void iqk_mul_mat_vec_q_cuda(const mmvq_args & args, cudaStream_t stream) {
|
||||
|
||||
GGML_ASSERT(args.ncols_x % ggml_blck_size(type) == 0);
|
||||
//GGML_ASSERT(ncols_y <= MMVQ_MAX_BATCH_SIZE);
|
||||
@@ -428,7 +419,7 @@ void iqk_mul_mat_vec_q_cuda(const mmvq_args & args, cudaStream_t stream) {
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void get_int_from_table_16_shift(const uint32_t & q4, uint16_t shift, const uint8_t * all_values,
|
||||
static __device__ __forceinline__ void get_int_from_table_16_shift(const uint32_t & q4, uint16_t shift, const uint8_t * all_values,
|
||||
int & val1, int & val2) {
|
||||
|
||||
uint32_t aux32; const uint8_t * q8 = (const uint8_t *)&aux32;
|
||||
@@ -476,7 +467,7 @@ static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4, con
|
||||
#endif
|
||||
}
|
||||
|
||||
__device__ __forceinline__ int int_from_table(const uint8_t * a8, const uint8_t * values) {
|
||||
static __device__ __forceinline__ int int_from_table(const uint8_t * a8, const uint8_t * values) {
|
||||
uint16_t v1 = values[a8[0]] | (values[a8[1]] << 8);
|
||||
uint16_t v2 = values[a8[2]] | (values[a8[3]] << 8);
|
||||
return v1 | (v2 << 16);
|
||||
@@ -506,8 +497,6 @@ __device__ __forceinline__ int int_from_table(const uint8_t * a8, const uint8_t
|
||||
#define VDR_IQ3_K_Q8_1_MMVQ 4
|
||||
#define VDR_IQ3_K_Q8_1_MMQ 4
|
||||
|
||||
} // namespace
|
||||
|
||||
extern void mul_mat_vec_iq2_k_q8_1_cuda(const mmvq_args & args, cudaStream_t stream);
|
||||
extern void mul_mat_vec_iq3_k_q8_1_cuda(const mmvq_args & args, cudaStream_t stream);
|
||||
extern void mul_mat_vec_iq4_k_q8_1_cuda(const mmvq_args & args, cudaStream_t stream);
|
||||
|
||||
Reference in New Issue
Block a user