Files
vaiola/pythonapp/Libs/getpytorch.py
Bacruru Sakaguchi 9e5e214944 initial commit
2025-09-12 17:10:13 +07:00

55 lines
2.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from typing import List
from dataclasses import dataclass
from .pip_api import pip_api
@dataclass
class PyTorchInfo:
"""Датакласс для хранения информации о версиях PyTorch компонентов"""
torch: List[str]
torchvision: List[str]
torchaudio: List[str]
class getpytorch:
"""Класс для получения версий компонентов PyTorch"""
BASE_URL = "https://download.pytorch.org/whl/"
def __init__(self, base_url: str = None):
self.base_url = base_url or self.BASE_URL
def get_versions(self, api: str) -> PyTorchInfo:
"""Получает версии всех компонентов PyTorch для указанного API"""
base_url = f"{self.base_url.rstrip('/')}/{api}"
return PyTorchInfo(
torch=self.get_torch_versions(api),
torchvision=self.get_torchvision_versions(api),
torchaudio=self.get_torchaudio_versions(api),
)
def get_torch_versions(self, api: str) -> List[str]:
"""Получает версии torch"""
return pip_api.get_pkg_versions('torch', f"{self.base_url.rstrip('/')}/{api}")
def get_torchvision_versions(self, api: str) -> List[str]:
"""Получает версии torchvision"""
return pip_api.get_pkg_versions('torchvision', f"{self.base_url.rstrip('/')}/{api}")
def get_torchaudio_versions(self, api: str) -> List[str]:
"""Получает версии torchaudio"""
return pip_api.get_pkg_versions('torchaudio', f"{self.base_url.rstrip('/')}/{api}")
if __name__ == "__main__":
# Пример использования
pytorch = getpytorch()
api = input("API version (cu121): ") or "cu121"
versions = pytorch.get_versions(api)
print(f"Все версии PyTorch: {versions.torch}")
print(f"Все версии torchvision: {versions.torchvision}")
print(f"Все версии torchaudio: {versions.torchaudio}")