mirror of
https://github.com/SillyTavern/SillyTavern-Extras.git
synced 2026-03-10 22:10:22 +00:00
22 lines
680 B
Python
22 lines
680 B
Python
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
#
|
|
# This source code is licensed under the MIT license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
import math
|
|
import torch.nn.functional as F
|
|
|
|
|
|
def pad_to_multiple(x, multiple, dim=-1, value=0):
    """Right-pad ``x`` along ``dim`` so its size is a multiple of ``multiple``.

    Inspired from
    https://github.com/lucidrains/local-attention/blob/master/local_attention/local_attention.py#L41

    Args:
        x: Tensor to pad, or ``None``.
        multiple: The size of ``dim`` in the result is a multiple of this.
        dim: Dimension to pad. Both negative (counted from the end) and
            non-negative indices are accepted.
        value: Fill value passed to ``F.pad`` for the padded positions.

    Returns:
        A ``(padded, remainder)`` tuple where ``remainder`` is the number of
        elements appended along ``dim``. When ``x`` is ``None`` the result is
        ``(None, 0)``; when no padding is needed ``x`` itself is returned
        together with ``0``.
    """
    if x is None:
        return None, 0

    tsz = x.size(dim)
    # Distance from tsz up to the next multiple, computed with integer
    # arithmetic only so that large sizes are not subject to float rounding.
    remainder = -tsz % multiple
    if remainder == 0:
        return x, 0

    # F.pad takes (left, right) pairs starting from the LAST dimension, so
    # the number of leading zero pairs must be derived from a negative index.
    # Without this normalisation a non-negative ``dim`` would produce an empty
    # offset and the padding would silently be applied to the last dimension.
    if dim >= 0:
        dim -= x.dim()
    pad_offset = (0,) * (-1 - dim) * 2

    return F.pad(x, (*pad_offset, 0, remainder), value=value), remainder
|