mirror of https://github.com/kvcache-ai/ktransformers.git

Commit: Add vendor control
@@ -15,6 +15,7 @@
 #include <torch/torch.h>
 #include <cstdint>
 #include <c10/cuda/CUDAGuard.h>
+typedef hip_bfloat16 nv_bfloat16;
 
 __global__ void dequantize_q8_0_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int ele_per_blk, const int num_blocks) {
     long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
@@ -31,6 +31,7 @@ from ktransformers.models.modeling_mixtral import MixtralForCausalLM
 from ktransformers.util.utils import prefill_and_generate, get_compute_capability
 from ktransformers.server.config.config import Config
 from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled
+from ktransformers.util.vendors import device_manager, get_device, to_device, GPUVendor
 
 custom_models = {
     "DeepseekV2ForCausalLM": DeepseekV2ForCausalLM,
@@ -169,7 +170,7 @@ def local_chat(
     assert Config().long_context_config['max_seq_len'] > input_tensor.shape[1] + max_new_tokens, \
     "please change max_seq_len in ~/.ktransformers/config.yaml"
 
-    if system != "Windows" and (config.architectures[0] == "DeepseekV2ForCausalLM" or config.architectures[0] == "DeepseekV3ForCausalLM") and flashinfer_enabled and get_compute_capability() >= 8:
+    if system != "Windows" and (config.architectures[0] == "DeepseekV2ForCausalLM" or config.architectures[0] == "DeepseekV3ForCausalLM") and flashinfer_enabled and get_compute_capability() >= 8 and device_manager.gpu_vendor == GPUVendor.NVIDIA:
         generated = prefill_and_generate(
             model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think, chunk_prefill_size = chunk_prefill_size,
             use_flashinfer_mla = True, num_heads = config.num_attention_heads, head_dim_ckv = config.kv_lora_rank, head_dim_kpe = config.qk_rope_head_dim, q_head_dim = config.qk_rope_head_dim + config.qk_nope_head_dim
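The only change in this hunk is the extra device_manager.gpu_vendor == GPUVendor.NVIDIA clause, which restricts the flashinfer MLA path to NVIDIA GPUs of compute capability >= 8; every other vendor now falls through to the generic generation path. A minimal sketch of the same gate as a standalone predicate (the helper name can_use_flashinfer_mla is illustrative, not part of the commit):

    from ktransformers.util.vendors import device_manager, GPUVendor

    def can_use_flashinfer_mla(system: str, arch: str,
                               flashinfer_enabled: bool,
                               compute_capability: int) -> bool:
        # The four pre-existing conditions plus the new vendor check:
        # flashinfer MLA needs a non-Windows OS, a DeepSeek model,
        # flashinfer installed, and an NVIDIA GPU of Ampere class
        # (compute capability 8) or newer.
        return (
            system != "Windows"
            and arch in ("DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM")
            and flashinfer_enabled
            and compute_capability >= 8
            and device_manager.gpu_vendor == GPUVendor.NVIDIA
        )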
@@ -20,8 +20,14 @@ from ktransformers.util.utils import get_compute_capability
 import logging
 from transformers.configuration_utils import PretrainedConfig
 from transformers.cache_utils import Cache
-from flash_attn import flash_attn_func
-from ktransformers.operators.triton_attention import decode_attention_fwd_grouped
+from ktransformers.util.vendors import device_manager, get_device, to_device, GPUVendor
+
+try:
+    from flash_attn import flash_attn_func
+except ImportError:
+    pass
+from ktransformers.operators.triton_attention import decode_attention_fwd_grouped
+from ktransformers.operators.triton_attention_prefill import context_attention_fwd
 import os
 from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled
 if flashinfer_enabled:
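Wrapping the flash_attn import in try/except lets the module load on machines where flash-attn is not installed (for example AMD or CPU-only builds), with the new Triton prefill kernel as the fallback. A sketch of the same pattern with an explicit availability flag (the flag and helper are illustrative, not from the commit):

    try:
        from flash_attn import flash_attn_func
    except ImportError:
        flash_attn_func = None  # fall back to the Triton path

    def has_flash_attn() -> bool:
        # True only when the optional dependency imported successfully.
        return flash_attn_func is not None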
@@ -319,18 +325,27 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         key_states[:, :, :, self.qk_nope_head_dim:] = k_pe.view(bsz, kv_seq_len, 1, -1)
 
         value_states = value_states.view(bsz, kv_seq_len, self.num_heads, self.v_head_dim)
         value_states_padded = torch.nn.functional.pad(value_states, [0, query_states.shape[-1] - value_states.shape[-1]], value=0)
 
-        attn_output = flash_attn_func(
-            query_states,
-            key_states,
-            value_states_padded,
-            softmax_scale=self.softmax_scale,
-            causal=True,
-        )
+        # for bsz = 1
+        attn_output = torch.zeros(bsz * q_len, self.num_heads, self.v_head_dim, device=hidden_states.device)
+        b_start_loc = torch.zeros(bsz, dtype=torch.int64, device=hidden_states.device)
+        b_seq_len = torch.full((bsz,), q_len, dtype=torch.int64, device=hidden_states.device)
+
+        max_input_len = q_len
+
+        context_attention_fwd(
+            q=query_states.squeeze(0).view(-1, self.num_heads, self.q_head_dim),
+            k=key_states.squeeze(0).view(-1, self.num_heads, self.q_head_dim),
+            v=value_states.squeeze(0).view(-1, self.num_heads, self.v_head_dim),
+            o=attn_output,
+            b_start_loc=b_start_loc,
+            b_seq_len=b_seq_len,
+            max_input_len=max_input_len,
+            is_causal=True
+        )
 
         if self.q_head_dim != self.v_head_dim:
-            attn_output = attn_output[:, :, :, : self.v_head_dim]
+            attn_output = attn_output[:, :, : self.v_head_dim]
 
         attn_output = attn_output.reshape(
             bsz, q_len, self.num_heads * self.v_head_dim
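context_attention_fwd takes varlen-packed tensors of shape [b * s, head, head_dim] plus per-sequence metadata; the replacement code above builds that metadata for the bsz = 1 case, where the single start offset is simply zero. For completeness, a hedged sketch of how the same metadata could be derived for bsz > 1 (not part of this commit, which explicitly marks the path "for bsz = 1"):

    import torch

    def make_varlen_metadata(seq_lens: torch.Tensor, device="cuda"):
        # seq_lens: [b] per-sequence lengths of the packed q/k/v tensors
        b_seq_len = seq_lens.to(dtype=torch.int64, device=device)
        b_start_loc = torch.zeros_like(b_seq_len)
        # exclusive prefix sum: sequence i starts where sequences 0..i-1 end
        b_start_loc[1:] = torch.cumsum(b_seq_len, dim=0)[:-1]
        max_input_len = int(b_seq_len.max().item())
        return b_start_loc, b_seq_len, max_input_len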
@@ -589,7 +604,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        if os.name == 'nt' or get_compute_capability()<8:
+        if os.name == 'nt' or get_compute_capability()<8 or device_manager.gpu_vendor != GPUVendor.NVIDIA:
             print("for Windows or GPU before ampere, use forward_windows")
             return self.forward_windows(
                 hidden_states,
@@ -17,7 +17,10 @@ import logging
 logger = logging.getLogger("dynamic_attention")
 sys.path.append(os.path.dirname(__file__) + "/../ktransformers_ext/cpu_backend")
 from ktransformers.operators.cpuinfer import CPUInfer, CPUInferKVCache
-from flash_attn import flash_attn_func, flash_attn_with_kvcache
+try:
+    from flash_attn import flash_attn_func, flash_attn_with_kvcache
+except ImportError:
+    print("flash_attn not found")
 
 
 import math
@@ -53,6 +53,7 @@ from ktransformers.models.modeling_deepseek import (
     DeepseekV2DecoderLayer,
     DeepseekV2MoE,
 )
+from ktransformers.util.vendors import device_manager, get_device, to_device, GPUVendor
 from transformers.models.qwen2_moe.configuration_qwen2_moe import Qwen2MoeConfig
 from ktransformers.models.configuration_llama import LlamaConfig
 from ktransformers.operators.base_operator import BaseInjectedModule
@@ -649,7 +650,7 @@ class KDeepseekV2Model(BaseInjectedModule):
         if per_layer_prefill_flag:
             causal_mask = None
         else:
-            if os.name == 'nt' or get_compute_capability()<8:
+            if os.name == 'nt' or get_compute_capability()<8 or device_manager.gpu_vendor != GPUVendor.NVIDIA:
                 print("for Windows or GPU before ampere, use forward_windows")
                 # only use mask in forward windows or can't flash attn
                 causal_mask = self._update_causal_mask(
@@ -6,7 +6,7 @@
 import triton
 import triton.language as tl
 
+from ktransformers.util.vendors import device_manager, get_device, to_device, GPUVendor
 @triton.jit
 def tanh(x):
     # Tanh is just a scaled sigmoid
@@ -181,8 +181,8 @@ def _decode_grouped_att_m_fwd(
     # [TODO] work around shmem limit on MI3xx
 
-    # TODO: support hip
-    #if is_hip_ and Lk >= 576:
-    #    BLOCK = 16
+    if device_manager.gpu_vendor == GPUVendor.AMD and Lk >= 576:
+        BLOCK = 16
+
     if Lk == 576:
         BLOCK_DMODEL = 512
 
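This hunk turns a commented-out HIP workaround into a live runtime check: on AMD GPUs with head dims >= 576 the tile size drops to 16 to fit the MI3xx shared-memory budget noted in the TODO. The pattern generalizes to any vendor-specific tuning knob; a sketch (only the AMD case comes from the commit, the default value is an assumption about the surrounding kernel):

    from ktransformers.util.vendors import device_manager, GPUVendor

    def pick_block(Lk: int, default_block: int = 32) -> int:
        # Shrink the tile on AMD for very large head dims to stay within
        # the MI3xx shared-memory limit.
        if device_manager.gpu_vendor == GPUVendor.AMD and Lk >= 576:
            return 16
        return default_block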
ktransformers/operators/triton_attention_prefill.py (new file, 206 lines)
@@ -0,0 +1,206 @@
# Adapted from
# https://github.com/sgl-project/sglang/blob/9f635ea50de920aa507f486daafba26a5b837574/python/sglang/srt/layers/attention/triton_ops/prefill_attention.py
# which was originally adapted from
# https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L1

"""
Memory-efficient attention for prefill.
It supports page size = 1.
"""

# Adapted from
# https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L1
import torch
import triton
import triton.language as tl

is_cuda_available = torch.cuda.is_available()
if is_cuda_available:
    CUDA_CAPABILITY = torch.cuda.get_device_capability()


@triton.jit
def _fwd_kernel(
    Q,
    K,
    V,
    sm_scale,
    B_Start_Loc,
    B_Seqlen,
    Out,
    stride_qbs,
    stride_qh,
    stride_kbs,
    stride_kh,
    stride_vbs,
    stride_vh,
    stride_obs,
    stride_oh,
    kv_group_num: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
    IS_CAUSAL: tl.constexpr,
    Lk: tl.constexpr,
):
    cur_batch = tl.program_id(0)
    cur_head = tl.program_id(1)
    start_m = tl.program_id(2)

    cur_kv_head = cur_head // kv_group_num

    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
    cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)

    block_start_loc = BLOCK_M * start_m

    # initialize offsets
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_DMODEL)
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    off_q = (
        (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs
        + cur_head * stride_qh
        + offs_d[None, :]
    )
    off_k = offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None]
    off_v = offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :]

    mask_d = offs_d < Lk

    q = tl.load(
        Q + off_q,
        mask=(offs_m[:, None] < cur_batch_seq_len) & (mask_d[None, :]),
        other=0.0,
    )

    k_ptrs = K + off_k
    v_ptrs = V + off_v

    # initialize pointer to m and l
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)

    block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0)

    end_n = (
        cur_batch_seq_len
        if not IS_CAUSAL
        else tl.minimum((start_m + 1) * BLOCK_M, cur_batch_seq_len)
    )
    for start_n in range(0, block_mask * end_n, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        k = tl.load(
            k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs,
            mask=((start_n + offs_n[None, :]) < cur_batch_seq_len) & (mask_d[:, None]),
            other=0.0,
        )
        # mask = tl.load(mask_ptrs + start_n, mask=start_n + offs_n < cur_batch_end_loc, other=0.0)

        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k)
        qk *= sm_scale

        if IS_CAUSAL:
            qk += tl.where(
                (start_n + offs_n[None, :] < cur_batch_seq_len)
                & (offs_m[:, None] >= (start_n + offs_n[None, :])),
                0,
                float("-inf"),
            )
        else:
            qk += tl.where(
                (start_n + offs_n[None, :]) < cur_batch_seq_len, 0, float("-inf")
            )

        # -- compute m_ij, p, l_ij
        m_ij = tl.max(qk, 1)
        p = tl.exp(qk - m_ij[:, None])
        l_ij = tl.sum(p, 1)
        # -- update m_i and l_i
        m_i_new = tl.maximum(m_i, m_ij)
        alpha = tl.exp(m_i - m_i_new)
        beta = tl.exp(m_ij - m_i_new)
        l_i_new = alpha * l_i + beta * l_ij
        # -- update output accumulator --
        # scale p
        p_scale = beta / l_i_new
        p = p * p_scale[:, None]
        # scale acc
        acc_scale = l_i / l_i_new * alpha
        acc = acc * acc_scale[:, None]
        # update acc
        v = tl.load(
            v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs,
            mask=((start_n + offs_n[:, None]) < cur_batch_seq_len) & (mask_d[None, :]),
            other=0.0,
        )

        p = p.to(v.dtype)
        acc += tl.dot(p, v)
        # update m_i and l_i
        l_i = l_i_new
        m_i = m_i_new
    # initialize pointers to output
    off_o = (
        (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs
        + cur_head * stride_oh
        + offs_d[None, :]
    )
    out_ptrs = Out + off_o
    tl.store(
        out_ptrs, acc, mask=(offs_m[:, None] < cur_batch_seq_len) & (mask_d[None, :])
    )


def context_attention_fwd(
    q, k, v, o, b_start_loc, b_seq_len, max_input_len, is_causal=True
):
    """
    q, k, v: [b * s, head, head_dim]
    b_start_loc: [b]
    b_seq_len: [b]
    out: [b * s, head, head_dim]
    """
    if is_cuda_available and CUDA_CAPABILITY[0] > 8:
        BLOCK = 128
    else:
        BLOCK = 64

    Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]

    sm_scale = 1.0 / (Lq**0.5)
    batch, head = b_seq_len.shape[0], q.shape[1]
    kv_group_num = q.shape[1] // k.shape[1]

    grid = (batch, head, triton.cdiv(max_input_len, BLOCK))
    num_warps = 4 if Lk <= 64 else 8

    _fwd_kernel[grid](
        q,
        k,
        v,
        sm_scale,
        b_start_loc,
        b_seq_len,
        o,
        q.stride(0),
        q.stride(1),
        k.stride(0),
        k.stride(1),
        v.stride(0),
        v.stride(1),
        o.stride(0),
        o.stride(1),
        kv_group_num=kv_group_num,
        BLOCK_M=BLOCK,
        BLOCK_DMODEL=triton.next_power_of_2(Lk),
        BLOCK_N=BLOCK,
        IS_CAUSAL=is_causal,
        num_warps=num_warps,
        num_stages=1,
        Lk=Lk,
    )
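A minimal smoke test for the new kernel (not part of the commit): a single causal sequence with 8 query heads sharing one KV head, checked against PyTorch's scaled_dot_product_attention. It assumes a CUDA (or ROCm) device is available.

    import torch
    from ktransformers.operators.triton_attention_prefill import context_attention_fwd

    s, hq, hkv, d = 128, 8, 1, 64
    q = torch.randn(s, hq, d, device="cuda", dtype=torch.float16)
    k = torch.randn(s, hkv, d, device="cuda", dtype=torch.float16)
    v = torch.randn(s, hkv, d, device="cuda", dtype=torch.float16)
    o = torch.zeros(s, hq, d, device="cuda", dtype=torch.float32)  # kernel accumulates in fp32

    b_start_loc = torch.zeros(1, dtype=torch.int64, device="cuda")
    b_seq_len = torch.full((1,), s, dtype=torch.int64, device="cuda")
    context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len=s)

    # Reference: broadcast the single KV head across the 8 query heads (GQA).
    ref = torch.nn.functional.scaled_dot_product_attention(
        q.transpose(0, 1).float(),
        k.expand(-1, hq, -1).transpose(0, 1).float(),
        v.expand(-1, hq, -1).transpose(0, 1).float(),
        is_causal=True,
    ).transpose(0, 1)
    assert torch.allclose(o, ref, atol=2e-2, rtol=2e-2)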
ktransformers/util/vendors.py (new file, 202 lines)
@@ -0,0 +1,202 @@
from __future__ import annotations

from enum import IntEnum, auto
from typing import Optional, Union, List
import torch

class GPUVendor(IntEnum):
    NVIDIA = auto()
    AMD = auto()
    MooreThreads = auto()
    MetaX = auto()
    MUSA = auto()
    Unknown = auto()

class DeviceManager:
    """
    Device manager that provides a unified interface for handling different GPU vendors
    """
    def __init__(self):
        self.gpu_vendor = self._detect_gpu_vendor()
        self.available_devices = self._get_available_devices()

    def _detect_gpu_vendor(self) -> GPUVendor:
        """Detect the GPU vendor type"""
        if not torch.cuda.is_available():
            # Check MUSA availability (assuming a musa module exists)
            try:
                import musa
                if musa.is_available():
                    return GPUVendor.MUSA
            except (ImportError, AttributeError):
                pass

            return GPUVendor.Unknown

        device_name = torch.cuda.get_device_name(0).lower()

        if any(name in device_name for name in ["nvidia", "geforce", "quadro", "tesla", "titan", "rtx", "gtx"]):
            return GPUVendor.NVIDIA
        elif any(name in device_name for name in ["amd", "radeon", "rx", "vega", "instinct", "firepro", "mi"]):
            return GPUVendor.AMD
        elif any(name in device_name for name in ["mthreads", "moore", "mtt"]):
            return GPUVendor.MooreThreads
        elif any(name in device_name for name in ["metax", "meta"]):
            return GPUVendor.MetaX
        elif "musa" in device_name:
            return GPUVendor.MUSA

        # Backend check
        try:
            if hasattr(torch.version, 'hip') and torch.version.hip is not None:
                return GPUVendor.AMD
            elif hasattr(torch.version, 'cuda') and torch.version.cuda is not None:
                return GPUVendor.NVIDIA
        except Exception:
            pass

        return GPUVendor.Unknown

    def _get_available_devices(self) -> List[int]:
        """Get list of available device indices"""
        devices = []

        if self.gpu_vendor == GPUVendor.NVIDIA or self.gpu_vendor == GPUVendor.AMD:
            devices = list(range(torch.cuda.device_count()))
        elif self.gpu_vendor == GPUVendor.MUSA:
            try:
                import musa
                devices = list(range(musa.device_count()))
            except (ImportError, AttributeError):
                pass

        return devices

    def get_device_str(self, device_id: Union[int, str]) -> str:
        """
        Get device string for the given device ID

        Args:
            device_id: Device index (0, 1, 2, etc.), -1 for CPU, or "cpu" string

        Returns:
            Device string representation (e.g., "cuda:0", "musa:1", "cpu")
        """
        if device_id == -1 or device_id == "cpu":
            return "cpu"

        if isinstance(device_id, int):
            if self.gpu_vendor == GPUVendor.NVIDIA or self.gpu_vendor == GPUVendor.AMD:
                if device_id < torch.cuda.device_count():
                    return f"cuda:{device_id}"
            elif self.gpu_vendor == GPUVendor.MUSA:
                try:
                    import musa
                    if device_id < musa.device_count():
                        return f"musa:{device_id}"
                except (ImportError, AttributeError):
                    pass

        return "cpu"

    def to_torch_device(self, device_id: Union[int, str] = 0) -> torch.device:
        """
        Convert device ID to torch.device object

        Args:
            device_id: Device index (0, 1, 2, etc.), -1 for CPU, or "cpu" string

        Returns:
            torch.device object
        """
        device_str = self.get_device_str(device_id)

        # Handle MUSA device
        if device_str.startswith("musa:"):
            try:
                import musa
                index = int(device_str.split(":")[-1])
                return musa.device(index)
            except (ImportError, ValueError, AttributeError):
                return torch.device("cpu")

        # Standard PyTorch device
        return torch.device(device_str)

    def move_tensor_to_device(self, tensor: torch.Tensor, device_id: Union[int, str] = 0) -> torch.Tensor:
        """
        Move tensor to specified device

        Args:
            tensor: PyTorch tensor to move
            device_id: Device index (0, 1, 2, etc.), -1 for CPU, or "cpu" string

        Returns:
            Tensor moved to the specified device
        """
        device = self.to_torch_device(device_id)
        return tensor.to(device)

    def is_available(self, index: int = 0) -> bool:
        """
        Check if device at specified index is available

        Args:
            index: Device index to check

        Returns:
            True if the device is available, False otherwise
        """
        if index < 0:
            return True  # CPU is always available

        return index in self.available_devices

    def get_all_devices(self) -> List[int]:
        """
        Get all available device indices

        Returns:
            List of available device indices (0, 1, 2, etc.)
        """
        return self.available_devices

# Create global device manager instance
device_manager = DeviceManager()

# Convenience functions
def get_device(device_id: Union[int, str] = 0) -> torch.device:
    """
    Get torch.device object for the specified device ID

    Args:
        device_id: Device index (0, 1, 2, etc.), -1 for CPU, or "cpu" string

    Returns:
        torch.device object
    """
    return device_manager.to_torch_device(device_id)

def to_device(tensor: torch.Tensor, device_id: Union[int, str] = 0) -> torch.Tensor:
    """
    Move tensor to specified device

    Args:
        tensor: PyTorch tensor to move
        device_id: Device index (0, 1, 2, etc.), -1 for CPU, or "cpu" string

    Returns:
        Tensor moved to the specified device
    """
    return device_manager.move_tensor_to_device(tensor, device_id)

# Example usage, guarded so that importing this module has no side effects
if __name__ == "__main__":
    # Get devices
    cpu_device = get_device(-1)       # CPU using index -1
    cpu_device2 = get_device("cpu")   # CPU using string "cpu"
    gpu0 = get_device(0)              # First GPU

    # Move tensors
    x = torch.randn(3, 3)
    x_gpu = to_device(x, 0)           # Move to first GPU
    x_cpu1 = to_device(x, -1)         # Move to CPU using index -1
    x_cpu2 = to_device(x, "cpu")      # Move to CPU using string "cpu"
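Putting the pieces together, the hunks above all consume this probe the same way: compare device_manager.gpu_vendor against a GPUVendor member and pick a code path. A condensed sketch of that dispatch (the backend labels are illustrative, not from the commit):

    from ktransformers.util.vendors import device_manager, GPUVendor

    if device_manager.gpu_vendor == GPUVendor.NVIDIA:
        backend = "flashinfer"   # MLA wrapper stays NVIDIA-only in this commit
    elif device_manager.gpu_vendor == GPUVendor.AMD:
        backend = "triton"       # portable Triton kernels with the HIP tweaks above
    else:
        backend = "fallback"     # forward_windows / eager attention path
    print(f"vendor={device_manager.gpu_vendor.name}, backend={backend}")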