From ce8e7a1271ec78ff99666f81e9a035e3e74c98a9 Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Sat, 22 Jul 2023 00:05:15 -0600 Subject: [PATCH] Added prompt flag to adjust network multiplier --- jobs/process/TrainSliderProcess.py | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/jobs/process/TrainSliderProcess.py b/jobs/process/TrainSliderProcess.py index 6e260940..1fa6d13c 100644 --- a/jobs/process/TrainSliderProcess.py +++ b/jobs/process/TrainSliderProcess.py @@ -67,6 +67,7 @@ class SampleConfig: self.walk_seed = kwargs.get('walk_seed', False) self.guidance_scale = kwargs.get('guidance_scale', 7) self.sample_steps = kwargs.get('sample_steps', 20) + self.network_multiplier = kwargs.get('network_multiplier', 1) class NetworkConfig: @@ -214,6 +215,7 @@ class TrainSliderProcess(BaseTrainProcess): pipeline.set_progress_bar_config(disable=True) start_seed = self.sample_config.seed + start_multiplier = self.network.multiplier current_seed = start_seed pipeline.to(self.device_torch) @@ -228,12 +230,22 @@ class TrainSliderProcess(BaseTrainProcess): for i in tqdm(range(len(self.sample_config.prompts)), desc=f"Generating Samples - step: {step}"): raw_prompt = self.sample_config.prompts[i] - prompt = raw_prompt + neg = self.sample_config.neg - p_split = raw_prompt.split('--n') + multiplier = self.sample_config.network_multiplier + p_split = raw_prompt.split('--') + prompt = p_split[0].strip() + if len(p_split) > 1: - prompt = p_split[0].strip() - neg = p_split[1].strip() + for split in p_split: + flag = split[:1] + content = split[1:].strip() + if flag == 'n': + neg = content + elif flag == 'm': + # multiplier + multiplier = float(content) + height = self.sample_config.height width = self.sample_config.width height = max(64, height - height % 8) # round to divisible by 8 @@ -242,6 +254,7 @@ class TrainSliderProcess(BaseTrainProcess): if self.sample_config.walk_seed: current_seed += i + self.network.multiplier = multiplier torch.manual_seed(current_seed) torch.cuda.manual_seed(current_seed) @@ -278,6 +291,7 @@ class TrainSliderProcess(BaseTrainProcess): self.sd.unet.to(original_device_dict['unet']) self.sd.text_encoder.to(original_device_dict['text_encoder']) self.network.train() + self.network.multiplier = start_multiplier # self.sd.tokenizer.to(original_device_dict['tokenizer']) def update_training_metadata(self):