mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-29 02:31:17 +00:00
Huge memory optimizations, many big fixes
This commit is contained in:
@@ -14,12 +14,9 @@ def flush():
|
||||
|
||||
|
||||
class SDTrainer(BaseSDTrainProcess):
|
||||
sd: StableDiffusion
|
||||
data_loader: DataLoader = None
|
||||
|
||||
def __init__(self, process_id: int, job, config: OrderedDict, **kwargs):
|
||||
super().__init__(process_id, job, config, **kwargs)
|
||||
pass
|
||||
|
||||
def before_model_load(self):
|
||||
pass
|
||||
@@ -40,6 +37,7 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch)
|
||||
|
||||
self.optimizer.zero_grad()
|
||||
flush()
|
||||
|
||||
# text encoding
|
||||
grad_on_text_encoder = False
|
||||
@@ -71,7 +69,7 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
timestep=timesteps,
|
||||
guidance_scale=1.0,
|
||||
)
|
||||
|
||||
# 9.18 gb
|
||||
noise = noise.to(self.device_torch, dtype=dtype)
|
||||
|
||||
if self.sd.prediction_type == 'v_prediction':
|
||||
|
||||
Reference in New Issue
Block a user