mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-29 02:31:17 +00:00
Handle multi control inputs for control lora training
This commit is contained in:
@@ -241,6 +241,9 @@ class AdapterConfig:
|
||||
self.lora_config: NetworkConfig = NetworkConfig(**lora_config)
|
||||
else:
|
||||
self.lora_config = None
|
||||
self.num_control_images: int = kwargs.get('num_control_images', 1)
|
||||
# decimal for how often the control is dropped out and replaced with noise 1.0 is 100%
|
||||
self.control_image_dropout: float = kwargs.get('control_image_dropout', 0.0)
|
||||
|
||||
|
||||
class EmbeddingConfig:
|
||||
@@ -710,7 +713,7 @@ class DatasetConfig:
|
||||
self.flip_x: bool = kwargs.get('flip_x', False)
|
||||
self.flip_y: bool = kwargs.get('flip_y', False)
|
||||
self.augments: List[str] = kwargs.get('augments', [])
|
||||
self.control_path: str = kwargs.get('control_path', None) # depth maps, etc
|
||||
self.control_path: Union[str,List[str]] = kwargs.get('control_path', None) # depth maps, etc
|
||||
# instead of cropping ot match image, it will serve the full size control image (clip images ie for ip adapters)
|
||||
self.full_size_control_images: bool = kwargs.get('full_size_control_images', False)
|
||||
self.alpha_mask: bool = kwargs.get('alpha_mask', False) # if true, will use alpha channel as mask
|
||||
@@ -833,6 +836,7 @@ class GenerateImageConfig:
|
||||
logger: Optional[EmptyLogger] = None,
|
||||
num_frames: int = 1,
|
||||
fps: int = 15,
|
||||
ctrl_idx: int = 0
|
||||
):
|
||||
self.width: int = width
|
||||
self.height: int = height
|
||||
@@ -863,6 +867,7 @@ class GenerateImageConfig:
|
||||
self.extra_values = extra_values if extra_values is not None else []
|
||||
self.num_frames = num_frames
|
||||
self.fps = fps
|
||||
self.ctrl_idx = ctrl_idx
|
||||
|
||||
|
||||
# prompt string will override any settings above
|
||||
@@ -1056,6 +1061,8 @@ class GenerateImageConfig:
|
||||
self.num_frames = int(content)
|
||||
elif flag == 'fps':
|
||||
self.fps = int(content)
|
||||
elif flag == 'ctrl_idx':
|
||||
self.ctrl_idx = int(content)
|
||||
|
||||
def post_process_embeddings(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user