mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-04-26 09:18:59 +00:00
rework sample function
This commit is contained in:
136
backend/sampling/condition.py
Normal file
136
backend/sampling/condition.py
Normal file
@@ -0,0 +1,136 @@
|
||||
import torch
|
||||
import math
|
||||
|
||||
|
||||
def repeat_to_batch_size(tensor, batch_size):
|
||||
if tensor.shape[0] > batch_size:
|
||||
return tensor[:batch_size]
|
||||
elif tensor.shape[0] < batch_size:
|
||||
return tensor.repeat([math.ceil(batch_size / tensor.shape[0])] + [1] * (len(tensor.shape) - 1))[:batch_size]
|
||||
return tensor
|
||||
|
||||
|
||||
def lcm(a, b):
|
||||
return abs(a * b) // math.gcd(a, b)
|
||||
|
||||
|
||||
class Condition:
|
||||
def __init__(self, cond):
|
||||
self.cond = cond
|
||||
|
||||
def _copy_with(self, cond):
|
||||
return self.__class__(cond)
|
||||
|
||||
def process_cond(self, batch_size, device, **kwargs):
|
||||
return self._copy_with(repeat_to_batch_size(self.cond, batch_size).to(device))
|
||||
|
||||
def can_concat(self, other):
|
||||
if self.cond.shape != other.cond.shape:
|
||||
return False
|
||||
return True
|
||||
|
||||
def concat(self, others):
|
||||
conds = [self.cond]
|
||||
for x in others:
|
||||
conds.append(x.cond)
|
||||
return torch.cat(conds)
|
||||
|
||||
|
||||
class ConditionNoiseShape(Condition):
|
||||
def process_cond(self, batch_size, device, area, **kwargs):
|
||||
data = self.cond[:, :, area[2]:area[0] + area[2], area[3]:area[1] + area[3]]
|
||||
return self._copy_with(repeat_to_batch_size(data, batch_size).to(device))
|
||||
|
||||
|
||||
class ConditionCrossAttn(Condition):
|
||||
def can_concat(self, other):
|
||||
s1 = self.cond.shape
|
||||
s2 = other.cond.shape
|
||||
if s1 != s2:
|
||||
if s1[0] != s2[0] or s1[2] != s2[2]:
|
||||
return False
|
||||
|
||||
mult_min = lcm(s1[1], s2[1])
|
||||
diff = mult_min // min(s1[1], s2[1])
|
||||
if diff > 4:
|
||||
return False
|
||||
return True
|
||||
|
||||
def concat(self, others):
|
||||
conds = [self.cond]
|
||||
crossattn_max_len = self.cond.shape[1]
|
||||
for x in others:
|
||||
c = x.cond
|
||||
crossattn_max_len = lcm(crossattn_max_len, c.shape[1])
|
||||
conds.append(c)
|
||||
|
||||
out = []
|
||||
for c in conds:
|
||||
if c.shape[1] < crossattn_max_len:
|
||||
c = c.repeat(1, crossattn_max_len // c.shape[1], 1)
|
||||
out.append(c)
|
||||
return torch.cat(out)
|
||||
|
||||
|
||||
class ConditionConstant(Condition):
|
||||
def __init__(self, cond):
|
||||
self.cond = cond
|
||||
|
||||
def process_cond(self, batch_size, device, **kwargs):
|
||||
return self._copy_with(self.cond)
|
||||
|
||||
def can_concat(self, other):
|
||||
if self.cond != other.cond:
|
||||
return False
|
||||
return True
|
||||
|
||||
def concat(self, others):
|
||||
return self.cond
|
||||
|
||||
|
||||
def compile_conditions(cond):
|
||||
if isinstance(cond, torch.Tensor):
|
||||
result = dict(
|
||||
cross_attn=cond,
|
||||
model_conds=dict(
|
||||
c_crossattn=ConditionCrossAttn(cond),
|
||||
)
|
||||
)
|
||||
return [result, ]
|
||||
|
||||
cross_attn = cond['crossattn']
|
||||
pooled_output = cond['vector']
|
||||
|
||||
result = dict(
|
||||
cross_attn=cross_attn,
|
||||
pooled_output=pooled_output,
|
||||
model_conds=dict(
|
||||
c_crossattn=ConditionCrossAttn(cross_attn),
|
||||
y=Condition(pooled_output)
|
||||
)
|
||||
)
|
||||
|
||||
return [result, ]
|
||||
|
||||
|
||||
def compile_weighted_conditions(cond, weights):
|
||||
transposed = list(map(list, zip(*weights)))
|
||||
results = []
|
||||
|
||||
for cond_pre in transposed:
|
||||
current_indices = []
|
||||
current_weight = 0
|
||||
for i, w in cond_pre:
|
||||
current_indices.append(i)
|
||||
current_weight = w
|
||||
|
||||
if hasattr(cond, 'advanced_indexing'):
|
||||
feed = cond.advanced_indexing(current_indices)
|
||||
else:
|
||||
feed = cond[current_indices]
|
||||
|
||||
h = compile_conditions(feed)
|
||||
h[0]['strength'] = current_weight
|
||||
results += h
|
||||
|
||||
return results
|
||||
Reference in New Issue
Block a user