mirror of
https://github.com/SillyTavern/SillyTavern-Extras.git
synced 2026-03-10 14:00:13 +00:00
43 lines
1.5 KiB
Python
43 lines
1.5 KiB
Python
from abc import ABC, abstractmethod
|
|
from typing import Dict, List
|
|
|
|
from torch import Tensor
|
|
from torch.nn import Module
|
|
|
|
from tha3.compute.cached_computation_func import TensorCachedComputationFunc, TensorListCachedComputationFunc
|
|
|
|
|
|
class CachedComputationProtocol(ABC):
    """Abstract protocol for producing named lists of tensors with memoization.

    Subclasses implement :meth:`compute_output` to produce the list of
    tensors associated with a ``key``. :meth:`get_output` wraps that with a
    cache that lives in the caller-supplied ``outputs`` dict. Once a key has
    been stored there, later ``get_output`` calls with the same key and dict
    return the stored value instead of calling ``compute_output`` again.
    """

    def get_output(self,
                   key: str,
                   modules: Dict[str, Module],
                   batch: List[Tensor],
                   outputs: Dict[str, List[Tensor]]):
        """Return ``outputs[key]``, computing and storing it first on a miss.

        Args:
            key: Name of the output to fetch.
            modules: Named modules, forwarded unchanged to ``compute_output``.
            batch: Input tensors, forwarded unchanged to ``compute_output``.
            outputs: Cache of already-computed outputs. It is mutated in place
                when ``key`` is missing.

        Returns:
            The value stored under ``key`` in ``outputs``.
        """
        # Only compute on a cache miss; either way the answer is read back
        # from the dict so both paths return the same stored object.
        if key not in outputs:
            outputs[key] = self.compute_output(key, modules, batch, outputs)
        return outputs[key]

    @abstractmethod
    def compute_output(self,
                       key: str,
                       modules: Dict[str, Module],
                       batch: List[Tensor],
                       outputs: Dict[str, List[Tensor]]) -> List[Tensor]:
        """Compute the list of tensors for ``key`` (implemented by subclasses).

        ``outputs`` is the same cache dict passed to :meth:`get_output`.
        Implementations can presumably use it (via ``get_output``) to reuse
        other keys' results. NOTE(review): confirm against concrete
        subclasses.
        """

    def get_output_tensor_func(self, key: str, index: int) -> TensorCachedComputationFunc:
        """Build a callable that returns element ``index`` of the output for ``key``.

        The returned function takes ``(modules, batch, outputs)``. It goes
        through :meth:`get_output`, so it shares the same ``outputs`` cache.
        """
        protocol = self

        def func(modules: Dict[str, Module],
                 batch: List[Tensor],
                 outputs: Dict[str, List[Tensor]]):
            tensors = protocol.get_output(key, modules, batch, outputs)
            return tensors[index]

        return func

    def get_output_tensor_list_func(self, key: str) -> TensorListCachedComputationFunc:
        """Build a callable that returns the whole output list for ``key``.

        The returned function takes ``(modules, batch, outputs)``. It goes
        through :meth:`get_output`, so it shares the same ``outputs`` cache.
        """
        protocol = self

        def func(modules: Dict[str, Module],
                 batch: List[Tensor],
                 outputs: Dict[str, List[Tensor]]):
            return protocol.get_output(key, modules, batch, outputs)

        return func