Working multi-GPU training. Still needs a lot of tweaks and testing.

Jaret Burkett
2025-01-25 16:46:20 -07:00
parent 441474e81f
commit 5e663746b8
9 changed files with 432 additions and 294 deletions
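
The diff below replaces the trainer's manual torch.cuda.amp.GradScaler handling and bare loss.backward() calls with Hugging Face Accelerate equivalents (accelerator.backward, accelerator.clip_grad_norm_) and swaps print for a rank-aware print_acc. The following is a minimal, self-contained sketch of the general Accelerate pattern the commit appears to move toward; it is not part of the diff, and the toy model, data, and hyperparameters are illustrative assumptions, not taken from the repository:

# Minimal sketch of the Accelerate training pattern this commit appears to adopt.
# The toy model, data, and hyperparameters below are illustrative assumptions.
import torch
from accelerate import Accelerator

# Passing mixed_precision="fp16" or "bf16" lets Accelerate own autocast and the GradScaler.
accelerator = Accelerator()

model = torch.nn.Linear(8, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
model, optimizer = accelerator.prepare(model, optimizer)

for _ in range(10):
    x = torch.randn(4, 8, device=accelerator.device)
    y = torch.randn(4, 1, device=accelerator.device)
    loss = torch.nn.functional.mse_loss(model(x), y)
    accelerator.backward(loss)                            # replaces scaler.scale(loss).backward()
    accelerator.clip_grad_norm_(model.parameters(), 1.0)  # unscales fp16 grads, then clips
    optimizer.step()                                      # replaces scaler.step(...) / scaler.update()
    optimizer.zero_grad(set_to_none=True)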


@@ -20,6 +20,7 @@ from toolkit.guidance import get_targeted_guidance_loss, get_guidance_loss, Guid
from toolkit.image_utils import show_tensors, show_latents
from toolkit.ip_adapter import IPAdapter
from toolkit.custom_adapter import CustomAdapter
from toolkit.print import print_acc
from toolkit.prompt_utils import PromptEmbeds, concat_prompt_embeds
from toolkit.reference_adapter import ReferenceAdapter
from toolkit.stable_diffusion_model import StableDiffusion, BlankNetwork
@@ -59,8 +60,6 @@ class SDTrainer(BaseSDTrainProcess):
self.negative_prompt_pool: Union[List[str], None] = None
self.batch_negative_prompt: Union[List[str], None] = None
self.scaler = torch.cuda.amp.GradScaler()
self.is_bfloat = self.train_config.dtype == "bfloat16" or self.train_config.dtype == "bf16"
self.do_grad_scale = True
@@ -70,12 +69,12 @@ class SDTrainer(BaseSDTrainProcess):
if self.adapter_config.train:
self.do_grad_scale = False
if self.train_config.dtype in ["fp16", "float16"]:
# patch the scaler to allow fp16 training
org_unscale_grads = self.scaler._unscale_grads_
def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16):
return org_unscale_grads(optimizer, inv_scale, found_inf, True)
self.scaler._unscale_grads_ = _unscale_grads_replacer
# if self.train_config.dtype in ["fp16", "float16"]:
# # patch the scaler to allow fp16 training
# org_unscale_grads = self.scaler._unscale_grads_
# def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16):
# return org_unscale_grads(optimizer, inv_scale, found_inf, True)
# self.scaler._unscale_grads_ = _unscale_grads_replacer
self.cached_blank_embeds: Optional[PromptEmbeds] = None
self.cached_trigger_embeds: Optional[PromptEmbeds] = None
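
The block commented out above was a monkey patch that forced GradScaler._unscale_grads_ to accept fp16 gradients. With Accelerate owning mixed precision, that choice would instead be made when the Accelerator is constructed. The helper below is a hedged sketch only; how the trainer actually maps train_config.dtype onto Accelerate is not shown in this diff:

# Sketch: mapping a train_config-style dtype string onto Accelerate's mixed precision.
# This helper is hypothetical; the trainer's real wiring is not part of this diff.
from accelerate import Accelerator

def make_accelerator(dtype: str) -> Accelerator:
    if dtype in ("fp16", "float16"):
        return Accelerator(mixed_precision="fp16")  # Accelerate creates and manages the GradScaler
    if dtype in ("bf16", "bfloat16"):
        return Accelerator(mixed_precision="bf16")  # no GradScaler needed for bfloat16
    return Accelerator()                            # full precision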
@@ -168,11 +167,11 @@ class SDTrainer(BaseSDTrainProcess):
raise ValueError("Cannot unload text encoder if training text encoder")
# cache embeddings
print("\n***** UNLOADING TEXT ENCODER *****")
print("This will train only with a blank prompt or trigger word, if set")
print("If this is not what you want, remove the unload_text_encoder flag")
print("***********************************")
print("")
print_acc("\n***** UNLOADING TEXT ENCODER *****")
print_acc("This will train only with a blank prompt or trigger word, if set")
print_acc("If this is not what you want, remove the unload_text_encoder flag")
print_acc("***********************************")
print_acc("")
self.sd.text_encoder_to(self.device_torch)
self.cached_blank_embeds = self.sd.encode_prompt("")
if self.trigger_word is not None:
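
print_acc (imported from toolkit.print at the top of the file) replaces bare print so that log lines are not duplicated once per GPU process. Its real implementation is not shown in this diff; the following is a minimal sketch of a rank-aware print built on Accelerate's PartialState:

# Hypothetical sketch of a rank-aware print helper; toolkit/print.py may differ.
from accelerate import PartialState

def print_acc(*args, **kwargs):
    # Only the main process (rank 0) writes to the console.
    if PartialState().is_main_process:
        print(*args, **kwargs)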
@@ -484,7 +483,7 @@ class SDTrainer(BaseSDTrainProcess):
prior_loss = prior_loss * prior_mask_multiplier * self.train_config.inverted_mask_prior_multiplier
if torch.isnan(prior_loss).any():
print("Prior loss is nan")
print_acc("Prior loss is nan")
prior_loss = None
else:
prior_loss = prior_loss.mean([1, 2, 3])
@@ -553,7 +552,6 @@ class SDTrainer(BaseSDTrainProcess):
noise=noise,
sd=self.sd,
unconditional_embeds=unconditional_embeds,
scaler=self.scaler,
**kwargs
)
@@ -668,7 +666,7 @@ class SDTrainer(BaseSDTrainProcess):
# loss = self.apply_snr(loss, timesteps)
loss = loss.mean()
loss.backward()
self.accelerator.backward(loss)
# detach it so parent class can run backward on no grads without throwing error
loss = loss.detach()
@@ -823,7 +821,7 @@ class SDTrainer(BaseSDTrainProcess):
# loss = self.apply_snr(loss, timesteps)
loss = loss.mean()
loss.backward()
self.accelerator.backward(loss)
# detach it so parent class can run backward on no grads without throwing error
loss = loss.detach()
@@ -1446,8 +1444,8 @@ class SDTrainer(BaseSDTrainProcess):
quad_count=quad_count
)
else:
print("No Clip Image")
print([file_item.path for file_item in batch.file_items])
print_acc("No Clip Image")
print_acc([file_item.path for file_item in batch.file_items])
raise ValueError("Could not find clip image")
if not self.adapter_config.train_image_encoder:
@@ -1625,7 +1623,7 @@ class SDTrainer(BaseSDTrainProcess):
)
# check if nan
if torch.isnan(loss):
print("loss is nan")
print_acc("loss is nan")
loss = torch.zeros_like(loss).requires_grad_(True)
with self.timer('backward'):
@@ -1640,10 +1638,7 @@ class SDTrainer(BaseSDTrainProcess):
# if self.is_bfloat:
# loss.backward()
# else:
if not self.do_grad_scale:
loss.backward()
else:
self.scaler.scale(loss).backward()
self.accelerator.backward(loss)
return loss.detach()
# flush()
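
The hunk above removes the do_grad_scale branching entirely: accelerator.backward(loss) applies loss scaling itself when fp16 mixed precision is enabled and also handles gradient synchronization across processes. A simplified before/after comparison (the backward_old names follow the removed code; backward_new shows the single Accelerate call the diff switches to):

# Simplified comparison of the removed branching vs. the single Accelerate call.
import torch
from accelerate import Accelerator

def backward_old(loss, scaler: torch.cuda.amp.GradScaler, do_grad_scale: bool):
    if do_grad_scale:
        scaler.scale(loss).backward()  # manual fp16 loss scaling
    else:
        loss.backward()                # bf16 / fp32 path

def backward_new(loss, accelerator: Accelerator):
    accelerator.backward(loss)         # scales when fp16 is active, syncs grads across GPUs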
@@ -1668,21 +1663,14 @@ class SDTrainer(BaseSDTrainProcess):
if not self.is_grad_accumulation_step:
# fix this for multi params
if self.train_config.optimizer != 'adafactor':
if self.do_grad_scale:
self.scaler.unscale_(self.optimizer)
if isinstance(self.params[0], dict):
for i in range(len(self.params)):
torch.nn.utils.clip_grad_norm_(self.params[i]['params'], self.train_config.max_grad_norm)
self.accelerator.clip_grad_norm_(self.params[i]['params'], self.train_config.max_grad_norm)
else:
torch.nn.utils.clip_grad_norm_(self.params, self.train_config.max_grad_norm)
self.accelerator.clip_grad_norm_(self.params, self.train_config.max_grad_norm)
# only step if we are not accumulating
with self.timer('optimizer_step'):
# self.optimizer.step()
if not self.do_grad_scale:
self.optimizer.step()
else:
self.scaler.step(self.optimizer)
self.scaler.update()
self.optimizer.step()
self.optimizer.zero_grad(set_to_none=True)
if self.adapter and isinstance(self.adapter, CustomAdapter):
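
In the final hunk, accelerator.clip_grad_norm_ takes over from the manual scaler.unscale_ plus torch.nn.utils.clip_grad_norm_ pair, and the scaler.step / scaler.update sequence collapses to a plain optimizer.step(). A sketch of the resulting step, assuming params and max_grad_norm as used in the trainer (the adafactor and gradient-accumulation special cases are omitted):

# Sketch of the post-commit optimizer step; simplified relative to the trainer.
def optimizer_step(accelerator, optimizer, params, max_grad_norm):
    # clip_grad_norm_ unscales fp16 gradients internally before clipping,
    # so the explicit scaler.unscale_(optimizer) call is no longer needed.
    accelerator.clip_grad_norm_(params, max_grad_norm)
    optimizer.step()  # replaces scaler.step(optimizer) + scaler.update()
    optimizer.zero_grad(set_to_none=True)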