mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-28 18:21:16 +00:00
Fixed issue with bucket dataloader corpping in too much. Added normalization capabilities to LoRA modules. Testing effects, but should prevent them from burning and also make them more compatable with stacking many LoRAs
This commit is contained in:
@@ -38,10 +38,13 @@ SD_PREFIX_TEXT_ENCODER2 = "te2"
|
||||
|
||||
|
||||
class BlankNetwork:
|
||||
multiplier = 1.0
|
||||
is_active = True
|
||||
|
||||
def __init__(self):
|
||||
self.multiplier = 1.0
|
||||
self.is_active = True
|
||||
self.is_normalizing = False
|
||||
|
||||
def apply_stored_normalizer(self, target_normalize_scaler: float = 1.0):
|
||||
pass
|
||||
|
||||
def __enter__(self):
|
||||
@@ -258,6 +261,12 @@ class StableDiffusion:
|
||||
else:
|
||||
network = BlankNetwork()
|
||||
|
||||
was_network_normalizing = network.is_normalizing
|
||||
# apply the normalizer if it is normalizing before inference and disable it
|
||||
if network.is_normalizing:
|
||||
network.apply_stored_normalizer()
|
||||
network.is_normalizing = False
|
||||
|
||||
# save current seed state for training
|
||||
rng_state = torch.get_rng_state()
|
||||
cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None
|
||||
@@ -377,6 +386,7 @@ class StableDiffusion:
|
||||
if self.network is not None:
|
||||
self.network.train()
|
||||
self.network.multiplier = start_multiplier
|
||||
self.network.is_normalizing = was_network_normalizing
|
||||
# self.tokenizer.to(original_device_dict['tokenizer'])
|
||||
|
||||
def get_latent_noise(
|
||||
|
||||
Reference in New Issue
Block a user