Optimization: wider loads in GPTQ kernel (int2) — working

This commit is contained in:
turboderp
2023-09-07 04:07:13 +02:00
parent f259fafda9
commit c2f62e1f1f
6 changed files with 53 additions and 31 deletions

View File

@@ -25,8 +25,8 @@ class ExLlamaV2Embedding(ExLlamaV2Module):
# Padding token should embed a zero vector, but sometimes it doesn't (?)
if not torch.is_grad_enabled():
w[pad_id] *= 0
# if not torch.is_grad_enabled():
# w[pad_id] *= 0
self.embedding = nn.Embedding(vocab_size, hidden_size, pad_token_id, device ="meta")
self.embedding.weight = w

View File

@@ -55,6 +55,15 @@ public:
int shift = (column & 0x07) * 4;
return (data[row * width / 8 + column / 8] >> shift) & 0x0f;
}
// Read two consecutive 4-bit quantized items into `items`, using a single
// 32-bit load of the packed word that holds them (wider-load variant of
// item() above, which extracts one nibble per call).
// Items are packed 8 per uint32_t: `shift` selects the nibble of `column`
// within its word, and the second item is taken from the next-higher nibble.
// NOTE(review): assumes `column` is even so both nibbles fall inside the
// same 32-bit word — for column % 8 == 7 the second nibble would be read
// past the word's top bits; confirm callers only pass even columns.
__device__ __forceinline__ void item2(int (&items)[2], int row, int column) const
{
// Offset in nibbles within the word (column & 7), times 4 bits per item.
int shift = (column & 0x07) * 4;
// One 32-bit load covers both items; pre-shift so item 0 is at bit 0.
uint32_t d = data[row * width / 8 + column / 8] >> shift;
items[0] = d & 0x0f;
items[1] = (d >> 4) & 0x0f;
}
};
#endif

View File

@@ -14,7 +14,7 @@
#include "quant/qdq_6.cuh"
#include "quant/qdq_8.cuh"
#define BLOCK_KN_SIZE 128
#define BLOCK_KN_SIZE 256
#define BLOCK_M_SIZE_MAX 8
#define MAX_GROUPS_IN_BLOCK (BLOCK_KN_SIZE / 32)
#define CLEAR_N_SIZE 256
@@ -34,16 +34,16 @@ void gemm_half_q_half_cuda_part
bool clear
)
{
dim3 blockDim, gridDim;
blockDim.x = BLOCK_KN_SIZE;
blockDim.y = 1;
blockDim.z = 1;
gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE * 2);
gridDim.y = DIVIDE(size_m, m_count);
gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE);
if (!b->is_gptq)
{
dim3 blockDim, gridDim;
blockDim.x = BLOCK_KN_SIZE;
blockDim.y = 1;
blockDim.z = 1;
gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE);
gridDim.y = DIVIDE(size_m, m_count);
gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE);
fp_gemm_half_q_half_kernel kernel = pick_gemm_half_q_half_kernel(true, m_count);
kernel<<<gridDim, blockDim>>>
@@ -70,6 +70,14 @@ void gemm_half_q_half_cuda_part
}
else
{
dim3 blockDim, gridDim;
blockDim.x = BLOCK_KN_SIZE;
blockDim.y = 1;
blockDim.z = 1;
gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE * 2);
gridDim.y = DIVIDE(size_m, m_count);
gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE);
fp_gemm_half_q_half_gptq_kernel kernel = pick_gemm_half_q_half_gptq_kernel(true, m_count);
// DBGX((uint64_t) b->cuda_q_perm);

View File

