mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-29 18:51:37 +00:00
8 bit training working on flux
This commit is contained in:
@@ -1353,7 +1353,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
**network_kwargs
|
||||
)
|
||||
|
||||
self.network.force_to(self.device_torch, dtype=dtype)
|
||||
|
||||
# todo switch everything to proper mixed precision like this
|
||||
self.network.force_to(self.device_torch, dtype=torch.float32)
|
||||
# give network to sd so it can use it
|
||||
self.sd.network = self.network
|
||||
self.network._update_torch_multiplier()
|
||||
@@ -1365,6 +1367,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
self.train_config.train_unet
|
||||
)
|
||||
|
||||
# we cannot merge in if quantized
|
||||
if self.model_config.quantize:
|
||||
# todo find a way around this
|
||||
self.network.can_merge_in = False
|
||||
|
||||
if is_lorm:
|
||||
self.network.is_lorm = True
|
||||
# make sure it is on the right device
|
||||
|
||||
Reference in New Issue
Block a user