mirror of
https://github.com/turboderp-org/exllamav3.git
synced 2026-04-20 14:29:51 +00:00
safetensors_alt: Allow writing bfloat16 tensors
This commit is contained in:
@@ -46,10 +46,14 @@ def _tensor_bytes_view(t: torch.Tensor) -> memoryview:
|
||||
mv = memoryview(us).cast("B")
|
||||
byte_off = int(t.storage_offset() * t.element_size())
|
||||
nbytes = _tensor_nbytes(t)
|
||||
return mv[byte_off : byte_off + nbytes]
|
||||
return mv[byte_off: byte_off + nbytes]
|
||||
except TypeError:
|
||||
arr = t.numpy() # zero-copy view for CPU contiguous tensors
|
||||
return memoryview(arr).cast("B")
|
||||
if t.dtype is torch.bfloat16:
|
||||
arr_u8 = t.view(torch.uint8).numpy() # zero-copy
|
||||
return memoryview(arr_u8) # already bytes
|
||||
else:
|
||||
arr = t.numpy() # zero-copy view for CPU contiguous tensors
|
||||
return memoryview(arr).cast("B")
|
||||
|
||||
|
||||
def _read_exact(f: io.BufferedReader, n: int) -> bytes:
|
||||
|
||||
Reference in New Issue
Block a user