Files
vaiola/modules/civit/client.py
bsakaguchi 817283034a Add civit module
Add civit.fetch function
2025-09-19 18:12:57 +07:00

112 lines
4.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
from dataclasses import dataclass
from pathlib import Path
import requests
import time
from typing import Optional
from requests import Session
from modelspace.Repository import Repository
from pythonapp.Libs.ConfigDataClass import Config
@dataclass
class ClientConfig(Config):
api_key: str = ''
base_url: str = 'https://civitai.com/'
class Client:
def __init__(self, path, api_key: str):
self.path = path
os.makedirs(self.path, exist_ok=True)
self.config = ClientConfig(str(Path(self.path) / 'config.json'), autosave=True)
if self.config.api_key == '': self.config.api_key = api_key
self.config.save()
self._headers = {'Content-Type': 'application/json', 'Authorization': f'Bearer {self.config.api_key}'}
self.session = Session()
self.session.headers.update(self._headers)
pass
def enroll_key(self, key: str):
self.config.api_key = key
self.config.save()
@staticmethod
def build_query_string(params):
"""Build query string from dictionary of parameters
Args:
params (dict): Dictionary of parameters
Returns:
str: Query string in format '?param1=value1&param2=value2'
"""
if not params:
return ""
filtered_params = {k: v for k, v in params.items() if v is not None}
if not filtered_params:
return ""
query_parts = []
for key, value in filtered_params.items():
query_parts.append(f"{key}={value}")
return "?" + "&".join(query_parts)
def make_get_request(self, url: str, max_retries: int = 10, delay: float = 3.0,
timeout: int = 300, **kwargs) -> Optional[requests.Response]:
"""
Выполняет GET запрос с обработкой ошибок и повторными попытками
Args:
url (str): URL для запроса
max_retries (int): Максимальное количество попыток (по умолчанию 3)
delay (float): Задержка между попытками в секундах (по умолчанию 1.0)
timeout (int): Таймаут запроса в секундах (по умолчанию 30)
**kwargs: Дополнительные аргументы для requests.get()
Returns:
Optional[requests.Response]: Объект Response или None в случае ошибки
"""
session = self.session
for attempt in range(max_retries + 1):
try:
response = session.get(url, timeout=timeout, **kwargs)
response.raise_for_status() # Вызовет исключение для HTTP ошибок
return response
except (requests.exceptions.RequestException,
requests.exceptions.ConnectionError,
requests.exceptions.Timeout) as e:
if attempt == max_retries:
print(f"Не удалось выполнить запрос после {max_retries} попыток: {e}")
return None
print(f"Попытка {attempt + 1} не удалась: {e}. Повтор через {delay} секунд...")
time.sleep(delay)
return None
def get_creators_tags_raw(self, entity: str, page=None, limit = 200, query = None):
if not limit: limit = 200
if entity not in {'creators', 'tags'}: raise ValueError('Not in types')
response = self.make_get_request(
url = self.config.base_url + 'api/v1/' + entity + self.build_query_string(
{'page': page, 'limit': limit, 'query': query}
)
)
return response.json()
def get_creators_raw(self, page=None, limit = 200, query = None): return self.get_creators_tags_raw('creators', page, limit, query)
def get_tags_raw(self, page=None, limit = 200, query = None): return self.get_creators_tags_raw('tags', page, limit, query)