diff --git a/kt-kernel/cuda/custom_gguf/dequant.cu b/kt-kernel/cuda/custom_gguf/dequant.cu index c579469..a567a29 100644 --- a/kt-kernel/cuda/custom_gguf/dequant.cu +++ b/kt-kernel/cuda/custom_gguf/dequant.cu @@ -671,7 +671,7 @@ torch::Tensor dequantize_q8_0(const int8_t* data, const int num_bytes, const int //data_gpu.copy_(data, false); // Create output tensor - auto output = torch::zeros({ num_blocks, 32 }, torch::dtype(target_dtype).device(device)); + auto output = torch::zeros({ num_blocks, ele_per_blk }, torch::dtype(target_dtype).device(device)); switch (target_dtype) { case torch::kFloat16: @@ -705,7 +705,7 @@ torch::Tensor dequantize_q6_k(const int8_t* data, const int num_bytes, const int //data_gpu.copy_(data, false); // Create output tensor - auto output = torch::zeros({num_blocks, 256}, torch::dtype(target_dtype).device(device)); + auto output = torch::zeros({num_blocks, ele_per_blk}, torch::dtype(target_dtype).device(device)); switch (target_dtype) { case torch::kFloat16: @@ -736,7 +736,7 @@ torch::Tensor dequantize_q5_k(const int8_t* data, const int num_bytes, const int //data_gpu.copy_(data, false); // Create output tensor - auto output = torch::zeros({num_blocks, 256}, torch::dtype(target_dtype).device(device)); + auto output = torch::zeros({num_blocks, ele_per_blk}, torch::dtype(target_dtype).device(device)); switch (target_dtype) { case torch::kFloat16: @@ -768,7 +768,7 @@ torch::Tensor dequantize_q4_k(const int8_t* data, const int num_bytes, const int //data_gpu.copy_(data, false); // Create output tensor - auto output = torch::zeros({num_blocks, 256}, torch::dtype(target_dtype).device(device)); + auto output = torch::zeros({num_blocks, ele_per_blk}, torch::dtype(target_dtype).device(device)); switch (target_dtype) { case torch::kFloat16: @@ -799,7 +799,7 @@ torch::Tensor dequantize_q3_k(const int8_t* data, const int num_bytes, const int //data_gpu.copy_(data, false); // Create output tensor - auto output = torch::zeros({num_blocks, 256}, torch::dtype(target_dtype).device(device)); + auto output = torch::zeros({num_blocks, ele_per_blk}, torch::dtype(target_dtype).device(device)); switch (target_dtype) { case torch::kFloat16: @@ -830,7 +830,7 @@ torch::Tensor dequantize_q2_k(const int8_t* data, const int num_bytes, const int //data_gpu.copy_(data, false); // Create output tensor - auto output = torch::zeros({num_blocks, 256}, torch::dtype(target_dtype).device(device)); + auto output = torch::zeros({num_blocks, ele_per_blk}, torch::dtype(target_dtype).device(device)); switch (target_dtype) { case torch::kFloat16: @@ -861,7 +861,7 @@ torch::Tensor dequantize_iq4_xs(const int8_t* data, const int num_bytes, const i //data_gpu.copy_(data, false); // Create output tensor - auto output = torch::zeros({num_blocks, 256}, torch::dtype(target_dtype).device(device)); + auto output = torch::zeros({num_blocks, ele_per_blk}, torch::dtype(target_dtype).device(device)); switch (target_dtype) { case torch::kFloat16: diff --git a/kt-sft/csrc/ktransformers_ext/cuda/custom_gguf/dequant.cu b/kt-sft/csrc/ktransformers_ext/cuda/custom_gguf/dequant.cu index 8d80160..e77e007 100644 --- a/kt-sft/csrc/ktransformers_ext/cuda/custom_gguf/dequant.cu +++ b/kt-sft/csrc/ktransformers_ext/cuda/custom_gguf/dequant.cu @@ -671,7 +671,7 @@ torch::Tensor dequantize_q8_0(const int8_t* data, const int num_bytes, const int //data_gpu.copy_(data, false); // Create output tensor - auto output = torch::zeros({ num_blocks, 32 }, torch::dtype(target_dtype).device(device)); + auto output = torch::zeros({ num_blocks, ele_per_blk }, torch::dtype(target_dtype).device(device)); switch (target_dtype) { case torch::kFloat16: @@ -705,7 +705,7 @@ torch::Tensor dequantize_q6_k(const int8_t* data, const int num_bytes, const int //data_gpu.copy_(data, false); // Create output tensor - auto output = torch::zeros({num_blocks, 256}, torch::dtype(target_dtype).device(device)); + auto output = torch::zeros({num_blocks, ele_per_blk}, torch::dtype(target_dtype).device(device)); switch (target_dtype) { case torch::kFloat16: @@ -736,7 +736,7 @@ torch::Tensor dequantize_q5_k(const int8_t* data, const int num_bytes, const int //data_gpu.copy_(data, false); // Create output tensor - auto output = torch::zeros({num_blocks, 256}, torch::dtype(target_dtype).device(device)); + auto output = torch::zeros({num_blocks, ele_per_blk}, torch::dtype(target_dtype).device(device)); switch (target_dtype) { case torch::kFloat16: @@ -768,7 +768,7 @@ torch::Tensor dequantize_q4_k(const int8_t* data, const int num_bytes, const int //data_gpu.copy_(data, false); // Create output tensor - auto output = torch::zeros({num_blocks, 256}, torch::dtype(target_dtype).device(device)); + auto output = torch::zeros({num_blocks, ele_per_blk}, torch::dtype(target_dtype).device(device)); switch (target_dtype) { case torch::kFloat16: @@ -799,7 +799,7 @@ torch::Tensor dequantize_q3_k(const int8_t* data, const int num_bytes, const int //data_gpu.copy_(data, false); // Create output tensor - auto output = torch::zeros({num_blocks, 256}, torch::dtype(target_dtype).device(device)); + auto output = torch::zeros({num_blocks, ele_per_blk}, torch::dtype(target_dtype).device(device)); switch (target_dtype) { case torch::kFloat16: @@ -830,7 +830,7 @@ torch::Tensor dequantize_q2_k(const int8_t* data, const int num_bytes, const int //data_gpu.copy_(data, false); // Create output tensor - auto output = torch::zeros({num_blocks, 256}, torch::dtype(target_dtype).device(device)); + auto output = torch::zeros({num_blocks, ele_per_blk}, torch::dtype(target_dtype).device(device)); switch (target_dtype) { case torch::kFloat16: @@ -861,7 +861,7 @@ torch::Tensor dequantize_iq4_xs(const int8_t* data, const int num_bytes, const i //data_gpu.copy_(data, false); // Create output tensor - auto output = torch::zeros({num_blocks, 256}, torch::dtype(target_dtype).device(device)); + auto output = torch::zeros({num_blocks, ele_per_blk}, torch::dtype(target_dtype).device(device)); switch (target_dtype) { case torch::kFloat16: