diff --git a/model.py b/model.py
index aca6bdc..0bcc87d 100644
--- a/model.py
+++ b/model.py
@@ -239,6 +239,7 @@ class ModelContainer:
         **kwargs:
             'token_healing' (bool): Use token healing (default: False)
             'temperature' (float): Sampling temperature (default: 1.0)
+            'temperature_last' (bool): Apply temperature after all other samplers (default: False)
             'top_k' (int): Sampling top-K (default: 0)
             'top_p' (float): Sampling top-P (default: 1.0)
             'min_p' (float): Sampling min-P (default: 0.0)
@@ -270,6 +271,7 @@ class ModelContainer:
 
         gen_settings = ExLlamaV2Sampler.Settings()
         gen_settings.temperature = kwargs.get("temperature", 1.0)
+        gen_settings.temperature_last = kwargs.get("temperature_last", False)
         gen_settings.top_k = kwargs.get("top_k", 1)
         gen_settings.top_p = kwargs.get("top_p", 1.0)
         gen_settings.min_p = kwargs.get("min_p", 0.0)
diff --git a/model_test.py b/model_test.py
index bb9ef72..d2490c6 100644
--- a/model_test.py
+++ b/model_test.py
@@ -2,20 +2,21 @@
 from model import ModelContainer
 
 def progress(module, modules):
-    print(f"Loaded {module}/{modules} modules")
-    yield
+    yield module, modules
 
-mc = ModelContainer("/mnt/str/models/_exl2/mistral-7b-instruct-exl2/4.0bpw/", max_seq_len = 100)
-mc.load(progress)
+container = ModelContainer("/mnt/str/models/_exl2/mistral-7b-instruct-exl2/4.0bpw/")
+loader = container.load_gen(progress)
+for (module, modules) in loader:
+    print(module, modules)
 
-gen = mc.generate_gen("Once upon a tim", generate_window = 16, token_healing = True)
-for g in gen: print(g, end = "")
+generator = container.generate_gen("Once upon a tim", token_healing = True)
+for g in generator: print(g, end = "")
 
-mc.unload()
-del mc
+container.unload()
+del container
 
 mc = ModelContainer("/mnt/str/models/_exl2/mistral-7b-instruct-exl2/4.65bpw/")
 mc.load(progress)
-response = mc.generate("All work and no play makes turbo a derpy cat.\nAll work and no play makes turbo a derpy cat.\nAll", top_k = 1)
+response = mc.generate("All work and no play makes turbo a derpy cat.\nAll work and no play makes turbo a derpy cat.\nAll", top_k = 1, max_new_tokens = 1000, stream_interval = 0.5)
 print (response)
 