Reworked automagic optimizer and did more testing: unwrap the optimizer via unwrap_model() before saving its state_dict, and construct the model on accelerator.device instead of self.device. Starting to really like it; working well.

This commit is contained in:
Jaret Burkett
2025-04-28 08:01:10 -06:00
parent 88b3fbae37
commit 2b4c525489
5 changed files with 149 additions and 61 deletions

View File

@@ -62,7 +62,7 @@ from toolkit.config_modules import SaveConfig, LoggingConfig, SampleConfig, Netw
DecoratorConfig
from toolkit.logging_aitk import create_logger
from diffusers import FluxTransformer2DModel
from toolkit.accelerator import get_accelerator
from toolkit.accelerator import get_accelerator, unwrap_model
from toolkit.print import print_acc
from accelerate import Accelerator
import transformers
@@ -629,7 +629,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
try:
filename = f'optimizer.pt'
file_path = os.path.join(self.save_root, filename)
state_dict = self.optimizer.state_dict()
state_dict = unwrap_model(self.optimizer).state_dict()
torch.save(state_dict, file_path)
print_acc(f"Saved optimizer to {file_path}")
except Exception as e:
@@ -1457,7 +1457,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.load_training_state_from_metadata(previous_refiner_save)
self.sd = ModelClass(
device=self.device,
# todo handle single gpu and multi gpu here
# device=self.device,
device=self.accelerator.device,
model_config=model_config_to_load,
dtype=self.train_config.dtype,
custom_pipeline=self.custom_pipeline,