git status

This commit is contained in:
Jaret Burkett
2025-10-01 14:12:17 -06:00
parent b07b88c46b
commit 3086a58e5b
8 changed files with 438 additions and 31 deletions

View File

@@ -70,6 +70,8 @@ class SampleItem:
print(f"Invalid network_multiplier {self.network_multiplier}, defaulting to 1.0")
self.network_multiplier = 1.0
# only for models that support it, (qwen image edit 2509 for now)
self.do_cfg_norm: bool = kwargs.get('do_cfg_norm', False)
class SampleConfig:
def __init__(self, **kwargs):
@@ -104,6 +106,8 @@ class SampleConfig:
]
raw_samples = kwargs.get('samples', default_samples_kwargs)
self.samples = [SampleItem(self, **item) for item in raw_samples]
# only for models that support it, (qwen image edit 2509 for now)
self.do_cfg_norm: bool = kwargs.get('do_cfg_norm', False)
@property
def prompts(self):
@@ -993,7 +997,8 @@ class GenerateImageConfig:
ctrl_img_3: Optional[str] = None, # third control image for multi control model
num_frames: int = 1,
fps: int = 15,
ctrl_idx: int = 0
ctrl_idx: int = 0,
do_cfg_norm: bool = False,
):
self.width: int = width
self.height: int = height
@@ -1063,6 +1068,8 @@ class GenerateImageConfig:
self.width = max(64, self.width - self.width % 8) # round to divisible by 8
self.logger = logger
self.do_cfg_norm: bool = do_cfg_norm
def set_gen_time(self, gen_time: int = None):
if gen_time is not None:

View File

@@ -0,0 +1 @@
from .manager import MemoryManager

View File

@@ -0,0 +1,12 @@
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from toolkit.models.base_model import BaseModel
class MemoryManager:
def __init__(
self,
model: "BaseModel",
):
self.model: "BaseModel" = model

View File

@@ -41,6 +41,7 @@ from torchvision.transforms import functional as TF
from toolkit.accelerator import get_accelerator, unwrap_model
from typing import TYPE_CHECKING
from toolkit.print import print_acc
from toolkit.memory_management import MemoryManager
if TYPE_CHECKING:
from toolkit.lora_special import LoRASpecialNetwork
@@ -185,6 +186,8 @@ class BaseModel:
self.has_multiple_control_images = False
# do not resize control images
self.use_raw_control_images = False
self.memory_manager = MemoryManager(self)
# properties for old arch for backwards compatibility
@property

View File

@@ -70,6 +70,7 @@ from typing import TYPE_CHECKING
from toolkit.print import print_acc
from diffusers import FluxFillPipeline
from transformers import AutoModel, AutoTokenizer, Gemma2Model, Qwen2Model, LlamaModel
from toolkit.memory_management import MemoryManager
if TYPE_CHECKING:
from toolkit.lora_special import LoRASpecialNetwork
@@ -224,6 +225,8 @@ class StableDiffusion:
# do not resize control images
self.use_raw_control_images = False
self.memory_manager = MemoryManager(self)
# properties for old arch for backwards compatibility
@property
def is_xl(self):