mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-26 01:19:20 +00:00
quantize_mmq_q8_1_id
This commit is contained in:
@@ -2716,7 +2716,12 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
|
||||
if (num_src1_rows == 0) continue;
|
||||
size_t mapping_offset = cum_moe_counts[i02];
|
||||
|
||||
{
|
||||
if (use_quantized_src1) {
|
||||
quantize_mmq_q8_1_id_cuda((const float *)src1->data, src1_quantized.get(), (const char *)(dev_row_mapping.get() + mapping_offset),
|
||||
src1->ne[0], num_src1_rows, src1_padded_num_cols, src0_1->type, stream);
|
||||
src1_row.data = src1_quantized.get();
|
||||
}
|
||||
else {
|
||||
dim3 block_dims(std::min((unsigned int)ne10, 768u));
|
||||
dim3 grid_dims(num_src1_rows);
|
||||
k_copy_src_to_contiguous<<<grid_dims, block_dims, 0, stream>>>(
|
||||
@@ -2740,12 +2745,12 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
|
||||
dst_row.nb[2] = num_src1_rows*nb1;
|
||||
dst_row.nb[3] = num_src1_rows*nb1;
|
||||
|
||||
if (use_quantized_src1) {
|
||||
quantize_mmq_q8_1_cuda((const float *)src1_contiguous.get(), src1_quantized.get(), src1->ne[0], num_src1_rows, 1,
|
||||
src1_padded_num_cols, src0_1->type, stream);
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
src1_row.data = src1_quantized.get();
|
||||
}
|
||||
//if (use_quantized_src1) {
|
||||
// quantize_mmq_q8_1_cuda((const float *)src1_contiguous.get(), src1_quantized.get(), src1->ne[0], num_src1_rows, 1,
|
||||
// src1_padded_num_cols, src0_1->type, stream);
|
||||
// CUDA_CHECK(cudaGetLastError());
|
||||
// src1_row.data = src1_quantized.get();
|
||||
//}
|
||||
|
||||
dst_row.data = dst_up_contiguous.get();
|
||||
if (use_quantized_src1) {
|
||||
|
||||
@@ -166,6 +166,98 @@ static __global__ void quantize_mmq_q8_1(
|
||||
}
|
||||
}
|
||||
|
||||
struct mmid_row_mapping {
|
||||
int32_t i1;
|
||||
int32_t i2;
|
||||
};
|
||||
|
||||
template <mmq_q8_1_ds_layout ds_layout>
|
||||
static __global__ void quantize_mmq_q8_1_id(
|
||||
const float * __restrict__ x, void * __restrict__ vy, const char * row_mapping, const int64_t kx0, const int64_t kx1, const int64_t kx0_padded) {
|
||||
|
||||
constexpr int vals_per_scale = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 64 : 32;
|
||||
constexpr int vals_per_sum = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 16 : 32;
|
||||
|
||||
const int64_t ix0 = ((int64_t)blockDim.x*blockIdx.x + threadIdx.x)*4;
|
||||
|
||||
if (ix0 >= kx0_padded) {
|
||||
return;
|
||||
}
|
||||
|
||||
const float4 * x4 = (const float4 *) x;
|
||||
|
||||
const mmid_row_mapping * mapping = (const mmid_row_mapping *)row_mapping;
|
||||
const int64_t ii = mapping[blockIdx.y].i2;
|
||||
|
||||
block_q8_1_mmq * y = (block_q8_1_mmq *) vy;
|
||||
|
||||
const int64_t ib0 = blockIdx.z*((int64_t)gridDim.y*gridDim.x*blockDim.x/QK8_1); // first block of channel
|
||||
const int64_t ib = ib0 + (ix0 / (4*QK8_1))*kx1 + blockIdx.y; // block index in channel
|
||||
const int64_t iqs = ix0 % (4*QK8_1); // quant index in block
|
||||
|
||||
// Load 4 floats per thread and calculate max. abs. value between them:
|
||||
const float4 xi = ix0 < kx0 ? x4[(ii*kx0 + ix0)/4] : make_float4(0.0f, 0.0f, 0.0f, 0.0f);
|
||||
float amax = fabsf(xi.x);
|
||||
amax = fmaxf(amax, fabsf(xi.y));
|
||||
amax = fmaxf(amax, fabsf(xi.z));
|
||||
amax = fmaxf(amax, fabsf(xi.w));
|
||||
|
||||
// Exchange max. abs. value between vals_per_scale/4 threads.
|
||||
#pragma unroll
|
||||
for (int mask = vals_per_scale/8; mask > 0; mask >>= 1) {
|
||||
amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, mask, WARP_SIZE));
|
||||
}
|
||||
|
||||
float sum;
|
||||
if (ds_layout != MMQ_Q8_1_DS_LAYOUT_D4) {
|
||||
sum = xi.x + xi.y + xi.z + xi.w;
|
||||
|
||||
// Exchange calculate sum across vals_per_sum/4 threads.
|
||||
#pragma unroll
|
||||
for (int mask = vals_per_sum/8; mask > 0; mask >>= 1) {
|
||||
sum += __shfl_xor_sync(0xFFFFFFFF, sum, mask, WARP_SIZE);
|
||||
}
|
||||
}
|
||||
|
||||
const float d = amax/127.f;
|
||||
const float d_inv = d > 0 ? 1/d : 0.f;
|
||||
char4 q;
|
||||
q.x = roundf(xi.x*d_inv);
|
||||
q.y = roundf(xi.y*d_inv);
|
||||
q.z = roundf(xi.z*d_inv);
|
||||
q.w = roundf(xi.w*d_inv);
|
||||
|
||||
// Write back 4 int8 values as a single 32 bit value for better memroy bandwidth:
|
||||
char4 * yqs4 = (char4 *) y[ib].qs;
|
||||
yqs4[iqs/4] = q;
|
||||
|
||||
if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6) {
|
||||
if (iqs % 16 != 0 || iqs >= 96) {
|
||||
return;
|
||||
}
|
||||
|
||||
y[ib].d2s6[2 + iqs/16] = sum;
|
||||
|
||||
if (iqs % 64 != 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
y[ib].d2s6[iqs/64] = d;
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
if (iqs % 32 != 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (ds_layout == MMQ_Q8_1_DS_LAYOUT_DS4) {
|
||||
y[ib].ds4[iqs/32] = make_half2(d, sum);
|
||||
} else {
|
||||
y[ib].d4[iqs/32] = d;
|
||||
}
|
||||
}
|
||||
|
||||
void quantize_row_q8_1_cuda(
|
||||
const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels,
|
||||
const int64_t kx0_padded, const ggml_type type_x, cudaStream_t stream) {
|
||||
@@ -208,6 +300,35 @@ void quantize_mmq_q8_1_cuda(
|
||||
}
|
||||
}
|
||||
|
||||
void quantize_mmq_q8_1_id_cuda(
|
||||
const float * x, void * vy, const char * row_mapping, const int64_t kx0, const int64_t kx1, const int64_t kx0_padded,
|
||||
const ggml_type type_x, cudaStream_t stream) {
|
||||
|
||||
GGML_ASSERT(kx0_padded % (4*QK8_1) == 0);
|
||||
|
||||
const int64_t block_num_x = (kx0_padded + 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1) / (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ);
|
||||
const dim3 num_blocks(block_num_x, kx1, 1);
|
||||
const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE_MMQ, 1, 1);
|
||||
switch (mmq_get_q8_1_ds_layout(type_x)) {
|
||||
case MMQ_Q8_1_DS_LAYOUT_D4:
|
||||
quantize_mmq_q8_1_id<MMQ_Q8_1_DS_LAYOUT_D4>
|
||||
<<<num_blocks, block_size, 0, stream>>>(x, vy, row_mapping, kx0, kx1, kx0_padded);
|
||||
break;
|
||||
case MMQ_Q8_1_DS_LAYOUT_DS4:
|
||||
quantize_mmq_q8_1_id<MMQ_Q8_1_DS_LAYOUT_DS4>
|
||||
<<<num_blocks, block_size, 0, stream>>>(x, vy, row_mapping, kx0, kx1, kx0_padded);
|
||||
break;
|
||||
case MMQ_Q8_1_DS_LAYOUT_D2S6:
|
||||
quantize_mmq_q8_1_id<MMQ_Q8_1_DS_LAYOUT_D2S6>
|
||||
<<<num_blocks, block_size, 0, stream>>>(x, vy, row_mapping, kx0, kx1, kx0_padded);
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void quantize_tensor_q8_1_cuda(const struct ggml_tensor * src, void * vy, const enum ggml_type type, cudaStream_t stream) {
|
||||
GGML_ASSERT(src->ne[1] == 1 && src->ne[3] == 1);
|
||||
GGML_ASSERT(src->type == GGML_TYPE_F32);
|
||||
|
||||
@@ -30,5 +30,9 @@ void quantize_mmq_q8_1_cuda(
|
||||
const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels, const int64_t kx0_padded,
|
||||
const ggml_type type_x, cudaStream_t stream);
|
||||
|
||||
void quantize_mmq_q8_1_id_cuda(
|
||||
const float * x, void * vy, const char * row_mapping, const int64_t kx0, const int64_t kx1, const int64_t kx0_padded,
|
||||
const ggml_type type_x, cudaStream_t stream);
|
||||
|
||||
// For now only applicable for tensors with ne[1] = 1, ne[3] = 1, and useful if ne[2] > 1
|
||||
void quantize_tensor_q8_1_cuda(const struct ggml_tensor * src, void * vy, const enum ggml_type type, cudaStream_t stream);
|
||||
|
||||
Reference in New Issue
Block a user