Work on supporting flex.2 potential arch

This commit is contained in:
Jaret Burkett
2025-02-17 14:10:25 -07:00
parent 1f7784510d
commit 4af6c5cf30
4 changed files with 918 additions and 23 deletions

View File

@@ -275,6 +275,8 @@ class TrainSliderProcess(BaseSDTrainProcess):
return adapter_tensors
def hook_train_loop(self, batch: Union['DataLoaderBatchDTO', None]):
if isinstance(batch, list):
batch = batch[0]
# set to eval mode
self.sd.set_device_state(self.eval_slider_device_state)
with torch.no_grad():
@@ -364,10 +366,32 @@ class TrainSliderProcess(BaseSDTrainProcess):
denoised_latents = torch.cat([noisy_latents] * self.prompt_chunk_size, dim=0)
current_timestep = timesteps
else:
self.sd.noise_scheduler.set_timesteps(
self.train_config.max_denoising_steps, device=self.device_torch
)
if self.train_config.noise_scheduler == 'flowmatch':
linear_timesteps = any([
self.train_config.linear_timesteps,
self.train_config.linear_timesteps2,
self.train_config.timestep_type == 'linear',
])
timestep_type = 'linear' if linear_timesteps else None
if timestep_type is None:
timestep_type = self.train_config.timestep_type
# make fake latents
l = torch.randn(
true_batch_size, 16, height, width
).to(self.device_torch, dtype=dtype)
self.sd.noise_scheduler.set_train_timesteps(
self.train_config.max_denoising_steps,
device=self.device_torch,
timestep_type=timestep_type,
latents=l
)
else:
self.sd.noise_scheduler.set_timesteps(
self.train_config.max_denoising_steps, device=self.device_torch
)
# ger a random number of steps
timesteps_to = torch.randint(