mirror of
https://github.com/SillyTavern/SillyTavern-Extras.git
synced 2026-03-05 11:30:13 +00:00
188 lines
6.8 KiB
Cython
188 lines
6.8 KiB
Cython
# cython: language_level=3
|
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
#
|
|
# This source code is licensed under the MIT license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
import numpy as np
|
|
import torch
|
|
from itertools import chain
|
|
from libc.math cimport ceil
|
|
|
|
cimport cython
|
|
cimport numpy as np
|
|
|
|
from libc.stdint cimport int32_t, int64_t
|
|
|
|
# NumPy dtype passed as ``dtype=`` when this module allocates its index arrays.
DTYPE = np.int64
|
|
# C-level element type matching DTYPE; used for typed buffers, memoryviews and locals.
ctypedef int64_t DTYPE_t
|
|
|
|
|
|
@cython.boundscheck(False)
@cython.wraparound(False)
@cython.nonecheck(False)
cdef np.ndarray[DTYPE_t, ndim=2] _get_slice_indices_none_mode(np.ndarray[DTYPE_t, ndim=1] sizes, int block_size):
    """Cut the concatenated stream into consecutive fixed-size blocks.

    Entry boundaries in ``sizes`` are ignored; only the sum of ``sizes``
    matters. Every block spans ``block_size`` positions except possibly the
    last one, which is clipped to the total size.

    Args:
        sizes: 1-D int64 array of entry lengths.
        block_size: number of positions per block.

    Returns:
        An int64 array of shape ``(ceil(total / block_size), 2)`` whose rows
        are half-open ``[start, end)`` offsets into the flat stream.
    """
    cdef DTYPE_t total_size = sizes.sum()
    # Ceil so that a trailing partial block still gets its own row.
    cdef DTYPE_t length = <DTYPE_t> ceil(total_size / <double> block_size)
    cdef np.ndarray[DTYPE_t, ndim=2] slice_indices = np.zeros([length, 2], dtype=DTYPE)
    cdef DTYPE_t[:, :] slice_indices_view = slice_indices
    cdef DTYPE_t i
    cdef DTYPE_t start
    cdef DTYPE_t end
    for i in range(length):
        start = i * block_size
        end = min(start + block_size, total_size)
        # Index with [i, j] rather than [i][j]: the chained form first builds
        # an intermediate 1-D slice on every access.
        slice_indices_view[i, 0] = start
        slice_indices_view[i, 1] = end
    return slice_indices
|
|
|
|
|
|
cdef np.ndarray[DTYPE_t, ndim=2] _fast_convert_to_np_array(list list_of_list):
    """
    Turn a list of equally long integer rows into a 2-D DTYPE array.

    The rows are streamed into one flat buffer and reshaped afterwards, which
    only pays off for many rows with few columns each.
    """
    row_values = chain.from_iterable(list_of_list)
    cdef np.ndarray[DTYPE_t, ndim=1] flat = np.fromiter(row_values, dtype=DTYPE, count=-1)
    # One output row per input row; the column count is inferred by NumPy.
    shape = (len(list_of_list), -1)
    return flat.reshape(shape)
|
|
|
|
|
|
@cython.boundscheck(False)
@cython.wraparound(False)
@cython.nonecheck(False)
cpdef np.ndarray[DTYPE_t, ndim=2] _get_slice_indices_fast(np.ndarray[DTYPE_t, ndim=1] sizes, str break_mode, int block_size, int document_sep_len):
    """Compute the half-open ``[start, end)`` offsets of every block.

    Offsets index into the flat stream obtained by concatenating all entries
    described by ``sizes``.

    Args:
        sizes: 1-D int64 array of entry lengths.
        break_mode: how entries are grouped into blocks:
            * ``None`` / ``'none'``: fixed-size blocks, entry boundaries ignored.
            * ``'complete'``: consecutive whole entries are packed into a block
              while their running total stays within ``block_size``; an entry
              that is larger than ``block_size`` gets a block of its own.
            * ``'complete_doc'``: like ``'complete'``, but an entry whose size
              equals ``document_sep_len`` closes the current block and is
              itself skipped; blocks of size <= 1 are dropped.
            * ``'eos'``: one block per entry.
        block_size: maximum block length (unused for ``'eos'``).
        document_sep_len: size that marks a document separator entry
            (only used for ``'complete_doc'``).

    Returns:
        An int64 array of shape ``(n_blocks, 2)``. When no block is produced
        the result is an empty ``(0, 2)`` array.

    Raises:
        ValueError: if ``break_mode`` is not one of the supported values.
    """
    cdef DTYPE_t tok_idx = 0
    cdef DTYPE_t sz_idx = 0
    cdef DTYPE_t curr_size = 0
    cdef DTYPE_t[:] sizes_view = sizes
    cdef np.ndarray[DTYPE_t, ndim=2] slice_indices
    cdef list slice_indices_list = []

    if break_mode is None or break_mode == 'none':
        slice_indices = _get_slice_indices_none_mode(sizes, block_size)
    elif break_mode == 'complete':
        while sz_idx < len(sizes_view):
            # "curr_size == 0" lets an oversized entry start a block anyway,
            # which also guarantees progress on every iteration.
            if curr_size + sizes_view[sz_idx] <= block_size or curr_size == 0:
                curr_size += sizes_view[sz_idx]
                sz_idx += 1
            else:
                slice_indices_list.append((tok_idx, tok_idx + curr_size))
                tok_idx += curr_size
                curr_size = 0
        if curr_size > 0:
            slice_indices_list.append((tok_idx, tok_idx + curr_size))
        if slice_indices_list:
            slice_indices = _fast_convert_to_np_array(slice_indices_list)
        else:
            # reshape((0, -1)) cannot infer a column count and raises, so
            # build the empty result explicitly.
            slice_indices = np.zeros((0, 2), dtype=DTYPE)
    elif break_mode == 'complete_doc':
        while sz_idx < len(sizes_view):
            if (
                (curr_size + sizes_view[sz_idx] <= block_size or curr_size == 0)
                # an empty sentence indicates end-of-document:
                and sizes_view[sz_idx] != document_sep_len
            ):
                curr_size += sizes_view[sz_idx]
                sz_idx += 1
            else:
                # Only keep non-empty documents.
                if curr_size > 1:
                    slice_indices_list.append((tok_idx, tok_idx + curr_size))
                tok_idx += curr_size
                curr_size = 0
                # Skip over the separator entry itself.
                if sizes_view[sz_idx] == document_sep_len:
                    tok_idx += sizes_view[sz_idx]
                    sz_idx += 1
        if curr_size > 1:
            slice_indices_list.append((tok_idx, tok_idx + curr_size))
        if slice_indices_list:
            slice_indices = _fast_convert_to_np_array(slice_indices_list)
        else:
            # Same empty-result guard as in 'complete' mode.
            slice_indices = np.zeros((0, 2), dtype=DTYPE)
    elif break_mode == 'eos':
        slice_indices = np.zeros((len(sizes), 2), dtype=DTYPE)
        cumsum = sizes.cumsum(axis=0)
        # Block k starts where block k-1 ended; the first start stays 0.
        slice_indices[1:, 0] = cumsum[:cumsum.shape[0] - 1]
        slice_indices[:, 1] = cumsum
    else:
        raise ValueError('Invalid break_mode: ' + break_mode)
    return slice_indices
|
|
|
|
|
|
@cython.boundscheck(False)
@cython.wraparound(False)
@cython.nonecheck(False)
cpdef np.ndarray[DTYPE_t, ndim=2] _get_block_to_dataset_index_fast(np.ndarray[DTYPE_t, ndim=1] sizes, np.ndarray[DTYPE_t, ndim=2] slice_indices):
    """Map every block's flat offsets back to entries of the dataset.

    Args:
        sizes: 1-D int64 array of entry lengths.
        slice_indices: int64 array of shape ``(n_blocks, 2)`` holding the
            ``[start, end)`` flat offsets of each block.

    Returns:
        An int64 array of shape ``(n_blocks, 3)``; per row:
        column 0 is the dataset index containing ``start``, column 1 the
        offset of ``start`` inside that entry, and column 2 the dataset index
        containing ``end - 1`` (equal to column 0 for an empty block).
    """
    cdef DTYPE_t start_ds_idx
    cdef DTYPE_t start_offset
    cdef DTYPE_t end_ds_idx
    cdef DTYPE_t i
    cdef DTYPE_t s
    cdef DTYPE_t e
    cdef DatasetSearcher ds = DatasetSearcher(sizes)
    cdef Py_ssize_t x_max = slice_indices.shape[0]
    cdef np.ndarray[DTYPE_t, ndim=2] block_to_dataset_index = np.zeros([x_max, 3], dtype=DTYPE)
    cdef DTYPE_t[:, :] block_to_dataset_index_view = block_to_dataset_index
    cdef DTYPE_t[:, :] slice_indices_view = slice_indices

    for i in range(x_max):
        # Index with [i, j] rather than [i][j]: the chained form first builds
        # an intermediate 1-D slice on every access.
        s = slice_indices_view[i, 0]
        e = slice_indices_view[i, 1]
        ds.seek(s)
        start_ds_idx = ds.current_index
        start_offset = ds.current_offset
        if e <= s:
            # Empty block: it starts and ends in the same entry.
            end_ds_idx = start_ds_idx
        else:
            # ``e`` is exclusive, so the last covered position is ``e - 1``.
            ds.seek(e - 1)
            end_ds_idx = ds.current_index
        block_to_dataset_index_view[i, 0] = start_ds_idx  # starting index in dataset
        block_to_dataset_index_view[i, 1] = start_offset  # starting offset within starting index
        block_to_dataset_index_view[i, 2] = end_ds_idx  # ending index in dataset
    return block_to_dataset_index
|
|
|
|
|
|
cdef class DatasetSearcher(object):
    """Helper for mapping "flat" indices to indices and offsets in an
    underlying dataset.

    The searcher keeps a cursor (flat position, dataset index, offset inside
    that entry) and moves it forward entry by entry. Seeking backwards
    restarts the cursor from the beginning.
    """
    cdef DTYPE_t current_i  # "flat" index of the cursor
    cdef DTYPE_t current_offset  # offset within current index in underlying dataset
    cdef DTYPE_t current_index  # index in underlying dataset
    cdef DTYPE_t[:] sizes  # length of every entry of the underlying dataset

    def __init__(self, DTYPE_t[:] sizes):
        self.sizes = sizes
        self.reset()

    cdef reset(self):
        """Move the cursor back to the very first position."""
        self.current_offset = 0  # offset within current index in underlying dataset
        self.current_i = 0  # "flat" index
        self.current_index = 0  # index in underlying dataset

    @cython.boundscheck(False)
    @cython.wraparound(False)
    @cython.nonecheck(False)
    cdef int step(self, DTYPE_t i) except -1:
        """Advance the cursor towards flat index ``i`` by at most one entry.

        Returns 1 when a whole entry boundary was crossed and another call may
        be needed, 0 when no further stepping is required. The ``except -1``
        clause lets the assertion and the range check below propagate instead
        of being swallowed by a plain C return value.

        Raises:
            IndexError: if ``i`` lies beyond the data described by ``sizes``.
        """
        cdef DTYPE_t to_consume
        cdef DTYPE_t remaining
        if i < self.current_i:
            # The cursor only moves forward; restart for backward seeks.
            self.reset()
        if i > self.current_i:
            # Bounds checking is disabled for this method, so guard the read
            # of ``sizes`` explicitly.
            if self.current_index >= self.sizes.shape[0]:
                raise IndexError('flat index is out of range for the given sizes')
            to_consume = i - self.current_i
            remaining = self.sizes[self.current_index] - self.current_offset
            if remaining > to_consume:
                # Target lies inside the current entry.
                self.current_offset += to_consume
                self.current_i += to_consume
            else:
                assert remaining >= 0
                # Consume the rest of this entry and move to the next one.
                self.current_i += remaining
                self.current_index += 1
                self.current_offset = 0
                return 1
        return 0

    @cython.boundscheck(False)
    @cython.wraparound(False)
    @cython.nonecheck(False)
    cdef seek(self, DTYPE_t i):
        """Position the cursor exactly at flat index ``i``."""
        cdef int not_done = 1
        while not_done == 1:
            not_done = self.step(i)
        assert self.current_i == i
|