Added some experimental training techniques. Ignore for now. Still in testing.

This commit is contained in:
Jaret Burkett
2025-05-21 02:19:54 -06:00
parent 01101be196
commit e5181d23cd
6 changed files with 240 additions and 43 deletions

View File

@@ -142,7 +142,9 @@ class StableDiffusion:
):
self.accelerator = get_accelerator()
self.custom_pipeline = custom_pipeline
self.device = device
self.device = str(device)
if "cuda" in self.device and ":" not in self.device:
self.device = f"{self.device}:0"
self.device_torch = torch.device(device)
self.dtype = dtype
self.torch_dtype = get_torch_dtype(dtype)
@@ -2086,7 +2088,10 @@ class StableDiffusion:
noise_pred = noise_pred
else:
if self.unet.device != self.device_torch:
self.unet.to(self.device_torch)
try:
self.unet.to(self.device_torch)
except Exception as e:
pass
if self.unet.dtype != self.torch_dtype:
self.unet = self.unet.to(dtype=self.torch_dtype)
if self.is_flux: