# mirror of https://github.com/ostris/ai-toolkit.git
# synced 2026-01-26
import os
|
|
import shutil
|
|
from collections import OrderedDict
|
|
import gc
|
|
from typing import List
|
|
|
|
import torch
|
|
from tqdm import tqdm
|
|
|
|
from .tools.dataset_tools_config_modules import DatasetSyncCollectionConfig, RAW_DIR, NEW_DIR
|
|
from .tools.sync_tools import get_unsplash_images, get_pexels_images, get_local_image_file_names, download_image, \
|
|
get_img_paths
|
|
from jobs.process import BaseExtensionProcess
|
|
|
|
|
|
def flush():
    """Release as much GPU memory as possible back to the driver.

    Runs the garbage collector FIRST so that any unreachable tensors are
    actually destroyed, then empties PyTorch's CUDA caching allocator so the
    freed blocks are returned to the device (the original order — cache
    first, collect second — left collector-freed memory still cached).
    Safe to call on CPU-only machines: ``empty_cache`` is a no-op there.
    """
    gc.collect()
    torch.cuda.empty_cache()
|
|
|
|
|
|
class SyncFromCollection(BaseExtensionProcess):
    """Extension process that mirrors remote photo collections (Unsplash or
    Pexels) into local dataset directories.

    New files are first downloaded into a NEW_DIR staging folder and, once all
    datasets have been processed, moved into RAW_DIR so partially-synced runs
    never pollute the raw set.
    """

    def __init__(self, process_id: int, job, config: OrderedDict):
        super().__init__(process_id, job, config)

        # Global minimum dimensions; individual dataset configs may override.
        self.min_width = config.get('min_width', 1024)
        self.min_height = config.get('min_height', 1024)

        # add our min_width and min_height to each dataset config if they don't exist
        for dataset_config in config.get('dataset_sync', []):
            if 'min_width' not in dataset_config:
                dataset_config['min_width'] = self.min_width
            if 'min_height' not in dataset_config:
                dataset_config['min_height'] = self.min_height

        self.dataset_configs: List[DatasetSyncCollectionConfig] = [
            DatasetSyncCollectionConfig(**dataset_config)
            for dataset_config in config.get('dataset_sync', [])
        ]
        print(f"Found {len(self.dataset_configs)} dataset configs")

    def move_new_images(self, root_dir: str):
        """Promote staged images from NEW_DIR into RAW_DIR, then delete the
        (now empty) staging folder.

        :param root_dir: dataset root containing the RAW_DIR/NEW_DIR subdirs.
        """
        raw_dir = os.path.join(root_dir, RAW_DIR)
        new_dir = os.path.join(root_dir, NEW_DIR)

        # A dataset that failed to sync (or had nothing new) may never have
        # created its staging dir; there is nothing to promote in that case.
        if not os.path.isdir(new_dir):
            return

        # Make sure the destination exists before the first move.
        os.makedirs(raw_dir, exist_ok=True)

        for img_path in get_img_paths(new_dir):
            # move to raw
            new_path = os.path.join(raw_dir, os.path.basename(img_path))
            shutil.move(img_path, new_path)

        # remove new dir
        shutil.rmtree(new_dir)

    def sync_dataset(self, config: DatasetSyncCollectionConfig):
        """Download any not-yet-present photos for one dataset config.

        :param config: the dataset to sync; ``config.host`` selects the API.
        :return: dict with ``num_downloaded`` / ``num_skipped`` / ``bad`` /
                 ``total`` counters (``total`` counts successful iterations only).
        :raises ValueError: if ``config.host`` is not a supported provider.
        """
        if config.host == 'unsplash':
            get_images = get_unsplash_images
        elif config.host == 'pexels':
            get_images = get_pexels_images
        else:
            raise ValueError(f"Unknown host: {config.host}")

        results = {
            'num_downloaded': 0,
            'num_skipped': 0,
            'bad': 0,
            'total': 0,
        }

        photos = get_images(config)
        raw_dir = os.path.join(config.directory, RAW_DIR)
        new_dir = os.path.join(config.directory, NEW_DIR)
        raw_images = get_local_image_file_names(raw_dir)
        new_images = get_local_image_file_names(new_dir)

        for photo in tqdm(photos, desc=f"{config.host}-{config.collection_id}"):
            try:
                # Skip anything already present in raw or still staged in new.
                if photo.filename not in raw_images and photo.filename not in new_images:
                    # Use the per-dataset minimums (back-filled with the
                    # global defaults in __init__); the previous code read
                    # self.min_width/self.min_height, silently ignoring any
                    # per-dataset override.
                    download_image(
                        photo, new_dir,
                        min_width=config.min_width,
                        min_height=config.min_height,
                    )
                    results['num_downloaded'] += 1
                else:
                    results['num_skipped'] += 1
            except Exception as e:
                print(f" - BAD({photo.id}): {e}")
                results['bad'] += 1
                continue
            results['total'] += 1

        return results

    def print_results(self, results):
        """Print a one-line summary of a results dict; tolerates ``None``
        (which happens when every dataset failed and no totals exist)."""
        if results is None:
            return
        print(
            f" - new:{results['num_downloaded']}, old:{results['num_skipped']}, bad:{results['bad']} total:{results['total']}")

    def run(self):
        """Sync every configured dataset, promote staged images, and report
        aggregate results plus any per-dataset failures."""
        super().run()
        print(f"Syncing {len(self.dataset_configs)} datasets")
        all_results = None
        failed_datasets = []
        for dataset_config in tqdm(self.dataset_configs, desc="Syncing datasets", leave=True):
            try:
                results = self.sync_dataset(dataset_config)
                if all_results is None:
                    all_results = {**results}
                else:
                    for key, value in results.items():
                        all_results[key] += value

                self.print_results(results)
            except Exception as e:
                print(f" - FAILED: {e}")
                # HTTP errors from requests carry a .response with details.
                if hasattr(e, 'response'):
                    error = f"{e.response.status_code}: {e.response.text}"
                    print(f" - {error}")
                    failed_datasets.append({'dataset': dataset_config, 'error': error})
                else:
                    failed_datasets.append({'dataset': dataset_config, 'error': str(e)})
                continue

        print("Moving new images to raw")
        for dataset_config in self.dataset_configs:
            self.move_new_images(dataset_config.directory)

        print("Done syncing datasets")
        # all_results may be None if every dataset failed; print_results
        # handles that instead of raising a TypeError as before.
        self.print_results(all_results)

        if len(failed_datasets) > 0:
            print(f"Failed to sync {len(failed_datasets)} datasets")
            for failed in failed_datasets:
                print(f" - {failed['dataset'].host}-{failed['dataset'].collection_id}")
                print(f" - ERR: {failed['error']}")