@@ -53,6 +53,8 @@ static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
             return MMQ_Q8_1_DS_LAYOUT_D4;
         case GGML_TYPE_Q5_1:
             return MMQ_Q8_1_DS_LAYOUT_DS4;
+        case GGML_TYPE_Q6_0:
+            return MMQ_Q8_1_DS_LAYOUT_D4;
         case GGML_TYPE_Q8_0:
             return MMQ_Q8_1_DS_LAYOUT_D4;
         case GGML_TYPE_Q2_K:
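Q6_0 is a symmetric format: each block stores one scale d and 6-bit quants, with no minimum. Its dot product against a Q8_1 block therefore only needs the activation scale, i.e. the D4 layout, just like Q5_0 and Q8_0; the DS4 layout (scale plus block sum) is reserved for formats with a minimum such as Q4_1 and Q5_1. A minimal scalar sketch of the difference, with illustrative function names not taken from the file:

// Symmetric x (Q5_0/Q6_0/Q8_0 style): x_i = d_x*q_i, so one float per Q8_1
// block (its scale d_y) suffices -> MMQ_Q8_1_DS_LAYOUT_D4.
float dot_d4(float d_x, float d_y, int sumi) {
    return d_x * d_y * sumi;               // sumi = sum_i q_i*s_i from dp4a
}

// Asymmetric x (Q4_1/Q5_1 style): x_i = d_x*q_i + m_x; the DS4 layout also
// keeps the Q8_1 block sum so the m_x term folds in without extra work.
float dot_ds4(float d_x, float m_x, float d_y, int sumi, float sum_y) {
    return d_x * d_y * sumi + m_x * sum_y; // sum_y = sum_i y_i = d_y * sum_i s_i
}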
@@ -155,6 +157,8 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml
        type == GGML_TYPE_Q4_1 ? MMQ_DP4A_TXS_Q4_1 :
        type == GGML_TYPE_Q5_0 ? MMQ_DP4A_TXS_Q8_0 :
        type == GGML_TYPE_Q5_1 ? MMQ_DP4A_TXS_Q8_1 :
+       type == GGML_TYPE_Q6_0 ? MMQ_DP4A_TXS_Q8_0 :
        type == GGML_TYPE_Q8_0 ? MMQ_DP4A_TXS_Q8_0 :
        type == GGML_TYPE_Q2_K ? MMQ_DP4A_TXS_Q2_K :
        type == GGML_TYPE_Q3_K ? MMQ_DP4A_TXS_Q3_K :
@@ -189,6 +192,7 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
        type == GGML_TYPE_Q4_1 ? MMQ_MMA_TILE_X_K_Q8_1 :
        type == GGML_TYPE_Q5_0 ? MMQ_MMA_TILE_X_K_Q8_0 :
        type == GGML_TYPE_Q5_1 ? MMQ_MMA_TILE_X_K_Q8_1 :
+       type == GGML_TYPE_Q6_0 ? MMQ_MMA_TILE_X_K_Q8_0 :
        type == GGML_TYPE_Q8_0 ? MMQ_MMA_TILE_X_K_Q8_0 :
        type == GGML_TYPE_Q2_K ? MMQ_MMA_TILE_X_K_Q2_K :
        type == GGML_TYPE_Q3_K ? MMQ_MMA_TILE_X_K_Q3_K :
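Both lookup tables route Q6_0 to the Q8_0 entries (MMQ_DP4A_TXS_Q8_0 and MMQ_MMA_TILE_X_K_Q8_0). This works because the loader below expands the 6-bit quants to full int8 values at load time, so the shared-memory tile ends up with exactly the shape of a Q8_0 tile and no Q6_0-specific sizes are needed.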
@@ -556,6 +560,69 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     }
 }
 
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q6_0(
+    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+    int * x_qs = (int *) x_tile;
+    float * x_df = (float *) (x_qs + WARP_SIZE*2);
+#else
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_0, mmq_y);
+    int * x_qs = (int *) x_tile;
+    float * x_df = (float *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
+
+    const int kbx  = threadIdx.x / QI6_0;
+    const int kqsx = threadIdx.x % QI6_0;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+        int i = i0 + threadIdx.y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q6_0 * bxi = (const block_q6_0 *) x + kbx0 + i*stride + kbx;
+
+        const int ql = get_int_b2(bxi->qs, kqsx);
+        const int qh = get_int_b2(bxi->qh, kqsx%2) >> 4*(kqsx/2);
+
+        int qs0 = ((ql >> 0) & 0x0F0F0F0F) | ((qh << 4) & 0x30303030);
+        int qs1 = ((ql >> 4) & 0x0F0F0F0F) | ((qh << 2) & 0x30303030);
+        qs0 = __vsubss4(qs0, 0x20202020); // subtract 32
+        qs1 = __vsubss4(qs1, 0x20202020); // subtract 32
+
+#ifdef INT8_MMA_AVAILABLE
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI6_0) + kqsx + 0]     = qs0;
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI6_0) + kqsx + QI6_0] = qs1;
+#else
+        x_qs[i*(2*WARP_SIZE + 1)     + kbx*(2*QI6_0) + kqsx + 0]     = qs0;
+        x_qs[i*(2*WARP_SIZE + 1)     + kbx*(2*QI6_0) + kqsx + QI6_0] = qs1;
+#endif // INT8_MMA_AVAILABLE
+    }
+
+    const int blocks_per_tile_x_row = WARP_SIZE / QI6_0;
+    const int kbxd = threadIdx.x % blocks_per_tile_x_row;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI6_0) {
+        int i = i0 + threadIdx.y * QI6_0 + threadIdx.x / blocks_per_tile_x_row;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q6_0 * bxi = (const block_q6_0 *) x + kbx0 + i*stride + kbxd;
+
+#ifdef INT8_MMA_AVAILABLE
+        x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
+#else
+        x_df[i*(WARP_SIZE/QI6_0) + i/QI6_0 + kbxd] = bxi->d;
+#endif // INT8_MMA_AVAILABLE
+    }
+}
+
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q8_0(
     const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
 
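The interesting part of this hunk is the unpack. A block_q6_0 keeps the low 4 bits of its 32 values in qs (two nibbles per byte, Q4_0-style) and the high 2 bits in qh; each thread rebuilds eight 6-bit values as two 32-bit words and recentres them from [0, 63] to [-32, 31], after which the tile is indistinguishable from an unpacked Q8_0 tile. A host-side scalar sketch of the same bit logic (the helper name is illustrative, not from the patch):

#include <cstdint>

// Scalar model of the SIMD-in-a-register unpack above. ql is one 32-bit word
// of bxi->qs (four bytes, each holding two low nibbles); qh is the matching
// word of bxi->qh after the kernel's >> 4*(kqsx/2) shift, so bits 0-1 of each
// byte extend the low-nibble value and bits 2-3 extend the high-nibble value.
static void unpack_q6_0_word(uint32_t ql, uint32_t qh, int8_t out0[4], int8_t out1[4]) {
    const uint32_t qs0 = ((ql >> 0) & 0x0F0F0F0Fu) | ((qh << 4) & 0x30303030u);
    const uint32_t qs1 = ((ql >> 4) & 0x0F0F0F0Fu) | ((qh << 2) & 0x30303030u);
    for (int j = 0; j < 4; ++j) {
        // Each byte now holds a 6-bit value in [0, 63]; subtracting 32 maps it
        // to [-32, 31], which is what __vsubss4(x, 0x20202020) does per byte
        // (no saturation can occur in this range).
        out0[j] = (int8_t) ((int) ((qs0 >> 8*j) & 0x3F) - 32);
        out1[j] = (int8_t) ((int) ((qs1 >> 8*j) & 0x3F) - 32);
    }
}

Because the recentred values are written out in the Q8_0 tile layout (MMQ_MMA_TILE_X_K_Q8_0, or the bank-conflict-padded 2*WARP_SIZE + 1 stride for dp4a), everything downstream can treat the tile as plain Q8_0 data.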
@@ -2379,6 +2446,14 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_1> {
     static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
 };
 
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q6_0> {
+    static constexpr int              vdr          = VDR_Q6_0_Q8_1_MMQ;
+    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q6_0<mmq_y, nwarps, need_check>;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
+    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+};
+
 template <int mmq_x, int mmq_y, int nwarps, bool need_check>
 struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q8_0> {
     static constexpr int              vdr          = VDR_Q8_0_Q8_1_MMQ;
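With the tile already in Q8_0 layout, the Q6_0 traits are pure glue: only vdr and load_tiles are type-specific, while vec_dot_mma and vec_dot_dp4a point straight at the existing Q8_0 x Q8_1 kernels, using the D4 layout selected in the first hunk. The final hunk below then just registers the new case alongside the other quant types.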
@@ -2910,6 +2985,7 @@ extern DECL_MMQ_CASE(GGML_TYPE_Q4_0);
 extern DECL_MMQ_CASE(GGML_TYPE_Q4_1);
 extern DECL_MMQ_CASE(GGML_TYPE_Q5_0);
 extern DECL_MMQ_CASE(GGML_TYPE_Q5_1);
+extern DECL_MMQ_CASE(GGML_TYPE_Q6_0);
 extern DECL_MMQ_CASE(GGML_TYPE_Q8_0);
 extern DECL_MMQ_CASE(GGML_TYPE_Q2_K);
 extern DECL_MMQ_CASE(GGML_TYPE_Q3_K);