forge 2.0.0

see also discussions
This commit is contained in:
lllyasviel
2024-08-10 19:24:19 -07:00
committed by GitHub
parent 4014013d05
commit cfa5242a75
28 changed files with 785 additions and 1249 deletions

View File

@@ -56,8 +56,6 @@ parser.add_argument("--cuda-malloc", action="store_true")
parser.add_argument("--cuda-stream", action="store_true")
parser.add_argument("--pin-shared-memory", action="store_true")
parser.add_argument("--i-am-lllyasviel", action="store_true")
args = parser.parse_known_args()[0]
# Some dynamic args that may be changed by webui rather than cmd flags.

View File

@@ -7,7 +7,7 @@ from backend.patcher.vae import VAE
from backend.patcher.unet import UnetPatcher
from backend.text_processing.classic_engine import ClassicTextProcessingEngine
from backend.text_processing.t5_engine import T5TextProcessingEngine
from backend.args import dynamic_args, args
from backend.args import dynamic_args
from backend.modules.k_prediction import PredictionFlux
from backend import memory_management
@@ -16,9 +16,6 @@ class Flux(ForgeDiffusionEngine):
matched_guesses = [model_list.Flux]
def __init__(self, estimated_config, huggingface_components):
if not args.i_am_lllyasviel:
raise NotImplementedError('Flux is not implemented yet!')
super().__init__(estimated_config, huggingface_components)
self.is_inpaint = False
@@ -68,9 +65,6 @@ class Flux(ForgeDiffusionEngine):
self.use_distilled_cfg_scale = True
# WebUI Legacy
self.first_stage_model = vae.first_stage_model
def set_clip_skip(self, clip_skip):
self.text_processing_engine_l.clip_skip = clip_skip
@@ -93,7 +87,7 @@ class Flux(ForgeDiffusionEngine):
@torch.inference_mode()
def get_prompt_lengths_on_ui(self, prompt):
_, token_count = self.text_processing_engine_t5.process_texts([prompt])
token_count = len(self.text_processing_engine_t5.tokenize([prompt])[0])
return token_count, max(255, token_count)
@torch.inference_mode()

View File

@@ -53,7 +53,6 @@ class StableDiffusion(ForgeDiffusionEngine):
# WebUI Legacy
self.is_sd1 = True
self.first_stage_model = vae.first_stage_model
def set_clip_skip(self, clip_skip):
self.text_processing_engine.clip_skip = clip_skip

View File

@@ -53,7 +53,6 @@ class StableDiffusion2(ForgeDiffusionEngine):
# WebUI Legacy
self.is_sd2 = True
self.first_stage_model = vae.first_stage_model
def set_clip_skip(self, clip_skip):
self.text_processing_engine.clip_skip = clip_skip

View File

@@ -72,7 +72,6 @@ class StableDiffusionXL(ForgeDiffusionEngine):
# WebUI Legacy
self.is_sdxl = True
self.first_stage_model = vae.first_stage_model
def set_clip_skip(self, clip_skip):
self.text_processing_engine_l.clip_skip = clip_skip

View File

