Gradio 4 + WebUI 1.10

This commit is contained in:
layerdiffusion
2024-07-26 08:51:34 -07:00
parent e95333c556
commit e26abf87ec
201 changed files with 7562 additions and 4834 deletions

View File

@@ -1,6 +1,7 @@
from __future__ import annotations
import torch.nn
import torch
def get_param(model) -> torch.nn.Parameter:
@@ -15,3 +16,10 @@ def get_param(model) -> torch.nn.Parameter:
return param
raise ValueError(f"No parameters found in model {model!r}")
def float64(t: torch.Tensor):
"""return torch.float64 if device is not mps or xpu, else return torch.float32"""
if t.device.type in ['mps', 'xpu']:
return torch.float32
return torch.float64