mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
git status
This commit is contained in:
@@ -70,6 +70,8 @@ class SampleItem:
|
||||
print(f"Invalid network_multiplier {self.network_multiplier}, defaulting to 1.0")
|
||||
self.network_multiplier = 1.0
|
||||
|
||||
# only for models that support it, (qwen image edit 2509 for now)
|
||||
self.do_cfg_norm: bool = kwargs.get('do_cfg_norm', False)
|
||||
|
||||
class SampleConfig:
|
||||
def __init__(self, **kwargs):
|
||||
@@ -104,6 +106,8 @@ class SampleConfig:
|
||||
]
|
||||
raw_samples = kwargs.get('samples', default_samples_kwargs)
|
||||
self.samples = [SampleItem(self, **item) for item in raw_samples]
|
||||
# only for models that support it, (qwen image edit 2509 for now)
|
||||
self.do_cfg_norm: bool = kwargs.get('do_cfg_norm', False)
|
||||
|
||||
@property
|
||||
def prompts(self):
|
||||
@@ -993,7 +997,8 @@ class GenerateImageConfig:
|
||||
ctrl_img_3: Optional[str] = None, # third control image for multi control model
|
||||
num_frames: int = 1,
|
||||
fps: int = 15,
|
||||
ctrl_idx: int = 0
|
||||
ctrl_idx: int = 0,
|
||||
do_cfg_norm: bool = False,
|
||||
):
|
||||
self.width: int = width
|
||||
self.height: int = height
|
||||
@@ -1063,6 +1068,8 @@ class GenerateImageConfig:
|
||||
self.width = max(64, self.width - self.width % 8) # round to divisible by 8
|
||||
|
||||
self.logger = logger
|
||||
|
||||
self.do_cfg_norm: bool = do_cfg_norm
|
||||
|
||||
def set_gen_time(self, gen_time: int = None):
|
||||
if gen_time is not None:
|
||||
|
||||
1
toolkit/memory_management/__init__.py
Normal file
1
toolkit/memory_management/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .manager import MemoryManager
|
||||
12
toolkit/memory_management/manager.py
Normal file
12
toolkit/memory_management/manager.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from toolkit.models.base_model import BaseModel
|
||||
|
||||
|
||||
class MemoryManager:
|
||||
def __init__(
|
||||
self,
|
||||
model: "BaseModel",
|
||||
):
|
||||
self.model: "BaseModel" = model
|
||||
@@ -41,6 +41,7 @@ from torchvision.transforms import functional as TF
|
||||
from toolkit.accelerator import get_accelerator, unwrap_model
|
||||
from typing import TYPE_CHECKING
|
||||
from toolkit.print import print_acc
|
||||
from toolkit.memory_management import MemoryManager
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from toolkit.lora_special import LoRASpecialNetwork
|
||||
@@ -185,6 +186,8 @@ class BaseModel:
|
||||
self.has_multiple_control_images = False
|
||||
# do not resize control images
|
||||
self.use_raw_control_images = False
|
||||
|
||||
self.memory_manager = MemoryManager(self)
|
||||
|
||||
# properties for old arch for backwards compatibility
|
||||
@property
|
||||
|
||||
@@ -70,6 +70,7 @@ from typing import TYPE_CHECKING
|
||||
from toolkit.print import print_acc
|
||||
from diffusers import FluxFillPipeline
|
||||
from transformers import AutoModel, AutoTokenizer, Gemma2Model, Qwen2Model, LlamaModel
|
||||
from toolkit.memory_management import MemoryManager
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from toolkit.lora_special import LoRASpecialNetwork
|
||||
@@ -224,6 +225,8 @@ class StableDiffusion:
|
||||
# do not resize control images
|
||||
self.use_raw_control_images = False
|
||||
|
||||
self.memory_manager = MemoryManager(self)
|
||||
|
||||
# properties for old arch for backwards compatibility
|
||||
@property
|
||||
def is_xl(self):
|
||||
|
||||
Reference in New Issue
Block a user