mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Various features and fixes. Too much brain fog to do a proper description
This commit is contained in:
@@ -83,7 +83,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
else:
|
||||
self.network_config = None
|
||||
self.train_config = TrainConfig(**self.get_conf('train', {}))
|
||||
self.model_config = ModelConfig(**self.get_conf('model', {}))
|
||||
model_config = self.get_conf('model', {})
|
||||
|
||||
# update modelconfig dtype to match train
|
||||
model_config['dtype'] = self.train_config.dtype
|
||||
self.model_config = ModelConfig(**model_config)
|
||||
|
||||
self.save_config = SaveConfig(**self.get_conf('save', {}))
|
||||
self.sample_config = SampleConfig(**self.get_conf('sample', {}))
|
||||
first_sample_config = self.get_conf('first_sample', None)
|
||||
@@ -723,6 +728,17 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
noise_offset=self.train_config.noise_offset,
|
||||
).to(self.device_torch, dtype=dtype)
|
||||
|
||||
if self.train_config.random_noise_shift > 0.0:
|
||||
# get random noise -1 to 1
|
||||
noise_shift = torch.rand((noise.shape[0], noise.shape[1], 1, 1), device=noise.device,
|
||||
dtype=noise.dtype) * 2 - 1
|
||||
|
||||
# multiply by shift amount
|
||||
noise_shift *= self.train_config.random_noise_shift
|
||||
|
||||
# add to noise
|
||||
noise += noise_shift
|
||||
|
||||
return noise
|
||||
|
||||
def process_general_training_batch(self, batch: 'DataLoaderBatchDTO'):
|
||||
|
||||
Reference in New Issue
Block a user