mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 03:01:28 +00:00
Pixel shuffle adapter. Some bug fixes thrown in
This commit is contained in:
@@ -59,7 +59,7 @@ from tqdm import tqdm
|
||||
from toolkit.config_modules import SaveConfig, LoggingConfig, SampleConfig, NetworkConfig, TrainConfig, ModelConfig, \
|
||||
GenerateImageConfig, EmbeddingConfig, DatasetConfig, preprocess_dataset_raw_config, AdapterConfig, GuidanceConfig, validate_configs, \
|
||||
DecoratorConfig
|
||||
from toolkit.logging import create_logger
|
||||
from toolkit.logging_aitk import create_logger
|
||||
from diffusers import FluxTransformer2DModel
|
||||
from toolkit.accelerator import get_accelerator
|
||||
from toolkit.print import print_acc
|
||||
@@ -578,7 +578,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
direct_save = True
|
||||
if self.adapter_config.type == 'redux':
|
||||
direct_save = True
|
||||
if self.adapter_config.type == 'control_lora':
|
||||
if self.adapter_config.type in ['control_lora', 'subpixel']:
|
||||
direct_save = True
|
||||
save_ip_adapter_from_diffusers(
|
||||
state_dict,
|
||||
@@ -918,6 +918,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
noise = self.sd.get_latent_noise(
|
||||
height=latents.shape[2],
|
||||
width=latents.shape[3],
|
||||
num_channels=latents.shape[1],
|
||||
batch_size=batch_size,
|
||||
noise_offset=self.train_config.noise_offset,
|
||||
).to(self.device_torch, dtype=dtype)
|
||||
|
||||
Reference in New Issue
Block a user