@@ -3,6 +3,7 @@ import torch
import logging
import importlib
import backend.args
import huggingface_guess
from diffusers import DiffusionPipeline
@@ -69,9 +70,10 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
config = read_arbitrary_config(config_path)
dtype = memory_management.text_encoder_dtype()
sd_dtype = state_dict['transformer.encoder.block.0.layer.0.SelfAttention.k.weight'].dtype
sd_dtype = memory_management.state_dict_dtype(state_dict)
if sd_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
print(f'Using Detected T5 Data Type: {sd_dtype}')
dtype = sd_dtype
with modeling_utils.no_init_weights():
@@ -81,32 +83,60 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
load_state_dict(model, state_dict, log_name=cls_name, ignore_errors=['transformer.encoder.embed_tokens.weight'])
return model
if cls_name == 'UNet2DConditionModel':
if cls_name in ['UNet2DConditionModel', 'FluxTransformer2DModel']:
model_loader = None
if cls_name == 'UNet2DConditionModel':
model_loader = lambda c: IntegratedUNet2DConditionModel.from_config(c)
if cls_name == 'FluxTransformer2DModel':
from backend.nn.flux import IntegratedFluxTransformer2DModel
model_loader = lambda c: IntegratedFluxTransformer2DModel(**c)
unet_config = guess.unet_config.copy()
state_dict_size = memory_management.state_dict_size(state_dict)
ini_dtype = memory_management.unet_dtype(model_params=state_dict_size)
ini_device = memory_management.unet_inital_load_device(parameters=state_dict_size, dtype=ini_dtype)
to_args = dict(device=ini_device, dtype=ini_dtype)
state_dict_dtype = memory_management.state_dict_dtype(state_dict)
with using_forge_operations(**to_args):
model = IntegratedUNet2DConditionModel.from_config(unet_config).to(**to_args)
model._internal_dict = unet_config
storage_dtype = memory_management.unet_dtype(model_params=state_dict_size, supported_dtypes=guess.supported_inference_dtypes)
unet_storage_dtype_overwrite = backend.args.dynamic_args.get('forge_unet_storage_dtype')
if unet_storage_dtype_overwrite is not None:
storage_dtype = unet_storage_dtype_overwrite
else:
if state_dict_dtype in [torch.float8_e4m3fn, torch.float8_e5m2, 'nf4', 'fp4']:
print(f'Using Detected UNet Type: {state_dict_dtype}')
storage_dtype = state_dict_dtype
if state_dict_dtype in ['nf4', 'fp4']:
print(f'Using pre-quant state dict!')
load_device = memory_management.get_torch_device()
computation_dtype = memory_management.get_computation_dtype(load_device, supported_dtypes=guess.supported_inference_dtypes)
offload_device = memory_management.unet_offload_device()
if storage_dtype in ['nf4', 'fp4']:
initial_device = memory_management.unet_inital_load_device(parameters=state_dict_size, dtype=computation_dtype)
with using_forge_operations(device=initial_device, dtype=computation_dtype, manual_cast_enabled=False, bnb_dtype=storage_dtype):
model = model_loader(unet_config)
else:
initial_device = memory_management.unet_inital_load_device(parameters=state_dict_size, dtype=storage_dtype)
need_manual_cast = storage_dtype != computation_dtype
to_args = dict(device=initial_device, dtype=storage_dtype)
with using_forge_operations(**to_args, manual_cast_enabled=need_manual_cast):
model = model_loader(unet_config).to(**to_args)
load_state_dict(model, state_dict)
return model
if cls_name == 'FluxTransformer2DModel':
from backend.nn.flux import IntegratedFluxTransformer2DModel
unet_config = guess.unet_config.copy()
state_dict_size = memory_management.state_dict_size(state_dict)
ini_dtype = memory_management.unet_dtype(model_params=state_dict_size)
ini_device = memory_management.unet_inital_load_device(parameters=state_dict_size, dtype=ini_dtype)
to_args = dict(device=ini_device, dtype=ini_dtype)
with using_forge_operations(**to_args):
model = IntegratedFluxTransformer2DModel(**unet_config).to(**to_args)
if hasattr(model, '_internal_dict'):
model._internal_dict = unet_config
else:
model.config = unet_config
load_state_dict(model, state_dict)
model.storage_dtype = storage_dtype
model.computation_dtype = computation_dtype
model.load_device = load_device
model.initial_device = initial_device
model.offload_device = offload_device
return model
print(f'Skipped: {component_name} = {lib_name}.{cls_name}')

View File

