8 bit training working on flux

This commit is contained in:
Jaret Burkett
2024-08-06 11:53:27 -06:00
parent 272c8608c2
commit c2424087d6
7 changed files with 82 additions and 31 deletions

View File

@@ -52,6 +52,8 @@ def get_torch_dtype(dtype_str):
return torch.float16
if dtype_str == "bf16" or dtype_str == "bfloat16":
return torch.bfloat16
if dtype_str == "8bit" or dtype_str == "e4m3fn" or dtype_str == "float8":
return torch.float8_e4m3fn
return dtype_str