mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-22 15:29:08 +00:00
Add ability to designate a dataset as i2v or t2v for models that support it
This commit is contained in:
@@ -243,41 +243,42 @@ class Wan22Model(Wan21):
|
||||
# for wan, only do i2v for video for now. Images do normal t2i
|
||||
conditioned_latent = latent_model_input
|
||||
noise_mask = None
|
||||
|
||||
if batch.dataset_config.do_i2v:
|
||||
with torch.no_grad():
|
||||
frames = batch.tensor
|
||||
if len(frames.shape) == 4:
|
||||
first_frames = frames
|
||||
elif len(frames.shape) == 5:
|
||||
first_frames = frames[:, 0]
|
||||
# Add conditioning using the standalone function
|
||||
conditioned_latent, noise_mask = add_first_frame_conditioning_v22(
|
||||
latent_model_input=latent_model_input.to(
|
||||
self.device_torch, self.torch_dtype
|
||||
),
|
||||
first_frame=first_frames.to(self.device_torch, self.torch_dtype),
|
||||
vae=self.vae,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown frame shape {frames.shape}")
|
||||
|
||||
with torch.no_grad():
|
||||
frames = batch.tensor
|
||||
if len(frames.shape) == 4:
|
||||
first_frames = frames
|
||||
elif len(frames.shape) == 5:
|
||||
first_frames = frames[:, 0]
|
||||
# Add conditioning using the standalone function
|
||||
conditioned_latent, noise_mask = add_first_frame_conditioning_v22(
|
||||
latent_model_input=latent_model_input.to(
|
||||
self.device_torch, self.torch_dtype
|
||||
),
|
||||
first_frame=first_frames.to(self.device_torch, self.torch_dtype),
|
||||
vae=self.vae,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown frame shape {frames.shape}")
|
||||
|
||||
# make the noise mask
|
||||
if noise_mask is None:
|
||||
noise_mask = torch.ones(
|
||||
conditioned_latent.shape,
|
||||
dtype=conditioned_latent.dtype,
|
||||
device=conditioned_latent.device,
|
||||
)
|
||||
# todo write this better
|
||||
t_chunks = torch.chunk(timestep, timestep.shape[0])
|
||||
out_t_chunks = []
|
||||
for t in t_chunks:
|
||||
# seq_len: num_latent_frames * latent_height//2 * latent_width//2
|
||||
temp_ts = (noise_mask[0][0][:, ::2, ::2] * t).flatten()
|
||||
# batch_size, seq_len
|
||||
temp_ts = temp_ts.unsqueeze(0)
|
||||
out_t_chunks.append(temp_ts)
|
||||
timestep = torch.cat(out_t_chunks, dim=0)
|
||||
# make the noise mask
|
||||
if noise_mask is None:
|
||||
noise_mask = torch.ones(
|
||||
conditioned_latent.shape,
|
||||
dtype=conditioned_latent.dtype,
|
||||
device=conditioned_latent.device,
|
||||
)
|
||||
# todo write this better
|
||||
t_chunks = torch.chunk(timestep, timestep.shape[0])
|
||||
out_t_chunks = []
|
||||
for t in t_chunks:
|
||||
# seq_len: num_latent_frames * latent_height//2 * latent_width//2
|
||||
temp_ts = (noise_mask[0][0][:, ::2, ::2] * t).flatten()
|
||||
# batch_size, seq_len
|
||||
temp_ts = temp_ts.unsqueeze(0)
|
||||
out_t_chunks.append(temp_ts)
|
||||
timestep = torch.cat(out_t_chunks, dim=0)
|
||||
|
||||
noise_pred = self.model(
|
||||
hidden_states=conditioned_latent,
|
||||
|
||||
Reference in New Issue
Block a user