Various bug fixes and improvements

This commit is contained in:
Jaret Burkett
2023-08-12 05:59:50 -06:00
parent 67dfd9ced0
commit 379992d89e
5 changed files with 180 additions and 93 deletions

View File

@@ -34,13 +34,16 @@ SCHEDLER_SCHEDULE = "scaled_linear"
def get_torch_dtype(dtype_str):
# if it is a torch dtype, return it
if isinstance(dtype_str, torch.dtype):
return dtype_str
if dtype_str == "float" or dtype_str == "fp32" or dtype_str == "single" or dtype_str == "float32":
return torch.float
if dtype_str == "fp16" or dtype_str == "half" or dtype_str == "float16":
return torch.float16
if dtype_str == "bf16" or dtype_str == "bfloat16":
return torch.bfloat16
return None
return dtype_str
def replace_filewords_prompt(prompt, args: argparse.Namespace):