mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-02-21 12:53:56 +00:00
Added some experimental training techniques. Ignore for now. Still in testing.
This commit is contained in:
@@ -142,7 +142,9 @@ class StableDiffusion:
|
||||
):
|
||||
self.accelerator = get_accelerator()
|
||||
self.custom_pipeline = custom_pipeline
|
||||
self.device = device
|
||||
self.device = str(device)
|
||||
if "cuda" in self.device and ":" not in self.device:
|
||||
self.device = f"{self.device}:0"
|
||||
self.device_torch = torch.device(device)
|
||||
self.dtype = dtype
|
||||
self.torch_dtype = get_torch_dtype(dtype)
|
||||
@@ -2086,7 +2088,10 @@ class StableDiffusion:
|
||||
noise_pred = noise_pred
|
||||
else:
|
||||
if self.unet.device != self.device_torch:
|
||||
self.unet.to(self.device_torch)
|
||||
try:
|
||||
self.unet.to(self.device_torch)
|
||||
except Exception as e:
|
||||
pass
|
||||
if self.unet.dtype != self.torch_dtype:
|
||||
self.unet = self.unet.to(dtype=self.torch_dtype)
|
||||
if self.is_flux:
|
||||
|
||||
Reference in New Issue
Block a user