@@ -8,7 +8,7 @@ import platform
from enum import Enum
from backend import stream
from backend.args import args, dynamic_args
from backend.args import args
cpu = torch.device('cpu')
@@ -281,12 +281,8 @@ except:
print("Could not pick default device.")
if 'rtx' in torch_device_name.lower():
if not args.pin_shared_memory:
print('Hint: your device supports --pin-shared-memory for potential speed improvements.')
if not args.cuda_malloc:
print('Hint: your device supports --cuda-malloc for potential speed improvements.')
if not args.cuda_stream:
print('Hint: your device supports --cuda-stream for potential speed improvements.')
current_loaded_models = []
@@ -305,8 +301,54 @@ def state_dict_size(sd, exclude_device=None):
return module_mem
def state_dict_dtype(state_dict):
for k in state_dict.keys():
if 'bitsandbytes__nf4' in k:
return 'nf4'
if 'bitsandbytes__fp4' in k:
return 'fp4'
dtype_counts = {}
for tensor in state_dict.values():
dtype = tensor.dtype
if dtype in dtype_counts:
dtype_counts[dtype] += 1
else:
dtype_counts[dtype] = 1
major_dtype = None
max_count = 0
for dtype, count in dtype_counts.items():
if count > max_count:
max_count = count
major_dtype = dtype
return major_dtype
def module_size(module, exclude_device=None):
return state_dict_size(module.state_dict(), exclude_device=exclude_device)
module_mem = 0
for p in module.parameters():
t = p.data
if exclude_device is not None:
if t.device == exclude_device:
continue
element_size = t.element_size()
if getattr(p, 'quant_type', None) in ['fp4', 'nf4']:
if element_size > 1:
# not quanted yet
element_size = 0.55 # a bit more than 0.5 because of quant state parameters
else:
# quanted
element_size = 1.1 # a bit more than 1.0 because of quant state parameters
module_mem += t.nelement() * element_size
return module_mem
class LoadedModel:
@@ -587,11 +629,6 @@ def unet_inital_load_device(parameters, dtype):
def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
unet_storage_dtype_overwrite = dynamic_args.get('forge_unet_storage_dtype')
if unet_storage_dtype_overwrite is not None:
return unet_storage_dtype_overwrite
if args.unet_in_bf16:
return torch.bfloat16
@@ -1040,6 +1077,18 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
return False
def can_install_bnb():
if not torch.cuda.is_available():
return False
cuda_version = tuple(int(x) for x in torch.version.cuda.split('.'))
if cuda_version >= (11, 7):
return True
return False
def soft_empty_cache(force=False):
global cpu_state
if cpu_state == CPUState.MPS:

View File

@@ -5,16 +5,13 @@ from backend.modules.k_prediction import k_prediction_from_diffusers_scheduler
class KModel(torch.nn.Module):
def __init__(self, model, diffusers_scheduler, storage_dtype, computation_dtype, k_predictor=None):
def __init__(self, model, diffusers_scheduler, k_predictor=None):
super().__init__()
self.storage_dtype = storage_dtype
self.computation_dtype = computation_dtype
self.storage_dtype = model.storage_dtype
self.computation_dtype = model.computation_dtype
need_manual_cast = self.storage_dtype != self.computation_dtype
operations.shift_manual_cast(model, enabled=need_manual_cast)
print(f'K-Model Created: {dict(storage_dtype=storage_dtype, computation_dtype=computation_dtype, manual_cast=need_manual_cast)}')
print(f'K-Model Created: {dict(storage_dtype=self.storage_dtype, computation_dtype=self.computation_dtype)}')
self.diffusion_model = model

View File

@@ -10,4 +10,3 @@ class IntegratedCLIP(torch.nn.Module):
if add_text_projection:
embed_dim = config.hidden_size
self.transformer.text_projection = torch.nn.Linear(embed_dim, embed_dim, bias=False)
self.transformer.text_projection.weight.copy_(torch.eye(embed_dim))

View File

@@ -433,9 +433,9 @@ class ResBlock(TimestepBlock):
def _forward(self, x, emb, transformer_options={}):
if self.updown:
in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
if "groupnorm_wrapper" in transformer_options:
if "group_norm_wrapper" in transformer_options:
in_norm, in_rest = in_rest[0], in_rest[1:]
h = transformer_options["groupnorm_wrapper"](in_norm, x, transformer_options)
h = transformer_options["group_norm_wrapper"](in_norm, x, transformer_options)
h = in_rest(h)
else:
h = in_rest(x)
@@ -443,9 +443,9 @@ class ResBlock(TimestepBlock):
x = self.x_upd(x)
h = in_conv(h)
else:
if "groupnorm_wrapper" in transformer_options:
if "group_norm_wrapper" in transformer_options:
in_norm = self.in_layers[0]
h = transformer_options["groupnorm_wrapper"](in_norm, x, transformer_options)
h = transformer_options["group_norm_wrapper"](in_norm, x, transformer_options)
h = self.in_layers[1:](h)
else:
h = self.in_layers(x)
@@ -456,8 +456,8 @@ class ResBlock(TimestepBlock):
emb_out = emb_out[..., None]
if self.use_scale_shift_norm:
out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
if "groupnorm_wrapper" in transformer_options:
h = transformer_options["groupnorm_wrapper"](out_norm, h, transformer_options)
if "group_norm_wrapper" in transformer_options:
h = transformer_options["group_norm_wrapper"](out_norm, h, transformer_options)
else:
h = out_norm(h)
if emb_out is not None:
@@ -470,8 +470,8 @@ class ResBlock(TimestepBlock):
if self.exchange_temb_dims:
emb_out = rearrange(emb_out, "b t c ... -> b c t ...")
h = h + emb_out
if "groupnorm_wrapper" in transformer_options:
h = transformer_options["groupnorm_wrapper"](self.out_layers[0], h, transformer_options)
if "group_norm_wrapper" in transformer_options:
h = transformer_options["group_norm_wrapper"](self.out_layers[0], h, transformer_options)
h = self.out_layers[1:](h)
else:
h = self.out_layers(h)
@@ -752,9 +752,9 @@ class IntegratedUNet2DConditionModel(nn.Module, ConfigMixin):
transformer_options["block"] = ("last", 0)
for block_modifier in block_modifiers:
h = block_modifier(h, 'before', transformer_options)
if "groupnorm_wrapper" in transformer_options:
if "group_norm_wrapper" in transformer_options:
out_norm, out_rest = self.out[0], self.out[1:]
h = transformer_options["groupnorm_wrapper"](out_norm, h, transformer_options)
h = transformer_options["group_norm_wrapper"](out_norm, h, transformer_options)
h = out_rest(h)
else:
h = self.out(h)

View File

