Merge branch 'master' into dr-support-pip-cm

This commit is contained in:
Dr.Lt.Data
2025-08-04 07:25:20 +09:00
9 changed files with 40 additions and 20 deletions

View File

@@ -43,7 +43,6 @@ if TYPE_CHECKING:
def broadcast_image_to(tensor, target_batch_size, batched_number):
current_batch_size = tensor.shape[0]
#print(current_batch_size, target_batch_size)
if current_batch_size == 1:
return tensor

View File

@@ -24,12 +24,17 @@ class CausalConv3d(ops.Conv3d):
self.padding[1], 2 * self.padding[0], 0)
self.padding = (0, 0, 0)
def forward(self, x, cache_x=None):
def forward(self, x, cache_x=None, cache_list=None, cache_idx=None):
if cache_list is not None:
cache_x = cache_list[cache_idx]
cache_list[cache_idx] = None
padding = list(self._padding)
if cache_x is not None and self._padding[4] > 0:
cache_x = cache_x.to(x.device)
x = torch.cat([cache_x, x], dim=2)
padding[4] -= cache_x.shape[2]
del cache_x
x = F.pad(x, padding)
return super().forward(x)
@@ -166,7 +171,7 @@ class ResidualBlock(nn.Module):
if in_dim != out_dim else nn.Identity()
def forward(self, x, feat_cache=None, feat_idx=[0]):
h = self.shortcut(x)
old_x = x
for layer in self.residual:
if isinstance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0]
@@ -178,12 +183,12 @@ class ResidualBlock(nn.Module):
cache_x.device), cache_x
],
dim=2)
x = layer(x, feat_cache[idx])
x = layer(x, cache_list=feat_cache, cache_idx=idx)
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = layer(x)
return x + h
return x + self.shortcut(old_x)
class AttentionBlock(nn.Module):

View File

@@ -151,7 +151,7 @@ class ResidualBlock(nn.Module):
],
dim=2,
)
x = layer(x, feat_cache[idx])
x = layer(x, cache_list=feat_cache, cache_idx=idx)
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:

View File

@@ -106,10 +106,12 @@ def model_sampling(model_config, model_type):
return ModelSampling(model_config)
def convert_tensor(extra, dtype):
def convert_tensor(extra, dtype, device):
if hasattr(extra, "dtype"):
if extra.dtype != torch.int and extra.dtype != torch.long:
extra = extra.to(dtype)
extra = extra.to(dtype=dtype, device=device)
else:
extra = extra.to(device=device)
return extra
@@ -169,20 +171,21 @@ class BaseModel(torch.nn.Module):
dtype = self.manual_cast_dtype
xc = xc.to(dtype)
device = xc.device
t = self.model_sampling.timestep(t).float()
if context is not None:
context = context.to(dtype)
context = context.to(dtype=dtype, device=device)
extra_conds = {}
for o in kwargs:
extra = kwargs[o]
if hasattr(extra, "dtype"):
extra = convert_tensor(extra, dtype)
extra = convert_tensor(extra, dtype, device)
elif isinstance(extra, list):
ex = []
for ext in extra:
ex.append(convert_tensor(ext, dtype))
ex.append(convert_tensor(ext, dtype, device))
extra = ex
extra_conds[o] = extra