safetensors_alt: Allow writing bfloat16 tensors

This commit is contained in:
turboderp
2026-02-10 17:47:44 +01:00
parent 6e4202eade
commit 89b841dd8a

View File

@@ -46,10 +46,14 @@ def _tensor_bytes_view(t: torch.Tensor) -> memoryview:
mv = memoryview(us).cast("B")
byte_off = int(t.storage_offset() * t.element_size())
nbytes = _tensor_nbytes(t)
return mv[byte_off : byte_off + nbytes]
return mv[byte_off: byte_off + nbytes]
except TypeError:
arr = t.numpy() # zero-copy view for CPU contiguous tensors
return memoryview(arr).cast("B")
if t.dtype is torch.bfloat16:
arr_u8 = t.view(torch.uint8).numpy() # zero-copy
return memoryview(arr_u8) # already bytes
else:
arr = t.numpy() # zero-copy view for CPU contiguous tensors
return memoryview(arr).cast("B")
def _read_exact(f: io.BufferedReader, n: int) -> bytes: