Add support for using quantized models with ramtorch

This commit is contained in:
Jaret Burkett
2025-10-06 13:46:57 -06:00
parent dc1cc3e78a
commit c9f982af83

View File

@@ -10,6 +10,7 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from typing import TYPE_CHECKING, Optional, Tuple from typing import TYPE_CHECKING, Optional, Tuple
from torch.overrides import has_torch_function_unary # (ADD) torchao detection
if TYPE_CHECKING: if TYPE_CHECKING:
from .manager import MemoryManager from .manager import MemoryManager
@@ -55,11 +56,53 @@ def _get_device_state(device: torch.device):
return _DEVICE_STATE[device] return _DEVICE_STATE[device]
# (ADD) detect torchao wrapper tensors
def _is_ao_quantized_tensor(t: Optional[torch.Tensor]) -> bool:
if t is None:
return False
try:
if has_torch_function_unary(t):
return t.__class__.__module__.startswith("torchao.")
except Exception:
pass
for attr in (
"_scale",
"_scales",
"_zero_point",
"_zp",
"_block_size",
"_group_size",
"_pack_dim",
):
if hasattr(t, attr):
return True
return False
def _is_quantized_tensor(t: Optional[torch.Tensor]) -> bool:
if t is None:
return False
# torch quantized tensors
try:
if torch.is_quantized(t): # type: ignore[attr-defined]
return True
except Exception:
pass
# (ADD) torchao quantized wrappers
if _is_ao_quantized_tensor(t):
return True
# packed/int formats (weight-only)
return not t.dtype.is_floating_point
def _ensure_cpu_pinned(t: Optional[torch.Tensor]) -> Optional[torch.Tensor]: def _ensure_cpu_pinned(t: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
if t is None: if t is None:
return None return None
if t.device.type != "cpu": if t.device.type != "cpu":
t = t.to("cpu", copy=True) t = t.to("cpu", copy=True)
# Don't attempt to pin quantized tensors; many backends don't support it
if _is_quantized_tensor(t):
return t
if torch.cuda.is_available(): if torch.cuda.is_available():
try: try:
t = t.pin_memory() t = t.pin_memory()
@@ -86,8 +129,35 @@ def _move_params_to_cpu_and_pin(module: nn.Module):
class _BouncingLinearFn(torch.autograd.Function): class _BouncingLinearFn(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, x, weight_cpu, bias_cpu, device: torch.device): def forward(ctx, x, weight_cpu, bias_cpu, device: torch.device):
# choose compute dtype to match activations
target_dtype = (
x.dtype
if x.dtype in (torch.bfloat16, torch.float16, torch.float32)
else torch.bfloat16
)
# GPU-side dequant/cast for quantized; float path unchanged
def _materialize_linear_weight(cpu_w, dev):
if _is_quantized_tensor(cpu_w):
# move quantized wrapper to GPU -> dequantize on GPU -> cast on GPU
w_q_gpu = cpu_w.to(dev, non_blocking=True)
try:
w_fp_gpu = w_q_gpu.dequantize()
except Exception:
w_fp_gpu = w_q_gpu.to(dtype=torch.float32, non_blocking=True)
if w_fp_gpu.dtype != target_dtype:
w_fp_gpu = w_fp_gpu.to(target_dtype, non_blocking=True)
return w_fp_gpu
# float path (preserve original behavior: NO dtype cast)
w_gpu = cpu_w.to(dev, non_blocking=True)
return w_gpu
if device.type != "cuda": if device.type != "cuda":
out = F.linear(x.to("cpu"), weight_cpu, bias_cpu) out = F.linear(
x.to("cpu"),
_materialize_linear_weight(weight_cpu, torch.device("cpu")),
bias_cpu,
)
ctx.save_for_backward(x.to("cpu"), weight_cpu, bias_cpu) ctx.save_for_backward(x.to("cpu"), weight_cpu, bias_cpu)
ctx.device = torch.device("cpu") ctx.device = torch.device("cpu")
return out.to(x.device) return out.to(x.device)
@@ -101,7 +171,7 @@ class _BouncingLinearFn(torch.autograd.Function):
with torch.cuda.stream(ts): with torch.cuda.stream(ts):
ts.wait_event(ev_cu_s) ts.wait_event(ev_cu_s)
w_bufs[idx] = weight_cpu.to(device, non_blocking=True) w_bufs[idx] = _materialize_linear_weight(weight_cpu, device)
b_bufs[idx] = ( b_bufs[idx] = (
bias_cpu.to(device, non_blocking=True) if bias_cpu is not None else None bias_cpu.to(device, non_blocking=True) if bias_cpu is not None else None
) )
@@ -114,21 +184,39 @@ class _BouncingLinearFn(torch.autograd.Function):
ctx.save_for_backward(x, weight_cpu, bias_cpu) ctx.save_for_backward(x, weight_cpu, bias_cpu)
ctx.device = device ctx.device = device
ctx.target_dtype = target_dtype
return out return out
@staticmethod @staticmethod
def backward(ctx, grad_out): def backward(ctx, grad_out):
x, weight_cpu, bias_cpu = ctx.saved_tensors x, weight_cpu, bias_cpu = ctx.saved_tensors
device = ctx.device device = ctx.device
target_dtype = getattr(ctx, "target_dtype", grad_out.dtype)
if device.type != "cuda": if device.type != "cuda":
go_cpu = grad_out.to("cpu") go_cpu = grad_out.to("cpu")
x_cpu = x.to("cpu") x_cpu = x.to("cpu")
grad_input = go_cpu @ weight_cpu w_mat = (
grad_weight = go_cpu.flatten(0, -2).T @ x_cpu.flatten(0, -2) weight_cpu.dequantize()
if _is_quantized_tensor(weight_cpu)
else weight_cpu
)
if w_mat.dtype != target_dtype and target_dtype in (
torch.bfloat16,
torch.float16,
torch.float32,
):
w_mat = w_mat.to(target_dtype)
grad_input = go_cpu @ w_mat
grad_weight = (
go_cpu.flatten(0, -2).T @ x_cpu.flatten(0, -2)
if getattr(weight_cpu, "requires_grad", False)
and weight_cpu.dtype.is_floating_point
else None
)
grad_bias = ( grad_bias = (
go_cpu.sum(dim=tuple(range(go_cpu.ndim - 1))) go_cpu.sum(dim=tuple(range(go_cpu.ndim - 1)))
if bias_cpu is not None if (bias_cpu is not None and getattr(bias_cpu, "requires_grad", False))
else None else None
) )
return grad_input.to(grad_out.device), grad_weight, grad_bias, None return grad_input.to(grad_out.device), grad_weight, grad_bias, None
@@ -148,45 +236,62 @@ class _BouncingLinearFn(torch.autograd.Function):
idx = state["backward_clk"] idx = state["backward_clk"]
# Stage weights onto device (transfer stream), ping-pong to avoid races # GPU-side dequant/cast for quantized; float path unchanged
def _materialize_for_bwd(cpu_w):
if _is_quantized_tensor(cpu_w):
w_q_gpu = cpu_w.to(device, non_blocking=True)
try:
w_fp_gpu = w_q_gpu.dequantize()
except Exception:
w_fp_gpu = w_q_gpu.to(dtype=torch.float32, non_blocking=True)
if w_fp_gpu.dtype != target_dtype:
w_fp_gpu = w_fp_gpu.to(target_dtype, non_blocking=True)
return w_fp_gpu
# float path (preserve original behavior: NO dtype cast)
w = cpu_w.to(device, non_blocking=True)
return w
with torch.cuda.stream(transfer_stream): with torch.cuda.stream(transfer_stream):
transfer_stream.wait_event(ev_cu_b_start) transfer_stream.wait_event(ev_cu_b_start)
w_bwd_buffers[idx] = weight_cpu.to(device, non_blocking=True) w_bwd_buffers[idx] = _materialize_for_bwd(weight_cpu)
state["backward_clk"] ^= 1 state["backward_clk"] ^= 1
ev_tx_b.record() ev_tx_b.record()
# Compute stream waits for weights to arrive, then start compute
torch.cuda.current_stream().wait_event(ev_tx_b) torch.cuda.current_stream().wait_event(ev_tx_b)
ev_cu_b_start.record() ev_cu_b_start.record()
# 1) Compute grad_input using the freshly transferred weights # grad wrt input (GPU)
grad_input = grad_out @ w_bwd_buffers[idx] grad_input = grad_out.to(dtype=target_dtype) @ w_bwd_buffers[idx]
# 2) Ensure previous grad-to-CPU transfer that used this slot finished # ensure previous grad-to-CPU transfer that used this slot finished
torch.cuda.current_stream().wait_event(ev_tx_w_bwd_done) torch.cuda.current_stream().wait_event(ev_tx_w_bwd_done)
# 3) Compute weight/bias grads on GPU into staging buffers # compute grads if float masters exist
w_grad_buffers[idx] = grad_out.flatten(0, -2).T @ x.flatten(0, -2) grad_weight = None
if bias_cpu is not None: grad_bias = None
if (
getattr(weight_cpu, "requires_grad", False)
and weight_cpu.dtype.is_floating_point
):
w_grad_buffers[idx] = grad_out.flatten(0, -2).T @ x.flatten(0, -2)
if bias_cpu is not None and getattr(bias_cpu, "requires_grad", False):
reduce_dims = tuple(range(grad_out.ndim - 1)) reduce_dims = tuple(range(grad_out.ndim - 1))
b_grad_buffers[idx] = grad_out.sum(dim=reduce_dims) b_grad_buffers[idx] = grad_out.sum(dim=reduce_dims)
# Mark end of GPU compute
ev_cu_b_finish.record() ev_cu_b_finish.record()
# 4) Launch non-blocking H2D->CPU transfers on a separate grad stream (full-duplex)
with torch.cuda.stream(transfer_grad_stream): with torch.cuda.stream(transfer_grad_stream):
transfer_grad_stream.wait_event(ev_cu_b_finish) transfer_grad_stream.wait_event(ev_cu_b_finish)
grad_weight = w_grad_buffers[idx].to("cpu", non_blocking=True) if (
grad_bias = ( getattr(weight_cpu, "requires_grad", False)
b_grad_buffers[idx].to("cpu", non_blocking=True) and weight_cpu.dtype.is_floating_point
if bias_cpu is not None ):
else None grad_weight = w_grad_buffers[idx].to("cpu", non_blocking=True)
) if bias_cpu is not None and getattr(bias_cpu, "requires_grad", False):
# signal that this slot's CPU transfer is complete (safe for next reuse) grad_bias = b_grad_buffers[idx].to("cpu", non_blocking=True)
state["transfer_weight_backward_finished_event"].record() state["transfer_weight_backward_finished_event"].record()
return grad_input, grad_weight, grad_bias, None return grad_input.to(dtype=grad_out.dtype), grad_weight, grad_bias, None
class _BouncingConv2dFn(torch.autograd.Function): class _BouncingConv2dFn(torch.autograd.Function):
@@ -202,12 +307,39 @@ class _BouncingConv2dFn(torch.autograd.Function):
dilation: Tuple[int, int], dilation: Tuple[int, int],
groups: int, groups: int,
): ):
target_dtype = (
x.dtype
if x.dtype in (torch.bfloat16, torch.float16, torch.float32)
else torch.bfloat16
)
# GPU-side dequant/cast for quantized; float path unchanged
def _materialize_conv_weight(cpu_w, dev):
if _is_quantized_tensor(cpu_w):
w_q_gpu = cpu_w.to(dev, non_blocking=True)
try:
w_fp_gpu = w_q_gpu.dequantize()
except Exception:
w_fp_gpu = w_q_gpu.to(dtype=torch.float32, non_blocking=True)
if w_fp_gpu.dtype != target_dtype:
w_fp_gpu = w_fp_gpu.to(target_dtype, non_blocking=True)
return w_fp_gpu
# float path (preserve original behavior: NO dtype cast)
w_gpu = cpu_w.to(dev, non_blocking=True)
return w_gpu
if device.type != "cuda": if device.type != "cuda":
out = F.conv2d( out = F.conv2d(
x.to("cpu"), weight_cpu, bias_cpu, stride, padding, dilation, groups x.to("cpu"),
_materialize_conv_weight(weight_cpu, torch.device("cpu")),
bias_cpu,
stride,
padding,
dilation,
groups,
) )
ctx.save_for_backward(x.to("cpu"), weight_cpu, bias_cpu) ctx.save_for_backward(x.to("cpu"), weight_cpu, bias_cpu)
ctx.meta = ("cpu", stride, padding, dilation, groups) ctx.meta = ("cpu", stride, padding, dilation, groups, target_dtype)
return out.to(x.device) return out.to(x.device)
state = _get_device_state(device) state = _get_device_state(device)
@@ -219,7 +351,7 @@ class _BouncingConv2dFn(torch.autograd.Function):
with torch.cuda.stream(ts): with torch.cuda.stream(ts):
ts.wait_event(ev_cu_s) ts.wait_event(ev_cu_s)
w_bufs[idx] = weight_cpu.to(device, non_blocking=True) w_bufs[idx] = _materialize_conv_weight(weight_cpu, device)
b_bufs[idx] = ( b_bufs[idx] = (
bias_cpu.to(device, non_blocking=True) if bias_cpu is not None else None bias_cpu.to(device, non_blocking=True) if bias_cpu is not None else None
) )
@@ -231,22 +363,30 @@ class _BouncingConv2dFn(torch.autograd.Function):
out = F.conv2d(x, w_bufs[idx], b_bufs[idx], stride, padding, dilation, groups) out = F.conv2d(x, w_bufs[idx], b_bufs[idx], stride, padding, dilation, groups)
ctx.save_for_backward(x, weight_cpu, bias_cpu) ctx.save_for_backward(x, weight_cpu, bias_cpu)
ctx.meta = (device, stride, padding, dilation, groups) ctx.meta = (device, stride, padding, dilation, groups, target_dtype)
return out return out
@staticmethod @staticmethod
def backward(ctx, grad_out): def backward(ctx, grad_out):
x, weight_cpu, bias_cpu = ctx.saved_tensors x, weight_cpu, bias_cpu = ctx.saved_tensors
meta = ctx.meta device, stride, padding, dilation, groups, target_dtype = ctx.meta
device, stride, padding, dilation, groups = meta
if ( if (
isinstance(device, torch.device) and device.type != "cuda" isinstance(device, torch.device) and device.type != "cuda"
) or device == "cpu": ) or device == "cpu":
# CPU grads
go = grad_out.to("cpu") go = grad_out.to("cpu")
x_cpu = x.to("cpu") x_cpu = x.to("cpu")
w_cpu = weight_cpu w_cpu = (
weight_cpu.dequantize()
if _is_quantized_tensor(weight_cpu)
else weight_cpu
)
if w_cpu.dtype != target_dtype and target_dtype in (
torch.bfloat16,
torch.float16,
torch.float32,
):
w_cpu = w_cpu.to(target_dtype)
from torch.nn.grad import conv2d_input, conv2d_weight # type: ignore from torch.nn.grad import conv2d_input, conv2d_weight # type: ignore
grad_input = conv2d_input( grad_input = conv2d_input(
@@ -258,16 +398,25 @@ class _BouncingConv2dFn(torch.autograd.Function):
dilation=dilation, dilation=dilation,
groups=groups, groups=groups,
) )
grad_weight = conv2d_weight( grad_weight = (
x_cpu, conv2d_weight(
w_cpu.shape, x_cpu,
go, w_cpu.shape,
stride=stride, go,
padding=padding, stride=stride,
dilation=dilation, padding=padding,
groups=groups, dilation=dilation,
groups=groups,
)
if getattr(weight_cpu, "requires_grad", False)
and weight_cpu.dtype.is_floating_point
else None
)
grad_bias = (
go.sum(dim=(0, 2, 3))
if (bias_cpu is not None and getattr(bias_cpu, "requires_grad", False))
else None
) )
grad_bias = go.sum(dim=(0, 2, 3)) if bias_cpu is not None else None
return ( return (
grad_input.to(grad_out.device), grad_input.to(grad_out.device),
grad_weight, grad_weight,
@@ -279,12 +428,10 @@ class _BouncingConv2dFn(torch.autograd.Function):
None, None,
) )
# CUDA path (full-duplex)
state = _get_device_state(device) state = _get_device_state(device)
transfer_stream = state["transfer_stream"] transfer_stream = state["transfer_stream"]
transfer_grad_stream = state["transfer_grad_stream"] transfer_grad_stream = state["transfer_grad_stream"]
# device-side buffers
w_bwd_buffers = state["w_bwd_buffers"] w_bwd_buffers = state["w_bwd_buffers"]
w_grad_buffers = state["w_grad_buffers"] w_grad_buffers = state["w_grad_buffers"]
b_grad_buffers = state["b_grad_buffers"] b_grad_buffers = state["b_grad_buffers"]
@@ -296,23 +443,37 @@ class _BouncingConv2dFn(torch.autograd.Function):
idx = state["backward_clk"] idx = state["backward_clk"]
# GPU-side dequant/cast for quantized; float path unchanged
def _materialize_for_bwd(cpu_w):
if _is_quantized_tensor(cpu_w):
w_q_gpu = cpu_w.to(device, non_blocking=True)
try:
w_fp_gpu = w_q_gpu.dequantize()
except Exception:
w_fp_gpu = w_q_gpu.to(dtype=torch.float32, non_blocking=True)
if w_fp_gpu.dtype != target_dtype:
w_fp_gpu = w_fp_gpu.to(target_dtype, non_blocking=True)
return w_fp_gpu
# float path (preserve original behavior: NO dtype cast)
w = cpu_w.to(device, non_blocking=True)
return w
# Stage weights for input-grad compute # Stage weights for input-grad compute
with torch.cuda.stream(transfer_stream): with torch.cuda.stream(transfer_stream):
transfer_stream.wait_event(ev_cu_b_start) transfer_stream.wait_event(ev_cu_b_start)
w_bwd_buffers[idx] = weight_cpu.to(device, non_blocking=True) w_bwd_buffers[idx] = _materialize_for_bwd(weight_cpu)
state["backward_clk"] ^= 1 state["backward_clk"] ^= 1
ev_tx_b.record() ev_tx_b.record()
torch.cuda.current_stream().wait_event(ev_tx_b) torch.cuda.current_stream().wait_event(ev_tx_b)
ev_cu_b_start.record() ev_cu_b_start.record()
# grad wrt input on GPU with streamed weights
from torch.nn.grad import conv2d_input, conv2d_weight # type: ignore from torch.nn.grad import conv2d_input, conv2d_weight # type: ignore
grad_input = conv2d_input( grad_input = conv2d_input(
x.shape, x.shape,
w_bwd_buffers[idx], w_bwd_buffers[idx],
grad_out, grad_out.to(dtype=target_dtype),
stride=stride, stride=stride,
padding=padding, padding=padding,
dilation=dilation, dilation=dilation,
@@ -323,33 +484,48 @@ class _BouncingConv2dFn(torch.autograd.Function):
torch.cuda.current_stream().wait_event(ev_tx_w_bwd_done) torch.cuda.current_stream().wait_event(ev_tx_w_bwd_done)
# Compute heavy grads on GPU into staging buffers # Compute heavy grads on GPU into staging buffers
w_grad_buffers[idx] = conv2d_weight( grad_weight = None
x, grad_bias = None
weight_cpu.shape, if (
grad_out, getattr(weight_cpu, "requires_grad", False)
stride=stride, and weight_cpu.dtype.is_floating_point
padding=padding, ):
dilation=dilation, w_grad_buffers[idx] = conv2d_weight(
groups=groups, x,
) weight_cpu.shape,
if bias_cpu is not None: grad_out,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
)
if bias_cpu is not None and getattr(bias_cpu, "requires_grad", False):
b_grad_buffers[idx] = grad_out.sum(dim=(0, 2, 3)) b_grad_buffers[idx] = grad_out.sum(dim=(0, 2, 3))
# Mark end of GPU math
ev_cu_b_finish.record() ev_cu_b_finish.record()
# Launch CPU copies on the dedicated grad stream (overlaps with next H2D) # Launch CPU copies on the dedicated grad stream (overlaps with next H2D)
with torch.cuda.stream(transfer_grad_stream): with torch.cuda.stream(transfer_grad_stream):
transfer_grad_stream.wait_event(ev_cu_b_finish) transfer_grad_stream.wait_event(ev_cu_b_finish)
grad_weight = w_grad_buffers[idx].to("cpu", non_blocking=True) if (
grad_bias = ( getattr(weight_cpu, "requires_grad", False)
b_grad_buffers[idx].to("cpu", non_blocking=True) and weight_cpu.dtype.is_floating_point
if bias_cpu is not None ):
else None grad_weight = w_grad_buffers[idx].to("cpu", non_blocking=True)
) if bias_cpu is not None and getattr(bias_cpu, "requires_grad", False):
grad_bias = b_grad_buffers[idx].to("cpu", non_blocking=True)
state["transfer_weight_backward_finished_event"].record() state["transfer_weight_backward_finished_event"].record()
return grad_input, grad_weight, grad_bias, None, None, None, None, None return (
grad_input.to(dtype=grad_out.dtype),
grad_weight,
grad_bias,
None,
None,
None,
None,
None,
)
class BaseLayerMemoryManager: class BaseLayerMemoryManager: