Rework chat history logic

turboderp
2023-11-13 09:17:36 +01:00
parent 89755b5852
commit da403bc7a4
2 changed files with 143 additions and 57 deletions

View File

@@ -123,6 +123,8 @@ class Session:
     settings: {} = None
     # mode: str
+
+    history_first = 0
 
     def __init__(self, session_uuid = None):
         self.session_uuid = session_uuid
@@ -193,62 +195,142 @@ class Session:
         return new_block
 
-    def create_context(self, prompt_format, max_len, prefix = ""):
+    def create_context(self, prompt_format, max_len, min_len, prefix = ""):
 
         if prompt_format.is_instruct():
-
-            prompts = []
-            responses = []
-
-            # Create prompt-response pairs, pad in case of multiple prompts or responses in a row
-
-            for h in self.history:
-                if h["author"] == "assistant":
-                    if len(prompts) == len(responses): prompts.append("")
-                    responses.append(h["text"])
-                elif h["author"] == "user":
-                    if len(prompts) != len(responses): responses.append("")
-                    prompts.append(h["text"])
-                else:
-                    print("Unknown author")
-
-            # Create context until we run out of space
-
-            while True:
-                context_str = ""
-                for turn in range(len(prompts)):
-                    p = prompts[turn]
-                    r = responses[turn] if turn < len(responses) else None
-                    sp = self.settings["system_prompt"] if context_str == "" else None
-                    up_text = prompt_format.format(p, r, sp, self.settings)
-                    context_str = context_str + up_text
-
-                context_str += prefix
-                context_ids = get_loaded_model().tokenizer.encode(context_str, encode_special_tokens = prompt_format.encode_special_tokens())
-                if context_ids.shape[-1] < max_len: return context_str, context_ids
-
-                prompts = prompts[1:]
-                responses = responses[1:]
-
-        # Non-instruct format
-
+            return self.create_context_instruct(prompt_format, max_len, min_len, prefix)
         else:
-            history_copy = self.history
-            while True:
-                context_str = self.settings["system_prompt"] + "\n" + "\n".join([h["text"] for h in history_copy]) + "\n"
-                context_str += prefix
-                context_ids = get_loaded_model().tokenizer.encode(context_str, encode_special_tokens = prompt_format.encode_special_tokens())
-                if context_ids.shape[-1] < max_len: return context_str, context_ids
-                history_copy = history_copy[1:]
+            return self.create_context_raw(prompt_format, max_len, min_len, prefix)
+
+    def create_context_instruct(self, prompt_format, max_len, min_len, prefix = ""):
+
+        tokenizer = get_loaded_model().tokenizer
+
+        prompts = []
+        responses = []
+
+        # Create prompt-response pairs, pad in case of multiple prompts or responses in a row
+
+        for h in self.history:
+            if h["author"] == "assistant":
+                if len(prompts) == len(responses): prompts.append("")
+                responses.append(h["text"])
+            elif h["author"] == "user":
+                if len(prompts) != len(responses): responses.append("")
+                prompts.append(h["text"])
+            else:
+                print("Unknown author")
+
+        # Get relative length of system prompt
+
+        p1 = prompt_format.format("", None, None, self.settings)
+        p2 = prompt_format.format("", "", self.settings["system_prompt"], self.settings)
+        t1 = tokenizer.encode(p1, encode_special_tokens = prompt_format.encode_special_tokens())
+        t2 = tokenizer.encode(p2, encode_special_tokens = prompt_format.encode_special_tokens())
+        system_length = t2.shape[-1] - t1.shape[-1]
+
+        # Format and tokenize prompt-response pairs without system prompt
+
+        pairs = []
+        tokenized_pairs = []
+        for turn in range(len(prompts)):
+            p = prompts[turn]
+            r = responses[turn] if turn < len(responses) else None
+            pair = prompt_format.format(p, r, None, self.settings)
+            pairs.append(pair)
+            tokenized_pairs.append(tokenizer.encode(pair, encode_special_tokens = prompt_format.encode_special_tokens()))
+
+        lengths = [tp.shape[-1] for tp in tokenized_pairs]
+
+        # Advance or roll back history
+
+        current_length = system_length + sum(lengths[self.history_first:])
+        if current_length > max_len:
+            target_max = min_len
+            while current_length > target_max and self.history_first < len(prompts) - 1:
+                current_length -= lengths[self.history_first]
+                self.history_first += 1
+
+        while current_length < min_len and self.history_first > 0:
+            if current_length + lengths[self.history_first - 1] > max_len: break
+            self.history_first -= 1
+            current_length += lengths[self.history_first]
+
+        # Reinsert system prompt at new first position
+
+        p = prompts[self.history_first]
+        r = responses[self.history_first] if self.history_first < len(responses) else None
+        pair = prompt_format.format(p, r, self.settings["system_prompt"], self.settings)
+        pairs[self.history_first] = pair
+        tokenized_pairs[self.history_first] = tokenizer.encode(pair, encode_special_tokens = prompt_format.encode_special_tokens())
+
+        # Create context
+
+        context_str = "".join(pairs[self.history_first:])
+        context_ids = torch.cat(tokenized_pairs[self.history_first:], dim = -1)
+
+        # print("self.history_first", self.history_first)
+        # print("context_ids.shape[-1]", context_ids.shape[-1])
+
+        return context_str, context_ids
+
+    def create_context_raw(self, prompt_format, max_len, min_len, prefix = ""):
+
+        tokenizer = get_loaded_model().tokenizer
+
+        history_copy = [h["text"] for h in self.history]
+
+        # Get length of system prompt
+
+        if self.settings["system_prompt"] and self.settings["system_prompt"].strip() != "":
+            system_prompt = self.settings["system_prompt"] + "\n"
+            system_prompt_tokenized = tokenizer.encode(system_prompt, encode_special_tokens = prompt_format.encode_special_tokens())
+            system_length = system_prompt_tokenized.shape[-1]
+        else:
+            system_prompt = ""
+            system_prompt_tokenized = torch.empty((1, 0), dtype = torch.long)
+            system_length = 0
+
+        # Format and tokenize blocks without system prompt
+
+        blocks = []
+        tokenized_blocks = []
+        for turn in range(len(history_copy)):
+            block = history_copy[turn] + "\n"
+            blocks.append(block)
+            tokenized_blocks.append(tokenizer.encode(block, encode_special_tokens = prompt_format.encode_special_tokens()))
+
+        if prefix != "":
+            block = prefix
+            blocks.append(block)
+            tokenized_blocks.append(tokenizer.encode(block, encode_special_tokens = prompt_format.encode_special_tokens()))
+
+        lengths = [tp.shape[-1] for tp in tokenized_blocks]
+
+        # Advance or roll back history
+
+        current_length = system_length + sum(lengths[self.history_first:])
+        if current_length > max_len:
+            target_max = min_len
+            while current_length > target_max and self.history_first < len(history_copy) - 1:
+                current_length -= lengths[self.history_first]
+                self.history_first += 1
+
+        while current_length < min_len and self.history_first > 0:
+            if current_length + lengths[self.history_first - 1] > max_len: break
+            self.history_first -= 1
+            current_length += lengths[self.history_first]
+
+        # Create context
+
+        context_str = system_prompt + "".join(blocks[self.history_first:])
+        context_ids = torch.cat([system_prompt_tokenized] + tokenized_blocks[self.history_first:], dim = -1)
+
+        # print("self.history_first", self.history_first)
+        # print("context_ids.shape[-1]", context_ids.shape[-1])
+
+        return context_str, context_ids
 
     def generate(self, data):
@@ -329,9 +411,11 @@ class Session:
         last_chunk_time = time.time()
 
         full_response = ""  # TODO: Preload response
-        save_tokens = torch.empty((1, 0), dtype = torch.bool)
+        save_tokens = torch.empty((1, 0), dtype = torch.long)
         chunk_buffer = ""
+        chunk_size = self.settings["chunktokens"]
 
         # If not in instruct mode, generate bot name prefix
 
         prefix = ""
@@ -342,8 +426,9 @@ class Session:
if r.strip() != "": bot_roles.append(r + ":")
assert len(bot_roles) >= 1
past_tokens = model.config.max_seq_len - self.settings["chunktokens"] - save_tokens.shape[-1]
context_str, context_ids = self.create_context(prompt_format, past_tokens)
past_tokens = model.config.max_seq_len - chunk_size - save_tokens.shape[-1]
past_tokens_min = model.config.max_seq_len - 2 * chunk_size - save_tokens.shape[-1]
context_str, context_ids = self.create_context(prompt_format, past_tokens, past_tokens_min)
gen_settings.filters = [ExLlamaV2SelectFilter(model, tokenizer, bot_roles, case_insensitive = False)]
mt.set_stage("prompt")
@@ -388,8 +473,9 @@ class Session:
                 packet["block_uuid"] = new_block["block_uuid"]
                 yield json.dumps(packet) + "\n"
 
-        past_tokens = model.config.max_seq_len - self.settings["chunktokens"] - save_tokens.shape[-1]
-        context_str, context_ids = self.create_context(prompt_format, past_tokens, prefix)
+        past_tokens = model.config.max_seq_len - chunk_size - save_tokens.shape[-1]
+        past_tokens_min = model.config.max_seq_len - 2 * chunk_size - save_tokens.shape[-1]
+        context_str, context_ids = self.create_context(prompt_format, past_tokens, past_tokens_min, prefix)
         context_ids = torch.cat((context_ids, save_tokens), dim = -1)
 
         mt.set_stage("prompt")
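
Note on the windowing scheme above (a reading of the diff, not text from the commit; the names adjust_window and lengths are illustrative, not from the repo): history_first persists on the Session and marks the oldest visible history block. When the tokenized total exceeds max_len, the window advances until the total drops to min_len or below, and it only rolls back when older blocks fit again without exceeding max_len. Since generate() now passes past_tokens and past_tokens_min, which differ by one chunk of "chunktokens", that gap acts as hysteresis: the window shifts roughly once per chunk of new text instead of on every message, keeping the tokenized prefix stable between generations. A minimal standalone sketch of the adjustment step:

    def adjust_window(lengths, first, system_length, max_len, min_len):
        # lengths: token count of each formatted history block
        # first:   index of the oldest visible block (Session.history_first)
        current = system_length + sum(lengths[first:])
        if current > max_len:
            # Drop oldest blocks until the total is at or below min_len, leaving
            # headroom so the next few messages don't shift the window again.
            while current > min_len and first < len(lengths) - 1:
                current -= lengths[first]
                first += 1
        # If room has opened up, pull older blocks back in, but never past max_len.
        while current < min_len and first > 0:
            if current + lengths[first - 1] > max_len: break
            first -= 1
            current += lengths[first]
        return first

    # Example: six 100-token blocks and a 20-token system prompt, with
    # max_seq_len = 550 and chunktokens = 100, so max_len = 450, min_len = 350.
    print(adjust_window([100] * 6, 0, 20, max_len = 450, min_len = 350))  # -> 2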

View File

@@ -72,8 +72,8 @@ export class SessionSettings {
         this.sss_i_repRange = new controls.SettingsSlider("sss-item-left", "Rep. range", "sss-item-mid", "sss-item-right sss-item-textbox-r", 0, 0, 4096, { "0": "off" }, this.settings, "repr", () => { this.updateView(true); });
         this.sss_i_mirostat = new controls.CheckboxLabel("sss-item-right clickable", "Mirostat", this.settings, "mirostat", () => { this.updateView(true); });
-        this.sss_i_mirostat_tau = new controls.SettingsSlider("sss-item-left", "Mirostat Tau", "sss-item-mid", "sss-item-right sss-item-textbox-r", 2, 0.01, 10, null, this.settings, "mirostat_tau", () => { this.updateView(true); });
-        this.sss_i_mirostat_eta = new controls.SettingsSlider("sss-item-left", "Mirostat Eta", "sss-item-mid", "sss-item-right sss-item-textbox-r", 2, 0.01, 5, null, this.settings, "mirostat_eta", () => { this.updateView(true); });
+        this.sss_i_mirostat_tau = new controls.SettingsSlider("sss-item-left", "Mirostat tau", "sss-item-mid", "sss-item-right sss-item-textbox-r", 2, 0.01, 10, null, this.settings, "mirostat_tau", () => { this.updateView(true); });
+        this.sss_i_mirostat_eta = new controls.SettingsSlider("sss-item-left", "Mirostat eta", "sss-item-mid", "sss-item-right sss-item-textbox-r", 2, 0.01, 5, null, this.settings, "mirostat_eta", () => { this.updateView(true); });
 
         this.sss_sampling.inner.appendChild(this.sss_i_temperature.element);
         this.sss_sampling.inner.appendChild(this.sss_i_topK.element);