@@ -90,7 +90,7 @@ __global__ void gemm_half_q_half_gptq_kernel
*((uint32_t*)c_.item_ptr(offset_m + m, n)) = 0;
}
// __syncthreads();
__syncthreads();
// Find initial group
@@ -107,10 +107,12 @@ __global__ void gemm_half_q_half_gptq_kernel
// Initial group
int zeros[2];
half2 z1z16[2][2];
half2 y1y16[2][2];
dequant_4bit_8_prep_zero_scale(b_gptq_qzeros_.item(group, n) + 1, b_gptq_scales_.item(group, n), z1z16[0], y1y16[0]);
dequant_4bit_8_prep_zero_scale(b_gptq_qzeros_.item(group, n + 1) + 1, b_gptq_scales_.item(group, n), z1z16[1], y1y16[1]);
b_gptq_qzeros_.item2(zeros, group, n);
dequant_4bit_8_prep_zero_scale(zeros[0] + 1, b_gptq_scales_.item(group, n ), z1z16[0], y1y16[0]);
dequant_4bit_8_prep_zero_scale(zeros[1] + 1, b_gptq_scales_.item(group, n + 1), z1z16[1], y1y16[1]);
__syncthreads();
@@ -126,8 +128,9 @@ __global__ void gemm_half_q_half_gptq_kernel
if (k == nextgroup)
{
group++;
dequant_4bit_8_prep_zero_scale(b_gptq_qzeros_.item(group, n) + 1, b_gptq_scales_.item(group, n), z1z16[0], y1y16[0]);
dequant_4bit_8_prep_zero_scale(b_gptq_qzeros_.item(group, n + 1) + 1, b_gptq_scales_.item(group, n), z1z16[1], y1y16[1]);
b_gptq_qzeros_.item2(zeros, group, n);
dequant_4bit_8_prep_zero_scale(zeros[0] + 1, b_gptq_scales_.item(group, n), z1z16[0], y1y16[0]);
dequant_4bit_8_prep_zero_scale(zeros[1] + 1, b_gptq_scales_.item(group, n + 1), z1z16[1], y1y16[1]);
nextgroup += groupsize;
}
@@ -135,13 +138,11 @@ __global__ void gemm_half_q_half_gptq_kernel
for (int j = 0; j < 4; j++)
{
half2 dq[2][4];
// const int2* b_ptr2 = (int2*) b_ptr;
// int2 load_int2 = *b_ptr2;
// uint32_t load[] = { (uint32_t) load_int2.x, (uint32_t) load_int2.y };
uint32_t load[] = { b_ptr[0], b_ptr[1] };
const int2* b_ptr2 = (int2*) b_ptr;
int2 load_int2 = *b_ptr2;
dequant_4bit_8_gptq(&load[0], dq[0], z1z16[0], y1y16[0], size_n);
dequant_4bit_8_gptq(&load[1], dq[1], z1z16[1], y1y16[1], size_n);
dequant_4bit_8_gptq(load_int2.x, dq[0], z1z16[0], y1y16[0], size_n);
dequant_4bit_8_gptq(load_int2.y, dq[1], z1z16[1], y1y16[1], size_n);
for (int m = 0; m < m_count; m++)
{
block_c[m][0] = dot22_8(dq[0], a_ptr + m * a_stride, block_c[m][0]);

View File

@@ -85,7 +85,7 @@ __forceinline__ __device__ void dequant_4bit_8_prep_zero_scale
__forceinline__ __device__ void dequant_4bit_8_gptq
(
const uint32_t* q,
const uint32_t q_0,
half2 (&dq)[4],
half2 (&z1z16)[2],
half2 (&y1y16)[2],
@@ -94,7 +94,7 @@ __forceinline__ __device__ void dequant_4bit_8_gptq
{
const uint32_t c0 = 0x64006400;
uint32_t qa = q[0];
uint32_t qa = q_0;
half2_uint32 q0((qa & 0x000f000f) | c0); // half2(q[ 0], q[ 1]) + 1024
half2_uint32 q1((qa & 0x00f000f0) | c0); // half2(q[ 2], q[ 3]) * 16 + 1024
qa >>= 8;

View File

@@ -28,7 +28,7 @@ with torch.no_grad():
config_quant = ExLlamaV2Config()
# config_quant.model_dir = "/mnt/str/models/_exl2/llama-7b-4.0bpw-h6-exl2/"
# config_quant.model_dir = "/mnt/str/models/llama-7b-4bit-128g/"
config_quant.model_dir = "/mnt/str/models/llama-7b-4bit-128g/"
# config_quant.model_dir = "/mnt/str/models/_test_models/TheBloke_WizardLM-30B-Uncensored-GPTQ/"
config_quant.prepare()
model_quant = ExLlamaV2(config_quant)
@@ -133,7 +133,7 @@ with torch.no_grad():
module_quant = model_quant.modules_dict[k]
module_quant.load()
# Test that result of multiplication with identity matrix is the same with and without reconstruction
# Test that result of multiplication with identity and random matrix is the same with and without reconstruction
print()
@@ -145,12 +145,16 @@ with torch.no_grad():
module_quant.load()
if isinstance(module_quant, ExLlamaV2Linear):
ident = torch.eye(module_quant.in_features, dtype = torch.half).cuda()
mat = torch.eye(module_quant.in_features, dtype = torch.half).cuda()
test1 = module_quant.forward(mat, force_cuda = True)
test2 = module_quant.forward(mat, force_recons = True)
diff_i = torch.max((test1 - test2).abs())
test1 = module_quant.forward(ident, force_cuda = True)
test2 = module_quant.forward(ident, force_recons = True)
mat = torch.randn((module_quant.in_features, module_quant.in_features), dtype = torch.half).cuda()
test1 = module_quant.forward(mat, force_cuda = True)
test2 = module_quant.forward(mat, force_recons = True)
diff_r = F.mse_loss(test1, test2)
diff = torch.max((test1 - test2).abs())
print (f"{k:40} {diff.item():.4f}")
print (f"{k:40} ident: {diff_i.item():.6f} u: {diff_r.item():.6f}")
xx = 0