Added prompt flag to adjust network multiplier

This commit is contained in:
Jaret Burkett
2023-07-22 00:05:15 -06:00
parent 596e59dd6d
commit ce8e7a1271

View File

@@ -67,6 +67,7 @@ class SampleConfig:
self.walk_seed = kwargs.get('walk_seed', False)
self.guidance_scale = kwargs.get('guidance_scale', 7)
self.sample_steps = kwargs.get('sample_steps', 20)
self.network_multiplier = kwargs.get('network_multiplier', 1)
class NetworkConfig:
@@ -214,6 +215,7 @@ class TrainSliderProcess(BaseTrainProcess):
pipeline.set_progress_bar_config(disable=True)
start_seed = self.sample_config.seed
start_multiplier = self.network.multiplier
current_seed = start_seed
pipeline.to(self.device_torch)
@@ -228,12 +230,22 @@ class TrainSliderProcess(BaseTrainProcess):
for i in tqdm(range(len(self.sample_config.prompts)), desc=f"Generating Samples - step: {step}"):
raw_prompt = self.sample_config.prompts[i]
prompt = raw_prompt
neg = self.sample_config.neg
p_split = raw_prompt.split('--n')
multiplier = self.sample_config.network_multiplier
p_split = raw_prompt.split('--')
prompt = p_split[0].strip()
if len(p_split) > 1:
prompt = p_split[0].strip()
neg = p_split[1].strip()
for split in p_split:
flag = split[:1]
content = split[1:].strip()
if flag == 'n':
neg = content
elif flag == 'm':
# multiplier
multiplier = float(content)
height = self.sample_config.height
width = self.sample_config.width
height = max(64, height - height % 8) # round to divisible by 8
@@ -242,6 +254,7 @@ class TrainSliderProcess(BaseTrainProcess):
if self.sample_config.walk_seed:
current_seed += i
self.network.multiplier = multiplier
torch.manual_seed(current_seed)
torch.cuda.manual_seed(current_seed)
@@ -278,6 +291,7 @@ class TrainSliderProcess(BaseTrainProcess):
self.sd.unet.to(original_device_dict['unet'])
self.sd.text_encoder.to(original_device_dict['text_encoder'])
self.network.train()
self.network.multiplier = start_multiplier
# self.sd.tokenizer.to(original_device_dict['tokenizer'])
def update_training_metadata(self):