import torch
import math


def repeat_to_batch_size(tensor, batch_size):
    # Trim or tile the tensor along dim 0 so it has exactly batch_size entries.
    if tensor.shape[0] > batch_size:
        return tensor[:batch_size]
    elif tensor.shape[0] < batch_size:
        return tensor.repeat([math.ceil(batch_size / tensor.shape[0])] + [1] * (len(tensor.shape) - 1))[:batch_size]
    return tensor


def lcm(a, b):
    # Least common multiple; math.lcm is only available on Python 3.9+.
    return abs(a * b) // math.gcd(a, b)


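# Usage sketch (illustrative only, not part of the original module):
#   repeat_to_batch_size(torch.zeros(2, 4), 5).shape  ->  torch.Size([5, 4])
#   repeat_to_batch_size(torch.zeros(8, 4), 5).shape  ->  torch.Size([5, 4])
#   lcm(6, 8)  ->  24

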
class Condition:
    # Base wrapper around a conditioning tensor.
    def __init__(self, cond):
        self.cond = cond

    def _copy_with(self, cond):
        return self.__class__(cond)

    def process_cond(self, batch_size, device, **kwargs):
        # Match the sampling batch size and move the tensor to the target device.
        return self._copy_with(repeat_to_batch_size(self.cond, batch_size).to(device))

    def can_concat(self, other):
        if self.cond.shape != other.cond.shape:
            return False
        return True

    def concat(self, others):
        conds = [self.cond]
        for x in others:
            conds.append(x.cond)
        return torch.cat(conds)


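# Usage sketch (illustrative only, shapes are assumptions):
#   Condition(torch.randn(1, 1280)).process_cond(batch_size=2, device="cpu")
#   returns a new Condition whose .cond has shape (2, 1280).

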
class ConditionNoiseShape(Condition):
    def process_cond(self, batch_size, device, area, **kwargs):
        # Crop the conditioning to the active area: offsets area[2]/area[3] and
        # extents area[0]/area[1] index the last two (spatial) dimensions.
        data = self.cond[:, :, area[2]:area[0] + area[2], area[3]:area[1] + area[3]]
        return self._copy_with(repeat_to_batch_size(data, batch_size).to(device))


class ConditionCrossAttn(Condition):
    def can_concat(self, other):
        s1 = self.cond.shape
        s2 = other.cond.shape
        if s1 != s2:
            if s1[0] != s2[0] or s1[2] != s2[2]:
                return False

            # Mismatched token counts can still be concatenated by tiling to the
            # least common multiple, but only if the required padding stays small.
            mult_min = lcm(s1[1], s2[1])
            diff = mult_min // min(s1[1], s2[1])
            if diff > 4:
                return False
        return True

    def concat(self, others):
        conds = [self.cond]
        crossattn_max_len = self.cond.shape[1]
        for x in others:
            c = x.cond
            crossattn_max_len = lcm(crossattn_max_len, c.shape[1])
            conds.append(c)

        out = []
        for c in conds:
            if c.shape[1] < crossattn_max_len:
                # Tile along the token dimension up to the common length.
                c = c.repeat(1, crossattn_max_len // c.shape[1], 1)
            out.append(c)
        return torch.cat(out)


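# Usage sketch (illustrative only, shapes are assumptions): prompts of 77 and
# 154 tokens can be batched because 154 is a small multiple of 77; the shorter
# one is tiled to 154 tokens before torch.cat.
#   a = ConditionCrossAttn(torch.randn(1, 77, 768))
#   b = ConditionCrossAttn(torch.randn(1, 154, 768))
#   a.can_concat(b)        ->  True
#   a.concat([b]).shape    ->  torch.Size([2, 154, 768])

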
class ConditionConstant(Condition):
    # Wraps a constant value that is passed through unchanged.
    def __init__(self, cond):
        self.cond = cond

    def process_cond(self, batch_size, device, **kwargs):
        return self._copy_with(self.cond)

    def can_concat(self, other):
        if self.cond != other.cond:
            return False
        return True

    def concat(self, others):
        return self.cond


def compile_conditions(cond):
    # Normalize encoder output into the list-of-dicts format used downstream.
    if cond is None:
        return None

    # A bare tensor is treated as cross-attention conditioning only.
    if isinstance(cond, torch.Tensor):
        result = dict(
            cross_attn=cond,
            model_conds=dict(
                c_crossattn=ConditionCrossAttn(cond),
            )
        )
        return [result, ]

    # Otherwise expect a dict with 'crossattn' and pooled 'vector' outputs.
    cross_attn = cond['crossattn']
    pooled_output = cond['vector']

    result = dict(
        cross_attn=cross_attn,
        pooled_output=pooled_output,
        model_conds=dict(
            c_crossattn=ConditionCrossAttn(cross_attn),
            y=Condition(pooled_output)
        )
    )

    if 'guidance' in cond:
        result['model_conds']['guidance'] = Condition(cond['guidance'])

    return [result, ]


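# Usage sketch (illustrative only, shapes are assumptions):
#   compile_conditions(torch.randn(1, 77, 768))
#     -> [{'cross_attn': <tensor>, 'model_conds': {'c_crossattn': ConditionCrossAttn}}]
#   compile_conditions({'crossattn': torch.randn(1, 77, 2048),
#                       'vector': torch.randn(1, 1280)})
#     -> [{'cross_attn': ..., 'pooled_output': ..., 'model_conds': {...}}]

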
def compile_weighted_conditions(cond, weights):
    # weights is a sequence of rows of (index, weight) pairs; transpose it so
    # each column (one group of indices) is compiled together.
    transposed = list(map(list, zip(*weights)))
    results = []

    for cond_pre in transposed:
        current_indices = []
        current_weight = 0
        # Collect every index in this group; the last weight seen becomes the strength.
        for i, w in cond_pre:
            current_indices.append(i)
            current_weight = w

        # Select the rows of the conditioning that belong to this group.
        if hasattr(cond, 'advanced_indexing'):
            feed = cond.advanced_indexing(current_indices)
        else:
            feed = cond[current_indices]

        h = compile_conditions(feed)
        h[0]['strength'] = current_weight
        results += h

    return results


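# A minimal self-test sketch (not part of the original file); it assumes
# CPU-only torch and made-up tensor shapes, and only exercises the helpers above.
if __name__ == "__main__":
    prompt = {'crossattn': torch.randn(1, 77, 768), 'vector': torch.randn(1, 1280)}
    compiled = compile_conditions(prompt)
    c_crossattn = compiled[0]['model_conds']['c_crossattn']
    processed = c_crossattn.process_cond(batch_size=4, device='cpu')
    print(processed.cond.shape)  # expected: torch.Size([4, 77, 768])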