Mirror of https://github.com/ostris/ai-toolkit.git (synced 2026-04-30 03:01:28 +00:00)
Working multi-GPU training. Still needs a lot of tweaks and testing.
@@ -20,6 +20,7 @@ from toolkit.guidance import get_targeted_guidance_loss, get_guidance_loss, Guid
 from toolkit.image_utils import show_tensors, show_latents
 from toolkit.ip_adapter import IPAdapter
 from toolkit.custom_adapter import CustomAdapter
+from toolkit.print import print_acc
 from toolkit.prompt_utils import PromptEmbeds, concat_prompt_embeds
 from toolkit.reference_adapter import ReferenceAdapter
 from toolkit.stable_diffusion_model import StableDiffusion, BlankNetwork
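
The newly imported print_acc is presumably a rank-aware logger, so that under multi-GPU training each message is printed once rather than once per process. A minimal sketch of such a helper, assuming Hugging Face Accelerate's PartialState; this is illustrative, not the toolkit's actual toolkit/print.py:

    # Illustrative sketch only (assumes the `accelerate` package is installed);
    # not the toolkit's actual implementation of print_acc.
    from accelerate import PartialState

    _state = PartialState()

    def print_acc(*args, **kwargs):
        # Print only on the main process so multi-GPU runs don't duplicate log lines.
        if _state.is_main_process:
            print(*args, **kwargs)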
@@ -59,8 +60,6 @@ class SDTrainer(BaseSDTrainProcess):
         self.negative_prompt_pool: Union[List[str], None] = None
         self.batch_negative_prompt: Union[List[str], None] = None

-        self.scaler = torch.cuda.amp.GradScaler()
-
         self.is_bfloat = self.train_config.dtype == "bfloat16" or self.train_config.dtype == "bf16"

         self.do_grad_scale = True
@@ -70,12 +69,12 @@ class SDTrainer(BaseSDTrainProcess):
         if self.adapter_config.train:
             self.do_grad_scale = False

-        if self.train_config.dtype in ["fp16", "float16"]:
-            # patch the scaler to allow fp16 training
-            org_unscale_grads = self.scaler._unscale_grads_
-            def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16):
-                return org_unscale_grads(optimizer, inv_scale, found_inf, True)
-            self.scaler._unscale_grads_ = _unscale_grads_replacer
+        # if self.train_config.dtype in ["fp16", "float16"]:
+        #     # patch the scaler to allow fp16 training
+        #     org_unscale_grads = self.scaler._unscale_grads_
+        #     def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16):
+        #         return org_unscale_grads(optimizer, inv_scale, found_inf, True)
+        #     self.scaler._unscale_grads_ = _unscale_grads_replacer

         self.cached_blank_embeds: Optional[PromptEmbeds] = None
         self.cached_trigger_embeds: Optional[PromptEmbeds] = None
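
Context for the block being commented out above: torch.cuda.amp.GradScaler refuses to unscale float16 gradients unless its private _unscale_grads_ method is called with allow_fp16=True, so the old code monkey-patched that method to force the flag when parameters are kept in fp16. A standalone sketch of the same pattern (this relies on a private PyTorch API that may change between versions; the local scaler below is an example object, not the trainer's):

    # Sketch of the allow_fp16 monkey-patch on a local GradScaler instance.
    import torch

    scaler = torch.cuda.amp.GradScaler()
    org_unscale_grads = scaler._unscale_grads_

    def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16):
        # Delegate to the original method, but always pass allow_fp16=True so
        # fp16 parameter gradients are unscaled instead of raising a ValueError.
        return org_unscale_grads(optimizer, inv_scale, found_inf, True)

    scaler._unscale_grads_ = _unscale_grads_replacer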
@@ -168,11 +167,11 @@ class SDTrainer(BaseSDTrainProcess):
                 raise ValueError("Cannot unload text encoder if training text encoder")
             # cache embeddings

-            print("\n***** UNLOADING TEXT ENCODER *****")
-            print("This will train only with a blank prompt or trigger word, if set")
-            print("If this is not what you want, remove the unload_text_encoder flag")
-            print("***********************************")
-            print("")
+            print_acc("\n***** UNLOADING TEXT ENCODER *****")
+            print_acc("This will train only with a blank prompt or trigger word, if set")
+            print_acc("If this is not what you want, remove the unload_text_encoder flag")
+            print_acc("***********************************")
+            print_acc("")
             self.sd.text_encoder_to(self.device_torch)
             self.cached_blank_embeds = self.sd.encode_prompt("")
             if self.trigger_word is not None:
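
The print-to-print_acc swap above sits inside the unload_text_encoder path, which encodes a blank prompt (and optionally the trigger word) once, caches the embeddings, and can then drop the text encoder to free VRAM. A rough sketch of that caching pattern; encode_prompt and text_encoder_to mirror names visible in the diff, while the sd object and moving the encoder to CPU are assumptions for illustration:

    # Hypothetical helper illustrating the cache-then-unload pattern; `sd` stands
    # in for the StableDiffusion wrapper and is not a runnable object here.
    import torch

    def cache_prompt_embeds(sd, device, trigger_word=None):
        sd.text_encoder_to(device)                   # load the text encoder once
        blank_embeds = sd.encode_prompt("")          # cached for prompt-free steps
        trigger_embeds = sd.encode_prompt(trigger_word) if trigger_word else None
        sd.text_encoder_to("cpu")                    # assumed: park the encoder off-GPU
        torch.cuda.empty_cache()                     # reclaim the freed VRAM
        return blank_embeds, trigger_embeds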
@@ -484,7 +483,7 @@ class SDTrainer(BaseSDTrainProcess):

                 prior_loss = prior_loss * prior_mask_multiplier * self.train_config.inverted_mask_prior_multiplier
                 if torch.isnan(prior_loss).any():
-                    print("Prior loss is nan")
+                    print_acc("Prior loss is nan")
                     prior_loss = None
                 else:
                     prior_loss = prior_loss.mean([1, 2, 3])
@@ -553,7 +552,6 @@ class SDTrainer(BaseSDTrainProcess):
                 noise=noise,
                 sd=self.sd,
                 unconditional_embeds=unconditional_embeds,
-                scaler=self.scaler,
                 **kwargs
             )

@@ -668,7 +666,7 @@ class SDTrainer(BaseSDTrainProcess):

             # loss = self.apply_snr(loss, timesteps)
             loss = loss.mean()
-            loss.backward()
+            self.accelerator.backward(loss)

             # detach it so parent class can run backward on no grads without throwing error
             loss = loss.detach()
@@ -823,7 +821,7 @@ class SDTrainer(BaseSDTrainProcess):

             # loss = self.apply_snr(loss, timesteps)
             loss = loss.mean()
-            loss.backward()
+            self.accelerator.backward(loss)

             # detach it so parent class can run backward on no grads without throwing error
             loss = loss.detach()
@@ -1446,8 +1444,8 @@ class SDTrainer(BaseSDTrainProcess):
                         quad_count=quad_count
                     )
                 else:
-                    print("No Clip Image")
-                    print([file_item.path for file_item in batch.file_items])
+                    print_acc("No Clip Image")
+                    print_acc([file_item.path for file_item in batch.file_items])
                     raise ValueError("Could not find clip image")

             if not self.adapter_config.train_image_encoder:
@@ -1625,7 +1623,7 @@ class SDTrainer(BaseSDTrainProcess):
                 )
                 # check if nan
                 if torch.isnan(loss):
-                    print("loss is nan")
+                    print_acc("loss is nan")
                     loss = torch.zeros_like(loss).requires_grad_(True)

         with self.timer('backward'):
@@ -1640,10 +1638,7 @@ class SDTrainer(BaseSDTrainProcess):
             # if self.is_bfloat:
             # loss.backward()
             # else:
-            if not self.do_grad_scale:
-                loss.backward()
-            else:
-                self.scaler.scale(loss).backward()
+            self.accelerator.backward(loss)

         return loss.detach()
         # flush()
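
In the hunks above, the manual loss.backward() / self.scaler.scale(loss).backward() branches collapse into self.accelerator.backward(loss), which applies loss scaling when mixed precision is enabled and synchronizes gradients across processes. A minimal self-contained sketch of that Accelerate pattern, using a throwaway model and data rather than anything from ai-toolkit:

    # Minimal Accelerate training step; the model, optimizer and data are placeholders.
    import torch
    from accelerate import Accelerator

    accelerator = Accelerator()   # pass mixed_precision="fp16"/"bf16" to enable AMP
    model = torch.nn.Linear(8, 1)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    model, optimizer = accelerator.prepare(model, optimizer)

    x = torch.randn(16, 8, device=accelerator.device)
    y = torch.randn(16, 1, device=accelerator.device)

    loss = torch.nn.functional.mse_loss(model(x), y)
    accelerator.backward(loss)    # replaces scaler.scale(loss).backward() / loss.backward()
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)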
@@ -1668,21 +1663,14 @@ class SDTrainer(BaseSDTrainProcess):
         if not self.is_grad_accumulation_step:
             # fix this for multi params
             if self.train_config.optimizer != 'adafactor':
-                if self.do_grad_scale:
-                    self.scaler.unscale_(self.optimizer)
                 if isinstance(self.params[0], dict):
                     for i in range(len(self.params)):
-                        torch.nn.utils.clip_grad_norm_(self.params[i]['params'], self.train_config.max_grad_norm)
+                        self.accelerator.clip_grad_norm_(self.params[i]['params'], self.train_config.max_grad_norm)
                 else:
-                    torch.nn.utils.clip_grad_norm_(self.params, self.train_config.max_grad_norm)
+                    self.accelerator.clip_grad_norm_(self.params, self.train_config.max_grad_norm)
             # only step if we are not accumulating
             with self.timer('optimizer_step'):
-                # self.optimizer.step()
-                if not self.do_grad_scale:
-                    self.optimizer.step()
-                else:
-                    self.scaler.step(self.optimizer)
-                    self.scaler.update()
+                self.optimizer.step()

                 self.optimizer.zero_grad(set_to_none=True)
         if self.adapter and isinstance(self.adapter, CustomAdapter):
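
In the final hunk, gradient clipping moves from an explicit scaler.unscale_ plus torch.nn.utils.clip_grad_norm_ to accelerator.clip_grad_norm_, which handles unscaling when mixed precision is active, and the scaler-aware step/update pair becomes a plain optimizer.step(). A compact sketch of that clipping-and-step pattern (placeholder model and data again; the sync_gradients guard is the usual Accelerate idiom for gradient accumulation and is not something shown in this diff):

    # Clipping and stepping under Accelerate; names other than clip_grad_norm_,
    # backward and sync_gradients are placeholders.
    import torch
    from accelerate import Accelerator

    accelerator = Accelerator()
    model = torch.nn.Linear(8, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
    model, optimizer = accelerator.prepare(model, optimizer)

    loss = model(torch.randn(4, 8, device=accelerator.device)).pow(2).mean()
    accelerator.backward(loss)

    if accelerator.sync_gradients:                   # skip clipping on accumulation micro-steps
        accelerator.clip_grad_norm_(model.parameters(), max_norm=1.0)

    optimizer.step()                                 # no scaler.step()/scaler.update() needed
    optimizer.zero_grad(set_to_none=True)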