@@ -1,3 +1,5 @@
# Copyright Forge 2024
import time
import torch
import contextlib
@@ -8,7 +10,7 @@ from backend import stream, memory_management
stash = {}
def weights_manual_cast(layer, x, skip_dtype=False):
def weights_manual_cast(layer, x, skip_weight_dtype=False, skip_bias_dtype=False):
weight, bias, signal = None, None, None
non_blocking = True
@@ -18,21 +20,28 @@ def weights_manual_cast(layer, x, skip_dtype=False):
target_dtype = x.dtype
target_device = x.device
if skip_dtype:
target_dtype = None
if skip_weight_dtype:
weight_args = dict(device=target_device, non_blocking=non_blocking)
else:
weight_args = dict(device=target_device, dtype=target_dtype, non_blocking=non_blocking)
if skip_bias_dtype:
bias_args = dict(device=target_device, non_blocking=non_blocking)
else:
bias_args = dict(device=target_device, dtype=target_dtype, non_blocking=non_blocking)
if stream.should_use_stream():
with stream.stream_context()(stream.mover_stream):
if layer.weight is not None:
weight = layer.weight.to(device=target_device, dtype=target_dtype, non_blocking=non_blocking)
weight = layer.weight.to(**weight_args)
if layer.bias is not None:
bias = layer.bias.to(device=target_device, dtype=target_dtype, non_blocking=non_blocking)
bias = layer.bias.to(**bias_args)
signal = stream.mover_stream.record_event()
else:
if layer.weight is not None:
weight = layer.weight.to(device=target_device, dtype=target_dtype, non_blocking=non_blocking)
weight = layer.weight.to(**weight_args)
if layer.bias is not None:
bias = layer.bias.to(device=target_device, dtype=target_dtype, non_blocking=non_blocking)
bias = layer.bias.to(**bias_args)
return weight, bias, signal
@@ -72,19 +81,27 @@ def cleanup_cache():
current_device = None
current_dtype = None
current_manual_cast_enabled = False
current_bnb_dtype = None
class ForgeOperations:
class Linear(torch.nn.Linear):
class Linear(torch.nn.Module):
def __init__(self, *args, **kwargs):
kwargs['device'] = current_device
kwargs['dtype'] = current_dtype
super().__init__(*args, **kwargs)
super().__init__()
self.dummy = torch.nn.Parameter(torch.empty(1, device=current_device, dtype=current_dtype))
self.weight = None
self.bias = None
self.parameters_manual_cast = current_manual_cast_enabled
def reset_parameters(self):
return None
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
if hasattr(self, 'dummy'):
if prefix + 'weight' in state_dict:
self.weight = torch.nn.Parameter(state_dict[prefix + 'weight'].to(self.dummy))
if prefix + 'bias' in state_dict:
self.bias = torch.nn.Parameter(state_dict[prefix + 'bias'].to(self.dummy))
del self.dummy
else:
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
def forward(self, x):
if self.parameters_manual_cast:
@@ -92,7 +109,7 @@ class ForgeOperations:
with main_stream_worker(weight, bias, signal):
return torch.nn.functional.linear(x, weight, bias)
else:
return super().forward(x)
return torch.nn.functional.linear(x, self.weight, self.bias)
class Conv2d(torch.nn.Conv2d):
@@ -269,21 +286,61 @@ class ForgeOperations:
def forward(self, x):
if self.parameters_manual_cast:
weight, bias, signal = weights_manual_cast(self, x, skip_dtype=True)
weight, bias, signal = weights_manual_cast(self, x, skip_weight_dtype=True, skip_bias_dtype=True)
with main_stream_worker(weight, bias, signal):
return torch.nn.functional.embedding(x, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse)
else:
return super().forward(x)
@contextlib.contextmanager
def using_forge_operations(operations=None, device=None, dtype=None, manual_cast_enabled=False):
global current_device, current_dtype, current_manual_cast_enabled
try:
from backend.operations_bnb import ForgeLoader4Bit, ForgeParams4bit, functional_linear_4bits
current_device, current_dtype, current_manual_cast_enabled = device, dtype, manual_cast_enabled
class ForgeOperationsBNB4bits(ForgeOperations):
class Linear(ForgeLoader4Bit):
def __init__(self, *args, **kwargs):
super().__init__(device=current_device, dtype=current_dtype, quant_type=current_bnb_dtype)
self.parameters_manual_cast = current_manual_cast_enabled
def forward(self, x):
self.weight.quant_state = self.quant_state
if self.bias is not None and self.bias.dtype != x.dtype:
# Maybe this could also be applied to all non-bnb ops, since the cost is very low.
# It only runs once, and most linear layers do not have a bias.
self.bias.data = self.bias.data.to(x.dtype)
if not self.parameters_manual_cast:
return functional_linear_4bits(x, self.weight, self.bias)
elif not self.weight.bnb_quantized:
assert x.device.type == 'cuda', 'BNB Must Use CUDA as Computation Device!'
layer_original_device = self.weight.device
self.weight = self.weight._quantize(x.device)
bias = self.bias.to(x.device) if self.bias is not None else None
out = functional_linear_4bits(x, self.weight, bias)
self.weight = self.weight.to(layer_original_device)
return out
else:
weight, bias, signal = weights_manual_cast(self, x, skip_weight_dtype=True, skip_bias_dtype=True)
with main_stream_worker(weight, bias, signal):
return functional_linear_4bits(x, weight, bias)
bnb_avaliable = True
except:
bnb_avaliable = False
@contextlib.contextmanager
def using_forge_operations(operations=None, device=None, dtype=None, manual_cast_enabled=False, bnb_dtype=None):
global current_device, current_dtype, current_manual_cast_enabled, current_bnb_dtype
current_device, current_dtype, current_manual_cast_enabled, current_bnb_dtype = device, dtype, manual_cast_enabled, bnb_dtype
if operations is None:
operations = ForgeOperations
if bnb_avaliable and bnb_dtype in ['nf4', 'fp4']:
operations = ForgeOperationsBNB4bits
else:
operations = ForgeOperations
op_names = ['Linear', 'Conv1d', 'Conv2d', 'Conv3d', 'ConvTranspose1d', 'ConvTranspose2d', 'ConvTranspose3d', 'GroupNorm', 'LayerNorm', 'Embedding']
backups = {op_name: getattr(torch.nn, op_name) for op_name in op_names}

File diff suppressed because it is too large Load Diff

View File

@@ -428,11 +428,17 @@ class ControlLora(ControlNet):
controlnet_config = model.diffusion_model.config.copy()
controlnet_config.pop("out_channels")
controlnet_config["hint_channels"] = self.control_weights["input_hint_block.0.weight"].shape[1]
controlnet_config["dtype"] = dtype = model.storage_dtype
dtype = model.storage_dtype
if dtype in ['nf4', 'fp4']:
dtype = torch.float16
controlnet_config["dtype"] = dtype
self.manual_cast_dtype = model.computation_dtype
with using_forge_operations(operations=ControlLoraOps):
with using_forge_operations(operations=ControlLoraOps, dtype=dtype):
self.control_model = cldm.ControlNet(**controlnet_config)
self.control_model.to(device=memory_management.get_torch_device(), dtype=dtype)

