Pixel shuffle adapter. Some bug fixes thrown in

This commit is contained in:
Jaret Burkett
2025-03-29 21:15:01 -06:00
parent b94d7aafea
commit 860d892214
10 changed files with 594 additions and 11 deletions

View File

@@ -59,7 +59,7 @@ from tqdm import tqdm
from toolkit.config_modules import SaveConfig, LoggingConfig, SampleConfig, NetworkConfig, TrainConfig, ModelConfig, \
GenerateImageConfig, EmbeddingConfig, DatasetConfig, preprocess_dataset_raw_config, AdapterConfig, GuidanceConfig, validate_configs, \
DecoratorConfig
from toolkit.logging import create_logger
from toolkit.logging_aitk import create_logger
from diffusers import FluxTransformer2DModel
from toolkit.accelerator import get_accelerator
from toolkit.print import print_acc
@@ -578,7 +578,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
direct_save = True
if self.adapter_config.type == 'redux':
direct_save = True
if self.adapter_config.type == 'control_lora':
if self.adapter_config.type in ['control_lora', 'subpixel']:
direct_save = True
save_ip_adapter_from_diffusers(
state_dict,
@@ -918,6 +918,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
noise = self.sd.get_latent_noise(
height=latents.shape[2],
width=latents.shape[3],
num_channels=latents.shape[1],
batch_size=batch_size,
noise_offset=self.train_config.noise_offset,
).to(self.device_torch, dtype=dtype)