Working multi-GPU training. Still needs a lot of tweaks and testing.

Jaret Burkett
2025-01-25 16:46:20 -07:00
parent 441474e81f
commit 5e663746b8
9 changed files with 432 additions and 294 deletions


@@ -61,6 +61,11 @@ from toolkit.config_modules import SaveConfig, LoggingConfig, SampleConfig, NetworkConfig, \
     DecoratorConfig
 from toolkit.logging import create_logger
 from diffusers import FluxTransformer2DModel
+from toolkit.accelerator import get_accelerator
+from toolkit.print import print_acc
+from accelerate import Accelerator
+import transformers
+import diffusers
 
 
 def flush():
     torch.cuda.empty_cache()
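
Note: `get_accelerator` and `print_acc` are new toolkit helpers whose implementations are not part of this diff. A minimal sketch of what they presumably provide (a process-wide `Accelerator` singleton and a main-process-only print) would be:

    from accelerate import Accelerator

    _accelerator = None

    def get_accelerator() -> Accelerator:
        # Reuse one Accelerator per process so every caller shares the same
        # device assignment and process group.
        global _accelerator
        if _accelerator is None:
            _accelerator = Accelerator()
        return _accelerator

    def print_acc(*args, **kwargs):
        # Print on the main process only, so a multi-GPU run does not emit
        # one copy of every message per rank.
        if get_accelerator().is_main_process:
            print(*args, **kwargs)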
@@ -71,6 +76,14 @@ class BaseSDTrainProcess(BaseTrainProcess):
     def __init__(self, process_id: int, job, config: OrderedDict, custom_pipeline=None):
         super().__init__(process_id, job, config)
+        self.accelerator: Accelerator = get_accelerator()
+        if self.accelerator.is_local_main_process:
+            transformers.utils.logging.set_verbosity_warning()
+            diffusers.utils.logging.set_verbosity_info()
+        else:
+            transformers.utils.logging.set_verbosity_error()
+            diffusers.utils.logging.set_verbosity_error()
         self.sd: StableDiffusion
         self.embedding: Union[Embedding, None] = None
@@ -82,8 +95,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.grad_accumulation_step = 1
# if true, then we do not do an optimizer step. We are accumulating gradients
self.is_grad_accumulation_step = False
self.device = self.get_conf('device', self.job.device)
self.device_torch = torch.device(self.device)
self.device = str(self.accelerator.device)
self.device_torch = self.accelerator.device
network_config = self.get_conf('network', None)
if network_config is not None:
self.network_config = NetworkConfig(**network_config)
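
The device now comes from Accelerate instead of the job config: under `accelerate launch`, each process is bound to its own GPU, so `accelerator.device` resolves to a different `cuda:N` per rank and the same job config works unchanged on one or many GPUs. For illustration (assuming a multi-GPU host):

    from accelerate import Accelerator

    accelerator = Accelerator()
    # With `accelerate launch --num_processes 2 script.py`, rank 0 prints
    # "rank 0: cuda:0" and rank 1 prints "rank 1: cuda:1".
    print(f"rank {accelerator.process_index}: {accelerator.device}")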
@@ -91,6 +104,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
             self.network_config = None
         self.train_config = TrainConfig(**self.get_conf('train', {}))
         model_config = self.get_conf('model', {})
+        self.modules_being_trained: List[torch.nn.Module] = []
         # update modelconfig dtype to match train
         model_config['dtype'] = self.train_config.dtype
@@ -222,6 +236,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
         return generate_image_config_list
 
     def sample(self, step=None, is_first=False):
+        if not self.accelerator.is_main_process:
+            return
         flush()
         sample_folder = os.path.join(self.save_root, 'samples')
         gen_img_config_list = []
@@ -316,6 +332,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
         elif self.model_config.is_xl:
             o_dict['ss_base_model_version'] = 'sdxl_1.0'
+        elif self.model_config.is_flux:
+            o_dict['ss_base_model_version'] = 'flux.1'
         else:
             o_dict['ss_base_model_version'] = 'sd_1.5'
@@ -344,6 +362,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
         return info
 
     def clean_up_saves(self):
+        if not self.accelerator.is_main_process:
+            return
         # remove old saves
         # get latest saved step
         latest_item = None
@@ -400,7 +420,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
         items_to_remove = list(dict.fromkeys(items_to_remove))
         for item in items_to_remove:
-            self.print(f"Removing old save: {item}")
+            print_acc(f"Removing old save: {item}")
             if os.path.isdir(item):
                 shutil.rmtree(item)
             else:
@@ -418,6 +438,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
             pass
 
     def save(self, step=None):
+        if not self.accelerator.is_main_process:
+            return
         flush()
         if self.ema is not None:
             # always save params as ema
@@ -594,10 +616,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 state_dict = self.optimizer.state_dict()
                 torch.save(state_dict, file_path)
             except Exception as e:
-                print(e)
-                print("Could not save optimizer")
+                print_acc(e)
+                print_acc("Could not save optimizer")
 
-        self.print(f"Saved to {file_path}")
+        print_acc(f"Saved to {file_path}")
         self.clean_up_saves()
         self.post_save_hook(file_path)
@@ -619,7 +641,49 @@ class BaseSDTrainProcess(BaseTrainProcess):
         return params
 
     def hook_before_train_loop(self):
-        self.logger.start()
+        if self.accelerator.is_main_process:
+            self.logger.start()
+        self.prepare_accelerator()
+
+    def prepare_accelerator(self):
+        # set some config
+        self.accelerator.even_batches = False
+        # prepare all the model stuff for accelerator (hopefully we don't miss any)
+        if self.sd.vae is not None:
+            self.sd.vae = self.accelerator.prepare(self.sd.vae)
+        if self.sd.unet is not None:
+            self.sd.unet = self.accelerator.prepare(self.sd.unet)
+            # todo: always do it?
+            self.modules_being_trained.append(self.sd.unet)
+        if self.sd.text_encoder is not None and self.train_config.train_text_encoder:
+            if isinstance(self.sd.text_encoder, list):
+                self.sd.text_encoder = [self.accelerator.prepare(model) for model in self.sd.text_encoder]
+                self.modules_being_trained.extend(self.sd.text_encoder)
+            else:
+                self.sd.text_encoder = self.accelerator.prepare(self.sd.text_encoder)
+                self.modules_being_trained.append(self.sd.text_encoder)
+        if self.sd.refiner_unet is not None and self.train_config.train_refiner:
+            self.sd.refiner_unet = self.accelerator.prepare(self.sd.refiner_unet)
+            self.modules_being_trained.append(self.sd.refiner_unet)
+        # todo: do we need to do the network or will "unet" get it?
+        if self.sd.network is not None:
+            self.sd.network = self.accelerator.prepare(self.sd.network)
+            self.modules_being_trained.append(self.sd.network)
+        if self.adapter is not None and self.adapter_config.train:
+            # todo: adapters may not be a module. need to check
+            self.adapter = self.accelerator.prepare(self.adapter)
+            self.modules_being_trained.append(self.adapter)
+        # prepare other things
+        self.optimizer = self.accelerator.prepare(self.optimizer)
+        if self.lr_scheduler is not None:
+            self.lr_scheduler = self.accelerator.prepare(self.lr_scheduler)
+        # self.data_loader = self.accelerator.prepare(self.data_loader)
+        # if self.data_loader_reg is not None:
+        #     self.data_loader_reg = self.accelerator.prepare(self.data_loader_reg)
 
     def ensure_params_requires_grad(self, force=False):
         if self.train_config.do_paramiter_swapping and not force:
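
`prepare_accelerator` follows the standard Accelerate pattern: each trainable module plus the optimizer and LR scheduler goes through `Accelerator.prepare`, which moves it to this rank's device and wraps modules for DDP. Note the data loaders are deliberately left unprepared (commented out), so Accelerate is not yet sharding batches across ranks here. For comparison, a minimal self-contained version of the full pattern (toy model and data, not the toolkit's types):

    import torch
    from accelerate import Accelerator

    accelerator = Accelerator()
    model = torch.nn.Linear(8, 1)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    dataset = torch.utils.data.TensorDataset(torch.randn(64, 8), torch.randn(64, 1))
    loader = torch.utils.data.DataLoader(dataset, batch_size=4)

    # prepare() wraps the model for DDP and moves it to this rank's device;
    # preparing the loader (unlike the diff above) also shards batches per rank.
    model, optimizer, loader = accelerator.prepare(model, optimizer, loader)

    for x, y in loader:
        optimizer.zero_grad()
        loss = torch.nn.functional.mse_loss(model(x), y)
        accelerator.backward(loss)  # replaces loss.backward()
        optimizer.step()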
@@ -692,6 +756,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
         return latest_path
 
     def load_training_state_from_metadata(self, path):
+        if not self.accelerator.is_main_process:
+            return
         meta = None
         # if path is folder, then it is diffusers
         if os.path.isdir(path):
@@ -708,7 +774,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
             if 'epoch' in meta['training_info']:
                 self.epoch_num = meta['training_info']['epoch']
             self.start_step = self.step_num
-            print(f"Found step {self.step_num} in metadata, starting from there")
+            print_acc(f"Found step {self.step_num} in metadata, starting from there")
 
     def load_weights(self, path):
         if self.network is not None:
@@ -716,7 +782,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.load_training_state_from_metadata(path)
return extra_weights
else:
print("load_weights not implemented for non-network models")
print_acc("load_weights not implemented for non-network models")
return None
def apply_snr(self, seperated_loss, timesteps):
@@ -747,7 +813,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
             if 'epoch' in meta['training_info']:
                 self.epoch_num = meta['training_info']['epoch']
             self.start_step = self.step_num
-            print(f"Found step {self.step_num} in metadata, starting from there")
+            print_acc(f"Found step {self.step_num} in metadata, starting from there")
 
     # def get_sigmas(self, timesteps, n_dim=4, dtype=torch.float32):
     #     self.sd.noise_scheduler.set_timesteps(1000, device=self.device_torch)
@@ -1244,7 +1310,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 self.adapter.to(self.device_torch, dtype=dtype)
                 if latest_save_path is not None and not is_control_net:
                     # load adapter from path
-                    print(f"Loading adapter from {latest_save_path}")
+                    print_acc(f"Loading adapter from {latest_save_path}")
                     if is_t2i:
                         loaded_state_dict = load_t2i_model(
                             latest_save_path,
@@ -1290,7 +1356,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
             latest_save_path = self.get_latest_save_path()
             if latest_save_path is not None:
-                print(f"#### IMPORTANT RESUMING FROM {latest_save_path} ####")
+                print_acc(f"#### IMPORTANT RESUMING FROM {latest_save_path} ####")
                 model_config_to_load.name_or_path = latest_save_path
                 self.load_training_state_from_metadata(latest_save_path)
@@ -1357,7 +1423,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
             #             block.attn.set_processor(processor)
             # except ImportError:
-            #     print("sage attention is not installed. Using SDP instead")
+            #     print_acc("sage attention is not installed. Using SDP instead")
 
         if self.train_config.gradient_checkpointing:
             if self.sd.is_flux:
@@ -1531,8 +1597,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
             latest_save_path = self.get_latest_save_path(lora_name)
             extra_weights = None
             if latest_save_path is not None:
-                self.print(f"#### IMPORTANT RESUMING FROM {latest_save_path} ####")
-                self.print(f"Loading from {latest_save_path}")
+                print_acc(f"#### IMPORTANT RESUMING FROM {latest_save_path} ####")
+                print_acc(f"Loading from {latest_save_path}")
                 extra_weights = self.load_weights(latest_save_path)
                 self.network.multiplier = 1.0
@@ -1665,17 +1731,17 @@ class BaseSDTrainProcess(BaseTrainProcess):
                     previous_lrs.append(group['lr'])
 
             try:
-                print(f"Loading optimizer state from {optimizer_state_file_path}")
+                print_acc(f"Loading optimizer state from {optimizer_state_file_path}")
                 optimizer_state_dict = torch.load(optimizer_state_file_path, weights_only=True)
                 optimizer.load_state_dict(optimizer_state_dict)
                 del optimizer_state_dict
                 flush()
             except Exception as e:
-                print(f"Failed to load optimizer state from {optimizer_state_file_path}")
-                print(e)
+                print_acc(f"Failed to load optimizer state from {optimizer_state_file_path}")
+                print_acc(e)
 
             # update the optimizer LR from the params
-            print(f"Updating optimizer LR from params")
+            print_acc(f"Updating optimizer LR from params")
             if len(previous_lrs) > 0:
                 for i, group in enumerate(optimizer.param_groups):
                     group['lr'] = previous_lrs[i]
@@ -1711,24 +1777,27 @@ class BaseSDTrainProcess(BaseTrainProcess):
         self.hook_before_train_loop()
 
         if self.has_first_sample_requested and self.step_num <= 1 and not self.train_config.disable_sampling:
-            self.print("Generating first sample from first sample config")
+            print_acc("Generating first sample from first sample config")
             self.sample(0, is_first=True)
 
         # sample first
         if self.train_config.skip_first_sample or self.train_config.disable_sampling:
-            self.print("Skipping first sample due to config setting")
+            print_acc("Skipping first sample due to config setting")
         elif self.step_num <= 1 or self.train_config.force_first_sample:
-            self.print("Generating baseline samples before training")
+            print_acc("Generating baseline samples before training")
             self.sample(self.step_num)
 
-        self.progress_bar = ToolkitProgressBar(
-            total=self.train_config.steps,
-            desc=self.job.name,
-            leave=True,
-            initial=self.step_num,
-            iterable=range(0, self.train_config.steps),
-        )
-        self.progress_bar.pause()
+        if self.accelerator.is_local_main_process:
+            self.progress_bar = ToolkitProgressBar(
+                total=self.train_config.steps,
+                desc=self.job.name,
+                leave=True,
+                initial=self.step_num,
+                iterable=range(0, self.train_config.steps),
+            )
+            self.progress_bar.pause()
+        else:
+            self.progress_bar = None
 
         if self.data_loader is not None:
             dataloader = self.data_loader
@@ -1753,7 +1822,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
         flush()
         # self.step_num = 0
 
-        # print(f"Compiling Model")
+        # print_acc(f"Compiling Model")
         # torch.compile(self.sd.unet, dynamic=True)
 
         # make sure all params require grad
@@ -1779,7 +1848,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.is_grad_accumulation_step = True
if self.train_config.free_u:
self.sd.pipeline.enable_freeu(s1=0.9, s2=0.2, b1=1.1, b2=1.2)
self.progress_bar.unpause()
if self.progress_bar is not None:
self.progress_bar.unpause()
with torch.no_grad():
# if is even step and we have a reg dataset, use that
# todo improve this logic to send one of each through if we can buckets and batch size might be an issue
@@ -1802,13 +1872,15 @@ class BaseSDTrainProcess(BaseTrainProcess):
except StopIteration:
with self.timer('reset_batch:reg'):
# hit the end of an epoch, reset
self.progress_bar.pause()
if self.progress_bar is not None:
self.progress_bar.pause()
dataloader_iterator_reg = iter(dataloader_reg)
trigger_dataloader_setup_epoch(dataloader_reg)
with self.timer('get_batch:reg'):
batch = next(dataloader_iterator_reg)
self.progress_bar.unpause()
if self.progress_bar is not None:
self.progress_bar.unpause()
is_reg_step = True
elif dataloader is not None:
try:
@@ -1817,7 +1889,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
                     except StopIteration:
                         with self.timer('reset_batch'):
                             # hit the end of an epoch, reset
-                            self.progress_bar.pause()
+                            if self.progress_bar is not None:
+                                self.progress_bar.pause()
                             dataloader_iterator = iter(dataloader)
                             trigger_dataloader_setup_epoch(dataloader)
                             self.epoch_num += 1
@@ -1827,7 +1900,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
                             self.grad_accumulation_step = 0
                         with self.timer('get_batch'):
                             batch = next(dataloader_iterator)
-                        self.progress_bar.unpause()
+                        if self.progress_bar is not None:
+                            self.progress_bar.unpause()
                 else:
                     batch = None
                 batch_list.append(batch)
@@ -1849,8 +1923,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 # flush()
                 ### HOOK ###
-                loss_dict = self.hook_train_loop(batch_list)
+                with self.accelerator.accumulate(self.modules_being_trained):
+                    loss_dict = self.hook_train_loop(batch_list)
                 self.timer.stop('train_loop')
                 if not did_first_flush:
                     flush()
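
`Accelerator.accumulate` is a context manager that suppresses DDP gradient synchronization (and lets a prepared optimizer skip its step) until the configured number of accumulation steps has elapsed. Its signature takes models as separate positional arguments (`accumulate(*models)`), so the list passed here may deserve a second look. A minimal sketch of the intended usage, with toy objects:

    import torch
    from accelerate import Accelerator

    accelerator = Accelerator(gradient_accumulation_steps=4)
    model = torch.nn.Linear(8, 1)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    model, optimizer = accelerator.prepare(model, optimizer)

    for step in range(8):
        x = torch.randn(4, 8, device=accelerator.device)
        y = torch.randn(4, 1, device=accelerator.device)
        with accelerator.accumulate(model):
            loss = torch.nn.functional.mse_loss(model(x), y)
            accelerator.backward(loss)
            # The prepared optimizer only actually steps on every 4th
            # micro-batch; earlier calls are no-ops during accumulation.
            optimizer.step()
            optimizer.zero_grad()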
@@ -1880,7 +1954,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
for key, value in loss_dict.items():
prog_bar_string += f" {key}: {value:.3e}"
self.progress_bar.set_postfix_str(prog_bar_string)
if self.progress_bar is not None:
self.progress_bar.set_postfix_str(prog_bar_string)
# if the batch is a DataLoaderBatchDTO, then we need to clean it up
if isinstance(batch, DataLoaderBatchDTO):
@@ -1889,8 +1964,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
# don't do on first step
if self.step_num != self.start_step:
if is_sample_step or is_save_step:
self.accelerator.wait_for_everyone()
if is_sample_step:
self.progress_bar.pause()
if self.progress_bar is not None:
self.progress_bar.pause()
flush()
# print above the progress bar
if self.train_config.free_u:
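
Because `sample()` and `save()` now return immediately on non-main ranks, the `wait_for_everyone()` barrier added above is what guarantees every rank has finished its optimizer step before rank 0 samples or saves, so the weights being written are consistent across processes. The general pattern, as a sketch:

    import torch
    from accelerate import Accelerator

    accelerator = Accelerator()
    model = accelerator.prepare(torch.nn.Linear(8, 1))

    accelerator.wait_for_everyone()  # every rank reaches this barrier first
    if accelerator.is_main_process:
        # unwrap_model strips the DDP wrapper so the plain state dict is saved
        torch.save(accelerator.unwrap_model(model).state_dict(), "checkpoint.pt")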
@@ -1902,57 +1980,70 @@ class BaseSDTrainProcess(BaseTrainProcess):
flush()
self.ensure_params_requires_grad()
self.progress_bar.unpause()
if self.progress_bar is not None:
self.progress_bar.unpause()
if is_save_step:
self.accelerator
# print above the progress bar
self.progress_bar.pause()
self.print(f"Saving at step {self.step_num}")
if self.progress_bar is not None:
self.progress_bar.pause()
print_acc(f"Saving at step {self.step_num}")
self.save(self.step_num)
self.ensure_params_requires_grad()
self.progress_bar.unpause()
if self.progress_bar is not None:
self.progress_bar.unpause()
if self.logging_config.log_every and self.step_num % self.logging_config.log_every == 0:
self.progress_bar.pause()
if self.progress_bar is not None:
self.progress_bar.pause()
with self.timer('log_to_tensorboard'):
# log to tensorboard
if self.writer is not None:
for key, value in loss_dict.items():
self.writer.add_scalar(f"{key}", value, self.step_num)
self.writer.add_scalar(f"lr", learning_rate, self.step_num)
self.progress_bar.unpause()
if self.accelerator.is_main_process:
if self.writer is not None:
for key, value in loss_dict.items():
self.writer.add_scalar(f"{key}", value, self.step_num)
self.writer.add_scalar(f"lr", learning_rate, self.step_num)
if self.progress_bar is not None:
self.progress_bar.unpause()
# log to logger
self.logger.log({
'learning_rate': learning_rate,
})
for key, value in loss_dict.items():
if self.accelerator.is_main_process:
# log to logger
self.logger.log({
f'loss/{key}': value,
'learning_rate': learning_rate,
})
elif self.logging_config.log_every is None:
# log every step
self.logger.log({
'learning_rate': learning_rate,
})
for key, value in loss_dict.items():
for key, value in loss_dict.items():
self.logger.log({
f'loss/{key}': value,
})
elif self.logging_config.log_every is None:
if self.accelerator.is_main_process:
# log every step
self.logger.log({
f'loss/{key}': value,
'learning_rate': learning_rate,
})
for key, value in loss_dict.items():
self.logger.log({
f'loss/{key}': value,
})
if self.performance_log_every > 0 and self.step_num % self.performance_log_every == 0:
self.progress_bar.pause()
if self.progress_bar is not None:
self.progress_bar.pause()
# print the timers and clear them
self.timer.print()
self.timer.reset()
self.progress_bar.unpause()
if self.progress_bar is not None:
self.progress_bar.unpause()
# commit log
self.logger.commit(step=self.step_num)
if self.accelerator.is_main_process:
self.logger.commit(step=self.step_num)
# sets progress bar to match out step
self.progress_bar.update(step - self.progress_bar.n)
if self.progress_bar is not None:
self.progress_bar.update(step - self.progress_bar.n)
#############################
# End of step
@@ -1966,16 +2057,19 @@ class BaseSDTrainProcess(BaseTrainProcess):
         ###################################################################
         ## END TRAIN LOOP
         ###################################################################
-        self.progress_bar.close()
+        self.accelerator.wait_for_everyone()
+        if self.progress_bar is not None:
+            self.progress_bar.close()
         if self.train_config.free_u:
             self.sd.pipeline.disable_freeu()
         if not self.train_config.disable_sampling:
            self.sample(self.step_num)
-        self.logger.commit(step=self.step_num)
-        print("")
-        self.save()
-        self.logger.finish()
+        print_acc("")
+        if self.accelerator.is_main_process:
+            self.save()
+            self.logger.finish()
+        self.accelerator.end_training()
 
         if self.save_config.push_to_hub:
             if ("HF_TOKEN" not in os.environ):
@@ -2001,6 +2095,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
         repo_id: str,
         private: bool = False,
     ):
+        if not self.accelerator.is_main_process:
+            return
         readme_content = self._generate_readme(repo_id)
         readme_path = os.path.join(self.save_root, "README.md")
         with open(readme_path, "w", encoding="utf-8") as f: