Fixed issue with the bucket dataloader cropping in too much. Added normalization capabilities to LoRA modules. Still testing the effects, but it should prevent them from burning and also make them more compatible with stacking many LoRAs

Jaret Burkett
2023-08-27 09:40:01 -06:00
parent 6bd3851058
commit 9b164a8688
5 changed files with 190 additions and 103 deletions
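
The normalization idea, as a minimal sketch: keep the magnitude of each LoRA delta near a fixed target so no single module overpowers the base model ("burns") and so several stacked LoRAs do not compound each other's scale. The LoRAModule below is hypothetical and pared down for illustration; only the apply_stored_normalizer name comes from the diff, and the scaling rule is an assumption, not the repo's exact math.

import torch
import torch.nn as nn


class LoRAModule(nn.Module):
    """Hypothetical LoRA module, pared down to illustrate normalization."""

    def __init__(self, in_dim: int, out_dim: int, rank: int = 4):
        super().__init__()
        self.lora_down = nn.Linear(in_dim, rank, bias=False)
        self.lora_up = nn.Linear(rank, out_dim, bias=False)
        self.multiplier = 1.0

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # low-rank delta added onto the frozen base layer's output
        return self.lora_up(self.lora_down(x)) * self.multiplier

    @torch.no_grad()
    def apply_stored_normalizer(self, target_normalize_scaler: float = 1.0) -> None:
        # Rescale the up-projection so the delta's magnitude stays near a
        # fixed target; bounding it is what prevents burning and makes
        # stacking many LoRAs better behaved. (Illustrative rule only.)
        norm = self.lora_up.weight.norm()
        if norm > 0:
            self.lora_up.weight.mul_(target_normalize_scaler / norm)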


@@ -38,10 +38,13 @@ SD_PREFIX_TEXT_ENCODER2 = "te2"
class BlankNetwork:
    multiplier = 1.0
    is_active = True

    def __init__(self):
        self.multiplier = 1.0
        self.is_active = True
        self.is_normalizing = False

    def apply_stored_normalizer(self, target_normalize_scaler: float = 1.0):
        pass

    def __enter__(self):
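
BlankNetwork is a no-op stand-in (a null object) so the inference path never has to branch on whether a LoRA network is attached. A sketch of how the calling side might use it, assuming __enter__/__exit__ toggle is_active; the generate_images wrapper and the context-manager bodies are illustrative, not the repo's exact code.

class BlankNetwork:
    def __init__(self):
        self.multiplier = 1.0
        self.is_active = True
        self.is_normalizing = False

    def apply_stored_normalizer(self, target_normalize_scaler: float = 1.0):
        pass  # nothing to normalize on a blank network

    def __enter__(self):
        self.is_active = True
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.is_active = False


def generate_images(network):
    # mirror the diff: remember the normalizing flag, apply and disable it
    # for the duration of sampling, then restore it afterwards
    was_network_normalizing = network.is_normalizing
    if network.is_normalizing:
        network.apply_stored_normalizer()
        network.is_normalizing = False
    with network:
        pass  # the diffusion sampling loop would run here
    network.is_normalizing = was_network_normalizing


generate_images(BlankNetwork())  # no LoRA attached; same code path works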
@@ -258,6 +261,12 @@ class StableDiffusion:
        else:
            network = BlankNetwork()

        was_network_normalizing = network.is_normalizing
        # if the network is normalizing, apply the stored normalizer before
        # inference, then disable normalization for the duration of sampling
        if network.is_normalizing:
            network.apply_stored_normalizer()
            network.is_normalizing = False

        # save the current seed state so training randomness can be restored
        rng_state = torch.get_rng_state()
        cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None
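
The RNG snapshot above is half of a save/restore pair: sampling consumes random numbers, and restoring the state afterwards keeps mid-training previews from perturbing the training run's random stream. A small self-contained sketch of the full pattern; the restore half is an assumption about what the surrounding code does.

import torch


def sample_with_rng_protection(sample_fn):
    # snapshot RNG state so preview generation does not change the random
    # stream that the training loop continues from
    rng_state = torch.get_rng_state()
    cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None
    try:
        return sample_fn()
    finally:
        torch.set_rng_state(rng_state)
        if cuda_rng_state is not None:
            torch.cuda.set_rng_state(cuda_rng_state)


# usage: random draws inside the wrapper leave the outer stream untouched
a = sample_with_rng_protection(lambda: torch.randn(4))
b = torch.randn(4)  # same values as if the sampling call never happened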
@@ -377,6 +386,7 @@ class StableDiffusion:
        if self.network is not None:
            self.network.train()
            self.network.multiplier = start_multiplier
            self.network.is_normalizing = was_network_normalizing
        # self.tokenizer.to(original_device_dict['tokenizer'])

    def get_latent_noise(