mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-28 10:11:14 +00:00
Huge memory optimizations, many big fixes
This commit is contained in:
@@ -9,6 +9,7 @@ from toolkit.data_loader import get_dataloader_from_datasets
|
||||
from toolkit.embedding import Embedding
|
||||
from toolkit.lora_special import LoRASpecialNetwork
|
||||
from toolkit.optimizer import get_optimizer
|
||||
from toolkit.paths import CONFIG_ROOT
|
||||
|
||||
from toolkit.scheduler import get_lr_scheduler
|
||||
from toolkit.stable_diffusion_model import StableDiffusion
|
||||
@@ -31,11 +32,12 @@ def flush():
|
||||
|
||||
|
||||
class BaseSDTrainProcess(BaseTrainProcess):
|
||||
sd: StableDiffusion
|
||||
embedding: Union[Embedding, None] = None
|
||||
|
||||
def __init__(self, process_id: int, job, config: OrderedDict, custom_pipeline=None):
|
||||
super().__init__(process_id, job, config)
|
||||
self.sd: StableDiffusion
|
||||
self.embedding: Union[Embedding, None] = None
|
||||
|
||||
self.custom_pipeline = custom_pipeline
|
||||
self.step_num = 0
|
||||
self.start_step = 0
|
||||
@@ -344,7 +346,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
|
||||
# remove grads for these
|
||||
noisy_latents.requires_grad = False
|
||||
noisy_latents = noisy_latents.detach()
|
||||
noise.requires_grad = False
|
||||
noise = noise.detach()
|
||||
|
||||
return noisy_latents, noise, timesteps, conditioned_prompts, imgs
|
||||
|
||||
|
||||
Reference in New Issue
Block a user