mirror of
https://github.com/SillyTavern/SillyTavern-Extras.git
synced 2026-05-01 03:41:24 +00:00
Live2d Init
This commit is contained in:
43
live2d/tha3/compute/cached_computation_protocol.py
Normal file
43
live2d/tha3/compute/cached_computation_protocol.py
Normal file
@@ -0,0 +1,43 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List
|
||||
|
||||
from torch import Tensor
|
||||
from torch.nn import Module
|
||||
|
||||
from tha3.compute.cached_computation_func import TensorCachedComputationFunc, TensorListCachedComputationFunc
|
||||
|
||||
|
||||
class CachedComputationProtocol(ABC):
|
||||
def get_output(self,
|
||||
key: str,
|
||||
modules: Dict[str, Module],
|
||||
batch: List[Tensor],
|
||||
outputs: Dict[str, List[Tensor]]):
|
||||
if key in outputs:
|
||||
return outputs[key]
|
||||
else:
|
||||
output = self.compute_output(key, modules, batch, outputs)
|
||||
outputs[key] = output
|
||||
return outputs[key]
|
||||
|
||||
@abstractmethod
|
||||
def compute_output(self,
|
||||
key: str,
|
||||
modules: Dict[str, Module],
|
||||
batch: List[Tensor],
|
||||
outputs: Dict[str, List[Tensor]]) -> List[Tensor]:
|
||||
pass
|
||||
|
||||
def get_output_tensor_func(self, key: str, index: int) -> TensorCachedComputationFunc:
|
||||
def func(modules: Dict[str, Module],
|
||||
batch: List[Tensor],
|
||||
outputs: Dict[str, List[Tensor]]):
|
||||
return self.get_output(key, modules, batch, outputs)[index]
|
||||
return func
|
||||
|
||||
def get_output_tensor_list_func(self, key: str) -> TensorListCachedComputationFunc:
|
||||
def func(modules: Dict[str, Module],
|
||||
batch: List[Tensor],
|
||||
outputs: Dict[str, List[Tensor]]):
|
||||
return self.get_output(key, modules, batch, outputs)
|
||||
return func
|
||||
Reference in New Issue
Block a user