diff --git a/spaces.py b/spaces.py index ea0fd7c7..cfcbe393 100644 --- a/spaces.py +++ b/spaces.py @@ -106,6 +106,9 @@ class GPUObject: def GPU(gpu_objects=None, manual_load=False): gpu_objects = gpu_objects or [] + if not isinstance(gpu_objects, (list, tuple)): + gpu_objects = [gpu_objects] + def decorator(func): @functools.wraps(func) def wrapper(*args, **kwargs):