Added a guidance burning loss. Modified DFE to work with new model. Bug fixes

This commit is contained in:
Jaret Burkett
2025-06-23 08:38:27 -06:00
parent 8602470952
commit ba1274d99e
5 changed files with 106 additions and 99 deletions

View File

@@ -1012,7 +1012,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
dtype = get_torch_dtype(self.train_config.dtype)
imgs = None
is_reg = any(batch.get_is_reg_list())
cfm_batch = None
if batch.tensor is not None:
imgs = batch.tensor
imgs = imgs.to(self.device_torch, dtype=dtype)

View File

@@ -113,6 +113,8 @@ class GenerateProcess(BaseProcess):
prompt_image_configs = []
for _ in range(self.generate_config.num_repeats):
for prompt in self.generate_config.prompts:
# remove --
prompt = prompt.replace('--', '').strip()
width = self.generate_config.width
height = self.generate_config.height
# prompt = self.clean_prompt(prompt)