# Source listing: SillyTavern-extras/talkinghead/tha3/compute/cached_computation_func.py
# Timestamp: 2023-08-11 06:50:59 +09:00
# Size: 10 lines, 329 B (Python)

from typing import Callable, Dict, List
from torch import Tensor
from torch.nn import Module
# Callable signatures for "cached computation" callbacks. Both aliases take
# the same three positional arguments and differ only in what they return:
#   0. Dict[str, Module]        -- torch modules, keyed by string name
#   1. List[Tensor]             -- a list of tensors (presumably the input
#                                  batch -- TODO confirm against callers)
#   2. Dict[str, List[Tensor]]  -- tensor lists keyed by string name
#                                  (presumably previously computed outputs
#                                  serving as the cache -- TODO confirm)
# The parameter lists are spelled out literally, not shared through a
# variable, because static type checkers require a literal list as the first
# argument to Callable[...].

# Callback that produces a single Tensor.
TensorCachedComputationFunc = Callable[
    [
        Dict[str, Module],
        List[Tensor],
        Dict[str, List[Tensor]],
    ],
    Tensor,
]

# Callback with the same arguments that produces a list of Tensors.
TensorListCachedComputationFunc = Callable[
    [
        Dict[str, Module],
        List[Tensor],
        Dict[str, List[Tensor]],
    ],
    List[Tensor],
]