55 lines
2.0 KiB
Python
55 lines
2.0 KiB
Python
from typing import List
|
||
from dataclasses import dataclass
|
||
from .pip_api import pip_api
|
||
|
||
|
||
|
||
@dataclass
|
||
class PyTorchInfo:
|
||
"""Датакласс для хранения информации о версиях PyTorch компонентов"""
|
||
torch: List[str]
|
||
torchvision: List[str]
|
||
torchaudio: List[str]
|
||
|
||
|
||
class getpytorch:
|
||
"""Класс для получения версий компонентов PyTorch"""
|
||
BASE_URL = "https://download.pytorch.org/whl/"
|
||
|
||
def __init__(self, base_url: str = None):
|
||
self.base_url = base_url or self.BASE_URL
|
||
|
||
def get_versions(self, api: str) -> PyTorchInfo:
|
||
"""Получает версии всех компонентов PyTorch для указанного API"""
|
||
base_url = f"{self.base_url.rstrip('/')}/{api}"
|
||
|
||
return PyTorchInfo(
|
||
torch=self.get_torch_versions(api),
|
||
torchvision=self.get_torchvision_versions(api),
|
||
torchaudio=self.get_torchaudio_versions(api),
|
||
)
|
||
|
||
def get_torch_versions(self, api: str) -> List[str]:
|
||
"""Получает версии torch"""
|
||
return pip_api.get_pkg_versions('torch', f"{self.base_url.rstrip('/')}/{api}")
|
||
|
||
def get_torchvision_versions(self, api: str) -> List[str]:
|
||
"""Получает версии torchvision"""
|
||
return pip_api.get_pkg_versions('torchvision', f"{self.base_url.rstrip('/')}/{api}")
|
||
|
||
def get_torchaudio_versions(self, api: str) -> List[str]:
|
||
"""Получает версии torchaudio"""
|
||
return pip_api.get_pkg_versions('torchaudio', f"{self.base_url.rstrip('/')}/{api}")
|
||
|
||
|
||
if __name__ == "__main__":
|
||
# Пример использования
|
||
pytorch = getpytorch()
|
||
|
||
api = input("API version (cu121): ") or "cu121"
|
||
versions = pytorch.get_versions(api)
|
||
|
||
print(f"Все версии PyTorch: {versions.torch}")
|
||
print(f"Все версии torchvision: {versions.torchvision}")
|
||
print(f"Все версии torchaudio: {versions.torchaudio}")
|