Added a guidance burning loss. Modified DFE to work with new model. Bug fixes

This commit is contained in:
Jaret Burkett
2025-06-23 08:38:27 -06:00
parent 8602470952
commit ba1274d99e
5 changed files with 106 additions and 99 deletions

View File

@@ -815,7 +815,10 @@ class BaseModel:
# predict the noise residual
if self.unet.device != self.device_torch:
self.unet.to(self.device_torch)
try:
self.unet.to(self.device_torch)
except Exception as e:
pass
if self.unet.dtype != self.torch_dtype:
self.unet = self.unet.to(dtype=self.torch_dtype)