Slight tweaks to SD

This commit is contained in:
turboderp
2023-12-25 16:54:13 +01:00
parent 5135f32dfa
commit f0c516d7c0
3 changed files with 15 additions and 4 deletions

View File

@@ -14,10 +14,12 @@ from exllamav2.generator import (
ExLlamaV2Sampler
)
import time
import time, torch
# Initialize model and draft model
torch.set_num_threads(1)
# model_directory = "/mnt/str/models/codellama-34b-instruct-exl2/4.0bpw"
model_directory = "/mnt/str/models/_gptq/TheBloke_Phine-CodeLlama-34B-v2-GPTQ/"
draft_directory = "/mnt/str/models/tinyllama-1b-ckpt503-exl2/3.5bpw"
@@ -44,7 +46,7 @@ tokenizer = ExLlamaV2Tokenizer(model_config)
# Initialize generators
normal_generator = ExLlamaV2StreamingGenerator(model, model_cache, tokenizer)
speculative_generator = ExLlamaV2StreamingGenerator(model, model_cache, tokenizer, draft, draft_cache, 5)
speculative_generator = ExLlamaV2StreamingGenerator(model, model_cache, tokenizer, draft, draft_cache, num_speculative_tokens = 5)
# Make sure CUDA is initialized so we can measure performance
@@ -96,12 +98,12 @@ def test_gen(generator, prompt, settings, max_new_tokens):
# Settings
gen_prompt = "Here is a simple Quicksort implementation in C++:"
# gen_prompt = "What's the best way to learn a new language?"
gen_settings = ExLlamaV2Sampler.Settings()
gen_settings.temperature = 0.6
gen_settings.top_k = 50
gen_settings.top_p = 0.6
gen_settings.top_a = 0.0
gen_settings.token_repetition_penalty = 1.0
gen_settings.disallow_tokens(tokenizer, [tokenizer.eos_token_id])

View File

@@ -1219,6 +1219,14 @@ void fast_fadd_cpu(torch::Tensor a, torch::Tensor b)
}
}
// Copy the raw bytes of tensor b into tensor a with a single memcpy.
// Sizes are compared in *bytes* (element count × element width), so the two
// tensors may have different dtypes/shapes as long as their total byte
// counts match.
// NOTE(review): memcpy over data_ptr() ignores strides — this assumes both
// tensors are contiguous and CPU-resident; confirm at call sites.
void fast_copy_cpu(torch::Tensor a, torch::Tensor b)
{
// byte size = numel × per-element width derived from the tensor's dtype
size_t size_a = a.numel() * torch::elementSize(torch::typeMetaToScalarType(a.dtype()));
size_t size_b = b.numel() * torch::elementSize(torch::typeMetaToScalarType(b.dtype()));
// Refuse mismatched buffers rather than over/under-copying
TORCH_CHECK(size_a == size_b, "a and b are not the same size");
memcpy(a.data_ptr(), b.data_ptr(), size_a);
}
// Bindings
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
@@ -1256,6 +1264,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
m.def("gemm_half_half_half", &gemm_half_half_half, "gemm_half_half_half");
m.def("fast_fill_cpu_ones_bool", &fast_fill_cpu_ones_bool, "fast_fill_cpu_ones_bool");
m.def("fast_fadd_cpu", &fast_fadd_cpu, "fast_fadd_cpu");
m.def("fast_copy_cpu", &fast_copy_cpu, "fast_copy_cpu");
// m.def("array_fp16_to_fp8_ref", &array_fp16_to_fp8_ref, "array_fp16_to_fp8_ref");
// m.def("array_fp8_to_fp16_ref", &array_fp8_to_fp16_ref, "array_fp8_to_fp16_ref");
}

View File

@@ -66,7 +66,7 @@ class ExLlamaV2Sampler:
c.token_repetition_penalty = self.token_repetition_penalty
c.token_repetition_range = self.token_repetition_range
c.token_repetition_decay = self.token_repetition_decay
c.token_bias = self.token_bias
c.token_bias = None
c.filters = []
return c