mirror of
https://github.com/turboderp-org/exllamav2.git
synced 2026-04-20 06:19:00 +00:00
Optimization, wider loads in GPTQ kernel (int2) working
This commit is contained in:
@@ -25,8 +25,8 @@ class ExLlamaV2Embedding(ExLlamaV2Module):
|
||||
|
||||
# Padding token should embed a zero vector, but sometimes it doesn't (?)
|
||||
|
||||
if not torch.is_grad_enabled():
|
||||
w[pad_id] *= 0
|
||||
# if not torch.is_grad_enabled():
|
||||
# w[pad_id] *= 0
|
||||
|
||||
self.embedding = nn.Embedding(vocab_size, hidden_size, pad_token_id, device ="meta")
|
||||
self.embedding.weight = w
|
||||
|
||||
@@ -55,6 +55,15 @@ public:
|
||||
int shift = (column & 0x07) * 4;
|
||||
return (data[row * width / 8 + column / 8] >> shift) & 0x0f;
|
||||
}
|
||||
|
||||
// Fetch two adjacent 4-bit values from the packed matrix in a single read.
//
// Reads the packed word at data[row * width / 8 + column / 8] (eight 4-bit
// entries per word), shifts it right by (column & 0x07) * 4 bits, then stores
// the low nibble of the shifted word in items[0] and the next nibble in
// items[1].
//
// NOTE(review): when (column & 0x07) == 7 the shift is 28, so with 32-bit
// words no bits remain for items[1] and it would come out as 0. Presumably
// callers only pass even column indices -- confirm against the call sites.
__device__ __forceinline__ void item2(int (&items)[2], int row, int column) const
|
||||
{
|
||||
int shift = (column & 0x07) * 4;
|
||||
uint32_t d = data[row * width / 8 + column / 8] >> shift;
|
||||
items[0] = d & 0x0f;
|
||||
items[1] = (d >> 4) & 0x0f;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
#endif
|
||||
@@ -14,7 +14,7 @@
|
||||
#include "quant/qdq_6.cuh"
|
||||
#include "quant/qdq_8.cuh"
|
||||
|
||||
#define BLOCK_KN_SIZE 128
|
||||
#define BLOCK_KN_SIZE 256
|
||||
#define BLOCK_M_SIZE_MAX 8
|
||||
#define MAX_GROUPS_IN_BLOCK (BLOCK_KN_SIZE / 32)
|
||||
#define CLEAR_N_SIZE 256
|
||||
@@ -34,16 +34,16 @@ void gemm_half_q_half_cuda_part
|
||||
bool clear
|
||||
)
|
||||
{
|
||||
dim3 blockDim, gridDim;
|
||||
blockDim.x = BLOCK_KN_SIZE;
|
||||
blockDim.y = 1;
|
||||
blockDim.z = 1;
|
||||
gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE * 2);
|
||||
gridDim.y = DIVIDE(size_m, m_count);
|
||||
gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE);
|
||||
|
||||
if (!b->is_gptq)
|
||||
{
|
||||
dim3 blockDim, gridDim;
|
||||
blockDim.x = BLOCK_KN_SIZE;
|
||||
blockDim.y = 1;
|
||||
blockDim.z = 1;
|
||||
gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE);
|
||||
gridDim.y = DIVIDE(size_m, m_count);
|
||||
gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE);
|
||||
|
||||
fp_gemm_half_q_half_kernel kernel = pick_gemm_half_q_half_kernel(true, m_count);
|
||||
|
||||
kernel<<<gridDim, blockDim>>>
|
||||
@@ -70,6 +70,14 @@ void gemm_half_q_half_cuda_part
|
||||
}
|
||||
else
|
||||
{
|
||||
dim3 blockDim, gridDim;
|
||||
blockDim.x = BLOCK_KN_SIZE;
|
||||
blockDim.y = 1;
|
||||
blockDim.z = 1;
|
||||
gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE * 2);
|
||||
gridDim.y = DIVIDE(size_m, m_count);
|
||||
gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE);
|
||||
|
||||
fp_gemm_half_q_half_gptq_kernel kernel = pick_gemm_half_q_half_gptq_kernel(true, m_count);
|
||||
|
||||
// DBGX((uint64_t) b->cuda_q_perm);
|
||||
|
||||
@@ -90,7 +90,7 @@ __global__ void gemm_half_q_half_gptq_kernel
|
||||
*((uint32_t*)c_.item_ptr(offset_m + m, n)) = 0;
|
||||
}
|
||||
|
||||
// __syncthreads();
|
||||
__syncthreads();
|
||||
|
||||
// Find initial group
|
||||
|
||||
@@ -107,10 +107,12 @@ __global__ void gemm_half_q_half_gptq_kernel
|
||||
|
||||
// Initial group
|
||||
|
||||
int zeros[2];
|
||||
half2 z1z16[2][2];
|
||||
half2 y1y16[2][2];
|
||||
dequant_4bit_8_prep_zero_scale(b_gptq_qzeros_.item(group, n) + 1, b_gptq_scales_.item(group, n), z1z16[0], y1y16[0]);
|
||||
dequant_4bit_8_prep_zero_scale(b_gptq_qzeros_.item(group, n + 1) + 1, b_gptq_scales_.item(group, n), z1z16[1], y1y16[1]);
|
||||
b_gptq_qzeros_.item2(zeros, group, n);
|
||||
dequant_4bit_8_prep_zero_scale(zeros[0] + 1, b_gptq_scales_.item(group, n ), z1z16[0], y1y16[0]);
|
||||
dequant_4bit_8_prep_zero_scale(zeros[1] + 1, b_gptq_scales_.item(group, n + 1), z1z16[1], y1y16[1]);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
@@ -126,8 +128,9 @@ __global__ void gemm_half_q_half_gptq_kernel
|
||||
if (k == nextgroup)
|
||||
{
|
||||
group++;
|
||||
dequant_4bit_8_prep_zero_scale(b_gptq_qzeros_.item(group, n) + 1, b_gptq_scales_.item(group, n), z1z16[0], y1y16[0]);
|
||||
dequant_4bit_8_prep_zero_scale(b_gptq_qzeros_.item(group, n + 1) + 1, b_gptq_scales_.item(group, n), z1z16[1], y1y16[1]);
|
||||
b_gptq_qzeros_.item2(zeros, group, n);
|
||||
dequant_4bit_8_prep_zero_scale(zeros[0] + 1, b_gptq_scales_.item(group, n), z1z16[0], y1y16[0]);
|
||||
dequant_4bit_8_prep_zero_scale(zeros[1] + 1, b_gptq_scales_.item(group, n + 1), z1z16[1], y1y16[1]);
|
||||
nextgroup += groupsize;
|
||||
}
|
||||
|
||||
@@ -135,13 +138,11 @@ __global__ void gemm_half_q_half_gptq_kernel
|
||||
for (int j = 0; j < 4; j++)
|
||||
{
|
||||
half2 dq[2][4];
|
||||
// const int2* b_ptr2 = (int2*) b_ptr;
|
||||
// int2 load_int2 = *b_ptr2;
|
||||
// uint32_t load[] = { (uint32_t) load_int2.x, (uint32_t) load_int2.y };
|
||||
uint32_t load[] = { b_ptr[0], b_ptr[1] };
|
||||
const int2* b_ptr2 = (int2*) b_ptr;
|
||||
int2 load_int2 = *b_ptr2;
|
||||
|
||||
dequant_4bit_8_gptq(&load[0], dq[0], z1z16[0], y1y16[0], size_n);
|
||||
dequant_4bit_8_gptq(&load[1], dq[1], z1z16[1], y1y16[1], size_n);
|
||||
dequant_4bit_8_gptq(load_int2.x, dq[0], z1z16[0], y1y16[0], size_n);
|
||||
dequant_4bit_8_gptq(load_int2.y, dq[1], z1z16[1], y1y16[1], size_n);
|
||||
for (int m = 0; m < m_count; m++)
|
||||
{
|
||||
block_c[m][0] = dot22_8(dq[0], a_ptr + m * a_stride, block_c[m][0]);
|
||||
|
||||
@@ -85,7 +85,7 @@ __forceinline__ __device__ void dequant_4bit_8_prep_zero_scale
|
||||
|
||||
__forceinline__ __device__ void dequant_4bit_8_gptq
|
||||
(
|
||||
const uint32_t* q,
|
||||
const uint32_t q_0,
|
||||
half2 (&dq)[4],
|
||||
half2 (&z1z16)[2],
|
||||
half2 (&y1y16)[2],
|
||||
@@ -94,7 +94,7 @@ __forceinline__ __device__ void dequant_4bit_8_gptq
|
||||
{
|
||||
const uint32_t c0 = 0x64006400;
|
||||
|
||||
uint32_t qa = q[0];
|
||||
uint32_t qa = q_0;
|
||||
half2_uint32 q0((qa & 0x000f000f) | c0); // half2(q[ 0], q[ 1]) + 1024
|
||||
half2_uint32 q1((qa & 0x00f000f0) | c0); // half2(q[ 2], q[ 3]) * 16 + 1024
|
||||
qa >>= 8;
|
||||
|
||||
@@ -28,7 +28,7 @@ with torch.no_grad():
|
||||
|
||||
config_quant = ExLlamaV2Config()
|
||||
# config_quant.model_dir = "/mnt/str/models/_exl2/llama-7b-4.0bpw-h6-exl2/"
|
||||
# config_quant.model_dir = "/mnt/str/models/llama-7b-4bit-128g/"
|
||||
config_quant.model_dir = "/mnt/str/models/llama-7b-4bit-128g/"
|
||||
# config_quant.model_dir = "/mnt/str/models/_test_models/TheBloke_WizardLM-30B-Uncensored-GPTQ/"
|
||||
config_quant.prepare()
|
||||
model_quant = ExLlamaV2(config_quant)
|
||||
@@ -133,7 +133,7 @@ with torch.no_grad():
|
||||
module_quant = model_quant.modules_dict[k]
|
||||
module_quant.load()
|
||||
|
||||
# Test that result of multiplication with identity matrix is the same with and without reconstruction
|
||||
# Test that result of multiplication with identity and random matrix is the same with and without reconstruction
|
||||
|
||||
print()
|
||||
|
||||
@@ -145,12 +145,16 @@ with torch.no_grad():
|
||||
module_quant.load()
|
||||
if isinstance(module_quant, ExLlamaV2Linear):
|
||||
|
||||
ident = torch.eye(module_quant.in_features, dtype = torch.half).cuda()
|
||||
mat = torch.eye(module_quant.in_features, dtype = torch.half).cuda()
|
||||
test1 = module_quant.forward(mat, force_cuda = True)
|
||||
test2 = module_quant.forward(mat, force_recons = True)
|
||||
diff_i = torch.max((test1 - test2).abs())
|
||||
|
||||
test1 = module_quant.forward(ident, force_cuda = True)
|
||||
test2 = module_quant.forward(ident, force_recons = True)
|
||||
mat = torch.randn((module_quant.in_features, module_quant.in_features), dtype = torch.half).cuda()
|
||||
test1 = module_quant.forward(mat, force_cuda = True)
|
||||
test2 = module_quant.forward(mat, force_recons = True)
|
||||
diff_r = F.mse_loss(test1, test2)
|
||||
|
||||
diff = torch.max((test1 - test2).abs())
|
||||
print (f"{k:40} {diff.item():.4f}")
|
||||
print (f"{k:40} ident: {diff_i.item():.6f} u: {diff_r.item():.6f}")
|
||||
|
||||
xx = 0
|
||||
|
||||
Reference in New Issue
Block a user