mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-03-13 14:39:50 +00:00
Various bug fixes and improvements
This commit is contained in:
@@ -34,13 +34,16 @@ SCHEDLER_SCHEDULE = "scaled_linear"
|
||||
|
||||
|
||||
def get_torch_dtype(dtype_str):
|
||||
# if it is a torch dtype, return it
|
||||
if isinstance(dtype_str, torch.dtype):
|
||||
return dtype_str
|
||||
if dtype_str == "float" or dtype_str == "fp32" or dtype_str == "single" or dtype_str == "float32":
|
||||
return torch.float
|
||||
if dtype_str == "fp16" or dtype_str == "half" or dtype_str == "float16":
|
||||
return torch.float16
|
||||
if dtype_str == "bf16" or dtype_str == "bfloat16":
|
||||
return torch.bfloat16
|
||||
return None
|
||||
return dtype_str
|
||||
|
||||
|
||||
def replace_filewords_prompt(prompt, args: argparse.Namespace):
|
||||
|
||||
Reference in New Issue
Block a user