import torch
import pandas, fastparquet  # fastparquet is imported so the "fastparquet" parquet engine is available
import os
from safetensors.torch import save_file
import random


def get_tokens(num_rows, length, filename, tokenizer):

    # Minimum number of tokens needed to fill the requested (num_rows, length) matrix
    min_tokens = num_rows * length

    # Read the parquet file and concatenate all column values of each row into one string
    df = pandas.read_parquet(filename, engine = "fastparquet")
    df['concatenated'] = df.apply(lambda r: ' '.join([str(v) for v in r.values]), axis = 1)

    all_tokens = torch.empty((1, 0), dtype = torch.long)

    # Tokenize rows until enough tokens have been collected
    for _, row in df['concatenated'].items():
        tokens = tokenizer.encode(row)
        all_tokens = torch.cat((all_tokens, tokens), dim = -1)
        if all_tokens.shape[-1] >= min_tokens: break

    if all_tokens.shape[-1] < min_tokens:
        print(f" ** Warning: Not enough sample data in {filename}")

    # Trim to exactly num_rows * length tokens and reshape into rows
    all_tokens = all_tokens.flatten()[:min_tokens]
    all_tokens = all_tokens.view((num_rows, length))

    # Print a short sample from the beginning and end of the dataset
    num_print_tokens = 50
    data_sample = all_tokens[0, :num_print_tokens]
    print(f" -- First {num_print_tokens} tokens of dataset:")
    print(f" {repr(tokenizer.decode(data_sample))}")
    data_sample = all_tokens[-1, -num_print_tokens:]
    print(f" -- Last {num_print_tokens} tokens of dataset:")
    print(f" {repr(tokenizer.decode(data_sample))}")

    return all_tokens


def tokenize(job, save_fn, tokenizer, measure = False):

    cal_ds = job["cal_dataset"]

    # Use the user-supplied calibration dataset if given, otherwise the bundled standard data
    if cal_ds is not None:
        rows = job["measurement_rows"] if measure else job["dataset_rows"]
        length = job["measurement_length"] if measure else job["length"]
        cal_tokens = get_tokens(rows, length, cal_ds, tokenizer)
    else:
        cal_tokens = get_standard_calibration(measure, tokenizer)

    # Save the tokenized calibration data and record its location in the job
    cal_filename = os.path.join(job["out_dir"], "cal_data.safetensors")
    cal_dict = { "input_ids": cal_tokens }
    save_file(cal_dict, cal_filename)
    job["cal_filename"] = cal_filename


def get_standard_calibration(measure, tokenizer):

    # Bundled calibration text files
    data_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "standard_cal_data")
    file_c4 = os.path.join(data_dir, "c4.utf8")
    file_code = os.path.join(data_dir, "code.utf8")
    file_multilingual = os.path.join(data_dir, "multilingual.utf8")
    file_technical = os.path.join(data_dir, "technical.utf8")
    file_wiki = os.path.join(data_dir, "wiki.utf8")
    file_tiny = os.path.join(data_dir, "tiny.utf8")

    # Number of 2048-token rows to draw from each source (fewer for the measurement pass)
    rows = []
    rows_c4 = 2 if measure else 10
    rows_wiki = 4 if measure else 48
    rows_code = 3 if measure else 15
    rows_tiny = 2 if measure else 10
    rows_multilingual = 3 if measure else 15
    rows_multilingual_s = 1 if measure else 5
    rows_technical = 2 if measure else 10
    rows_random = 2

    # C4: 10 rows

    with open(file_c4, "r", encoding="utf8") as f:
        lines = f.readlines()

    text = "\n\n".join(lines)
    tokens = tokenizer.encode(text)
    tokens = tokens[:, : tokens.shape[-1] - (tokens.shape[-1] % 2048)]
    tokenized_rows = tokens.view(-1, 2048)

    for i in range(rows_c4):
        rows.append(tokenized_rows[i:i+1])

    # Wiki: 24 aligned rows + 24 aligned rows with BOS

    with open(file_wiki, "r", encoding="utf8") as f:
        text = f.read()

    # Split the dump into articles (dropping the header line of each) and tokenize them individually
    articles = [a[a.find("\n") + 1:] for a in text.split("</doc>\n")]
    tokenized_articles = [tokenizer.encode(a, add_bos = True, add_eos = True) for a in articles]

    # Pack consecutive articles into rows of 2048 tokens; the first half of the rows drop the
    # leading BOS token, the second half keep it
    idx = 0
    for r in range(rows_wiki):
        length = 0
        idx0 = idx
        while length < 2049:
            length += tokenized_articles[idx].shape[-1]
            idx += 1
        row = torch.cat(tokenized_articles[idx0 : idx], dim = -1)
        if r < rows_wiki // 2: row = row[:, 1:2049]
        else: row = row[:, :2048]
        rows.append(row)

    # Code: 15 rows

    with open(file_code, "r", encoding="utf8") as f:
        text = f.read()

    tokens = tokenizer.encode(text)
    tokens = tokens[:, : tokens.shape[-1] - (tokens.shape[-1] % 2048)]
    tokenized_rows = tokens.view(-1, 2048)

    for i in range(rows_code):
        rows.append(tokenized_rows[i:i+1])

    # Tinystories: 5 aligned rows + 5 aligned rows with BOS

    with open(file_tiny, "r", encoding="utf8") as f:
        text = f.read()

    articles = text.split("<|endoftext|>")
    tokenized_articles = [tokenizer.encode(a.strip(), add_bos = True, add_eos = True) for a in articles]

    idx = 0
    for r in range(rows_tiny):
        length = 0
        idx0 = idx
        while length < 2049:
            length += tokenized_articles[idx].shape[-1]
            idx += 1
        row = torch.cat(tokenized_articles[idx0 : idx], dim = -1)
        if r < rows_tiny // 2: row = row[:, 1:2049]
        else: row = row[:, :2048]
        rows.append(row)

    # Multilingual: 15 rows + 5 shuffled rows

    with open(file_multilingual, "r", encoding="utf8") as f:
        text = f.read()

    tokens = tokenizer.encode(text)
    tokens = tokens[:, : tokens.shape[-1] - (tokens.shape[-1] % 2048)]
    tokenized_rows = tokens.view(-1, 2048)

    for i in range(rows_multilingual):
        rows.append(tokenized_rows[i:i+1])

    # Shuffled rows: assemble each 2048-token row from sixteen randomly chosen 128-token chunks
    tokenized_rows = tokens.view(-1, 128)
    random.seed(69420)
    for i in range(rows_multilingual_s):
        row = []
        for j in range(2048 // 128):
            k = random.randint(0, tokenized_rows.shape[0] - 1)
            row.append(tokenized_rows[k].unsqueeze(0))
        rows.append(torch.cat(row, dim = -1))

    # Randomized: 2 rows

    vocab_size = tokenizer.get_vocab_size()
    random.seed(69420)
    for i in range(rows_random):
        row = torch.randint(0, vocab_size, (1, 2048), dtype = torch.long)
        rows.append(row)

    # Technical: 10 rows

    with open(file_technical, "r", encoding="utf8") as f:
        text = f.read()

    tokens = tokenizer.encode(text)
    tokens = tokens[:, : tokens.shape[-1] - (tokens.shape[-1] % 2048)]
    tokenized_rows = tokens.view(-1, 2048)

    for i in range(rows_technical):
        rows.append(tokenized_rows[i:i+1])

    # Uncomment to dump the decoded calibration rows for inspection
    # for idx, r in enumerate(rows):
    #     print("------------------------------------------------------------------------------")
    #     print(idx)
    #     print("--------")
    #     print(tokenizer.decode(r))

    return torch.cat(rows, dim = 0)
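
# Illustrative usage sketch (kept commented out; not part of the conversion pipeline). It shows
# one way tokenize() above might be driven. The job dict keys mirror the lookups made in
# tokenize(); the ExLlamaV2Config/ExLlamaV2Tokenizer setup, the model path and the output
# directory are assumed example values, not anything shipped with or required by this module.
#
# from exllamav2 import ExLlamaV2Config, ExLlamaV2Tokenizer
#
# config = ExLlamaV2Config()
# config.model_dir = "/path/to/model"               # assumed model directory
# config.prepare()
# tokenizer = ExLlamaV2Tokenizer(config)
#
# job = {
#     "cal_dataset": None,                          # None -> use the bundled standard_cal_data
#     "measurement_rows": 16,
#     "dataset_rows": 100,
#     "measurement_length": 2048,
#     "length": 2048,
#     "out_dir": "/path/to/work_dir",               # assumed output directory
# }
# tokenize(job, None, tokenizer, measure = False)   # save_fn is unused here; writes out_dir/cal_data.safetensors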