Mirror of https://github.com/ostris/ai-toolkit.git (synced 2026-01-26 16:39:47 +00:00)
Working multi-GPU training. Still needs a lot of tweaks and testing.
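For context on the pattern this commit moves toward: the new code is built around Hugging Face Accelerate, which handles device placement, DDP wrapping, gradient accumulation, and per-rank gating. A minimal, self-contained sketch of that pattern (the model, optimizer, and data below are placeholders, not this repo's classes):

import torch
from accelerate import Accelerator

# placeholder model/optimizer/dataloader; ai-toolkit wires up its own
model = torch.nn.Linear(16, 16)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
dataset = torch.utils.data.TensorDataset(torch.randn(64, 16), torch.randn(64, 16))
dataloader = torch.utils.data.DataLoader(dataset, batch_size=8)

accelerator = Accelerator(gradient_accumulation_steps=2)
# prepare() moves everything to the right device and wraps the model for DDP
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

for inputs, targets in dataloader:
    with accelerator.accumulate(model):
        loss = torch.nn.functional.mse_loss(model(inputs), targets)
        accelerator.backward(loss)   # replaces loss.backward()
        optimizer.step()
        optimizer.zero_grad()

accelerator.wait_for_everyone()
if accelerator.is_main_process:
    torch.save(accelerator.unwrap_model(model).state_dict(), "model.pt")
accelerator.end_training()

The diff below applies the same ideas inside BaseSDTrainProcess: prepare the trainable modules, wrap the train step in accumulate(), and gate sampling, saving, and logging to the main process.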
@@ -61,6 +61,11 @@ from toolkit.config_modules import SaveConfig, LoggingConfig, SampleConfig, Netw
     DecoratorConfig
 from toolkit.logging import create_logger
 from diffusers import FluxTransformer2DModel
+from toolkit.accelerator import get_accelerator
+from toolkit.print import print_acc
+from accelerate import Accelerator
+import transformers
+import diffusers

 def flush():
     torch.cuda.empty_cache()
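Note on the two new toolkit imports: get_accelerator and print_acc are helpers added elsewhere in this commit, and their bodies are not part of this hunk. A plausible sketch of what they presumably look like, assuming get_accelerator caches a single shared Accelerator and print_acc defers to Accelerator.print (which only emits on the main process):

from accelerate import Accelerator

_accelerator = None

def get_accelerator() -> Accelerator:
    # lazily create one shared Accelerator so every module talks to the same instance
    global _accelerator
    if _accelerator is None:
        _accelerator = Accelerator()
    return _accelerator

def print_acc(*args, **kwargs):
    # Accelerator.print is a no-op on non-main processes, so logs are not repeated per GPU
    get_accelerator().print(*args, **kwargs)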
@@ -71,6 +76,14 @@ class BaseSDTrainProcess(BaseTrainProcess):

     def __init__(self, process_id: int, job, config: OrderedDict, custom_pipeline=None):
         super().__init__(process_id, job, config)
+        self.accelerator: Accelerator = get_accelerator()
+        if self.accelerator.is_local_main_process:
+            transformers.utils.logging.set_verbosity_warning()
+            diffusers.utils.logging.set_verbosity_info()
+        else:
+            transformers.utils.logging.set_verbosity_error()
+            diffusers.utils.logging.set_verbosity_error()
+
         self.sd: StableDiffusion
         self.embedding: Union[Embedding, None] = None

@@ -82,8 +95,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
         self.grad_accumulation_step = 1
         # if true, then we do not do an optimizer step. We are accumulating gradients
         self.is_grad_accumulation_step = False
-        self.device = self.get_conf('device', self.job.device)
-        self.device_torch = torch.device(self.device)
+        self.device = str(self.accelerator.device)
+        self.device_torch = self.accelerator.device
         network_config = self.get_conf('network', None)
         if network_config is not None:
             self.network_config = NetworkConfig(**network_config)
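The device now comes from the accelerator rather than the job config, so each rank automatically trains on its own GPU. A tiny illustration of the assumption (generic Accelerate usage, not repo code): launched with accelerate on two GPUs, rank 0 reports cuda:0 and rank 1 reports cuda:1.

from accelerate import Accelerator

accelerator = Accelerator()
# e.g. "process 0 -> cuda:0" and "process 1 -> cuda:1" when launched on two GPUs
print(f"process {accelerator.process_index} -> {accelerator.device}")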
@@ -91,6 +104,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
             self.network_config = None
         self.train_config = TrainConfig(**self.get_conf('train', {}))
         model_config = self.get_conf('model', {})
+        self.modules_being_trained: List[torch.nn.Module] = []

         # update modelconfig dtype to match train
         model_config['dtype'] = self.train_config.dtype
@@ -222,6 +236,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
         return generate_image_config_list

     def sample(self, step=None, is_first=False):
+        if not self.accelerator.is_main_process:
+            return
         flush()
         sample_folder = os.path.join(self.save_root, 'samples')
         gen_img_config_list = []
@@ -316,6 +332,8 @@ class BaseSDTrainProcess(BaseTrainProcess):

             elif self.model_config.is_xl:
                 o_dict['ss_base_model_version'] = 'sdxl_1.0'
+            elif self.model_config.is_flux:
+                o_dict['ss_base_model_version'] = 'flux.1'
             else:
                 o_dict['ss_base_model_version'] = 'sd_1.5'

@@ -344,6 +362,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
         return info

     def clean_up_saves(self):
+        if not self.accelerator.is_main_process:
+            return
         # remove old saves
         # get latest saved step
         latest_item = None
@@ -400,7 +420,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
         items_to_remove = list(dict.fromkeys(items_to_remove))

         for item in items_to_remove:
-            self.print(f"Removing old save: {item}")
+            print_acc(f"Removing old save: {item}")
             if os.path.isdir(item):
                 shutil.rmtree(item)
             else:
@@ -418,6 +438,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
             pass

     def save(self, step=None):
+        if not self.accelerator.is_main_process:
+            return
         flush()
         if self.ema is not None:
             # always save params as ema
@@ -594,10 +616,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 state_dict = self.optimizer.state_dict()
                 torch.save(state_dict, file_path)
             except Exception as e:
-                print(e)
-                print("Could not save optimizer")
+                print_acc(e)
+                print_acc("Could not save optimizer")

-        self.print(f"Saved to {file_path}")
+        print_acc(f"Saved to {file_path}")
         self.clean_up_saves()
         self.post_save_hook(file_path)

@@ -619,7 +641,49 @@ class BaseSDTrainProcess(BaseTrainProcess):
         return params

     def hook_before_train_loop(self):
-        self.logger.start()
+        if self.accelerator.is_main_process:
+            self.logger.start()
+        self.prepare_accelerator()
+
+
+    def prepare_accelerator(self):
+        # set some config
+        self.accelerator.even_batches=False
+
+        # # prepare all the models stuff for accelerator (hopefully we dont miss any)
+        if self.sd.vae is not None:
+            self.sd.vae = self.accelerator.prepare(self.sd.vae)
+        if self.sd.unet is not None:
+            self.sd.unet = self.accelerator.prepare(self.sd.unet)
+            # todo always tdo it?
+            self.modules_being_trained.append(self.sd.unet)
+        if self.sd.text_encoder is not None and self.train_config.train_text_encoder:
+            if isinstance(self.sd.text_encoder, list):
+                self.sd.text_encoder = [self.accelerator.prepare(model) for model in self.sd.text_encoder]
+                self.modules_being_trained.extend(self.sd.text_encoder)
+            else:
+                self.sd.text_encoder = self.accelerator.prepare(self.sd.text_encoder)
+                self.modules_being_trained.append(self.sd.text_encoder)
+        if self.sd.refiner_unet is not None and self.train_config.train_refiner:
+            self.sd.refiner_unet = self.accelerator.prepare(self.sd.refiner_unet)
+            self.modules_being_trained.append(self.sd.refiner_unet)
+        # todo, do we need to do the network or will "unet" get it?
+        if self.sd.network is not None:
+            self.sd.network = self.accelerator.prepare(self.sd.network)
+            self.modules_being_trained.append(self.sd.network)
+        if self.adapter is not None and self.adapter_config.train:
+            # todo adapters may not be a module. need to check
+            self.adapter = self.accelerator.prepare(self.adapter)
+            self.modules_being_trained.append(self.adapter)
+
+        # prepare other things
+        self.optimizer = self.accelerator.prepare(self.optimizer)
+        if self.lr_scheduler is not None:
+            self.lr_scheduler = self.accelerator.prepare(self.lr_scheduler)
+        # self.data_loader = self.accelerator.prepare(self.data_loader)
+        # if self.data_loader_reg is not None:
+        #     self.data_loader_reg = self.accelerator.prepare(self.data_loader_reg)
+

     def ensure_params_requires_grad(self, force=False):
         if self.train_config.do_paramiter_swapping and not force:
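accelerator.prepare() is what does the heavy lifting in the new prepare_accelerator method: it moves each module to its local device and, when more than one process is running, wraps it for DistributedDataParallel; a prepared optimizer and scheduler also coordinate with gradient accumulation so they only step when gradients are synced. The dataloader prepare calls are left commented out, so each rank presumably iterates its own copy of the dataset rather than an Accelerate-sharded split. A reduced sketch of the same idea with placeholder modules (not the toolkit's own classes):

import torch
from accelerate import Accelerator

accelerator = Accelerator()
unet = torch.nn.Linear(8, 8)                               # stand-in for the real UNet/transformer
optimizer = torch.optim.AdamW(unet.parameters(), lr=1e-4)

# several objects can be prepared in one call; order is preserved
unet, optimizer = accelerator.prepare(unet, optimizer)

# prepare() may wrap the module (e.g. in DDP), so unwrap it before saving weights
state_dict = accelerator.unwrap_model(unet).state_dict()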
@@ -692,6 +756,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
         return latest_path

     def load_training_state_from_metadata(self, path):
+        if not self.accelerator.is_main_process:
+            return
         meta = None
         # if path is folder, then it is diffusers
         if os.path.isdir(path):
@@ -708,7 +774,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
             if 'epoch' in meta['training_info']:
                 self.epoch_num = meta['training_info']['epoch']
             self.start_step = self.step_num
-            print(f"Found step {self.step_num} in metadata, starting from there")
+            print_acc(f"Found step {self.step_num} in metadata, starting from there")

     def load_weights(self, path):
         if self.network is not None:
@@ -716,7 +782,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
             self.load_training_state_from_metadata(path)
             return extra_weights
         else:
-            print("load_weights not implemented for non-network models")
+            print_acc("load_weights not implemented for non-network models")
             return None

     def apply_snr(self, seperated_loss, timesteps):
@@ -747,7 +813,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
             if 'epoch' in meta['training_info']:
                 self.epoch_num = meta['training_info']['epoch']
             self.start_step = self.step_num
-            print(f"Found step {self.step_num} in metadata, starting from there")
+            print_acc(f"Found step {self.step_num} in metadata, starting from there")

     # def get_sigmas(self, timesteps, n_dim=4, dtype=torch.float32):
     #     self.sd.noise_scheduler.set_timesteps(1000, device=self.device_torch)
@@ -1244,7 +1310,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 self.adapter.to(self.device_torch, dtype=dtype)
             if latest_save_path is not None and not is_control_net:
                 # load adapter from path
-                print(f"Loading adapter from {latest_save_path}")
+                print_acc(f"Loading adapter from {latest_save_path}")
                 if is_t2i:
                     loaded_state_dict = load_t2i_model(
                         latest_save_path,
@@ -1290,7 +1356,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
            latest_save_path = self.get_latest_save_path()

            if latest_save_path is not None:
-                print(f"#### IMPORTANT RESUMING FROM {latest_save_path} ####")
+                print_acc(f"#### IMPORTANT RESUMING FROM {latest_save_path} ####")
                model_config_to_load.name_or_path = latest_save_path
                self.load_training_state_from_metadata(latest_save_path)

@@ -1357,7 +1423,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
        #             block.attn.set_processor(processor)

        # except ImportError:
-        #     print("sage attention is not installed. Using SDP instead")
+        #     print_acc("sage attention is not installed. Using SDP instead")

        if self.train_config.gradient_checkpointing:
            if self.sd.is_flux:
@@ -1531,8 +1597,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
            latest_save_path = self.get_latest_save_path(lora_name)
            extra_weights = None
            if latest_save_path is not None:
-                self.print(f"#### IMPORTANT RESUMING FROM {latest_save_path} ####")
-                self.print(f"Loading from {latest_save_path}")
+                print_acc(f"#### IMPORTANT RESUMING FROM {latest_save_path} ####")
+                print_acc(f"Loading from {latest_save_path}")
                extra_weights = self.load_weights(latest_save_path)
                self.network.multiplier = 1.0

@@ -1665,17 +1731,17 @@ class BaseSDTrainProcess(BaseTrainProcess):
                previous_lrs.append(group['lr'])

            try:
-                print(f"Loading optimizer state from {optimizer_state_file_path}")
+                print_acc(f"Loading optimizer state from {optimizer_state_file_path}")
                optimizer_state_dict = torch.load(optimizer_state_file_path, weights_only=True)
                optimizer.load_state_dict(optimizer_state_dict)
                del optimizer_state_dict
                flush()
            except Exception as e:
-                print(f"Failed to load optimizer state from {optimizer_state_file_path}")
-                print(e)
+                print_acc(f"Failed to load optimizer state from {optimizer_state_file_path}")
+                print_acc(e)

            # update the optimizer LR from the params
-            print(f"Updating optimizer LR from params")
+            print_acc(f"Updating optimizer LR from params")
            if len(previous_lrs) > 0:
                for i, group in enumerate(optimizer.param_groups):
                    group['lr'] = previous_lrs[i]
@@ -1711,24 +1777,27 @@ class BaseSDTrainProcess(BaseTrainProcess):
        self.hook_before_train_loop()

        if self.has_first_sample_requested and self.step_num <= 1 and not self.train_config.disable_sampling:
-            self.print("Generating first sample from first sample config")
+            print_acc("Generating first sample from first sample config")
            self.sample(0, is_first=True)

        # sample first
        if self.train_config.skip_first_sample or self.train_config.disable_sampling:
-            self.print("Skipping first sample due to config setting")
+            print_acc("Skipping first sample due to config setting")
        elif self.step_num <= 1 or self.train_config.force_first_sample:
-            self.print("Generating baseline samples before training")
+            print_acc("Generating baseline samples before training")
            self.sample(self.step_num)

-        self.progress_bar = ToolkitProgressBar(
-            total=self.train_config.steps,
-            desc=self.job.name,
-            leave=True,
-            initial=self.step_num,
-            iterable=range(0, self.train_config.steps),
-        )
-        self.progress_bar.pause()
+        if self.accelerator.is_local_main_process:
+            self.progress_bar = ToolkitProgressBar(
+                total=self.train_config.steps,
+                desc=self.job.name,
+                leave=True,
+                initial=self.step_num,
+                iterable=range(0, self.train_config.steps),
+            )
+            self.progress_bar.pause()
+        else:
+            self.progress_bar = None

        if self.data_loader is not None:
            dataloader = self.data_loader
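Only the local main process builds a progress bar now; every other rank ends up with self.progress_bar = None, which is why the call sites later in this diff all grow if self.progress_bar is not None: guards. With a plain tqdm bar the same effect is usually achieved with the disable flag (a generic sketch, not this repo's ToolkitProgressBar):

from accelerate import Accelerator
from tqdm import tqdm

accelerator = Accelerator()
# renders only on the local main process; other ranks get a silent no-op bar
progress = tqdm(range(1000), disable=not accelerator.is_local_main_process)
for _ in progress:
    pass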
@@ -1753,7 +1822,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
        flush()
        # self.step_num = 0

-        # print(f"Compiling Model")
+        # print_acc(f"Compiling Model")
        # torch.compile(self.sd.unet, dynamic=True)

        # make sure all params require grad
@@ -1779,7 +1848,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
            self.is_grad_accumulation_step = True
            if self.train_config.free_u:
                self.sd.pipeline.enable_freeu(s1=0.9, s2=0.2, b1=1.1, b2=1.2)
-            self.progress_bar.unpause()
+            if self.progress_bar is not None:
+                self.progress_bar.unpause()
            with torch.no_grad():
                # if is even step and we have a reg dataset, use that
                # todo improve this logic to send one of each through if we can buckets and batch size might be an issue
@@ -1802,13 +1872,15 @@ class BaseSDTrainProcess(BaseTrainProcess):
                    except StopIteration:
                        with self.timer('reset_batch:reg'):
                            # hit the end of an epoch, reset
-                            self.progress_bar.pause()
+                            if self.progress_bar is not None:
+                                self.progress_bar.pause()
                            dataloader_iterator_reg = iter(dataloader_reg)
                            trigger_dataloader_setup_epoch(dataloader_reg)

                        with self.timer('get_batch:reg'):
                            batch = next(dataloader_iterator_reg)
-                        self.progress_bar.unpause()
+                        if self.progress_bar is not None:
+                            self.progress_bar.unpause()
                        is_reg_step = True
                elif dataloader is not None:
                    try:
@@ -1817,7 +1889,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
                    except StopIteration:
                        with self.timer('reset_batch'):
                            # hit the end of an epoch, reset
-                            self.progress_bar.pause()
+                            if self.progress_bar is not None:
+                                self.progress_bar.pause()
                            dataloader_iterator = iter(dataloader)
                            trigger_dataloader_setup_epoch(dataloader)
                            self.epoch_num += 1
@@ -1827,7 +1900,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
                            self.grad_accumulation_step = 0
                        with self.timer('get_batch'):
                            batch = next(dataloader_iterator)
-                        self.progress_bar.unpause()
+                        if self.progress_bar is not None:
+                            self.progress_bar.unpause()
                else:
                    batch = None
                batch_list.append(batch)
@@ -1849,8 +1923,8 @@ class BaseSDTrainProcess(BaseTrainProcess):

                # flush()
                ### HOOK ###
-
-                loss_dict = self.hook_train_loop(batch_list)
+                with self.accelerator.accumulate(self.modules_being_trained):
+                    loss_dict = self.hook_train_loop(batch_list)
                self.timer.stop('train_loop')
                if not did_first_flush:
                    flush()
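Wrapping the train step in accelerator.accumulate(...) is how Accelerate decides when to synchronize gradients across GPUs: on intermediate micro-batches the all-reduce is skipped and a prepared optimizer turns its step into a no-op, so only every Nth batch pays the communication cost. A self-contained sketch with a placeholder model (hook_train_loop itself is not shown in this hunk):

import torch
from accelerate import Accelerator

accelerator = Accelerator(gradient_accumulation_steps=4)
model = torch.nn.Linear(4, 4)                              # placeholder module
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
model, optimizer = accelerator.prepare(model, optimizer)

for i in range(8):
    x = torch.randn(2, 4, device=accelerator.device)
    with accelerator.accumulate(model):
        loss = model(x).pow(2).mean()
        accelerator.backward(loss)     # use this instead of loss.backward()
        optimizer.step()               # actually steps only when gradients sync
        optimizer.zero_grad()
    print(i, accelerator.sync_gradients)   # True when this micro-batch triggers a sync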
@@ -1880,7 +1954,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
                    for key, value in loss_dict.items():
                        prog_bar_string += f" {key}: {value:.3e}"

-                    self.progress_bar.set_postfix_str(prog_bar_string)
+                    if self.progress_bar is not None:
+                        self.progress_bar.set_postfix_str(prog_bar_string)

                # if the batch is a DataLoaderBatchDTO, then we need to clean it up
                if isinstance(batch, DataLoaderBatchDTO):
@@ -1889,8 +1964,11 @@ class BaseSDTrainProcess(BaseTrainProcess):

                # don't do on first step
                if self.step_num != self.start_step:
+                    if is_sample_step or is_save_step:
+                        self.accelerator.wait_for_everyone()
                    if is_sample_step:
-                        self.progress_bar.pause()
+                        if self.progress_bar is not None:
+                            self.progress_bar.pause()
                        flush()
                        # print above the progress bar
                        if self.train_config.free_u:
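The wait_for_everyone() added above is a cross-process barrier: before any sampling or saving begins, every rank blocks until all ranks have finished the current step, so rank 0 is not writing checkpoints while the others are still mid-backward. The general shape of the pattern (generic sketch, not repo code):

from accelerate import Accelerator

accelerator = Accelerator()

accelerator.wait_for_everyone()          # barrier: all ranks line up here
if accelerator.is_main_process:
    # only rank 0 does the slow filesystem work (samples, checkpoints)
    pass
accelerator.wait_for_everyone()          # optional second barrier before training resumes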
@@ -1902,57 +1980,70 @@ class BaseSDTrainProcess(BaseTrainProcess):
                        flush()

                        self.ensure_params_requires_grad()
-                        self.progress_bar.unpause()
+                        if self.progress_bar is not None:
+                            self.progress_bar.unpause()

                    if is_save_step:
+                        self.accelerator
                        # print above the progress bar
-                        self.progress_bar.pause()
-                        self.print(f"Saving at step {self.step_num}")
+                        if self.progress_bar is not None:
+                            self.progress_bar.pause()
+                        print_acc(f"Saving at step {self.step_num}")
                        self.save(self.step_num)
                        self.ensure_params_requires_grad()
-                        self.progress_bar.unpause()
+                        if self.progress_bar is not None:
+                            self.progress_bar.unpause()

                if self.logging_config.log_every and self.step_num % self.logging_config.log_every == 0:
-                    self.progress_bar.pause()
+                    if self.progress_bar is not None:
+                        self.progress_bar.pause()
                    with self.timer('log_to_tensorboard'):
                        # log to tensorboard
-                        if self.writer is not None:
-                            for key, value in loss_dict.items():
-                                self.writer.add_scalar(f"{key}", value, self.step_num)
-                            self.writer.add_scalar(f"lr", learning_rate, self.step_num)
-                    self.progress_bar.unpause()
+                        if self.accelerator.is_main_process:
+                            if self.writer is not None:
+                                for key, value in loss_dict.items():
+                                    self.writer.add_scalar(f"{key}", value, self.step_num)
+                                self.writer.add_scalar(f"lr", learning_rate, self.step_num)
+                    if self.progress_bar is not None:
+                        self.progress_bar.unpause()

-                    # log to logger
-                    self.logger.log({
-                        'learning_rate': learning_rate,
-                    })
-                    for key, value in loss_dict.items():
-                        self.logger.log({
-                            f'loss/{key}': value,
-                        })
+                    if self.accelerator.is_main_process:
+                        # log to logger
+                        self.logger.log({
+                            'learning_rate': learning_rate,
+                        })
+                        for key, value in loss_dict.items():
+                            self.logger.log({
+                                f'loss/{key}': value,
+                            })
                elif self.logging_config.log_every is None:
-                    # log every step
-                    self.logger.log({
-                        'learning_rate': learning_rate,
-                    })
-                    for key, value in loss_dict.items():
-                        self.logger.log({
-                            f'loss/{key}': value,
-                        })
+                    if self.accelerator.is_main_process:
+                        # log every step
+                        self.logger.log({
+                            'learning_rate': learning_rate,
+                        })
+                        for key, value in loss_dict.items():
+                            self.logger.log({
+                                f'loss/{key}': value,
+                            })

                if self.performance_log_every > 0 and self.step_num % self.performance_log_every == 0:
-                    self.progress_bar.pause()
+                    if self.progress_bar is not None:
+                        self.progress_bar.pause()
                    # print the timers and clear them
                    self.timer.print()
                    self.timer.reset()
-                    self.progress_bar.unpause()
+                    if self.progress_bar is not None:
+                        self.progress_bar.unpause()

                # commit log
-                self.logger.commit(step=self.step_num)
+                if self.accelerator.is_main_process:
+                    self.logger.commit(step=self.step_num)

                # sets progress bar to match out step
-                self.progress_bar.update(step - self.progress_bar.n)
+                if self.progress_bar is not None:
+                    self.progress_bar.update(step - self.progress_bar.n)

                #############################
                # End of step
@@ -1966,16 +2057,19 @@ class BaseSDTrainProcess(BaseTrainProcess):
        ###################################################################
        ## END TRAIN LOOP
        ###################################################################

-        self.progress_bar.close()
+        self.accelerator.wait_for_everyone()
+        if self.progress_bar is not None:
+            self.progress_bar.close()
        if self.train_config.free_u:
            self.sd.pipeline.disable_freeu()
        if not self.train_config.disable_sampling:
            self.sample(self.step_num)
        self.logger.commit(step=self.step_num)
-        print("")
-        self.save()
-        self.logger.finish()
+        print_acc("")
+        if self.accelerator.is_main_process:
+            self.save()
+            self.logger.finish()
+        self.accelerator.end_training()

        if self.save_config.push_to_hub:
            if("HF_TOKEN" not in os.environ):
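At shutdown the order matters: a final barrier, then the main process alone writes the last checkpoint and closes the job logger, then accelerator.end_training() lets Accelerate tear down anything it is tracking. end_training() is mainly relevant when Accelerate's own experiment trackers are in use; a minimal sketch of that side of the API (assuming a tensorboard tracker, which this repo does not necessarily use):

from accelerate import Accelerator

accelerator = Accelerator(log_with="tensorboard", project_dir="./logs")
accelerator.init_trackers("example_run")
accelerator.log({"loss": 0.0}, step=0)
accelerator.end_training()   # flushes and closes every tracker on the main process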
@@ -2001,6 +2095,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
        repo_id: str,
        private: bool = False,
    ):
+        if not self.accelerator.is_main_process:
+            return
        readme_content = self._generate_readme(repo_id)
        readme_path = os.path.join(self.save_root, "README.md")
        with open(readme_path, "w", encoding="utf-8") as f: