Files
exllamav2/conversion/tokenize.py
2023-09-20 10:03:11 +02:00

50 lines
1.6 KiB
Python

import torch
import pandas, fastparquet
import os
from safetensors.torch import save_file
def get_tokens(num_rows, length, filename, tokenizer):
    """Read a parquet dataset and tokenize it into a fixed-size token matrix.

    Rows of the parquet file are concatenated column-wise into strings, each
    string is tokenized, and tokens are accumulated until at least
    num_rows * length tokens are available. The result is truncated and
    reshaped to (num_rows, length).

    :param num_rows: number of sample rows in the output tensor
    :param length: tokens per output row
    :param filename: path to the parquet calibration dataset
    :param tokenizer: tokenizer exposing encode(str) -> LongTensor of shape
        (1, n) and decode(tensor) -> str  (assumed from usage; confirm against
        the project's tokenizer class)
    :return: LongTensor of shape (num_rows, length)
    """
    min_tokens = num_rows * length

    df = pandas.read_parquet(filename, engine = "fastparquet")
    # Join every column of each row into one whitespace-separated string
    df['concatenated'] = df.apply(lambda r: ' '.join([str(v) for v in r.values]), axis = 1)

    # Collect per-row token tensors and concatenate once at the end: calling
    # torch.cat inside the loop re-copies the accumulated buffer every
    # iteration, which is O(n^2) in the total token count.
    chunks = []
    total_tokens = 0
    for _, row in df['concatenated'].items():
        tokens = tokenizer.encode(row)
        chunks.append(tokens)
        total_tokens += tokens.shape[-1]
        if total_tokens >= min_tokens: break

    if chunks:
        all_tokens = torch.cat(chunks, dim = -1)
    else:
        all_tokens = torch.empty((1, 0), dtype = torch.long)

    # Fixed: the warning previously printed a literal placeholder instead of
    # interpolating the dataset path.
    if all_tokens.shape[-1] < min_tokens:
        print(f" ** Warning: Not enough sample data in {filename}")

    # Truncate to exactly min_tokens and reshape; if there were not enough
    # tokens the view() below raises, matching the original behavior.
    all_tokens = all_tokens.flatten()[:min_tokens]
    all_tokens = all_tokens.view((num_rows, length))

    # Print a short sample from each end of the dataset for a sanity check
    num_print_tokens = 50
    data_sample = all_tokens[0, :num_print_tokens]
    print(f" -- First {num_print_tokens} tokens of dataset:")
    print(f"    {repr(tokenizer.decode(data_sample))}")
    data_sample = all_tokens[-1, -num_print_tokens:]
    print(f" -- Last {num_print_tokens} tokens of dataset:")
    print(f"    {repr(tokenizer.decode(data_sample))}")

    return all_tokens
def tokenize(job, save_fn, tokenizer, measure = False):
    """Tokenize the job's calibration dataset and write it to a safetensors
    file, recording the output path in job["cal_filename"].

    :param job: job dict holding dataset path, row counts, lengths and out_dir
    :param save_fn: unused here; kept for interface compatibility
    :param tokenizer: tokenizer forwarded to get_tokens
    :param measure: select the measurement-pass row count/length instead of
        the full quantization-pass settings
    """
    # Measurement passes use their own (smaller) row count and length
    if measure:
        num_rows = job["measurement_rows"]
        row_length = job["measurement_length"]
    else:
        num_rows = job["dataset_rows"]
        row_length = job["length"]

    dataset_path = job["cal_dataset"]
    tokens = get_tokens(num_rows, row_length, dataset_path, tokenizer)

    output_path = os.path.join(job["out_dir"], "cal_data.safetensors")
    save_file({ "input_ids": tokens }, output_path)
    job["cal_filename"] = output_path