mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-29 18:51:37 +00:00
Bugfixes and cleanup
This commit is contained in:
@@ -807,10 +807,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
with self.timer('prepare_latents'):
|
||||
dtype = get_torch_dtype(self.train_config.dtype)
|
||||
imgs = None
|
||||
is_reg = any(batch.get_is_reg_list())
|
||||
if batch.tensor is not None:
|
||||
imgs = batch.tensor
|
||||
imgs = imgs.to(self.device_torch, dtype=dtype)
|
||||
if self.train_config.img_multiplier is not None:
|
||||
# dont adjust for regs.
|
||||
if self.train_config.img_multiplier is not None and not is_reg:
|
||||
# do it ad contrast
|
||||
imgs = reduce_contrast(imgs, self.train_config.img_multiplier)
|
||||
if batch.latents is not None:
|
||||
@@ -1495,8 +1497,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
|
||||
try:
|
||||
print(f"Loading optimizer state from {optimizer_state_file_path}")
|
||||
optimizer_state_dict = torch.load(optimizer_state_file_path)
|
||||
optimizer_state_dict = torch.load(optimizer_state_file_path, weights_only=True)
|
||||
optimizer.load_state_dict(optimizer_state_dict)
|
||||
del optimizer_state_dict
|
||||
flush()
|
||||
except Exception as e:
|
||||
print(f"Failed to load optimizer state from {optimizer_state_file_path}")
|
||||
print(e)
|
||||
|
||||
Reference in New Issue
Block a user