mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Add support for using quantized models with ramtorch
This commit is contained in:
@@ -10,6 +10,7 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from typing import TYPE_CHECKING, Optional, Tuple
|
from typing import TYPE_CHECKING, Optional, Tuple
|
||||||
|
from torch.overrides import has_torch_function_unary # (ADD) torchao detection
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .manager import MemoryManager
|
from .manager import MemoryManager
|
||||||
@@ -55,11 +56,53 @@ def _get_device_state(device: torch.device):
|
|||||||
return _DEVICE_STATE[device]
|
return _DEVICE_STATE[device]
|
||||||
|
|
||||||
|
|
||||||
|
# (ADD) detect torchao wrapper tensors
|
||||||
|
def _is_ao_quantized_tensor(t: Optional[torch.Tensor]) -> bool:
|
||||||
|
if t is None:
|
||||||
|
return False
|
||||||
|
try:
|
||||||
|
if has_torch_function_unary(t):
|
||||||
|
return t.__class__.__module__.startswith("torchao.")
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
for attr in (
|
||||||
|
"_scale",
|
||||||
|
"_scales",
|
||||||
|
"_zero_point",
|
||||||
|
"_zp",
|
||||||
|
"_block_size",
|
||||||
|
"_group_size",
|
||||||
|
"_pack_dim",
|
||||||
|
):
|
||||||
|
if hasattr(t, attr):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _is_quantized_tensor(t: Optional[torch.Tensor]) -> bool:
|
||||||
|
if t is None:
|
||||||
|
return False
|
||||||
|
# torch quantized tensors
|
||||||
|
try:
|
||||||
|
if torch.is_quantized(t): # type: ignore[attr-defined]
|
||||||
|
return True
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
# (ADD) torchao quantized wrappers
|
||||||
|
if _is_ao_quantized_tensor(t):
|
||||||
|
return True
|
||||||
|
# packed/int formats (weight-only)
|
||||||
|
return not t.dtype.is_floating_point
|
||||||
|
|
||||||
|
|
||||||
def _ensure_cpu_pinned(t: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
|
def _ensure_cpu_pinned(t: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
|
||||||
if t is None:
|
if t is None:
|
||||||
return None
|
return None
|
||||||
if t.device.type != "cpu":
|
if t.device.type != "cpu":
|
||||||
t = t.to("cpu", copy=True)
|
t = t.to("cpu", copy=True)
|
||||||
|
# Don't attempt to pin quantized tensors; many backends don't support it
|
||||||
|
if _is_quantized_tensor(t):
|
||||||
|
return t
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
try:
|
try:
|
||||||
t = t.pin_memory()
|
t = t.pin_memory()
|
||||||
@@ -86,8 +129,35 @@ def _move_params_to_cpu_and_pin(module: nn.Module):
|
|||||||
class _BouncingLinearFn(torch.autograd.Function):
|
class _BouncingLinearFn(torch.autograd.Function):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def forward(ctx, x, weight_cpu, bias_cpu, device: torch.device):
|
def forward(ctx, x, weight_cpu, bias_cpu, device: torch.device):
|
||||||
|
# choose compute dtype to match activations
|
||||||
|
target_dtype = (
|
||||||
|
x.dtype
|
||||||
|
if x.dtype in (torch.bfloat16, torch.float16, torch.float32)
|
||||||
|
else torch.bfloat16
|
||||||
|
)
|
||||||
|
|
||||||
|
# GPU-side dequant/cast for quantized; float path unchanged
|
||||||
|
def _materialize_linear_weight(cpu_w, dev):
|
||||||
|
if _is_quantized_tensor(cpu_w):
|
||||||
|
# move quantized wrapper to GPU -> dequantize on GPU -> cast on GPU
|
||||||
|
w_q_gpu = cpu_w.to(dev, non_blocking=True)
|
||||||
|
try:
|
||||||
|
w_fp_gpu = w_q_gpu.dequantize()
|
||||||
|
except Exception:
|
||||||
|
w_fp_gpu = w_q_gpu.to(dtype=torch.float32, non_blocking=True)
|
||||||
|
if w_fp_gpu.dtype != target_dtype:
|
||||||
|
w_fp_gpu = w_fp_gpu.to(target_dtype, non_blocking=True)
|
||||||
|
return w_fp_gpu
|
||||||
|
# float path (preserve original behavior: NO dtype cast)
|
||||||
|
w_gpu = cpu_w.to(dev, non_blocking=True)
|
||||||
|
return w_gpu
|
||||||
|
|
||||||
if device.type != "cuda":
|
if device.type != "cuda":
|
||||||
out = F.linear(x.to("cpu"), weight_cpu, bias_cpu)
|
out = F.linear(
|
||||||
|
x.to("cpu"),
|
||||||
|
_materialize_linear_weight(weight_cpu, torch.device("cpu")),
|
||||||
|
bias_cpu,
|
||||||
|
)
|
||||||
ctx.save_for_backward(x.to("cpu"), weight_cpu, bias_cpu)
|
ctx.save_for_backward(x.to("cpu"), weight_cpu, bias_cpu)
|
||||||
ctx.device = torch.device("cpu")
|
ctx.device = torch.device("cpu")
|
||||||
return out.to(x.device)
|
return out.to(x.device)
|
||||||
@@ -101,7 +171,7 @@ class _BouncingLinearFn(torch.autograd.Function):
|
|||||||
|
|
||||||
with torch.cuda.stream(ts):
|
with torch.cuda.stream(ts):
|
||||||
ts.wait_event(ev_cu_s)
|
ts.wait_event(ev_cu_s)
|
||||||
w_bufs[idx] = weight_cpu.to(device, non_blocking=True)
|
w_bufs[idx] = _materialize_linear_weight(weight_cpu, device)
|
||||||
b_bufs[idx] = (
|
b_bufs[idx] = (
|
||||||
bias_cpu.to(device, non_blocking=True) if bias_cpu is not None else None
|
bias_cpu.to(device, non_blocking=True) if bias_cpu is not None else None
|
||||||
)
|
)
|
||||||
@@ -114,21 +184,39 @@ class _BouncingLinearFn(torch.autograd.Function):
|
|||||||
|
|
||||||
ctx.save_for_backward(x, weight_cpu, bias_cpu)
|
ctx.save_for_backward(x, weight_cpu, bias_cpu)
|
||||||
ctx.device = device
|
ctx.device = device
|
||||||
|
ctx.target_dtype = target_dtype
|
||||||
return out
|
return out
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def backward(ctx, grad_out):
|
def backward(ctx, grad_out):
|
||||||
x, weight_cpu, bias_cpu = ctx.saved_tensors
|
x, weight_cpu, bias_cpu = ctx.saved_tensors
|
||||||
device = ctx.device
|
device = ctx.device
|
||||||
|
target_dtype = getattr(ctx, "target_dtype", grad_out.dtype)
|
||||||
|
|
||||||
if device.type != "cuda":
|
if device.type != "cuda":
|
||||||
go_cpu = grad_out.to("cpu")
|
go_cpu = grad_out.to("cpu")
|
||||||
x_cpu = x.to("cpu")
|
x_cpu = x.to("cpu")
|
||||||
grad_input = go_cpu @ weight_cpu
|
w_mat = (
|
||||||
grad_weight = go_cpu.flatten(0, -2).T @ x_cpu.flatten(0, -2)
|
weight_cpu.dequantize()
|
||||||
|
if _is_quantized_tensor(weight_cpu)
|
||||||
|
else weight_cpu
|
||||||
|
)
|
||||||
|
if w_mat.dtype != target_dtype and target_dtype in (
|
||||||
|
torch.bfloat16,
|
||||||
|
torch.float16,
|
||||||
|
torch.float32,
|
||||||
|
):
|
||||||
|
w_mat = w_mat.to(target_dtype)
|
||||||
|
grad_input = go_cpu @ w_mat
|
||||||
|
grad_weight = (
|
||||||
|
go_cpu.flatten(0, -2).T @ x_cpu.flatten(0, -2)
|
||||||
|
if getattr(weight_cpu, "requires_grad", False)
|
||||||
|
and weight_cpu.dtype.is_floating_point
|
||||||
|
else None
|
||||||
|
)
|
||||||
grad_bias = (
|
grad_bias = (
|
||||||
go_cpu.sum(dim=tuple(range(go_cpu.ndim - 1)))
|
go_cpu.sum(dim=tuple(range(go_cpu.ndim - 1)))
|
||||||
if bias_cpu is not None
|
if (bias_cpu is not None and getattr(bias_cpu, "requires_grad", False))
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
return grad_input.to(grad_out.device), grad_weight, grad_bias, None
|
return grad_input.to(grad_out.device), grad_weight, grad_bias, None
|
||||||
@@ -148,45 +236,62 @@ class _BouncingLinearFn(torch.autograd.Function):
|
|||||||
|
|
||||||
idx = state["backward_clk"]
|
idx = state["backward_clk"]
|
||||||
|
|
||||||
# Stage weights onto device (transfer stream), ping-pong to avoid races
|
# GPU-side dequant/cast for quantized; float path unchanged
|
||||||
|
def _materialize_for_bwd(cpu_w):
|
||||||
|
if _is_quantized_tensor(cpu_w):
|
||||||
|
w_q_gpu = cpu_w.to(device, non_blocking=True)
|
||||||
|
try:
|
||||||
|
w_fp_gpu = w_q_gpu.dequantize()
|
||||||
|
except Exception:
|
||||||
|
w_fp_gpu = w_q_gpu.to(dtype=torch.float32, non_blocking=True)
|
||||||
|
if w_fp_gpu.dtype != target_dtype:
|
||||||
|
w_fp_gpu = w_fp_gpu.to(target_dtype, non_blocking=True)
|
||||||
|
return w_fp_gpu
|
||||||
|
# float path (preserve original behavior: NO dtype cast)
|
||||||
|
w = cpu_w.to(device, non_blocking=True)
|
||||||
|
return w
|
||||||
|
|
||||||
with torch.cuda.stream(transfer_stream):
|
with torch.cuda.stream(transfer_stream):
|
||||||
transfer_stream.wait_event(ev_cu_b_start)
|
transfer_stream.wait_event(ev_cu_b_start)
|
||||||
w_bwd_buffers[idx] = weight_cpu.to(device, non_blocking=True)
|
w_bwd_buffers[idx] = _materialize_for_bwd(weight_cpu)
|
||||||
state["backward_clk"] ^= 1
|
state["backward_clk"] ^= 1
|
||||||
ev_tx_b.record()
|
ev_tx_b.record()
|
||||||
|
|
||||||
# Compute stream waits for weights to arrive, then start compute
|
|
||||||
torch.cuda.current_stream().wait_event(ev_tx_b)
|
torch.cuda.current_stream().wait_event(ev_tx_b)
|
||||||
ev_cu_b_start.record()
|
ev_cu_b_start.record()
|
||||||
|
|
||||||
# 1) Compute grad_input using the freshly transferred weights
|
# grad wrt input (GPU)
|
||||||
grad_input = grad_out @ w_bwd_buffers[idx]
|
grad_input = grad_out.to(dtype=target_dtype) @ w_bwd_buffers[idx]
|
||||||
|
|
||||||
# 2) Ensure previous grad-to-CPU transfer that used this slot finished
|
# ensure previous grad-to-CPU transfer that used this slot finished
|
||||||
torch.cuda.current_stream().wait_event(ev_tx_w_bwd_done)
|
torch.cuda.current_stream().wait_event(ev_tx_w_bwd_done)
|
||||||
|
|
||||||
# 3) Compute weight/bias grads on GPU into staging buffers
|
# compute grads if float masters exist
|
||||||
w_grad_buffers[idx] = grad_out.flatten(0, -2).T @ x.flatten(0, -2)
|
grad_weight = None
|
||||||
if bias_cpu is not None:
|
grad_bias = None
|
||||||
|
if (
|
||||||
|
getattr(weight_cpu, "requires_grad", False)
|
||||||
|
and weight_cpu.dtype.is_floating_point
|
||||||
|
):
|
||||||
|
w_grad_buffers[idx] = grad_out.flatten(0, -2).T @ x.flatten(0, -2)
|
||||||
|
if bias_cpu is not None and getattr(bias_cpu, "requires_grad", False):
|
||||||
reduce_dims = tuple(range(grad_out.ndim - 1))
|
reduce_dims = tuple(range(grad_out.ndim - 1))
|
||||||
b_grad_buffers[idx] = grad_out.sum(dim=reduce_dims)
|
b_grad_buffers[idx] = grad_out.sum(dim=reduce_dims)
|
||||||
|
|
||||||
# Mark end of GPU compute
|
|
||||||
ev_cu_b_finish.record()
|
ev_cu_b_finish.record()
|
||||||
|
|
||||||
# 4) Launch non-blocking H2D->CPU transfers on a separate grad stream (full-duplex)
|
|
||||||
with torch.cuda.stream(transfer_grad_stream):
|
with torch.cuda.stream(transfer_grad_stream):
|
||||||
transfer_grad_stream.wait_event(ev_cu_b_finish)
|
transfer_grad_stream.wait_event(ev_cu_b_finish)
|
||||||
grad_weight = w_grad_buffers[idx].to("cpu", non_blocking=True)
|
if (
|
||||||
grad_bias = (
|
getattr(weight_cpu, "requires_grad", False)
|
||||||
b_grad_buffers[idx].to("cpu", non_blocking=True)
|
and weight_cpu.dtype.is_floating_point
|
||||||
if bias_cpu is not None
|
):
|
||||||
else None
|
grad_weight = w_grad_buffers[idx].to("cpu", non_blocking=True)
|
||||||
)
|
if bias_cpu is not None and getattr(bias_cpu, "requires_grad", False):
|
||||||
# signal that this slot's CPU transfer is complete (safe for next reuse)
|
grad_bias = b_grad_buffers[idx].to("cpu", non_blocking=True)
|
||||||
state["transfer_weight_backward_finished_event"].record()
|
state["transfer_weight_backward_finished_event"].record()
|
||||||
|
|
||||||
return grad_input, grad_weight, grad_bias, None
|
return grad_input.to(dtype=grad_out.dtype), grad_weight, grad_bias, None
|
||||||
|
|
||||||
|
|
||||||
class _BouncingConv2dFn(torch.autograd.Function):
|
class _BouncingConv2dFn(torch.autograd.Function):
|
||||||
@@ -202,12 +307,39 @@ class _BouncingConv2dFn(torch.autograd.Function):
|
|||||||
dilation: Tuple[int, int],
|
dilation: Tuple[int, int],
|
||||||
groups: int,
|
groups: int,
|
||||||
):
|
):
|
||||||
|
target_dtype = (
|
||||||
|
x.dtype
|
||||||
|
if x.dtype in (torch.bfloat16, torch.float16, torch.float32)
|
||||||
|
else torch.bfloat16
|
||||||
|
)
|
||||||
|
|
||||||
|
# GPU-side dequant/cast for quantized; float path unchanged
|
||||||
|
def _materialize_conv_weight(cpu_w, dev):
|
||||||
|
if _is_quantized_tensor(cpu_w):
|
||||||
|
w_q_gpu = cpu_w.to(dev, non_blocking=True)
|
||||||
|
try:
|
||||||
|
w_fp_gpu = w_q_gpu.dequantize()
|
||||||
|
except Exception:
|
||||||
|
w_fp_gpu = w_q_gpu.to(dtype=torch.float32, non_blocking=True)
|
||||||
|
if w_fp_gpu.dtype != target_dtype:
|
||||||
|
w_fp_gpu = w_fp_gpu.to(target_dtype, non_blocking=True)
|
||||||
|
return w_fp_gpu
|
||||||
|
# float path (preserve original behavior: NO dtype cast)
|
||||||
|
w_gpu = cpu_w.to(dev, non_blocking=True)
|
||||||
|
return w_gpu
|
||||||
|
|
||||||
if device.type != "cuda":
|
if device.type != "cuda":
|
||||||
out = F.conv2d(
|
out = F.conv2d(
|
||||||
x.to("cpu"), weight_cpu, bias_cpu, stride, padding, dilation, groups
|
x.to("cpu"),
|
||||||
|
_materialize_conv_weight(weight_cpu, torch.device("cpu")),
|
||||||
|
bias_cpu,
|
||||||
|
stride,
|
||||||
|
padding,
|
||||||
|
dilation,
|
||||||
|
groups,
|
||||||
)
|
)
|
||||||
ctx.save_for_backward(x.to("cpu"), weight_cpu, bias_cpu)
|
ctx.save_for_backward(x.to("cpu"), weight_cpu, bias_cpu)
|
||||||
ctx.meta = ("cpu", stride, padding, dilation, groups)
|
ctx.meta = ("cpu", stride, padding, dilation, groups, target_dtype)
|
||||||
return out.to(x.device)
|
return out.to(x.device)
|
||||||
|
|
||||||
state = _get_device_state(device)
|
state = _get_device_state(device)
|
||||||
@@ -219,7 +351,7 @@ class _BouncingConv2dFn(torch.autograd.Function):
|
|||||||
|
|
||||||
with torch.cuda.stream(ts):
|
with torch.cuda.stream(ts):
|
||||||
ts.wait_event(ev_cu_s)
|
ts.wait_event(ev_cu_s)
|
||||||
w_bufs[idx] = weight_cpu.to(device, non_blocking=True)
|
w_bufs[idx] = _materialize_conv_weight(weight_cpu, device)
|
||||||
b_bufs[idx] = (
|
b_bufs[idx] = (
|
||||||
bias_cpu.to(device, non_blocking=True) if bias_cpu is not None else None
|
bias_cpu.to(device, non_blocking=True) if bias_cpu is not None else None
|
||||||
)
|
)
|
||||||
@@ -231,22 +363,30 @@ class _BouncingConv2dFn(torch.autograd.Function):
|
|||||||
out = F.conv2d(x, w_bufs[idx], b_bufs[idx], stride, padding, dilation, groups)
|
out = F.conv2d(x, w_bufs[idx], b_bufs[idx], stride, padding, dilation, groups)
|
||||||
|
|
||||||
ctx.save_for_backward(x, weight_cpu, bias_cpu)
|
ctx.save_for_backward(x, weight_cpu, bias_cpu)
|
||||||
ctx.meta = (device, stride, padding, dilation, groups)
|
ctx.meta = (device, stride, padding, dilation, groups, target_dtype)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def backward(ctx, grad_out):
|
def backward(ctx, grad_out):
|
||||||
x, weight_cpu, bias_cpu = ctx.saved_tensors
|
x, weight_cpu, bias_cpu = ctx.saved_tensors
|
||||||
meta = ctx.meta
|
device, stride, padding, dilation, groups, target_dtype = ctx.meta
|
||||||
device, stride, padding, dilation, groups = meta
|
|
||||||
|
|
||||||
if (
|
if (
|
||||||
isinstance(device, torch.device) and device.type != "cuda"
|
isinstance(device, torch.device) and device.type != "cuda"
|
||||||
) or device == "cpu":
|
) or device == "cpu":
|
||||||
# CPU grads
|
|
||||||
go = grad_out.to("cpu")
|
go = grad_out.to("cpu")
|
||||||
x_cpu = x.to("cpu")
|
x_cpu = x.to("cpu")
|
||||||
w_cpu = weight_cpu
|
w_cpu = (
|
||||||
|
weight_cpu.dequantize()
|
||||||
|
if _is_quantized_tensor(weight_cpu)
|
||||||
|
else weight_cpu
|
||||||
|
)
|
||||||
|
if w_cpu.dtype != target_dtype and target_dtype in (
|
||||||
|
torch.bfloat16,
|
||||||
|
torch.float16,
|
||||||
|
torch.float32,
|
||||||
|
):
|
||||||
|
w_cpu = w_cpu.to(target_dtype)
|
||||||
from torch.nn.grad import conv2d_input, conv2d_weight # type: ignore
|
from torch.nn.grad import conv2d_input, conv2d_weight # type: ignore
|
||||||
|
|
||||||
grad_input = conv2d_input(
|
grad_input = conv2d_input(
|
||||||
@@ -258,16 +398,25 @@ class _BouncingConv2dFn(torch.autograd.Function):
|
|||||||
dilation=dilation,
|
dilation=dilation,
|
||||||
groups=groups,
|
groups=groups,
|
||||||
)
|
)
|
||||||
grad_weight = conv2d_weight(
|
grad_weight = (
|
||||||
x_cpu,
|
conv2d_weight(
|
||||||
w_cpu.shape,
|
x_cpu,
|
||||||
go,
|
w_cpu.shape,
|
||||||
stride=stride,
|
go,
|
||||||
padding=padding,
|
stride=stride,
|
||||||
dilation=dilation,
|
padding=padding,
|
||||||
groups=groups,
|
dilation=dilation,
|
||||||
|
groups=groups,
|
||||||
|
)
|
||||||
|
if getattr(weight_cpu, "requires_grad", False)
|
||||||
|
and weight_cpu.dtype.is_floating_point
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
grad_bias = (
|
||||||
|
go.sum(dim=(0, 2, 3))
|
||||||
|
if (bias_cpu is not None and getattr(bias_cpu, "requires_grad", False))
|
||||||
|
else None
|
||||||
)
|
)
|
||||||
grad_bias = go.sum(dim=(0, 2, 3)) if bias_cpu is not None else None
|
|
||||||
return (
|
return (
|
||||||
grad_input.to(grad_out.device),
|
grad_input.to(grad_out.device),
|
||||||
grad_weight,
|
grad_weight,
|
||||||
@@ -279,12 +428,10 @@ class _BouncingConv2dFn(torch.autograd.Function):
|
|||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
|
|
||||||
# CUDA path (full-duplex)
|
|
||||||
state = _get_device_state(device)
|
state = _get_device_state(device)
|
||||||
transfer_stream = state["transfer_stream"]
|
transfer_stream = state["transfer_stream"]
|
||||||
transfer_grad_stream = state["transfer_grad_stream"]
|
transfer_grad_stream = state["transfer_grad_stream"]
|
||||||
|
|
||||||
# device-side buffers
|
|
||||||
w_bwd_buffers = state["w_bwd_buffers"]
|
w_bwd_buffers = state["w_bwd_buffers"]
|
||||||
w_grad_buffers = state["w_grad_buffers"]
|
w_grad_buffers = state["w_grad_buffers"]
|
||||||
b_grad_buffers = state["b_grad_buffers"]
|
b_grad_buffers = state["b_grad_buffers"]
|
||||||
@@ -296,23 +443,37 @@ class _BouncingConv2dFn(torch.autograd.Function):
|
|||||||
|
|
||||||
idx = state["backward_clk"]
|
idx = state["backward_clk"]
|
||||||
|
|
||||||
|
# GPU-side dequant/cast for quantized; float path unchanged
|
||||||
|
def _materialize_for_bwd(cpu_w):
|
||||||
|
if _is_quantized_tensor(cpu_w):
|
||||||
|
w_q_gpu = cpu_w.to(device, non_blocking=True)
|
||||||
|
try:
|
||||||
|
w_fp_gpu = w_q_gpu.dequantize()
|
||||||
|
except Exception:
|
||||||
|
w_fp_gpu = w_q_gpu.to(dtype=torch.float32, non_blocking=True)
|
||||||
|
if w_fp_gpu.dtype != target_dtype:
|
||||||
|
w_fp_gpu = w_fp_gpu.to(target_dtype, non_blocking=True)
|
||||||
|
return w_fp_gpu
|
||||||
|
# float path (preserve original behavior: NO dtype cast)
|
||||||
|
w = cpu_w.to(device, non_blocking=True)
|
||||||
|
return w
|
||||||
|
|
||||||
# Stage weights for input-grad compute
|
# Stage weights for input-grad compute
|
||||||
with torch.cuda.stream(transfer_stream):
|
with torch.cuda.stream(transfer_stream):
|
||||||
transfer_stream.wait_event(ev_cu_b_start)
|
transfer_stream.wait_event(ev_cu_b_start)
|
||||||
w_bwd_buffers[idx] = weight_cpu.to(device, non_blocking=True)
|
w_bwd_buffers[idx] = _materialize_for_bwd(weight_cpu)
|
||||||
state["backward_clk"] ^= 1
|
state["backward_clk"] ^= 1
|
||||||
ev_tx_b.record()
|
ev_tx_b.record()
|
||||||
|
|
||||||
torch.cuda.current_stream().wait_event(ev_tx_b)
|
torch.cuda.current_stream().wait_event(ev_tx_b)
|
||||||
ev_cu_b_start.record()
|
ev_cu_b_start.record()
|
||||||
|
|
||||||
# grad wrt input on GPU with streamed weights
|
|
||||||
from torch.nn.grad import conv2d_input, conv2d_weight # type: ignore
|
from torch.nn.grad import conv2d_input, conv2d_weight # type: ignore
|
||||||
|
|
||||||
grad_input = conv2d_input(
|
grad_input = conv2d_input(
|
||||||
x.shape,
|
x.shape,
|
||||||
w_bwd_buffers[idx],
|
w_bwd_buffers[idx],
|
||||||
grad_out,
|
grad_out.to(dtype=target_dtype),
|
||||||
stride=stride,
|
stride=stride,
|
||||||
padding=padding,
|
padding=padding,
|
||||||
dilation=dilation,
|
dilation=dilation,
|
||||||
@@ -323,33 +484,48 @@ class _BouncingConv2dFn(torch.autograd.Function):
|
|||||||
torch.cuda.current_stream().wait_event(ev_tx_w_bwd_done)
|
torch.cuda.current_stream().wait_event(ev_tx_w_bwd_done)
|
||||||
|
|
||||||
# Compute heavy grads on GPU into staging buffers
|
# Compute heavy grads on GPU into staging buffers
|
||||||
w_grad_buffers[idx] = conv2d_weight(
|
grad_weight = None
|
||||||
x,
|
grad_bias = None
|
||||||
weight_cpu.shape,
|
if (
|
||||||
grad_out,
|
getattr(weight_cpu, "requires_grad", False)
|
||||||
stride=stride,
|
and weight_cpu.dtype.is_floating_point
|
||||||
padding=padding,
|
):
|
||||||
dilation=dilation,
|
w_grad_buffers[idx] = conv2d_weight(
|
||||||
groups=groups,
|
x,
|
||||||
)
|
weight_cpu.shape,
|
||||||
if bias_cpu is not None:
|
grad_out,
|
||||||
|
stride=stride,
|
||||||
|
padding=padding,
|
||||||
|
dilation=dilation,
|
||||||
|
groups=groups,
|
||||||
|
)
|
||||||
|
if bias_cpu is not None and getattr(bias_cpu, "requires_grad", False):
|
||||||
b_grad_buffers[idx] = grad_out.sum(dim=(0, 2, 3))
|
b_grad_buffers[idx] = grad_out.sum(dim=(0, 2, 3))
|
||||||
|
|
||||||
# Mark end of GPU math
|
|
||||||
ev_cu_b_finish.record()
|
ev_cu_b_finish.record()
|
||||||
|
|
||||||
# Launch CPU copies on the dedicated grad stream (overlaps with next H2D)
|
# Launch CPU copies on the dedicated grad stream (overlaps with next H2D)
|
||||||
with torch.cuda.stream(transfer_grad_stream):
|
with torch.cuda.stream(transfer_grad_stream):
|
||||||
transfer_grad_stream.wait_event(ev_cu_b_finish)
|
transfer_grad_stream.wait_event(ev_cu_b_finish)
|
||||||
grad_weight = w_grad_buffers[idx].to("cpu", non_blocking=True)
|
if (
|
||||||
grad_bias = (
|
getattr(weight_cpu, "requires_grad", False)
|
||||||
b_grad_buffers[idx].to("cpu", non_blocking=True)
|
and weight_cpu.dtype.is_floating_point
|
||||||
if bias_cpu is not None
|
):
|
||||||
else None
|
grad_weight = w_grad_buffers[idx].to("cpu", non_blocking=True)
|
||||||
)
|
if bias_cpu is not None and getattr(bias_cpu, "requires_grad", False):
|
||||||
|
grad_bias = b_grad_buffers[idx].to("cpu", non_blocking=True)
|
||||||
state["transfer_weight_backward_finished_event"].record()
|
state["transfer_weight_backward_finished_event"].record()
|
||||||
|
|
||||||
return grad_input, grad_weight, grad_bias, None, None, None, None, None
|
return (
|
||||||
|
grad_input.to(dtype=grad_out.dtype),
|
||||||
|
grad_weight,
|
||||||
|
grad_bias,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class BaseLayerMemoryManager:
|
class BaseLayerMemoryManager:
|
||||||
|
|||||||
Reference in New Issue
Block a user