mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-25 23:34:49 +00:00
Merge remote-tracking branch 'origin/main' into qinghuazhou/expert_parallel_merge_main_test
# Conflicts:
#   src/core/connection.cc
#   test/mp_unit/port_channel_tests.cu
This commit is contained in:
@@ -45,8 +45,10 @@ void register_core(nb::module_& m) {
|
||||
.value("float16", DataType::FLOAT16)
|
||||
.value("float32", DataType::FLOAT32)
|
||||
.value("bfloat16", DataType::BFLOAT16)
|
||||
.value("float8_e4m3", DataType::FLOAT8_E4M3)
|
||||
.value("float8_e4m3fn", DataType::FLOAT8_E4M3FN)
|
||||
.value("float8_e4m3fnuz", DataType::FLOAT8_E4M3FNUZ)
|
||||
.value("float8_e5m2", DataType::FLOAT8_E5M2)
|
||||
.value("float8_e5m2fnuz", DataType::FLOAT8_E5M2FNUZ)
|
||||
.value("uint8", DataType::UINT8)
|
||||
.value("float8_e4m3b15", DataType::FLOAT8_E4M3B15);
|
||||
|
||||
@@ -328,4 +330,4 @@ NB_MODULE(_mscclpp, m) {
|
||||
|
||||
// ext
|
||||
register_algorithm_collection_builder(m);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,7 +13,7 @@ from mscclpp.language.utils import AlgoSpec
|
||||
default_algo_configs = [
|
||||
{
|
||||
"filename": "allreduce_2nodes_1K_64K.json",
|
||||
"function": def_algo.allreduce_2nodes,
|
||||
"function": def_algo.allreduce_multi_nodes,
|
||||
"spec": AlgoSpec(
|
||||
name="allreduce_2nodes_1K_64K",
|
||||
collective=AllReduce(16, 1, True),
|
||||
@@ -34,7 +34,7 @@ default_algo_configs = [
|
||||
},
|
||||
{
|
||||
"filename": "allreduce_2nodes_128K_2M.json",
|
||||
"function": def_algo.allreduce_2nodes,
|
||||
"function": def_algo.allreduce_multi_nodes,
|
||||
"spec": AlgoSpec(
|
||||
name="allreduce_2nodes_128K_2M",
|
||||
collective=AllReduce(16, 1, True),
|
||||
@@ -53,6 +53,48 @@ default_algo_configs = [
|
||||
),
|
||||
"additional_kwargs": {"thread_block_group_size": 4},
|
||||
},
|
||||
{
|
||||
"filename": "allreduce_4nodes_1K_64K.json",
|
||||
"function": def_algo.allreduce_multi_nodes,
|
||||
"spec": AlgoSpec(
|
||||
name="allreduce_4nodes_1K_64K",
|
||||
collective=AllReduce(32, 1, True),
|
||||
nranks_per_node=8,
|
||||
world_size=32,
|
||||
in_place=True,
|
||||
instances=1,
|
||||
protocol="LL",
|
||||
auto_sync=False,
|
||||
num_threads_per_block=1024,
|
||||
reuse_resources=True,
|
||||
use_double_scratch_buffer=True,
|
||||
min_message_size=1 << 10,
|
||||
max_message_size=64 << 10,
|
||||
tags={"default": 1},
|
||||
),
|
||||
"additional_kwargs": {"thread_block_group_size": 1},
|
||||
},
|
||||
{
|
||||
"filename": "allreduce_4nodes_128K_2M.json",
|
||||
"function": def_algo.allreduce_multi_nodes,
|
||||
"spec": AlgoSpec(
|
||||
name="allreduce_4nodes_128K_2M",
|
||||
collective=AllReduce(32, 1, True),
|
||||
nranks_per_node=8,
|
||||
world_size=32,
|
||||
in_place=True,
|
||||
instances=1,
|
||||
protocol="LL",
|
||||
auto_sync=False,
|
||||
num_threads_per_block=1024,
|
||||
reuse_resources=True,
|
||||
use_double_scratch_buffer=True,
|
||||
min_message_size=128 << 10,
|
||||
max_message_size=2 << 20,
|
||||
tags={"default": 1},
|
||||
),
|
||||
"additional_kwargs": {"thread_block_group_size": 4},
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from mscclpp.default_algos.allreduce_2nodes import allreduce_2nodes
|
||||
from mscclpp.default_algos.allreduce_multi_nodes import allreduce_multi_nodes
|
||||
|
||||
__all__ = ["allreduce_2nodes"]
|
||||
__all__ = ["allreduce_multi_nodes"]
|
||||
|
||||
@@ -2,9 +2,11 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""
|
||||
Multi-node AllReduce implementation using packet-based communication.
|
||||
This implements a hierarchical AllReduce: intra-node allreduce followed by
|
||||
inter-node exchange and final intra-node allreduce.
|
||||
Generalized multi-node AllReduce implementation using packet-based communication.
|
||||
This implements a hierarchical AllReduce for N nodes:
|
||||
1. Intra-node reduce-scatter (each GPU reduces its assigned chunk across the node)
|
||||
2. Inter-node allreduce (exchange fully intra-reduced chunks across all nodes)
|
||||
3. Intra-node broadcast (distribute the fully reduced chunks back to all GPUs in the node)
|
||||
"""
|
||||
|
||||
from mscclpp.language.utils import AlgoSpec
|
||||
@@ -15,7 +17,7 @@ from mscclpp.language.program import *
|
||||
from mscclpp.language.collectives import *
|
||||
|
||||
|
||||
def allreduce_2nodes(spec: AlgoSpec, thread_block_group_size: int) -> CollectiveProgram:
|
||||
def allreduce_multi_nodes(spec: AlgoSpec, thread_block_group_size: int) -> CollectiveProgram:
|
||||
"""
|
||||
Implements a multi-node AllReduce using a hierarchical approach:
|
||||
1. Intra-node allreduce
|
||||
@@ -23,10 +25,10 @@ def allreduce_2nodes(spec: AlgoSpec, thread_block_group_size: int) -> Collective
|
||||
3. Intra-node allreduce
|
||||
"""
|
||||
# Configuration constants
|
||||
num_nodes = 2
|
||||
num_nodes = spec.world_size // spec.nranks_per_node
|
||||
gpus_per_node = spec.nranks_per_node
|
||||
total_gpus = num_nodes * gpus_per_node
|
||||
packets_per_gpu = 2
|
||||
packets_per_gpu = num_nodes
|
||||
|
||||
with CollectiveProgram.from_spec(spec) as prog:
|
||||
# Initialize communication channels and buffers
|
||||
@@ -54,11 +56,21 @@ def allreduce_2nodes(spec: AlgoSpec, thread_block_group_size: int) -> Collective
|
||||
)
|
||||
)
|
||||
|
||||
scratch_buffer_size = packets_per_gpu * (total_gpus + 1)
|
||||
# Scratch buffer layout (3 contiguous regions):
|
||||
# Region 1 [0, total_gpus):
|
||||
# Intra-node reduce-scatter. Each GPU receives chunks from gpus_per_node peers,
|
||||
# packets_per_gpu each → gpus_per_node * packets_per_gpu = total_gpus slots.
|
||||
# Region 2 [total_gpus, total_gpus + num_nodes * packets_per_gpu):
|
||||
# Inter-node exchange. Each GPU receives reduced chunks from num_nodes nodes,
|
||||
# packets_per_gpu each → num_nodes * packets_per_gpu slots.
|
||||
# Region 3 [total_gpus + num_nodes * packets_per_gpu, end):
|
||||
# Intra-node broadcast. Each GPU receives final reduced data from gpus_per_node peers,
|
||||
# packets_per_gpu each → gpus_per_node * packets_per_gpu = total_gpus slots.
|
||||
# Total = 2 * total_gpus + num_nodes * packets_per_gpu
|
||||
scratch_buffer_size = 2 * total_gpus + packets_per_gpu * num_nodes
|
||||
for node_id in range(num_nodes):
|
||||
for local_gpu_id in range(gpus_per_node):
|
||||
current_rank_id = local_gpu_id + gpus_per_node * node_id
|
||||
next_node_rank_id = (local_gpu_id + gpus_per_node * (node_id + 1)) % total_gpus
|
||||
scratch_buffers.append(Buffer(current_rank_id, scratch_buffer_size))
|
||||
for peer_gpu_id in range(gpus_per_node):
|
||||
if peer_gpu_id != local_gpu_id:
|
||||
@@ -66,7 +78,12 @@ def allreduce_2nodes(spec: AlgoSpec, thread_block_group_size: int) -> Collective
|
||||
intra_node_memory_channels[(peer_rank_id, current_rank_id)] = MemoryChannel(
|
||||
peer_rank_id, current_rank_id
|
||||
)
|
||||
inter_node_port_channels[current_rank_id] = PortChannel(next_node_rank_id, current_rank_id)
|
||||
for peer_node_id in range(num_nodes):
|
||||
if peer_node_id != node_id:
|
||||
peer_node_rank_id = (local_gpu_id + gpus_per_node * peer_node_id) % total_gpus
|
||||
inter_node_port_channels[(current_rank_id, peer_node_rank_id)] = PortChannel(
|
||||
peer_node_rank_id, current_rank_id
|
||||
)
|
||||
|
||||
# AllReduce
|
||||
for node_id in range(num_nodes):
|
||||
@@ -74,7 +91,6 @@ def allreduce_2nodes(spec: AlgoSpec, thread_block_group_size: int) -> Collective
|
||||
current_rank_id = local_gpu_id + gpus_per_node * node_id
|
||||
current_rank = Rank(current_rank_id)
|
||||
input_buffer = current_rank.get_input_buffer()
|
||||
next_node_rank_id = (local_gpu_id + gpus_per_node * (node_id + 1)) % total_gpus
|
||||
|
||||
# Intra Node Exchange Data
|
||||
for peer_gpu_id in range(gpus_per_node):
|
||||
@@ -118,27 +134,32 @@ def allreduce_2nodes(spec: AlgoSpec, thread_block_group_size: int) -> Collective
|
||||
)
|
||||
|
||||
inter_node_offset = total_gpus
|
||||
inter_node_port_channels[current_rank_id].put_packets(
|
||||
scratch_buffers[next_node_rank_id][
|
||||
inter_node_offset
|
||||
+ local_gpu_id * packets_per_gpu : inter_node_offset
|
||||
+ local_gpu_id * packets_per_gpu
|
||||
+ packets_per_gpu
|
||||
],
|
||||
scratch_buffers[current_rank_id][
|
||||
local_gpu_id * packets_per_gpu : local_gpu_id * packets_per_gpu + packets_per_gpu
|
||||
],
|
||||
tb=0,
|
||||
)
|
||||
for peer_node_id in range(num_nodes):
|
||||
if peer_node_id != node_id:
|
||||
peer_node_rank_id = (local_gpu_id + gpus_per_node * peer_node_id) % total_gpus
|
||||
inter_node_port_channels[(current_rank_id, peer_node_rank_id)].put_packets(
|
||||
scratch_buffers[peer_node_rank_id][
|
||||
inter_node_offset
|
||||
+ node_id * packets_per_gpu : inter_node_offset
|
||||
+ node_id * packets_per_gpu
|
||||
+ packets_per_gpu
|
||||
],
|
||||
scratch_buffers[current_rank_id][
|
||||
local_gpu_id * packets_per_gpu : local_gpu_id * packets_per_gpu + packets_per_gpu
|
||||
],
|
||||
tb=0,
|
||||
)
|
||||
|
||||
# Reduce Received Data from Remote Node
|
||||
inter_node_data = [
|
||||
scratch_buffers[current_rank_id][
|
||||
inter_node_offset
|
||||
+ local_gpu_id * packets_per_gpu : inter_node_offset
|
||||
+ local_gpu_id * packets_per_gpu
|
||||
+ peer_node_id * packets_per_gpu : inter_node_offset
|
||||
+ peer_node_id * packets_per_gpu
|
||||
+ packets_per_gpu
|
||||
]
|
||||
for peer_node_id in range(num_nodes)
|
||||
if peer_node_id != node_id
|
||||
]
|
||||
current_rank.reduce(
|
||||
input_buffer[local_gpu_id * packets_per_gpu : local_gpu_id * packets_per_gpu + packets_per_gpu],
|
||||
@@ -148,12 +169,18 @@ def allreduce_2nodes(spec: AlgoSpec, thread_block_group_size: int) -> Collective
|
||||
)
|
||||
|
||||
current_rank.copy_packets(
|
||||
scratch_buffers[current_rank_id][scratch_buffer_size - packets_per_gpu : scratch_buffer_size],
|
||||
scratch_buffers[current_rank_id][
|
||||
inter_node_offset
|
||||
+ node_id * packets_per_gpu : inter_node_offset
|
||||
+ node_id * packets_per_gpu
|
||||
+ packets_per_gpu
|
||||
],
|
||||
input_buffer[local_gpu_id * packets_per_gpu : local_gpu_id * packets_per_gpu + packets_per_gpu],
|
||||
tb_group=global_intra_node_tbg,
|
||||
)
|
||||
|
||||
# Broadcast Reduced Data
|
||||
broadcast_offset = total_gpus + packets_per_gpu * num_nodes
|
||||
for peer_gpu_id in range(gpus_per_node):
|
||||
peer_rank_id = peer_gpu_id + gpus_per_node * node_id
|
||||
|
||||
@@ -161,13 +188,16 @@ def allreduce_2nodes(spec: AlgoSpec, thread_block_group_size: int) -> Collective
|
||||
tbg_id = peer_gpu_id if peer_gpu_id < local_gpu_id else peer_gpu_id - 1
|
||||
intra_node_memory_channels[(peer_rank_id, current_rank_id)].read_put_packets(
|
||||
scratch_buffers[peer_rank_id][
|
||||
inter_node_offset
|
||||
+ local_gpu_id * packets_per_gpu : inter_node_offset
|
||||
broadcast_offset
|
||||
+ local_gpu_id * packets_per_gpu : broadcast_offset
|
||||
+ local_gpu_id * packets_per_gpu
|
||||
+ packets_per_gpu
|
||||
],
|
||||
scratch_buffers[current_rank_id][
|
||||
scratch_buffer_size - packets_per_gpu : scratch_buffer_size
|
||||
inter_node_offset
|
||||
+ node_id * packets_per_gpu : inter_node_offset
|
||||
+ node_id * packets_per_gpu
|
||||
+ packets_per_gpu
|
||||
],
|
||||
tb_group=thread_block_groups[tbg_id],
|
||||
)
|
||||
@@ -181,8 +211,8 @@ def allreduce_2nodes(spec: AlgoSpec, thread_block_group_size: int) -> Collective
|
||||
peer_gpu_id * packets_per_gpu : peer_gpu_id * packets_per_gpu + packets_per_gpu
|
||||
],
|
||||
scratch_buffers[current_rank_id][
|
||||
inter_node_offset
|
||||
+ peer_gpu_id * packets_per_gpu : inter_node_offset
|
||||
broadcast_offset
|
||||
+ peer_gpu_id * packets_per_gpu : broadcast_offset
|
||||
+ peer_gpu_id * packets_per_gpu
|
||||
+ packets_per_gpu
|
||||
],
|
||||
@@ -190,3 +220,37 @@ def allreduce_2nodes(spec: AlgoSpec, thread_block_group_size: int) -> Collective
|
||||
)
|
||||
|
||||
return prog
|
||||
|
||||
|
||||
if __name__ == "__main__":
    import argparse

    # Command-line entry point: build the multi-node AllReduce program for the
    # requested topology and print its JSON representation on stdout.
    parser = argparse.ArgumentParser()
    parser.add_argument("--name", type=str, help="name of the program")
    # The two topology arguments have no usable default: allreduce_multi_nodes
    # computes world_size // nranks_per_node, so leaving either one unset would
    # fail on None inside the generator rather than as a usage error.
    parser.add_argument("--num_gpus", type=int, required=True, help="total number of gpus")
    parser.add_argument("--gpus_per_node", type=int, required=True, help="number of gpus per node")
    parser.add_argument("--tbg", type=int, default=1, help="thread block group size")
    parser.add_argument("--num_threads_per_block", type=int, default=1024, help="number of threads per block")
    parser.add_argument("--min_message_size", type=int, default=0, help="minimum message size")
    parser.add_argument("--max_message_size", type=int, default=2**64 - 1, help="maximum message size")

    args = parser.parse_args()

    # Reject topologies the generator cannot represent. The node count is derived
    # with integer division, so a remainder would silently leave ranks out of the
    # generated program.
    if args.num_gpus <= 0 or args.gpus_per_node <= 0:
        parser.error("--num_gpus and --gpus_per_node must be positive")
    if args.num_gpus % args.gpus_per_node != 0:
        parser.error("--num_gpus must be a multiple of --gpus_per_node")

    spec = AlgoSpec(
        name=args.name,
        collective=AllReduce(args.num_gpus, 1, True),
        nranks_per_node=args.gpus_per_node,
        world_size=args.num_gpus,
        in_place=True,
        instances=1,
        protocol="LL",
        auto_sync=False,
        num_threads_per_block=args.num_threads_per_block,
        reuse_resources=True,
        use_double_scratch_buffer=True,
        min_message_size=args.min_message_size,
        max_message_size=args.max_message_size,
    )

    prog = allreduce_multi_nodes(spec, args.tbg)
    print(prog.to_json())
|
||||
@@ -192,12 +192,14 @@ def torch_dtype_to_mscclpp_dtype(dtype: "torch.dtype") -> DataType:
|
||||
return DataType.int32
|
||||
elif dtype == torch.bfloat16:
|
||||
return DataType.bfloat16
|
||||
# Hardware supports either OCP format or FNUZ format for float8.
|
||||
# Mapping both to the same MSCClPP data type.
|
||||
elif dtype == torch.float8_e5m2 or dtype == torch.float8_e5m2fnuz:
|
||||
elif dtype == torch.float8_e5m2:
|
||||
return DataType.float8_e5m2
|
||||
elif dtype == torch.float8_e4m3fn or dtype == torch.float8_e4m3fnuz:
|
||||
return DataType.float8_e4m3
|
||||
elif dtype == torch.float8_e5m2fnuz:
|
||||
return DataType.float8_e5m2fnuz
|
||||
elif dtype == torch.float8_e4m3fn:
|
||||
return DataType.float8_e4m3fn
|
||||
elif dtype == torch.float8_e4m3fnuz:
|
||||
return DataType.float8_e4m3fnuz
|
||||
elif dtype == torch.uint8:
|
||||
return DataType.uint8
|
||||
else:
|
||||
|
||||
@@ -24,6 +24,8 @@ def parse_dtype(dtype_str):
|
||||
dtype_str = dtype_str.strip().lower()
|
||||
if dtype_str == "float16":
|
||||
return cp.float16
|
||||
elif dtype_str in ("bfloat16", "bf16"):
|
||||
return cp.float16 # same 2-byte size; mscclpp DataType is resolved from dtype_str
|
||||
elif dtype_str == "float32":
|
||||
return cp.float32
|
||||
elif dtype_str == "int32":
|
||||
@@ -119,15 +121,18 @@ def parse_size(size_str):
|
||||
return int(size_str)
|
||||
|
||||
|
||||
def dtype_to_mscclpp_dtype(dtype):
|
||||
if dtype == cp.float16:
|
||||
def dtype_to_mscclpp_dtype(dtype_str):
|
||||
dtype_str = dtype_str.strip().lower()
|
||||
if dtype_str == "float16":
|
||||
return DataType.float16
|
||||
elif dtype == cp.float32:
|
||||
elif dtype_str in ("bfloat16", "bf16"):
|
||||
return DataType.bfloat16
|
||||
elif dtype_str == "float32":
|
||||
return DataType.float32
|
||||
elif dtype == cp.int32:
|
||||
elif dtype_str == "int32":
|
||||
return DataType.int32
|
||||
else:
|
||||
raise ValueError(f"Unknown data type: {dtype}")
|
||||
raise ValueError(f"Unknown data type: {dtype_str}")
|
||||
|
||||
|
||||
def build_bufs(
|
||||
@@ -205,7 +210,7 @@ def main(
|
||||
result_buf.data.ptr,
|
||||
input_buf.nbytes,
|
||||
result_buf.nbytes,
|
||||
dtype_to_mscclpp_dtype(dtype),
|
||||
dtype_to_mscclpp_dtype(dtype_str),
|
||||
execution_plan,
|
||||
stream.ptr,
|
||||
packet_type,
|
||||
@@ -231,7 +236,7 @@ def main(
|
||||
npkit.shutdown()
|
||||
print(
|
||||
f"Rank: {mscclpp_group.my_rank} Execution time: {execution_time} us, "
|
||||
f"data size: {result_buf.nbytes} bytes data type: {dtype().dtype.name} "
|
||||
f"data size: {result_buf.nbytes} bytes data type: {dtype_str} "
|
||||
f"packet type: {packet_type}"
|
||||
)
|
||||
executor = None
|
||||
@@ -243,7 +248,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument("-path", "--execution_plan_path", type=str, required=True)
|
||||
parser.add_argument("--size", type=str, required=True)
|
||||
parser.add_argument("--in_place", action="store_true", help="flag to define an in-place operation")
|
||||
parser.add_argument("--dtype", type=str, default="float16", help="Choose from float16, float32, int32")
|
||||
parser.add_argument("--dtype", type=str, default="float16", help="Choose from float16, bfloat16, float32, int32")
|
||||
parser.add_argument("--packet_type", type=str, default="LL16", help="Choose from LL8, LL16")
|
||||
parser.add_argument("--n_iters", type=int, default=10)
|
||||
parser.add_argument("--n_graph_iters", type=int, default=10)
|
||||
|
||||
@@ -4,8 +4,10 @@
|
||||
#include <assert.h>
|
||||
|
||||
#if defined(__HIP_PLATFORM_AMD__)
|
||||
#include <hip/hip_bfloat16.h>
|
||||
#include <hip/hip_fp16.h>
|
||||
#else
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
#endif
|
||||
|
||||
@@ -30,6 +32,7 @@ static __device__ unsigned int ranqd1(unsigned int seed) {
|
||||
} \
|
||||
}
|
||||
|
||||
FILL_DATA(bfloat16, __nv_bfloat16)
|
||||
FILL_DATA(float16, __half)
|
||||
FILL_DATA(float32, float)
|
||||
FILL_DATA(int32, int)
|
||||
@@ -48,11 +51,12 @@ FILL_DATA(int32, int)
|
||||
} \
|
||||
}
|
||||
|
||||
TEST_DATA_ALL_GATHER(bfloat16, __nv_bfloat16)
|
||||
TEST_DATA_ALL_GATHER(float16, __half)
|
||||
TEST_DATA_ALL_GATHER(float32, float)
|
||||
TEST_DATA_ALL_GATHER(int32, int)
|
||||
|
||||
#define TEST_DATA_ALL_REDUCE(FuncNameType, DataType) \
|
||||
#define TEST_DATA_ALL_REDUCE(FuncNameType, DataType, Eps) \
|
||||
extern "C" __global__ void __launch_bounds__(1024, 1) test_data_all_reduce_##FuncNameType( \
|
||||
DataType* result_buf, DataType* test_buf, size_t num_elems, int num_ranks, int my_rank, int seq) { \
|
||||
for (int rank = 0; rank < num_ranks; rank++) { \
|
||||
@@ -66,15 +70,19 @@ TEST_DATA_ALL_GATHER(int32, int)
|
||||
} \
|
||||
} \
|
||||
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_elems; i += blockDim.x * gridDim.x) { \
|
||||
assert(abs(float(result_buf[i]) - float(test_buf[i])) < 1e-3 * num_ranks); \
|
||||
float expected = float(test_buf[i]); \
|
||||
float result = float(result_buf[i]); \
|
||||
float tol = Eps * num_ranks * (1.0f + abs(expected)); \
|
||||
assert(abs(result - expected) <= tol); \
|
||||
} \
|
||||
}
|
||||
|
||||
TEST_DATA_ALL_REDUCE(float16, __half)
|
||||
TEST_DATA_ALL_REDUCE(float32, float)
|
||||
TEST_DATA_ALL_REDUCE(int32, int)
|
||||
TEST_DATA_ALL_REDUCE(bfloat16, __nv_bfloat16, 7.8125e-3f)
|
||||
TEST_DATA_ALL_REDUCE(float16, __half, 9.765625e-4f)
|
||||
TEST_DATA_ALL_REDUCE(float32, float, 1.1920929e-7f)
|
||||
TEST_DATA_ALL_REDUCE(int32, int, 0.0f)
|
||||
|
||||
#define TEST_DATA_REDUCE_SCATTER(FuncNameType, DataType) \
|
||||
#define TEST_DATA_REDUCE_SCATTER(FuncNameType, DataType, Eps) \
|
||||
extern "C" __global__ void __launch_bounds__(1024, 1) test_data_reduce_scatter_##FuncNameType( \
|
||||
DataType* result_buf, DataType* test_buf, size_t num_elems, int num_ranks, int my_rank, int seq) { \
|
||||
int nem_elems_per_rank = num_elems / num_ranks; \
|
||||
@@ -91,14 +99,18 @@ TEST_DATA_ALL_REDUCE(int32, int)
|
||||
} \
|
||||
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_elems; i += blockDim.x * gridDim.x) { \
|
||||
if (i >= offset && i < offset + nem_elems_per_rank) { \
|
||||
assert(abs(float(result_buf[i - offset]) - float(test_buf[i])) < 1e-3 * num_ranks); \
|
||||
float expected = float(test_buf[i]); \
|
||||
float result = float(result_buf[i - offset]); \
|
||||
float tol = Eps * num_ranks * (1.0f + abs(expected)); \
|
||||
assert(abs(result - expected) <= tol); \
|
||||
} \
|
||||
} \
|
||||
}
|
||||
|
||||
TEST_DATA_REDUCE_SCATTER(float16, __half)
|
||||
TEST_DATA_REDUCE_SCATTER(float32, float)
|
||||
TEST_DATA_REDUCE_SCATTER(int32, int)
|
||||
TEST_DATA_REDUCE_SCATTER(bfloat16, __nv_bfloat16, 7.8125e-3f)
|
||||
TEST_DATA_REDUCE_SCATTER(float16, __half, 9.765625e-4f)
|
||||
TEST_DATA_REDUCE_SCATTER(float32, float, 1.1920929e-7f)
|
||||
TEST_DATA_REDUCE_SCATTER(int32, int, 0.0f)
|
||||
|
||||
#define TEST_DATA_ALL_TO_ALL(FuncNameType, DataType) \
|
||||
extern "C" __global__ void __launch_bounds__(1024, 1) test_data_all_to_all_##FuncNameType( \
|
||||
@@ -118,6 +130,7 @@ TEST_DATA_REDUCE_SCATTER(int32, int)
|
||||
} \
|
||||
}
|
||||
|
||||
TEST_DATA_ALL_TO_ALL(bfloat16, __nv_bfloat16)
|
||||
TEST_DATA_ALL_TO_ALL(float16, __half)
|
||||
TEST_DATA_ALL_TO_ALL(float32, float)
|
||||
TEST_DATA_ALL_TO_ALL(int32, int)
|
||||
@@ -21,6 +21,13 @@ from .mscclpp_mpi import MpiGroup, parametrize_mpi_groups, mpi_group
|
||||
# FP8 E4M3 (hardware) requires SM >= 89 (Ada / Hopper) on NVIDIA GPUs.
|
||||
# On AMD/ROCm (e.g. MI300X), FP8 is supported natively — no skip needed.
|
||||
_is_hip = hasattr(cp.cuda.runtime, "is_hip") and cp.cuda.runtime.is_hip
|
||||
_gcn_arch_name = ""
|
||||
if _is_hip:
|
||||
_gcn_arch_name = cp.cuda.runtime.getDeviceProperties(0).get("gcnArchName", b"")
|
||||
if isinstance(_gcn_arch_name, bytes):
|
||||
_gcn_arch_name = _gcn_arch_name.decode()
|
||||
_gcn_arch_name = _gcn_arch_name.split(":", maxsplit=1)[0]
|
||||
_is_cdna4 = _gcn_arch_name.startswith("gfx95")
|
||||
_skip_fp8 = not _is_hip and int(cp.cuda.Device().compute_capability) < 89
|
||||
pytestmark = pytest.mark.skipif(_skip_fp8, reason="FP8 accum tests require SM >= 89 on CUDA")
|
||||
|
||||
@@ -90,7 +97,78 @@ def float_to_e4m3fn(f32_array, chunk_size=65536):
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# FP8 E4M3B15 helpers (bias=15, max=0.9375, NaN = exp==15 or bits==0x80)
|
||||
# FP8 E4M3FNUZ helpers (AMD/ROCm; bias=8, max=240, NaN = bits==0x80, no -0)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def e4m3fnuz_to_float(uint8_array):
    """Decode a cupy uint8 array of E4M3FNUZ bit patterns to float32.

    Each byte is split into 1 sign bit, 4 exponent bits and 3 mantissa bits.
    The exponent bias is 8; an exponent field of 0 selects the subnormal
    formula. Byte 0x00 decodes to 0.0 and byte 0x80 decodes to NaN.
    """
    codes = uint8_array.astype(cp.int32)
    is_negative = ((codes >> 7) & 1) == 1
    exponent = (codes >> 3) & 0xF
    # Mantissa expressed as a fraction in [0, 1): mant / 8.
    fraction = (codes & 0x7).astype(cp.float32) / cp.float32(8.0)

    magnitude = cp.where(
        exponent == 0,
        # Subnormal: 2^(-7) * (mant / 8)
        cp.ldexp(fraction, cp.int32(-7)),
        # Normal: 2^(exp - 8) * (1 + mant / 8)
        cp.ldexp(cp.float32(1.0) + fraction, (exponent - 8).astype(cp.int32)),
    )
    decoded = cp.where(is_negative, -magnitude, magnitude)

    # Only 0x00 is zero; under fnuz the 0x80 pattern is reserved for NaN
    # instead of encoding a negative zero.
    decoded = cp.where(codes == 0, cp.float32(0.0), decoded)
    return cp.where(codes == 0x80, cp.float32(float("nan")), decoded)
|
||||
|
||||
|
||||
def float_to_e4m3fnuz(f32_array, chunk_size=65536):
    """Encode a cupy float32 array to uint8 E4M3FNUZ bit patterns.

    Inputs are clipped to [-240, 240]. Each magnitude is then mapped to the
    index of the closest entry in a table of the 128 non-negative encodings
    (0x00..0x7F), and the sign bit is OR-ed back in. The search runs over
    slices of at most ``chunk_size`` elements, which bounds the size of the
    intermediate (elements x 128) distance matrix.
    """
    # Decoded values of every sign-free byte; a NaN entry is replaced by +inf
    # so it can never be chosen as the nearest value.
    table = e4m3fnuz_to_float(cp.arange(128, dtype=cp.uint8))
    table = cp.where(cp.isnan(table), cp.float32(float("inf")), table)

    limited = cp.clip(f32_array.astype(cp.float32), -240.0, 240.0)
    negative = (limited < 0).astype(cp.uint8)
    magnitude = cp.abs(limited)

    flat_magnitude = magnitude.ravel()
    flat_codes = cp.zeros(magnitude.shape, dtype=cp.uint8).ravel()

    # Slice ends past the array length are clamped by slicing itself.
    for begin in range(0, magnitude.size, chunk_size):
        window = slice(begin, begin + chunk_size)
        distance = cp.abs(flat_magnitude[window][:, None] - table[None, :])
        flat_codes[window] = cp.argmin(distance, axis=1).astype(cp.uint8)

    encoded = flat_codes.reshape(magnitude.shape) | (negative << 7)
    # A negative input whose magnitude quantised to zero would come out as 0x80,
    # which is NaN under fnuz (there is no negative zero); store 0x00 instead.
    return cp.where(encoded == 0x80, cp.uint8(0), encoded)
|
||||
|
||||
|
||||
# Select the E4M3 codec and mscclpp DataType used by the tests for the current
# platform: HIP devices other than CDNA4 (gfx95*) take the fnuz variant, while
# CDNA4 and CUDA take the OCP fn variant.
e4m3_native_to_float, float_to_e4m3_native, fp8_native_dtype = (
    (e4m3fnuz_to_float, float_to_e4m3fnuz, DataType.float8_e4m3fnuz)
    if _is_hip and not _is_cdna4
    else (e4m3fn_to_float, float_to_e4m3fn, DataType.float8_e4m3fn)
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# FP8 E4M3B15 helpers (bias=15, encode saturates to ±1.75, no NaN)
|
||||
# Matches Triton's fp8e4b15: all 256 bit patterns are finite.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@@ -108,11 +186,6 @@ def e4m3b15_to_float(uint8_array):
|
||||
|
||||
result = cp.where(exp == 0, subnormal_val, normal_val)
|
||||
result = cp.where(sign == 1, -result, result)
|
||||
# Zero
|
||||
result = cp.where((exp == 0) & (mant == 0), cp.float32(0.0), result)
|
||||
# NaN: exp==15 or negative zero (0x80)
|
||||
nan_mask = (exp == 15) | (uint8_array.astype(cp.int32) == 0x80)
|
||||
result = cp.where(nan_mask, cp.float32(float("nan")), result)
|
||||
return result
|
||||
|
||||
|
||||
@@ -120,18 +193,17 @@ def float_to_e4m3b15(f32_array, chunk_size=65536):
|
||||
"""Encode a cupy float32 array to uint8 E4M3B15 bit patterns.
|
||||
|
||||
Same lookup-table approach as float_to_e4m3fn.
|
||||
Saturates to ±1.75 (0x7e/0xfe), matching Triton's fp8e4b15.
|
||||
"""
|
||||
# Build lookup table of all 128 positive E4M3B15 values (0x00..0x7F)
|
||||
all_bytes = cp.arange(128, dtype=cp.uint8)
|
||||
all_floats = e4m3b15_to_float(all_bytes) # (128,) float32
|
||||
# Mark NaN entries as inf so they're never selected as nearest
|
||||
all_floats = cp.where(cp.isnan(all_floats), cp.float32(float("inf")), all_floats)
|
||||
|
||||
# Clamp input and extract sign
|
||||
clamped = f32_array.astype(cp.float32)
|
||||
clamped = cp.clip(clamped, -0.9375, 0.9375)
|
||||
signs = (clamped < 0).astype(cp.uint8)
|
||||
absval = cp.abs(clamped)
|
||||
# Clamp input and extract sign.
|
||||
values = f32_array.astype(cp.float32)
|
||||
signs = cp.signbit(values).astype(cp.uint8)
|
||||
absval = cp.abs(values)
|
||||
absval = cp.clip(absval, cp.float32(0.0), cp.float32(1.75))
|
||||
|
||||
result = cp.zeros(absval.shape, dtype=cp.uint8)
|
||||
n = absval.size
|
||||
@@ -148,8 +220,6 @@ def float_to_e4m3b15(f32_array, chunk_size=65536):
|
||||
# Combine with sign bit
|
||||
result = result_flat.reshape(absval.shape)
|
||||
result = result | (signs << 7)
|
||||
# Handle exact zero
|
||||
result = cp.where(absval == 0, cp.uint8(0), result)
|
||||
return result
|
||||
|
||||
|
||||
@@ -226,12 +296,6 @@ def test_fp8_e4m3_accum(mpi_group: MpiGroup, algo_name: str, size: int):
|
||||
|
||||
buf = GpuBuffer(size, dtype=cp.uint8)
|
||||
|
||||
accum_configs = [
|
||||
("fp8_native", DataType.float8_e4m3),
|
||||
("float16", DataType.float16),
|
||||
("float32", DataType.float32),
|
||||
]
|
||||
|
||||
# rsag_zero_copy and fullmesh need explicit block/thread counts
|
||||
if "rsag" in algo_name:
|
||||
nb = max(1, min(32, size // (world_size * 32)))
|
||||
@@ -243,13 +307,19 @@ def test_fp8_e4m3_accum(mpi_group: MpiGroup, algo_name: str, size: int):
|
||||
nb = 0
|
||||
nt = 0
|
||||
|
||||
accum_configs = [
|
||||
("fp8_native", fp8_native_dtype),
|
||||
("float16", DataType.float16),
|
||||
("float32", DataType.float32),
|
||||
]
|
||||
|
||||
errors = {}
|
||||
for accum_label, accum_dtype in accum_configs:
|
||||
# Generate deterministic per-rank data (use numpy to avoid hipRAND issues on ROCm)
|
||||
rng = np.random.RandomState(42 + rank)
|
||||
src_f32 = cp.asarray(rng.randn(size).astype(np.float32))
|
||||
src_f32 = cp.clip(src_f32, -240.0, 240.0)
|
||||
src_fp8 = float_to_e4m3fn(src_f32)
|
||||
src_fp8 = float_to_e4m3_native(src_f32)
|
||||
|
||||
# Copy into symmetric buffer
|
||||
buf[:] = src_fp8
|
||||
@@ -260,12 +330,12 @@ def test_fp8_e4m3_accum(mpi_group: MpiGroup, algo_name: str, size: int):
|
||||
algo,
|
||||
comm_group,
|
||||
buf,
|
||||
dtype=DataType.float8_e4m3,
|
||||
dtype=fp8_native_dtype,
|
||||
accum_dtype=accum_dtype,
|
||||
nblocks=nb,
|
||||
nthreads_per_block=nt,
|
||||
)
|
||||
result_f32 = e4m3fn_to_float(result)
|
||||
result_f32 = e4m3_native_to_float(result)
|
||||
|
||||
# Compute float32 reference: sum all ranks' quantized FP8 inputs in float32
|
||||
ref_f32 = cp.zeros(size, dtype=cp.float32)
|
||||
@@ -273,12 +343,13 @@ def test_fp8_e4m3_accum(mpi_group: MpiGroup, algo_name: str, size: int):
|
||||
rng_r = np.random.RandomState(42 + r)
|
||||
rank_data = cp.asarray(rng_r.randn(size).astype(np.float32))
|
||||
rank_data = cp.clip(rank_data, -240.0, 240.0)
|
||||
rank_data_fp8 = float_to_e4m3fn(rank_data)
|
||||
ref_f32 += e4m3fn_to_float(rank_data_fp8)
|
||||
rank_data_fp8 = float_to_e4m3_native(rank_data)
|
||||
ref_f32 += e4m3_native_to_float(rank_data_fp8)
|
||||
|
||||
# Compute errors
|
||||
abs_err = cp.abs(result_f32 - ref_f32)
|
||||
mean_abs_err = float(cp.mean(abs_err))
|
||||
# Compute errors (only on valid, non-NaN entries)
|
||||
valid = ~cp.isnan(result_f32) & ~cp.isnan(ref_f32)
|
||||
abs_err = cp.abs(result_f32[valid] - ref_f32[valid])
|
||||
mean_abs_err = float(cp.mean(abs_err)) if abs_err.size > 0 else 0.0
|
||||
errors[accum_label] = mean_abs_err
|
||||
|
||||
# Reset between runs
|
||||
@@ -341,13 +412,10 @@ def test_fp8_e4m3b15_accum(mpi_group: MpiGroup, algo_name: str, size: int):
|
||||
|
||||
errors = {}
|
||||
for accum_label, accum_dtype in accum_configs:
|
||||
# Generate deterministic per-rank random uint8 values in valid e4m3b15 range
|
||||
# Generate deterministic per-rank random uint8 values covering the full e4m3b15 range.
|
||||
# All 256 bit patterns are valid (no NaN in this format).
|
||||
rng = np.random.RandomState(42 + rank)
|
||||
raw = cp.asarray(rng.randint(0, 0x78, (size,)).astype(np.uint8))
|
||||
signs = cp.asarray(rng.randint(0, 2, (size,)).astype(np.uint8)) << 7
|
||||
src_uint8 = raw | signs
|
||||
# Fix negative zero -> positive zero
|
||||
src_uint8 = cp.where(src_uint8 == 0x80, cp.uint8(0), src_uint8)
|
||||
src_uint8 = cp.asarray(rng.randint(0, 256, (size,)).astype(np.uint8))
|
||||
|
||||
# Copy into symmetric buffer
|
||||
buf[:] = src_uint8
|
||||
@@ -371,19 +439,15 @@ def test_fp8_e4m3b15_accum(mpi_group: MpiGroup, algo_name: str, size: int):
|
||||
ref_f32 = cp.zeros(size, dtype=cp.float32)
|
||||
for r in range(world_size):
|
||||
rng_r = np.random.RandomState(42 + r)
|
||||
raw_r = cp.asarray(rng_r.randint(0, 0x78, (size,)).astype(np.uint8))
|
||||
signs_r = cp.asarray(rng_r.randint(0, 2, (size,)).astype(np.uint8)) << 7
|
||||
bits_r = raw_r | signs_r
|
||||
bits_r = cp.where(bits_r == 0x80, cp.uint8(0), bits_r)
|
||||
bits_r = cp.asarray(rng_r.randint(0, 256, (size,)).astype(np.uint8))
|
||||
ref_f32 += e4m3b15_to_float(bits_r)
|
||||
|
||||
# Clamp reference to e4m3b15 representable range
|
||||
ref_f32 = cp.clip(ref_f32, -0.9375, 0.9375)
|
||||
ref_f32 = cp.clip(ref_f32, -1.75, 1.75)
|
||||
|
||||
# Compute errors (only on valid entries)
|
||||
valid = ~cp.isnan(result_f32) & ~cp.isnan(ref_f32)
|
||||
abs_err = cp.abs(result_f32[valid] - ref_f32[valid])
|
||||
mean_abs_err = float(cp.mean(abs_err)) if abs_err.size > 0 else 0.0
|
||||
# Compute errors
|
||||
abs_err = cp.abs(result_f32 - ref_f32)
|
||||
mean_abs_err = float(cp.mean(abs_err))
|
||||
errors[accum_label] = mean_abs_err
|
||||
|
||||
algo.reset()
|
||||
|
||||
Reference in New Issue
Block a user