Bugfixes and cleanup

This commit is contained in:
Jaret Burkett
2024-08-01 11:45:12 -06:00
parent 47744373f2
commit 03613c523f
4 changed files with 39 additions and 127 deletions

View File

@@ -807,10 +807,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
with self.timer('prepare_latents'):
dtype = get_torch_dtype(self.train_config.dtype)
imgs = None
is_reg = any(batch.get_is_reg_list())
if batch.tensor is not None:
imgs = batch.tensor
imgs = imgs.to(self.device_torch, dtype=dtype)
if self.train_config.img_multiplier is not None:
# dont adjust for regs.
if self.train_config.img_multiplier is not None and not is_reg:
# do it ad contrast
imgs = reduce_contrast(imgs, self.train_config.img_multiplier)
if batch.latents is not None:
@@ -1495,8 +1497,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
try:
print(f"Loading optimizer state from {optimizer_state_file_path}")
optimizer_state_dict = torch.load(optimizer_state_file_path)
optimizer_state_dict = torch.load(optimizer_state_file_path, weights_only=True)
optimizer.load_state_dict(optimizer_state_dict)
del optimizer_state_dict
flush()
except Exception as e:
print(f"Failed to load optimizer state from {optimizer_state_file_path}")
print(e)