Added a guidance burning loss. Modified DFE to work with new model. Bug fixes

This commit is contained in:
Jaret Burkett
2025-06-23 08:38:27 -06:00
parent 8602470952
commit ba1274d99e
5 changed files with 106 additions and 99 deletions

View File

@@ -1012,7 +1012,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
dtype = get_torch_dtype(self.train_config.dtype)
imgs = None
is_reg = any(batch.get_is_reg_list())
cfm_batch = None
if batch.tensor is not None:
imgs = batch.tensor
imgs = imgs.to(self.device_torch, dtype=dtype)