mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-26 17:39:15 +00:00
Merge branch 'master' into dr-support-pip-cm
This commit is contained in:
@@ -43,7 +43,6 @@ if TYPE_CHECKING:
|
||||
|
||||
def broadcast_image_to(tensor, target_batch_size, batched_number):
|
||||
current_batch_size = tensor.shape[0]
|
||||
#print(current_batch_size, target_batch_size)
|
||||
if current_batch_size == 1:
|
||||
return tensor
|
||||
|
||||
|
||||
@@ -24,12 +24,17 @@ class CausalConv3d(ops.Conv3d):
|
||||
self.padding[1], 2 * self.padding[0], 0)
|
||||
self.padding = (0, 0, 0)
|
||||
|
||||
def forward(self, x, cache_x=None):
|
||||
def forward(self, x, cache_x=None, cache_list=None, cache_idx=None):
|
||||
if cache_list is not None:
|
||||
cache_x = cache_list[cache_idx]
|
||||
cache_list[cache_idx] = None
|
||||
|
||||
padding = list(self._padding)
|
||||
if cache_x is not None and self._padding[4] > 0:
|
||||
cache_x = cache_x.to(x.device)
|
||||
x = torch.cat([cache_x, x], dim=2)
|
||||
padding[4] -= cache_x.shape[2]
|
||||
del cache_x
|
||||
x = F.pad(x, padding)
|
||||
|
||||
return super().forward(x)
|
||||
@@ -166,7 +171,7 @@ class ResidualBlock(nn.Module):
|
||||
if in_dim != out_dim else nn.Identity()
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
h = self.shortcut(x)
|
||||
old_x = x
|
||||
for layer in self.residual:
|
||||
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
@@ -178,12 +183,12 @@ class ResidualBlock(nn.Module):
|
||||
cache_x.device), cache_x
|
||||
],
|
||||
dim=2)
|
||||
x = layer(x, feat_cache[idx])
|
||||
x = layer(x, cache_list=feat_cache, cache_idx=idx)
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = layer(x)
|
||||
return x + h
|
||||
return x + self.shortcut(old_x)
|
||||
|
||||
|
||||
class AttentionBlock(nn.Module):
|
||||
|
||||
@@ -151,7 +151,7 @@ class ResidualBlock(nn.Module):
|
||||
],
|
||||
dim=2,
|
||||
)
|
||||
x = layer(x, feat_cache[idx])
|
||||
x = layer(x, cache_list=feat_cache, cache_idx=idx)
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
|
||||
@@ -106,10 +106,12 @@ def model_sampling(model_config, model_type):
|
||||
return ModelSampling(model_config)
|
||||
|
||||
|
||||
def convert_tensor(extra, dtype):
|
||||
def convert_tensor(extra, dtype, device):
|
||||
if hasattr(extra, "dtype"):
|
||||
if extra.dtype != torch.int and extra.dtype != torch.long:
|
||||
extra = extra.to(dtype)
|
||||
extra = extra.to(dtype=dtype, device=device)
|
||||
else:
|
||||
extra = extra.to(device=device)
|
||||
return extra
|
||||
|
||||
|
||||
@@ -169,20 +171,21 @@ class BaseModel(torch.nn.Module):
|
||||
dtype = self.manual_cast_dtype
|
||||
|
||||
xc = xc.to(dtype)
|
||||
device = xc.device
|
||||
t = self.model_sampling.timestep(t).float()
|
||||
if context is not None:
|
||||
context = context.to(dtype)
|
||||
context = context.to(dtype=dtype, device=device)
|
||||
|
||||
extra_conds = {}
|
||||
for o in kwargs:
|
||||
extra = kwargs[o]
|
||||
|
||||
if hasattr(extra, "dtype"):
|
||||
extra = convert_tensor(extra, dtype)
|
||||
extra = convert_tensor(extra, dtype, device)
|
||||
elif isinstance(extra, list):
|
||||
ex = []
|
||||
for ext in extra:
|
||||
ex.append(convert_tensor(ext, dtype))
|
||||
ex.append(convert_tensor(ext, dtype, device))
|
||||
extra = ex
|
||||
extra_conds[o] = extra
|
||||
|
||||
|
||||
Reference in New Issue
Block a user