This commit is contained in:
lllyasviel
2024-01-27 18:53:24 -08:00
parent a4d743e5f8
commit 20f1fb6c0b
2 changed files with 16 additions and 4 deletions

View File

@@ -268,7 +268,7 @@ def get_multicond_learned_conditioning(model, prompts, steps, hires_steps=None,
class DictWithShape(dict):
def __init__(self, x, shape):
def __init__(self, x):
super().__init__()
self.update(x)
@@ -282,6 +282,13 @@ class DictWithShape(dict):
self[k] = self[k].to(*args, **kwargs)
return self
def advanced_indexing(self, item):
result = {}
for k in self.keys():
if isinstance(self[k], torch.Tensor):
result[k] = self[k][item]
return DictWithShape(result)
def reconstruct_cond_batch(c: list[list[ScheduledPromptConditioning]], current_step):
param = c[0][0].cond
@@ -290,7 +297,7 @@ def reconstruct_cond_batch(c: list[list[ScheduledPromptConditioning]], current_s
if is_dict:
dict_cond = param
res = {k: torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype) for k, param in dict_cond.items()}
res = DictWithShape(res, (len(c),) + dict_cond['crossattn'].shape)
res = DictWithShape(res)
else:
res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype)
@@ -348,7 +355,7 @@ def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step):
if isinstance(tensors[0], dict):
keys = list(tensors[0].keys())
stacked = {k: stack_conds([x[k] for x in tensors]) for k in keys}
stacked = DictWithShape(stacked, stacked['crossattn'].shape)
stacked = DictWithShape(stacked)
else:
stacked = stack_conds(tensors).to(device=param.device, dtype=param.dtype)