View File

@@ -3,20 +3,18 @@ import torch
from backend.modules.k_model import KModel
from backend.patcher.base import ModelPatcher
from backend import memory_management
class UnetPatcher(ModelPatcher):
@classmethod
def from_model(cls, model, diffusers_scheduler, config, k_predictor=None):
parameters = memory_management.module_size(model)
unet_dtype = memory_management.unet_dtype(model_params=parameters)
load_device = memory_management.get_torch_device()
initial_load_device = memory_management.unet_inital_load_device(parameters, unet_dtype)
computation_dtype = memory_management.get_computation_dtype(load_device, supported_dtypes=config.supported_inference_dtypes)
model.to(device=initial_load_device, dtype=unet_dtype)
model = KModel(model=model, diffusers_scheduler=diffusers_scheduler, k_predictor=k_predictor, storage_dtype=unet_dtype, computation_dtype=computation_dtype)
return UnetPatcher(model, load_device=load_device, offload_device=memory_management.unet_offload_device(), current_device=initial_load_device)
model = KModel(model=model, diffusers_scheduler=diffusers_scheduler, k_predictor=k_predictor)
return UnetPatcher(
model,
load_device=model.diffusion_model.load_device,
offload_device=model.diffusion_model.offload_device,
current_device=model.diffusion_model.initial_device
)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -169,8 +167,8 @@ class UnetPatcher(ModelPatcher):
self.append_transformer_option('controlnet_conditioning_modifiers', modifier, ensure_uniqueness)
return
def set_groupnorm_wrapper(self, wrapper):
self.set_transformer_option('groupnorm_wrapper', wrapper)
def set_group_norm_wrapper(self, wrapper):
self.set_transformer_option('group_norm_wrapper', wrapper)
return
def set_controlnet_model_function_wrapper(self, wrapper):

View File

@@ -1,9 +1,7 @@
import math
import torch
from collections import namedtuple
from backend.text_processing import parsing, emphasis
from backend.text_processing.textual_inversion import EmbeddingDatabase
from backend import memory_management
@@ -50,9 +48,6 @@ class T5TextProcessingEngine:
if mult != 1.0:
self.token_mults[ident] = mult
def get_target_prompt_token_count(self, token_count):
return token_count
def tokenize(self, texts):
tokenized = self.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"]
return tokenized
@@ -112,45 +107,33 @@ class T5TextProcessingEngine:
return chunks, token_count
def process_texts(self, texts):
token_count = 0
def __call__(self, texts):
zs = []
cache = {}
batch_chunks = []
for line in texts:
if line in cache:
chunks = cache[line]
line_z_values = cache[line]
else:
chunks, current_token_count = self.tokenize_line(line)
token_count = max(current_token_count, token_count)
chunks, token_count = self.tokenize_line(line)
line_z_values = []
for chunk in chunks:
tokens = chunk.tokens
multipliers = chunk.multipliers
z = self.process_tokens([tokens], [multipliers])[0]
line_z_values.append(z)
cache[line] = line_z_values
cache[line] = chunks
zs.extend(line_z_values)
batch_chunks.append(chunks)
return torch.stack(zs)
return batch_chunks, token_count
def __call__(self, texts):
batch_chunks, token_count = self.process_texts(texts)
chunk_count = max([len(x) for x in batch_chunks])
zs = []
for i in range(chunk_count):
batch_chunk = [chunks[i] for chunks in batch_chunks]
tokens = [x.tokens for x in batch_chunk]
multipliers = [x.multipliers for x in batch_chunk]
z = self.process_tokens(tokens, multipliers)
zs.append(z)
return torch.hstack(zs)
def process_tokens(self, remade_batch_tokens, batch_multipliers):
tokens = torch.asarray(remade_batch_tokens)
def process_tokens(self, batch_tokens, batch_multipliers):
tokens = torch.asarray(batch_tokens)
z = self.encode_with_transformers(tokens)
self.emphasis.tokens = remade_batch_tokens
self.emphasis.tokens = batch_tokens
self.emphasis.multipliers = torch.asarray(batch_multipliers).to(z)
self.emphasis.z = z
self.emphasis.after_transformers()