mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-01-30 21:09:46 +00:00
Before this commit, credits were already in the entry and licenses were already in the root. This commit makes the info clearer.
83 lines
2.5 KiB
Python
83 lines
2.5 KiB
Python
# 1st edit by https://github.com/comfyanonymous/ComfyUI
|
|
# 2nd edit by Forge Official
|
|
|
|
|
|
import torch
|
|
import math
|
|
import ldm_patched.modules.utils
|
|
|
|
|
|
def lcm(a, b): #TODO: eventually replace by math.lcm (added in python3.9)
    """Return the least common multiple of a and b (0 if both are 0)."""
    if a == 0 and b == 0:
        # math.lcm(0, 0) == 0; the abs//gcd formula below would divide by zero.
        return 0
    return abs(a*b) // math.gcd(a, b)
|
|
|
|
class CONDRegular:
    """Base wrapper around a single conditioning tensor.

    Knows how to repeat itself to a batch size, move to a device, and
    concatenate with other compatible conditioning wrappers.
    """

    def __init__(self, cond):
        self.cond = cond

    def _copy_with(self, cond):
        # self.__class__ so subclasses clone to their own type.
        return self.__class__(cond)

    def process_cond(self, batch_size, device, **kwargs):
        """Repeat the tensor to batch_size entries and move it to device."""
        batched = ldm_patched.modules.utils.repeat_to_batch_size(self.cond, batch_size)
        return self._copy_with(batched.to(device))

    def can_concat(self, other):
        """Regular conds can only concat when their shapes match exactly."""
        return self.cond.shape == other.cond.shape

    def concat(self, others):
        """Concatenate this cond with all of *others* along the batch dim."""
        tensors = [self.cond]
        tensors.extend(o.cond for o in others)
        return torch.cat(tensors)
|
|
|
|
class CONDNoiseShape(CONDRegular):
    """Conditioning laid out like the noise; cropped to the sampled area."""

    def process_cond(self, batch_size, device, area, **kwargs):
        # area appears to be (height, width, y, x) — slicing matches that layout.
        h, w, y, x = area[0], area[1], area[2], area[3]
        cropped = self.cond[:, :, y:y + h, x:x + w]
        batched = ldm_patched.modules.utils.repeat_to_batch_size(cropped, batch_size)
        return self._copy_with(batched.to(device))
|
|
|
|
|
|
class CONDCrossAttn(CONDRegular):
    """Cross-attention conditioning.

    Tensors whose token counts differ can still be concatenated by
    repeat-padding the shorter ones up to a common (lcm) length, which
    does not change the attention result.
    """

    def can_concat(self, other):
        s1 = self.cond.shape
        s2 = other.cond.shape
        if s1 == s2:
            return True
        # batch or embedding-dim mismatch cannot be padded away
        if s1[0] != s2[0] or s1[2] != s2[2]: #these 2 cases should not happen
            return False
        padded_len = lcm(s1[1], s2[1])
        # arbitrary limit on the padding because it's probably going to
        # impact performance negatively if it's too much
        return padded_len // min(s1[1], s2[1]) <= 4

    def concat(self, others):
        conds = [self.cond]
        conds.extend(o.cond for o in others)

        # Common token length every tensor will be padded up to.
        target_len = conds[0].shape[1]
        for c in conds[1:]:
            target_len = lcm(target_len, c.shape[1])

        padded = []
        for c in conds:
            if c.shape[1] < target_len:
                # padding with repeat doesn't change result
                c = c.repeat(1, target_len // c.shape[1], 1)
            padded.append(c)
        return torch.cat(padded)
|
|
|
|
class CONDConstant(CONDRegular):
    """Conditioning holding a plain constant value.

    The value is passed through unchanged: it is never batched, moved to
    a device, or concatenated as a tensor.
    """

    def __init__(self, cond):
        self.cond = cond

    def process_cond(self, batch_size, device, **kwargs):
        # Constants ignore batching/device placement; wrap the same value.
        return self._copy_with(self.cond)

    def can_concat(self, other):
        # Only identical constants may be merged.
        return not (self.cond != other.cond)

    def concat(self, others):
        # All concat-able constants are equal, so return the value itself.
        return self.cond
|