mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 11:11:37 +00:00
Work on supporting flex.2 potential arch
This commit is contained in:
@@ -275,6 +275,8 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
return adapter_tensors
|
||||
|
||||
def hook_train_loop(self, batch: Union['DataLoaderBatchDTO', None]):
|
||||
if isinstance(batch, list):
|
||||
batch = batch[0]
|
||||
# set to eval mode
|
||||
self.sd.set_device_state(self.eval_slider_device_state)
|
||||
with torch.no_grad():
|
||||
@@ -364,10 +366,32 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
denoised_latents = torch.cat([noisy_latents] * self.prompt_chunk_size, dim=0)
|
||||
current_timestep = timesteps
|
||||
else:
|
||||
|
||||
self.sd.noise_scheduler.set_timesteps(
|
||||
self.train_config.max_denoising_steps, device=self.device_torch
|
||||
)
|
||||
if self.train_config.noise_scheduler == 'flowmatch':
|
||||
linear_timesteps = any([
|
||||
self.train_config.linear_timesteps,
|
||||
self.train_config.linear_timesteps2,
|
||||
self.train_config.timestep_type == 'linear',
|
||||
])
|
||||
|
||||
timestep_type = 'linear' if linear_timesteps else None
|
||||
if timestep_type is None:
|
||||
timestep_type = self.train_config.timestep_type
|
||||
|
||||
# make fake latents
|
||||
l = torch.randn(
|
||||
true_batch_size, 16, height, width
|
||||
).to(self.device_torch, dtype=dtype)
|
||||
|
||||
self.sd.noise_scheduler.set_train_timesteps(
|
||||
self.train_config.max_denoising_steps,
|
||||
device=self.device_torch,
|
||||
timestep_type=timestep_type,
|
||||
latents=l
|
||||
)
|
||||
else:
|
||||
self.sd.noise_scheduler.set_timesteps(
|
||||
self.train_config.max_denoising_steps, device=self.device_torch
|
||||
)
|
||||
|
||||
# ger a random number of steps
|
||||
timesteps_to = torch.randint(
|
||||
|
||||
Reference in New Issue
Block a user