mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-05-24 14:54:34 +00:00
CUTLASS 3.4.0 (#1286)
* CUTLASS 3.4.0 * Update CHANGELOG.md --------- Co-authored-by: Pradeep Ramani <prramani@nvidia.com>
This commit is contained in:
@@ -176,6 +176,17 @@ def is_torch_available():
|
||||
cutlass.DataType.s32: torch.int32,
|
||||
cutlass.DataType.u8: torch.uint8,
|
||||
}
|
||||
|
||||
def possibly_add_type(torch_type_name, cutlass_type):
|
||||
# Only try adding the type if the version of torch being used supports it
|
||||
if hasattr(torch, torch_type_name):
|
||||
torch_type = getattr(torch, torch_type_name)
|
||||
_torch_to_library_dict[torch_type] = cutlass_type
|
||||
_library_to_torch_dict[cutlass_type] = torch_type
|
||||
|
||||
possibly_add_type("float8_e4m3fn", cutlass.DataType.e4m3)
|
||||
possibly_add_type("float8_e5m2", cutlass.DataType.e5m2)
|
||||
|
||||
except ImportError:
|
||||
torch_available = False
|
||||
_torch_to_library_dict = {}
|
||||
|
||||
Reference in New Issue
Block a user