Mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-04-26 17:29:27 +00:00
Added ability to split up flux across gpus (experimental). Changed the way timestep scheduling works to prep for more specific schedules.
@@ -60,7 +60,7 @@ from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
 from toolkit.paths import ORIG_CONFIGS_ROOT, DIFFUSERS_CONFIGS_ROOT
 from huggingface_hub import hf_hub_download
-from toolkit.models.flux import bypass_flux_guidance, restore_flux_guidance
+from toolkit.models.flux import add_model_gpu_splitter_to_flux, bypass_flux_guidance, restore_flux_guidance

 from optimum.quanto import freeze, qfloat8, quantize, QTensor, qint4
 from typing import TYPE_CHECKING
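
The changed import above pulls in add_model_gpu_splitter_to_flux, the entry point for the experimental multi-GPU split. Its body is not part of this commit view, so the following is only a minimal sketch of the general technique the name suggests: assign contiguous chunks of a transformer's blocks to different devices and use pre-forward hooks to carry activations between them. Every name below except the concept itself is illustrative, not ai-toolkit's actual implementation.

import torch
import torch.nn as nn


def split_blocks_across_gpus(blocks: nn.ModuleList, devices: list) -> None:
    """Move contiguous chunks of `blocks` onto `devices` and register a
    pre-forward hook per block that carries its inputs to that device."""
    chunk = (len(blocks) + len(devices) - 1) // len(devices)  # ceil division
    for i, block in enumerate(blocks):
        device = torch.device(devices[min(i // chunk, len(devices) - 1)])
        block.to(device)

        def to_block_device(module, args, kwargs, _dev=device):
            # Move tensor args/kwargs to this block's device before it runs.
            args = tuple(a.to(_dev) if isinstance(a, torch.Tensor) else a
                         for a in args)
            kwargs = {k: v.to(_dev) if isinstance(v, torch.Tensor) else v
                      for k, v in kwargs.items()}
            return args, kwargs

        block.register_forward_pre_hook(to_block_device, with_kwargs=True)

For a diffusers Flux transformer this would be called on its block lists, e.g. split_blocks_across_gpus(transformer.transformer_blocks, ["cuda:0", "cuda:1"]) (attribute name assumed from diffusers). A real splitter also has to route the final hidden states back to an output device and keep embeddings and norm layers co-located with the blocks that consume them, which is presumably what the Flux-specific helper handles.
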
@@ -553,6 +553,10 @@ class StableDiffusion:
                 # low_cpu_mem_usage=False,
                 # device_map=None
             )
+            # hack in model gpu splitter
+            if self.model_config.split_model_over_gpus:
+                add_model_gpu_splitter_to_flux(transformer)
+
             if not self.low_vram:
                 # for low vram, we leave it on the cpu. Quantizes slower, but allows training on the primary gpu
                 transformer.to(torch.device(self.quantize_device), dtype=dtype)
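
The low_vram branch above leaves the transformer on the CPU while optimum.quanto quantizes it, trading quantization speed for free VRAM on the training GPU. Here is a hedged, standalone sketch of that trade-off, using the quantize/freeze/qfloat8 imports shown in the first hunk; the function name and defaults are illustrative, not ai-toolkit's API.

import torch
from optimum.quanto import freeze, qfloat8, quantize


def prepare_transformer(transformer, low_vram: bool,
                        quantize_device: str = "cuda:0",
                        dtype: torch.dtype = torch.bfloat16):
    if not low_vram:
        # Quantize on the GPU: faster, but the full-precision weights
        # occupy VRAM during the conversion.
        transformer.to(torch.device(quantize_device), dtype=dtype)
    quantize(transformer, weights=qfloat8)  # swap weights for float8 qtensors
    freeze(transformer)                     # materialize the quantized weights
    if low_vram:
        # Only now move the much smaller quantized model onto the GPU.
        transformer.to(torch.device(quantize_device))
    return transformer

Quantizing on the CPU is slower but keeps the primary GPU free, which is exactly the trade-off the in-diff comment describes.
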