From c9f982af8394e54fc31924d3bd0f578919d3da5a Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Mon, 6 Oct 2025 13:46:57 -0600
Subject: [PATCH] Add support for using quantized models with ramtorch

---
 toolkit/memory_management/manager_modules.py | 304 +++++++++++++++----
 1 file changed, 240 insertions(+), 64 deletions(-)

diff --git a/toolkit/memory_management/manager_modules.py b/toolkit/memory_management/manager_modules.py
index 4898dfc3..02852918 100644
--- a/toolkit/memory_management/manager_modules.py
+++ b/toolkit/memory_management/manager_modules.py
@@ -10,6 +10,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from typing import TYPE_CHECKING, Optional, Tuple
+from torch.overrides import has_torch_function_unary  # (ADD) torchao detection
 
 if TYPE_CHECKING:
     from .manager import MemoryManager
@@ -55,11 +56,53 @@
     return _DEVICE_STATE[device]
 
 
+# (ADD) detect torchao wrapper tensors
+def _is_ao_quantized_tensor(t: Optional[torch.Tensor]) -> bool:
+    if t is None:
+        return False
+    try:
+        if has_torch_function_unary(t):
+            return t.__class__.__module__.startswith("torchao.")
+    except Exception:
+        pass
+    for attr in (
+        "_scale",
+        "_scales",
+        "_zero_point",
+        "_zp",
+        "_block_size",
+        "_group_size",
+        "_pack_dim",
+    ):
+        if hasattr(t, attr):
+            return True
+    return False
+
+
+def _is_quantized_tensor(t: Optional[torch.Tensor]) -> bool:
+    if t is None:
+        return False
+    # torch quantized tensors
+    try:
+        if torch.is_quantized(t):  # type: ignore[attr-defined]
+            return True
+    except Exception:
+        pass
+    # (ADD) torchao quantized wrappers
+    if _is_ao_quantized_tensor(t):
+        return True
+    # packed/int formats (weight-only)
+    return not t.dtype.is_floating_point
+
+
 def _ensure_cpu_pinned(t: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
     if t is None:
         return None
     if t.device.type != "cpu":
         t = t.to("cpu", copy=True)
+    # Don't attempt to pin quantized tensors; many backends don't support it
+    if _is_quantized_tensor(t):
+        return t
     if torch.cuda.is_available():
         try:
             t = t.pin_memory()
@@ -86,8 +129,35 @@ def _move_params_to_cpu_and_pin(module: nn.Module):
 class _BouncingLinearFn(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x, weight_cpu, bias_cpu, device: torch.device):
+        # choose compute dtype to match activations
+        target_dtype = (
+            x.dtype
+            if x.dtype in (torch.bfloat16, torch.float16, torch.float32)
+            else torch.bfloat16
+        )
+
+        # GPU-side dequant/cast for quantized; float path unchanged
+        def _materialize_linear_weight(cpu_w, dev):
+            if _is_quantized_tensor(cpu_w):
+                # move quantized wrapper to GPU -> dequantize on GPU -> cast on GPU
+                w_q_gpu = cpu_w.to(dev, non_blocking=True)
+                try:
+                    w_fp_gpu = w_q_gpu.dequantize()
+                except Exception:
+                    w_fp_gpu = w_q_gpu.to(dtype=torch.float32, non_blocking=True)
+                if w_fp_gpu.dtype != target_dtype:
+                    w_fp_gpu = w_fp_gpu.to(target_dtype, non_blocking=True)
+                return w_fp_gpu
+            # float path (preserve original behavior: NO dtype cast)
+            w_gpu = cpu_w.to(dev, non_blocking=True)
+            return w_gpu
+
         if device.type != "cuda":
-            out = F.linear(x.to("cpu"), weight_cpu, bias_cpu)
+            out = F.linear(
+                x.to("cpu"),
+                _materialize_linear_weight(weight_cpu, torch.device("cpu")),
+                bias_cpu,
+            )
             ctx.save_for_backward(x.to("cpu"), weight_cpu, bias_cpu)
             ctx.device = torch.device("cpu")
             return out.to(x.device)
@@ -101,7 +171,7 @@
 
         with torch.cuda.stream(ts):
             ts.wait_event(ev_cu_s)
-            w_bufs[idx] = weight_cpu.to(device, non_blocking=True)
+            w_bufs[idx] = _materialize_linear_weight(weight_cpu, device)
             b_bufs[idx] = (
                 bias_cpu.to(device, non_blocking=True) if bias_cpu is not None else None
             )
@@ -114,21 +184,39 @@ class _BouncingLinearFn(torch.autograd.Function):
         ctx.save_for_backward(x, weight_cpu, bias_cpu)
         ctx.device = device
+        ctx.target_dtype = target_dtype
         return out
 
     @staticmethod
     def backward(ctx, grad_out):
         x, weight_cpu, bias_cpu = ctx.saved_tensors
         device = ctx.device
+        target_dtype = getattr(ctx, "target_dtype", grad_out.dtype)
 
         if device.type != "cuda":
             go_cpu = grad_out.to("cpu")
             x_cpu = x.to("cpu")
-            grad_input = go_cpu @ weight_cpu
-            grad_weight = go_cpu.flatten(0, -2).T @ x_cpu.flatten(0, -2)
+            w_mat = (
+                weight_cpu.dequantize()
+                if _is_quantized_tensor(weight_cpu)
+                else weight_cpu
+            )
+            if w_mat.dtype != target_dtype and target_dtype in (
+                torch.bfloat16,
+                torch.float16,
+                torch.float32,
+            ):
+                w_mat = w_mat.to(target_dtype)
+            grad_input = go_cpu @ w_mat
+            grad_weight = (
+                go_cpu.flatten(0, -2).T @ x_cpu.flatten(0, -2)
+                if getattr(weight_cpu, "requires_grad", False)
+                and weight_cpu.dtype.is_floating_point
+                else None
+            )
             grad_bias = (
                 go_cpu.sum(dim=tuple(range(go_cpu.ndim - 1)))
-                if bias_cpu is not None
+                if (bias_cpu is not None and getattr(bias_cpu, "requires_grad", False))
                 else None
             )
             return grad_input.to(grad_out.device), grad_weight, grad_bias, None
@@ -148,45 +236,62 @@ class _BouncingLinearFn(torch.autograd.Function):
         idx = state["backward_clk"]
 
-        # Stage weights onto device (transfer stream), ping-pong to avoid races
+        # GPU-side dequant/cast for quantized; float path unchanged
+        def _materialize_for_bwd(cpu_w):
+            if _is_quantized_tensor(cpu_w):
+                w_q_gpu = cpu_w.to(device, non_blocking=True)
+                try:
+                    w_fp_gpu = w_q_gpu.dequantize()
+                except Exception:
+                    w_fp_gpu = w_q_gpu.to(dtype=torch.float32, non_blocking=True)
+                if w_fp_gpu.dtype != target_dtype:
+                    w_fp_gpu = w_fp_gpu.to(target_dtype, non_blocking=True)
+                return w_fp_gpu
+            # float path (preserve original behavior: NO dtype cast)
+            w = cpu_w.to(device, non_blocking=True)
+            return w
+
         with torch.cuda.stream(transfer_stream):
             transfer_stream.wait_event(ev_cu_b_start)
-            w_bwd_buffers[idx] = weight_cpu.to(device, non_blocking=True)
+            w_bwd_buffers[idx] = _materialize_for_bwd(weight_cpu)
             state["backward_clk"] ^= 1
             ev_tx_b.record()
 
-        # Compute stream waits for weights to arrive, then start compute
         torch.cuda.current_stream().wait_event(ev_tx_b)
         ev_cu_b_start.record()
 
-        # 1) Compute grad_input using the freshly transferred weights
-        grad_input = grad_out @ w_bwd_buffers[idx]
+        # grad wrt input (GPU)
+        grad_input = grad_out.to(dtype=target_dtype) @ w_bwd_buffers[idx]
 
-        # 2) Ensure previous grad-to-CPU transfer that used this slot finished
+        # ensure previous grad-to-CPU transfer that used this slot finished
        torch.cuda.current_stream().wait_event(ev_tx_w_bwd_done)
 
-        # 3) Compute weight/bias grads on GPU into staging buffers
-        w_grad_buffers[idx] = grad_out.flatten(0, -2).T @ x.flatten(0, -2)
-        if bias_cpu is not None:
+        # compute grads if float masters exist
+        grad_weight = None
+        grad_bias = None
+        if (
+            getattr(weight_cpu, "requires_grad", False)
+            and weight_cpu.dtype.is_floating_point
+        ):
+            w_grad_buffers[idx] = grad_out.flatten(0, -2).T @ x.flatten(0, -2)
+        if bias_cpu is not None and getattr(bias_cpu, "requires_grad", False):
             reduce_dims = tuple(range(grad_out.ndim - 1))
             b_grad_buffers[idx] = grad_out.sum(dim=reduce_dims)
 
-        # Mark end of GPU compute
         ev_cu_b_finish.record()
 
-        # 4) Launch non-blocking H2D->CPU transfers on a separate grad stream (full-duplex)
         with torch.cuda.stream(transfer_grad_stream):
             transfer_grad_stream.wait_event(ev_cu_b_finish)
-            grad_weight = w_grad_buffers[idx].to("cpu", non_blocking=True)
-            grad_bias = (
-                b_grad_buffers[idx].to("cpu", non_blocking=True)
-                if bias_cpu is not None
-                else None
-            )
-            # signal that this slot's CPU transfer is complete (safe for next reuse)
+            if (
+                getattr(weight_cpu, "requires_grad", False)
+                and weight_cpu.dtype.is_floating_point
+            ):
+                grad_weight = w_grad_buffers[idx].to("cpu", non_blocking=True)
+            if bias_cpu is not None and getattr(bias_cpu, "requires_grad", False):
+                grad_bias = b_grad_buffers[idx].to("cpu", non_blocking=True)
             state["transfer_weight_backward_finished_event"].record()
 
-        return grad_input, grad_weight, grad_bias, None
+        return grad_input.to(dtype=grad_out.dtype), grad_weight, grad_bias, None
@@ -202,12 +307,39 @@ class _BouncingConv2dFn(torch.autograd.Function):
     @staticmethod
     def forward(
         ctx,
         x,
         weight_cpu,
         bias_cpu,
         device: torch.device,
         stride: Tuple[int, int],
         padding: Tuple[int, int],
         dilation: Tuple[int, int],
         groups: int,
     ):
+        target_dtype = (
+            x.dtype
+            if x.dtype in (torch.bfloat16, torch.float16, torch.float32)
+            else torch.bfloat16
+        )
+
+        # GPU-side dequant/cast for quantized; float path unchanged
+        def _materialize_conv_weight(cpu_w, dev):
+            if _is_quantized_tensor(cpu_w):
+                w_q_gpu = cpu_w.to(dev, non_blocking=True)
+                try:
+                    w_fp_gpu = w_q_gpu.dequantize()
+                except Exception:
+                    w_fp_gpu = w_q_gpu.to(dtype=torch.float32, non_blocking=True)
+                if w_fp_gpu.dtype != target_dtype:
+                    w_fp_gpu = w_fp_gpu.to(target_dtype, non_blocking=True)
+                return w_fp_gpu
+            # float path (preserve original behavior: NO dtype cast)
+            w_gpu = cpu_w.to(dev, non_blocking=True)
+            return w_gpu
+
         if device.type != "cuda":
             out = F.conv2d(
-                x.to("cpu"), weight_cpu, bias_cpu, stride, padding, dilation, groups
+                x.to("cpu"),
+                _materialize_conv_weight(weight_cpu, torch.device("cpu")),
+                bias_cpu,
+                stride,
+                padding,
+                dilation,
+                groups,
             )
             ctx.save_for_backward(x.to("cpu"), weight_cpu, bias_cpu)
-            ctx.meta = ("cpu", stride, padding, dilation, groups)
+            ctx.meta = ("cpu", stride, padding, dilation, groups, target_dtype)
             return out.to(x.device)
 
         state = _get_device_state(device)
@@ -219,7 +351,7 @@
 
         with torch.cuda.stream(ts):
             ts.wait_event(ev_cu_s)
-            w_bufs[idx] = weight_cpu.to(device, non_blocking=True)
+            w_bufs[idx] = _materialize_conv_weight(weight_cpu, device)
             b_bufs[idx] = (
                 bias_cpu.to(device, non_blocking=True) if bias_cpu is not None else None
             )
@@ -231,22 +363,30 @@ class _BouncingConv2dFn(torch.autograd.Function):
         out = F.conv2d(x, w_bufs[idx], b_bufs[idx], stride, padding, dilation, groups)
 
         ctx.save_for_backward(x, weight_cpu, bias_cpu)
-        ctx.meta = (device, stride, padding, dilation, groups)
+        ctx.meta = (device, stride, padding, dilation, groups, target_dtype)
         return out
 
     @staticmethod
     def backward(ctx, grad_out):
         x, weight_cpu, bias_cpu = ctx.saved_tensors
-        meta = ctx.meta
-        device, stride, padding, dilation, groups = meta
+        device, stride, padding, dilation, groups, target_dtype = ctx.meta
 
         if (
             isinstance(device, torch.device) and device.type != "cuda"
         ) or device == "cpu":
-            # CPU grads
             go = grad_out.to("cpu")
             x_cpu = x.to("cpu")
-            w_cpu = weight_cpu
+            w_cpu = (
+                weight_cpu.dequantize()
+                if _is_quantized_tensor(weight_cpu)
+                else weight_cpu
+            )
+            if w_cpu.dtype != target_dtype and target_dtype in (
+                torch.bfloat16,
+                torch.float16,
+                torch.float32,
+            ):
+                w_cpu = w_cpu.to(target_dtype)
             from torch.nn.grad import conv2d_input, conv2d_weight  # type: ignore
 
             grad_input = conv2d_input(
@@ -258,16 +398,25 @@ class _BouncingConv2dFn(torch.autograd.Function):
                 dilation=dilation,
                 groups=groups,
             )
-            grad_weight = conv2d_weight(
-                x_cpu,
-                w_cpu.shape,
-                go,
-                stride=stride,
-                padding=padding,
-                dilation=dilation,
-                groups=groups,
+            grad_weight = (
+                conv2d_weight(
+                    x_cpu,
+                    w_cpu.shape,
+                    go,
+                    stride=stride,
+                    padding=padding,
+                    dilation=dilation,
+                    groups=groups,
+                )
+                if getattr(weight_cpu, "requires_grad", False)
+                and weight_cpu.dtype.is_floating_point
+                else None
+            )
+            grad_bias = (
+                go.sum(dim=(0, 2, 3))
+                if (bias_cpu is not None and getattr(bias_cpu, "requires_grad", False))
+                else None
             )
-            grad_bias = go.sum(dim=(0, 2, 3)) if bias_cpu is not None else None
             return (
                 grad_input.to(grad_out.device),
                 grad_weight,
                grad_bias,
@@ -279,12 +428,10 @@ class _BouncingConv2dFn(torch.autograd.Function):
                 None,
             )
 
-        # CUDA path (full-duplex)
         state = _get_device_state(device)
         transfer_stream = state["transfer_stream"]
         transfer_grad_stream = state["transfer_grad_stream"]
 
-        # device-side buffers
         w_bwd_buffers = state["w_bwd_buffers"]
         w_grad_buffers = state["w_grad_buffers"]
         b_grad_buffers = state["b_grad_buffers"]
@@ -296,23 +443,37 @@ class _BouncingConv2dFn(torch.autograd.Function):
         idx = state["backward_clk"]
 
+        # GPU-side dequant/cast for quantized; float path unchanged
+        def _materialize_for_bwd(cpu_w):
+            if _is_quantized_tensor(cpu_w):
+                w_q_gpu = cpu_w.to(device, non_blocking=True)
+                try:
+                    w_fp_gpu = w_q_gpu.dequantize()
+                except Exception:
+                    w_fp_gpu = w_q_gpu.to(dtype=torch.float32, non_blocking=True)
+                if w_fp_gpu.dtype != target_dtype:
+                    w_fp_gpu = w_fp_gpu.to(target_dtype, non_blocking=True)
+                return w_fp_gpu
+            # float path (preserve original behavior: NO dtype cast)
+            w = cpu_w.to(device, non_blocking=True)
+            return w
+
         # Stage weights for input-grad compute
         with torch.cuda.stream(transfer_stream):
             transfer_stream.wait_event(ev_cu_b_start)
-            w_bwd_buffers[idx] = weight_cpu.to(device, non_blocking=True)
+            w_bwd_buffers[idx] = _materialize_for_bwd(weight_cpu)
             state["backward_clk"] ^= 1
             ev_tx_b.record()
 
         torch.cuda.current_stream().wait_event(ev_tx_b)
         ev_cu_b_start.record()
 
-        # grad wrt input on GPU with streamed weights
         from torch.nn.grad import conv2d_input, conv2d_weight  # type: ignore
 
         grad_input = conv2d_input(
             x.shape,
             w_bwd_buffers[idx],
-            grad_out,
+            grad_out.to(dtype=target_dtype),
             stride=stride,
             padding=padding,
             dilation=dilation,
@@ -323,33 +484,48 @@ class _BouncingConv2dFn(torch.autograd.Function):
 
         torch.cuda.current_stream().wait_event(ev_tx_w_bwd_done)
 
         # Compute heavy grads on GPU into staging buffers
-        w_grad_buffers[idx] = conv2d_weight(
-            x,
-            weight_cpu.shape,
-            grad_out,
-            stride=stride,
-            padding=padding,
-            dilation=dilation,
-            groups=groups,
-        )
-        if bias_cpu is not None:
+        grad_weight = None
+        grad_bias = None
+        if (
+            getattr(weight_cpu, "requires_grad", False)
+            and weight_cpu.dtype.is_floating_point
+        ):
+            w_grad_buffers[idx] = conv2d_weight(
+                x,
+                weight_cpu.shape,
+                grad_out,
+                stride=stride,
+                padding=padding,
+                dilation=dilation,
+                groups=groups,
+            )
+        if bias_cpu is not None and getattr(bias_cpu, "requires_grad", False):
             b_grad_buffers[idx] = grad_out.sum(dim=(0, 2, 3))
 
-        # Mark end of GPU math
         ev_cu_b_finish.record()
 
         # Launch CPU copies on the dedicated grad stream (overlaps with next H2D)
         with torch.cuda.stream(transfer_grad_stream):
             transfer_grad_stream.wait_event(ev_cu_b_finish)
-            grad_weight = w_grad_buffers[idx].to("cpu", non_blocking=True)
-            grad_bias = (
-                b_grad_buffers[idx].to("cpu", non_blocking=True)
-                if bias_cpu is not None
-                else None
-            )
+            if (
+                getattr(weight_cpu, "requires_grad", False)
+                and weight_cpu.dtype.is_floating_point
+            ):
+                grad_weight = w_grad_buffers[idx].to("cpu", non_blocking=True)
+            if bias_cpu is not None and getattr(bias_cpu, "requires_grad", False):
                grad_bias = b_grad_buffers[idx].to("cpu", non_blocking=True)
             state["transfer_weight_backward_finished_event"].record()
 
-        return grad_input, grad_weight, grad_bias, None, None, None, None, None
+        return (
+            grad_input.to(dtype=grad_out.dtype),
+            grad_weight,
+            grad_bias,
+            None,
+            None,
+            None,
+            None,
+            None,
+        )
 
 
 class BaseLayerMemoryManager:
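Note (editorial, not part of the patch): the recurring idea in the `_materialize_*`
helpers above is to keep quantized weights in their packed CPU form and only
materialize a floating-point copy on the GPU, per use, so just the compact wrapper
crosses the bus. A minimal standalone sketch of that dequant-or-cast fallback
follows; the name `materialize_weight` and the bf16 default are illustrative, not
from the patch:

    import torch

    def materialize_weight(w_cpu, device, target_dtype=torch.bfloat16):
        # Float weights keep the original fast path: async H2D copy, no cast.
        if w_cpu.dtype.is_floating_point and not w_cpu.is_quantized:
            return w_cpu.to(device, non_blocking=True)
        # Quantized/packed weights: move the compact tensor first, then
        # dequantize on-device so the full-precision copy is built on the GPU.
        w = w_cpu.to(device, non_blocking=True)
        try:
            w = w.dequantize()  # torch quantized tensors and torchao wrappers
        except Exception:
            w = w.to(dtype=torch.float32)  # plain int/packed formats: cast
        return w if w.dtype == target_dtype else w.to(target_dtype)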