Files
vaiola/modules/civit/client.py

177 lines
6.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
from dataclasses import dataclass
from pathlib import Path
import requests
import time
from typing import Optional
from requests import Session
from pythonapp.Libs.ConfigDataClass import Config
@dataclass
class ClientConfig(Config):
api_key: str = ''
base_url: str = 'https://civitai.com/'
class Client:
def __init__(self, path, api_key: str):
self.path = path
os.makedirs(self.path, exist_ok=True)
self.config = ClientConfig(str(Path(self.path) / 'config.json'), autosave=True)
if self.config.api_key == '': self.config.api_key = api_key
self.config.save()
self._headers = {'Content-Type': 'application/json', 'Authorization': f'Bearer {self.config.api_key}'}
self.session = Session()
self.session.headers.update(self._headers)
pass
def enroll_key(self, key: str):
self.config.api_key = key
self.config.save()
@staticmethod
def build_query_string(params):
"""Build query string from dictionary of parameters
Args:
params (dict): Dictionary of parameters
Returns:
str: Query string in format '?param1=value1&param2=value2'
"""
if not params:
return ""
filtered_params = {k: v for k, v in params.items() if v is not None}
if not filtered_params:
return ""
query_parts = []
for key, value in filtered_params.items():
query_parts.append(f"{key}={value}")
return "?" + "&".join(query_parts)
def make_get_request(self, url: str, max_retries: int = 10, delay: float = 3.0,
timeout: int = 300, **kwargs) -> Optional[requests.Response]:
"""
Выполняет GET запрос с обработкой ошибок и повторными попытками
Args:
url (str): URL для запроса
max_retries (int): Максимальное количество попыток (по умолчанию 3)
delay (float): Задержка между попытками в секундах (по умолчанию 1.0)
timeout (int): Таймаут запроса в секундах (по умолчанию 30)
**kwargs: Дополнительные аргументы для requests.get()
Returns:
Optional[requests.Response]: Объект Response или None в случае ошибки
"""
session = self.session
for attempt in range(max_retries + 1):
try:
response = session.get(url, timeout=timeout, **kwargs)
response.raise_for_status() # Вызовет исключение для HTTP ошибок
return response
except (requests.exceptions.RequestException,
requests.exceptions.ConnectionError,
requests.exceptions.Timeout) as e:
if attempt == max_retries:
print(f"Не удалось выполнить запрос после {max_retries} попыток: {e}")
return None
print(f"Попытка {attempt + 1} не удалась: {e}. Повтор через {delay} секунд...")
time.sleep(delay)
return None
def get_creators_tags_raw(self, entity: str, page=None, limit = 200, query = None):
if not limit: limit = 200
if entity not in {'creators', 'tags'}: raise ValueError('Not in types')
response = self.make_get_request(
url = self.config.base_url + 'api/v1/' + entity + self.build_query_string(
{'page': page, 'limit': limit, 'query': query}
)
)
return response.json()
def get_creators_raw(self, page=None, limit = 200, query = None): return self.get_creators_tags_raw('creators', page, limit, query)
def get_tags_raw(self, page=None, limit = 200, query = None): return self.get_creators_tags_raw('tags', page, limit, query)
def get_model_raw(self, model_id: int):
try:
return self.make_get_request(f'{self.config.base_url}/api/v1/models/{model_id}').json()
except requests.exceptions.HTTPError as e:
print(e)
return {}
def download_file(self, url: str, path: str, chill_time: int = 3, max_retries: int = 3):
"""
Загружает файл по URL в указанный путь
Args:
url (str): URL файла для загрузки
path (str): Путь для сохранения файла
chill_time (int): Время ожидания в секундах при ошибке (по умолчанию 3)
max_retries (int): Максимальное количество попыток загрузки (по умолчанию 3)
"""
path = Path(path)
path.parent.mkdir(parents=True, exist_ok=True)
for attempt in range(max_retries):
try:
# Создаем запрос с прогресс-баром
response = self.session.get(url, stream=True, timeout=30)
response.raise_for_status()
# Получаем размер файла
total_size = int(response.headers.get('content-length', 0))
# Загружаем файл по частями
with open(path, 'wb') as file:
downloaded = 0
for chunk in response.iter_content(chunk_size=8192):
if chunk:
file.write(chunk)
downloaded += len(chunk)
# Отображаем прогресс
if total_size > 0:
progress = (downloaded / total_size) * 100
print(f"\rDownloading: {progress:.1f}% ({downloaded}/{total_size} bytes)", end='',
flush=True)
print(f"\nFile downloaded successfully to {path}")
return True
except requests.exceptions.RequestException as e:
print(f"Attempt {attempt + 1} failed: {e}")
if attempt < max_retries - 1:
print(f"Waiting {chill_time} seconds before retry...")
time.sleep(chill_time)
else:
print(f"Failed to download file after {max_retries} attempts")
return False
except IOError as e:
print(f"IO Error while saving file: {e}")
return False
except Exception as e:
print(f"Unexpected error: {e}")
return False
return False