Handle multi control inputs for control lora training

This commit is contained in:
Jaret Burkett
2025-03-23 07:37:08 -06:00
parent ccb66c748f
commit f10937e6da
7 changed files with 446 additions and 75 deletions

View File

@@ -241,6 +241,9 @@ class AdapterConfig:
self.lora_config: NetworkConfig = NetworkConfig(**lora_config)
else:
self.lora_config = None
self.num_control_images: int = kwargs.get('num_control_images', 1)
# decimal for how often the control is dropped out and replaced with noise 1.0 is 100%
self.control_image_dropout: float = kwargs.get('control_image_dropout', 0.0)
class EmbeddingConfig:
@@ -710,7 +713,7 @@ class DatasetConfig:
self.flip_x: bool = kwargs.get('flip_x', False)
self.flip_y: bool = kwargs.get('flip_y', False)
self.augments: List[str] = kwargs.get('augments', [])
self.control_path: str = kwargs.get('control_path', None) # depth maps, etc
self.control_path: Union[str,List[str]] = kwargs.get('control_path', None) # depth maps, etc
# instead of cropping ot match image, it will serve the full size control image (clip images ie for ip adapters)
self.full_size_control_images: bool = kwargs.get('full_size_control_images', False)
self.alpha_mask: bool = kwargs.get('alpha_mask', False) # if true, will use alpha channel as mask
@@ -833,6 +836,7 @@ class GenerateImageConfig:
logger: Optional[EmptyLogger] = None,
num_frames: int = 1,
fps: int = 15,
ctrl_idx: int = 0
):
self.width: int = width
self.height: int = height
@@ -863,6 +867,7 @@ class GenerateImageConfig:
self.extra_values = extra_values if extra_values is not None else []
self.num_frames = num_frames
self.fps = fps
self.ctrl_idx = ctrl_idx
# prompt string will override any settings above
@@ -1056,6 +1061,8 @@ class GenerateImageConfig:
self.num_frames = int(content)
elif flag == 'fps':
self.fps = int(content)
elif flag == 'ctrl_idx':
self.ctrl_idx = int(content)
def post_process_embeddings(
self,