Various features and fixes. Too much brain fog to do a proper description

This commit is contained in:
Jaret Burkett
2024-07-18 07:34:14 -06:00
parent 58dffd43a8
commit 11e426fdf1
6 changed files with 119 additions and 25 deletions

View File

@@ -83,7 +83,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
else:
self.network_config = None
self.train_config = TrainConfig(**self.get_conf('train', {}))
self.model_config = ModelConfig(**self.get_conf('model', {}))
model_config = self.get_conf('model', {})
# update modelconfig dtype to match train
model_config['dtype'] = self.train_config.dtype
self.model_config = ModelConfig(**model_config)
self.save_config = SaveConfig(**self.get_conf('save', {}))
self.sample_config = SampleConfig(**self.get_conf('sample', {}))
first_sample_config = self.get_conf('first_sample', None)
@@ -723,6 +728,17 @@ class BaseSDTrainProcess(BaseTrainProcess):
noise_offset=self.train_config.noise_offset,
).to(self.device_torch, dtype=dtype)
if self.train_config.random_noise_shift > 0.0:
# get random noise -1 to 1
noise_shift = torch.rand((noise.shape[0], noise.shape[1], 1, 1), device=noise.device,
dtype=noise.dtype) * 2 - 1
# multiply by shift amount
noise_shift *= self.train_config.random_noise_shift
# add to noise
noise += noise_shift
return noise
def process_general_training_batch(self, batch: 'DataLoaderBatchDTO'):