mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-29 10:41:28 +00:00
Added ability to set timesteps to linear for flowmatching schedule
This commit is contained in:
@@ -908,7 +908,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
             )
         elif self.train_config.noise_scheduler == 'flowmatch':
             self.sd.noise_scheduler.set_train_timesteps(
-                num_train_timesteps, device=self.device_torch
+                num_train_timesteps,
+                device=self.device_torch,
+                linear=self.train_config.linear_timesteps
             )
         else:
             self.sd.noise_scheduler.set_timesteps(
@@ -356,6 +356,7 @@ class TrainConfig:
         # adds an additional loss to the network to encourage it output a normalized standard deviation
         self.target_norm_std = kwargs.get('target_norm_std', None)
         self.target_norm_std_value = kwargs.get('target_norm_std_value', 1.0)
+        self.linear_timesteps = kwargs.get('linear_timesteps', False)


 class ModelConfig:
@@ -44,17 +44,22 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
    """Return *sample* unchanged — this flow-match scheduler applies no input scaling."""
    return sample
||||||
def set_train_timesteps(self, num_timesteps, device, linear=False):
    """Create and store the timestep schedule used for training.

    Args:
        num_timesteps: number of timesteps to generate.
        device: torch device on which the schedule tensor lives.
        linear: if True, space the timesteps evenly from 1000 down to 0;
            otherwise sample them randomly with a bias toward the center
            of the range (inference, by contrast, biases toward the first
            steps).

    Returns:
        A 1-D tensor of timesteps in descending order. The same schedule
        is also assigned to ``self.timesteps``.
    """
    if linear:
        # Evenly spaced values running from 1000 down to 0.
        schedule = torch.linspace(1000, 0, num_timesteps, device=device)
        self.timesteps = schedule
        return schedule

    # Random draws squashed through a sigmoid, so the mass concentrates
    # near the middle of the (0, 1000) range rather than the ends.
    squashed = torch.sigmoid(torch.randn((num_timesteps,), device=device))
    # Reverse and scale so the values span from near 1000 toward 0.
    schedule = (1.0 - squashed) * 1000
    # Descending order, matching the linear branch's orientation.
    schedule = torch.sort(schedule, descending=True).values
    self.timesteps = schedule.to(device=device)
    return schedule
|
|||||||
Reference in New Issue
Block a user