8-bit training working on Flux

This commit is contained in:
Jaret Burkett
2024-08-06 11:53:27 -06:00
parent 272c8608c2
commit c2424087d6
7 changed files with 82 additions and 31 deletions

View File

@@ -1353,7 +1353,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
**network_kwargs
)
self.network.force_to(self.device_torch, dtype=dtype)
# todo switch everything to proper mixed precision like this
self.network.force_to(self.device_torch, dtype=torch.float32)
# give network to sd so it can use it
self.sd.network = self.network
self.network._update_torch_multiplier()
@@ -1365,6 +1367,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.train_config.train_unet
)
# we cannot merge in if quantized
if self.model_config.quantize:
# todo find a way around this
self.network.can_merge_in = False
if is_lorm:
self.network.is_lorm = True
# make sure it is on the right device