mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-04-20 06:49:29 +00:00
Add NPKit GPU event support (#310)
This commit is contained in:
@@ -79,3 +79,85 @@ jobs:
|
||||
export PATH=/usr/local/mpi/bin:$PATH
|
||||
mpirun -tag-output -x MSCCLPP_HOME=$(System.DefaultWorkingDirectory) -np 8 python3 -m pytest ./python/test/test_mscclpp.py -x
|
||||
workingDirectory: '$(System.DefaultWorkingDirectory)'
|
||||
|
||||
- job: UnitTestWithNpKit
  # Builds mscclpp with NPKit event collection compiled in, runs one executor
  # test through the C++ and Python entry points, converts the dumped events
  # with npkit_trace_generator.py, and checks that the expected event names
  # appear in the generated trace.
  timeoutInMinutes: 30
  pool:
    name: mscclpp
  strategy:
    matrix:
      cuda11:
        containerImage: ghcr.io/microsoft/mscclpp/mscclpp:base-dev-cuda11.8
      cuda12:
        containerImage: ghcr.io/microsoft/mscclpp/mscclpp:base-dev-cuda12.2

  container:
    image: $[ variables['containerImage'] ]
    options: --privileged --ipc=host --gpus=all --ulimit memlock=-1:-1

  steps:
  # Configure and build with the NPKit compile definitions passed via NPKIT_FLAGS.
  - task: Bash@3
    name: Build
    displayName: Build
    inputs:
      targetType: 'inline'
      script: |
        mkdir build && cd build
        cmake -DCMAKE_BUILD_TYPE=Release -DNPKIT_FLAGS="-DENABLE_NPKIT -DENABLE_NPKIT_EVENT_TIME_SYNC_CPU -DENABLE_NPKIT_EVENT_TIME_SYNC_GPU -DENABLE_NPKIT_EVENT_EXECUTOR_INIT_ENTRY -DENABLE_NPKIT_EVENT_EXECUTOR_INIT_EXIT -DENABLE_NPKIT_EVENT_EXECUTOR_OP_BASE_ENTRY -DENABLE_NPKIT_EVENT_EXECUTOR_OP_BASE_EXIT" ..
        make -j
      workingDirectory: '$(System.DefaultWorkingDirectory)'

  # Enable persistence mode, then set each GPU's application clocks to the
  # maximum memory/SM clocks it reports.
  - task: Bash@3
    name: LockGPUClock
    displayName: Lock GPU clock frequency
    inputs:
      targetType: 'inline'
      script: |
        sudo nvidia-smi -pm 1
        for i in $(seq 0 $(( $(nvidia-smi -L | wc -l) - 1 ))); do
          sudo nvidia-smi -ac $(nvidia-smi --query-gpu=clocks.max.memory,clocks.max.sm --format=csv,noheader,nounits -i $i | sed 's/\ //') -i $i
        done
      workingDirectory: '$(System.DefaultWorkingDirectory)'

  # C++ path: run one executor test with NPKIT_DUMP_DIR set, then verify the trace.
  - task: Bash@3
    name: MpUnitTests
    displayName: Run mscclpp multi-process unit tests
    inputs:
      targetType: 'inline'
      script: |
        set -e
        rm -rf ./npkit_dump && mkdir ./npkit_dump
        rm -rf ./npkit_output && mkdir ./npkit_output
        export PATH=/usr/local/mpi/bin:$PATH
        export NPKIT_DUMP_DIR=./npkit_dump
        mpirun -tag-output -np 2 ./build/test/mp_unit_tests --gtest_filter="ExecutorTest.TwoNodesAllreduce"
        python3 ./tools/npkit/npkit_trace_generator.py --npkit_dump_dir=./npkit_dump --npkit_event_header_path=./include/mscclpp/npkit/npkit_event.hpp --output_dir=./npkit_output
        # With `set -e`, the first missing event name fails the step.
        for event in NPKIT_EVENT_EXECUTOR_INIT_ENTRY NPKIT_EVENT_EXECUTOR_SIGNAL_ENTRY NPKIT_EVENT_EXECUTOR_WAIT_ENTRY NPKIT_EVENT_EXECUTOR_READ_REDUCE_COPY_SEND_ENTRY; do
          grep -q "$event" ./npkit_output/npkit_event_trace.json
        done
      workingDirectory: '$(System.DefaultWorkingDirectory)'

  # Python path: same check for the allreduce plan, then for the packet-based plan.
  - task: Bash@3
    name: PyTests
    displayName: Run pytests
    inputs:
      targetType: 'inline'
      script: |
        set -e
        rm -rf ./npkit_dump && mkdir ./npkit_dump
        rm -rf ./npkit_output && mkdir ./npkit_output
        export PATH=/usr/local/mpi/bin:$PATH
        export NPKIT_DUMP_DIR=./npkit_dump
        mpirun -tag-output -x MSCCLPP_HOME=$(System.DefaultWorkingDirectory) -np 8 python3 -m pytest ./python/test/test_mscclpp.py -x -k 'test_executor[allreduce.json'
        python3 ./tools/npkit/npkit_trace_generator.py --npkit_dump_dir=./npkit_dump --npkit_event_header_path=./include/mscclpp/npkit/npkit_event.hpp --output_dir=./npkit_output
        for event in NPKIT_EVENT_EXECUTOR_INIT_ENTRY NPKIT_EVENT_EXECUTOR_SIGNAL_ENTRY NPKIT_EVENT_EXECUTOR_WAIT_ENTRY NPKIT_EVENT_EXECUTOR_READ_REDUCE_COPY_SEND_ENTRY; do
          grep -q "$event" ./npkit_output/npkit_event_trace.json
        done
        # Start from empty dump/output directories for the second plan.
        rm -rf ./npkit_dump && mkdir ./npkit_dump
        rm -rf ./npkit_output && mkdir ./npkit_output
        mpirun -tag-output -x MSCCLPP_HOME=$(System.DefaultWorkingDirectory) -np 8 python3 -m pytest ./python/test/test_mscclpp.py -x -k 'test_executor[allreduce_packet.json'
        python3 ./tools/npkit/npkit_trace_generator.py --npkit_dump_dir=./npkit_dump --npkit_event_header_path=./include/mscclpp/npkit/npkit_event.hpp --output_dir=./npkit_output
        for event in NPKIT_EVENT_EXECUTOR_INIT_ENTRY NPKIT_EVENT_EXECUTOR_COPY_PACKET_ENTRY NPKIT_EVENT_EXECUTOR_PUT_PACKET_ENTRY NPKIT_EVENT_EXECUTOR_REDUCE_SEND_PACKET_ENTRY; do
          grep -q "$event" ./npkit_output/npkit_event_trace.json
        done
      workingDirectory: '$(System.DefaultWorkingDirectory)'
|
||||
|
||||
@@ -15,7 +15,6 @@ list(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake)
|
||||
|
||||
# Options
|
||||
option(ENABLE_TRACE "Enable tracing" OFF)
|
||||
option(USE_NPKIT "Use NPKIT" ON)
|
||||
option(BUILD_TESTS "Build tests" ON)
|
||||
option(BUILD_PYTHON_BINDINGS "Build Python bindings" ON)
|
||||
option(USE_CUDA "Use NVIDIA/CUDA." OFF)
|
||||
@@ -119,8 +118,8 @@ endif()
|
||||
if(ENABLE_TRACE)
|
||||
target_compile_definitions(mscclpp_obj PRIVATE ENABLE_TRACE)
|
||||
endif()
|
||||
if(USE_NPKIT)
|
||||
target_compile_definitions(mscclpp_obj PRIVATE ENABLE_NPKIT)
|
||||
if(NPKIT_FLAGS)
|
||||
target_compile_definitions(mscclpp_obj PRIVATE ${NPKIT_FLAGS})
|
||||
endif()
|
||||
|
||||
# libmscclpp
|
||||
|
||||
97
include/mscclpp/npkit/npkit.hpp
Normal file
97
include/mscclpp/npkit/npkit.hpp
Normal file
@@ -0,0 +1,97 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#ifndef NPKIT_H_
|
||||
#define NPKIT_H_
|
||||
|
||||
#include <mscclpp/device.hpp>
|
||||
#include <mscclpp/gpu_utils.hpp>
|
||||
#include <mscclpp/npkit/npkit_event.hpp>
|
||||
#include <mscclpp/npkit/npkit_struct.hpp>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
|
||||
#if defined(__HIP_PLATFORM_AMD__)
|
||||
#define NPKIT_GET_GPU_TIMESTAMP wall_clock64
|
||||
#else
|
||||
#define NPKIT_GET_GPU_TIMESTAMP clock64
|
||||
#endif
|
||||
|
||||
#define NPKIT_SHM_NUM_EVENTS 64
|
||||
|
||||
class NpKit {
|
||||
public:
|
||||
static const uint64_t kNumGpuEventBuffers = 1024;
|
||||
|
||||
static const uint64_t kNumCpuEventBuffers = 64;
|
||||
|
||||
static void Init(int rank);
|
||||
|
||||
static void Dump(const std::string& dump_dir);
|
||||
|
||||
static void Shutdown();
|
||||
|
||||
static NpKitEventCollectContext* GetGpuEventCollectContexts();
|
||||
|
||||
#if defined(MSCCLPP_DEVICE_COMPILE)
|
||||
static MSCCLPP_DEVICE_INLINE void CollectGpuEventShm(uint8_t type, uint32_t size, uint32_t rsvd, uint64_t timestamp,
|
||||
NpKitEvent* event_buffer, uint64_t* event_buffer_head) {
|
||||
if (*event_buffer_head < NPKIT_SHM_NUM_EVENTS) {
|
||||
if (threadIdx.x == 0) {
|
||||
NpKitEvent& event = event_buffer[*event_buffer_head];
|
||||
event.fields.type = type;
|
||||
event.fields.size = size;
|
||||
event.fields.rsvd = rsvd;
|
||||
event.fields.timestamp = timestamp;
|
||||
}
|
||||
(*event_buffer_head)++;
|
||||
}
|
||||
}
|
||||
|
||||
static MSCCLPP_DEVICE_INLINE void StoreGpuEventShm(NpKitEventCollectContext* npKitEventCollectContexts,
|
||||
NpKitEvent* event_buffer, uint64_t event_buffer_head) {
|
||||
#if defined(MSCCLPP_DEVICE_HIP)
|
||||
__synclds();
|
||||
#else // !defined(MSCCLPP_DEVICE_HIP)
|
||||
__syncthreads();
|
||||
#endif // !defined(MSCCLPP_DEVICE_HIP)
|
||||
NpKitEventCollectContext* npKitCtx = npKitEventCollectContexts + blockIdx.x;
|
||||
NpKitEvent* global_event_buffer = npKitCtx->event_buffer;
|
||||
uint64_t global_event_buffer_head = npKitCtx->event_buffer_head;
|
||||
for (size_t i = threadIdx.x; i < event_buffer_head * sizeof(NpKitEvent) / sizeof(int4); i += blockDim.x) {
|
||||
((int4*)(global_event_buffer + global_event_buffer_head))[i] = ((int4*)event_buffer)[i];
|
||||
}
|
||||
if (threadIdx.x == 0) {
|
||||
npKitCtx->event_buffer_head += event_buffer_head;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
static void CollectCpuEvent(uint8_t type, uint32_t size, uint32_t rsvd, uint64_t timestamp, int channel_id);
|
||||
|
||||
static uint64_t* GetCpuTimestamp();
|
||||
|
||||
private:
|
||||
static void CpuTimestampUpdateThread();
|
||||
|
||||
// 64K * 1024 * 16B = 1GB per GPU
|
||||
static const uint64_t kMaxNumGpuEventsPerBuffer = 1ULL << 16;
|
||||
|
||||
// 64K * 2 (send/recv) * (1024/64) = 2M, 2M * 64 * 16B = 2GB per CPU
|
||||
static const uint64_t kMaxNumCpuEventsPerBuffer = 1ULL << 21;
|
||||
|
||||
static std::vector<mscclpp::UniqueCudaPtr<NpKitEvent>> gpu_event_buffers_;
|
||||
static std::vector<std::unique_ptr<NpKitEvent[]>> cpu_event_buffers_;
|
||||
|
||||
static mscclpp::UniqueCudaPtr<NpKitEventCollectContext> gpu_collect_contexts_;
|
||||
static std::unique_ptr<NpKitEventCollectContext[]> cpu_collect_contexts_;
|
||||
|
||||
static uint64_t rank_;
|
||||
|
||||
static mscclpp::UniqueCudaHostPtr<uint64_t> cpu_timestamp_;
|
||||
static std::unique_ptr<std::thread> cpu_timestamp_update_thread_;
|
||||
static volatile bool cpu_timestamp_update_thread_should_stop_;
|
||||
};
|
||||
|
||||
#endif
|
||||
18
include/mscclpp/npkit/npkit_event.hpp
Normal file
18
include/mscclpp/npkit/npkit_event.hpp
Normal file
@@ -0,0 +1,18 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#ifndef NPKIT_EVENT_H_
|
||||
#define NPKIT_EVENT_H_
|
||||
|
||||
#define NPKIT_EVENT_INVALID 0x0
|
||||
|
||||
#define NPKIT_EVENT_TIME_SYNC_GPU 0x1
|
||||
#define NPKIT_EVENT_TIME_SYNC_CPU 0x2
|
||||
|
||||
#define NPKIT_EVENT_EXECUTOR_INIT_ENTRY 0x3
|
||||
#define NPKIT_EVENT_EXECUTOR_INIT_EXIT 0x4
|
||||
|
||||
#define NPKIT_EVENT_EXECUTOR_OP_BASE_ENTRY 0x5
|
||||
#define NPKIT_EVENT_EXECUTOR_OP_BASE_EXIT 0x15
|
||||
|
||||
#endif
|
||||
@@ -25,4 +25,4 @@ struct NpKitEventCollectContext {
|
||||
|
||||
#pragma pack(pop)
|
||||
|
||||
#endif
|
||||
#endif
|
||||
@@ -25,6 +25,7 @@ from ._mscclpp import (
|
||||
PacketType,
|
||||
version,
|
||||
is_nvls_supported,
|
||||
npkit,
|
||||
)
|
||||
|
||||
__version__ = version()
|
||||
|
||||
@@ -22,6 +22,7 @@ extern void register_utils(nb::module_& m);
|
||||
extern void register_numa(nb::module_& m);
|
||||
extern void register_nvls(nb::module_& m);
|
||||
extern void register_executor(nb::module_& m);
|
||||
extern void register_npkit(nb::module_& m);
|
||||
|
||||
template <typename T>
|
||||
void def_nonblocking_future(nb::handle& m, const std::string& typestr) {
|
||||
@@ -189,4 +190,5 @@ NB_MODULE(_mscclpp, m) {
|
||||
register_numa(m);
|
||||
register_nvls(m);
|
||||
register_executor(m);
|
||||
register_npkit(m);
|
||||
}
|
||||
|
||||
16
python/mscclpp/npkit_py.cpp
Normal file
16
python/mscclpp/npkit_py.cpp
Normal file
@@ -0,0 +1,16 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <nanobind/nanobind.h>
|
||||
#include <nanobind/stl/string.h>
|
||||
|
||||
#include <mscclpp/npkit/npkit.hpp>
|
||||
|
||||
namespace nb = nanobind;
|
||||
|
||||
// Adds an `npkit` submodule to `m` and binds the static NpKit lifecycle
// functions to it as `init`, `dump` and `shutdown`.
void register_npkit(nb::module_ &m) {
  m.def_submodule("npkit", "NPKit functions")
      .def("init", &NpKit::Init)
      .def("dump", &NpKit::Dump)
      .def("shutdown", &NpKit::Shutdown);
}
|
||||
@@ -7,6 +7,7 @@ from mscclpp import (
|
||||
Executor,
|
||||
ExecutionPlan,
|
||||
PacketType,
|
||||
npkit,
|
||||
)
|
||||
import mscclpp.comm as mscclpp_comm
|
||||
|
||||
@@ -87,6 +88,9 @@ def main(
|
||||
mscclpp_group = mscclpp_comm.CommGroup(MPI.COMM_WORLD)
|
||||
cp.cuda.Device(mscclpp_group.my_rank % mscclpp_group.nranks_per_node).use()
|
||||
executor = Executor(mscclpp_group.communicator)
|
||||
npkit_dump_dir = os.getenv("NPKIT_DUMP_DIR")
|
||||
if npkit_dump_dir is not None:
|
||||
npkit.init(mscclpp_group.my_rank)
|
||||
execution_plan = ExecutionPlan(execution_paln_name, execution_plan_path)
|
||||
|
||||
cp.random.seed(seed)
|
||||
@@ -119,6 +123,9 @@ def main(
|
||||
|
||||
mscclpp_group.barrier()
|
||||
execution_time = bench_time(100, 10, executor_func)
|
||||
if npkit_dump_dir is not None:
|
||||
npkit.dump(npkit_dump_dir)
|
||||
npkit.shutdown()
|
||||
print(
|
||||
f"Rank: {MPI.COMM_WORLD.rank} Execution time: {execution_time} us, "
|
||||
f"data size: {sendbuf.nbytes} bytes data type: {dtype().dtype.name} "
|
||||
|
||||
@@ -24,6 +24,7 @@ from mscclpp import (
|
||||
TcpBootstrap,
|
||||
Transport,
|
||||
is_nvls_supported,
|
||||
npkit,
|
||||
)
|
||||
import mscclpp.comm as mscclpp_comm
|
||||
from mscclpp.utils import KernelBuilder, pack
|
||||
@@ -603,6 +604,9 @@ def test_executor(mpi_group: MpiGroup, filename: str):
|
||||
project_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
mscclpp_group = mscclpp_comm.CommGroup(mpi_group.comm)
|
||||
executor = Executor(mscclpp_group.communicator)
|
||||
npkit_dump_dir = os.getenv("NPKIT_DUMP_DIR")
|
||||
if npkit_dump_dir is not None:
|
||||
npkit.init(mscclpp_group.my_rank)
|
||||
execution_plan = ExecutionPlan("allreduce_pairs", os.path.join(project_dir, "test", "execution-files", filename))
|
||||
|
||||
nelems = 1024 * 1024
|
||||
@@ -629,3 +633,6 @@ def test_executor(mpi_group: MpiGroup, filename: str):
|
||||
)
|
||||
stream.synchronize()
|
||||
assert cp.allclose(sendbuf, expected, atol=1e-3 * mpi_group.comm.size)
|
||||
if npkit_dump_dir is not None:
|
||||
npkit.dump(npkit_dump_dir)
|
||||
npkit.shutdown()
|
||||
|
||||
@@ -3,6 +3,9 @@
|
||||
|
||||
#include "connection.hpp"
|
||||
|
||||
#if defined(ENABLE_NPKIT)
|
||||
#include <mscclpp/npkit/npkit.hpp>
|
||||
#endif
|
||||
#include <mscclpp/utils.hpp>
|
||||
#include <sstream>
|
||||
#include <thread>
|
||||
@@ -10,7 +13,6 @@
|
||||
#include "debug.h"
|
||||
#include "endpoint.hpp"
|
||||
#include "infiniband/verbs.h"
|
||||
#include "npkit/npkit.h"
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
|
||||
@@ -13,19 +13,43 @@ void ExecutionKernel::launchKernel(int rank, int nthreadblocks, int nthreads, vo
|
||||
switch (dataType) {
|
||||
case DataType::INT32:
|
||||
executionKernel<int32_t, PacketType><<<nthreadblocks, nthreads, sharedMemSize, stream>>>(
|
||||
rank, (int32_t*)src, (int32_t*)dst, (int32_t*)scratch, scratchSize, plan, flag);
|
||||
rank, (int32_t*)src, (int32_t*)dst, (int32_t*)scratch, scratchSize, plan, flag
|
||||
#if defined(ENABLE_NPKIT)
|
||||
,
|
||||
NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp());
|
||||
#else
|
||||
);
|
||||
#endif
|
||||
break;
|
||||
case DataType::UINT32:
|
||||
executionKernel<uint32_t><<<nthreadblocks, nthreads, sharedMemSize, stream>>>(
|
||||
rank, (uint32_t*)src, (uint32_t*)dst, (uint32_t*)scratch, scratchSize, plan, flag);
|
||||
rank, (uint32_t*)src, (uint32_t*)dst, (uint32_t*)scratch, scratchSize, plan, flag
|
||||
#if defined(ENABLE_NPKIT)
|
||||
,
|
||||
NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp());
|
||||
#else
|
||||
);
|
||||
#endif
|
||||
break;
|
||||
case DataType::FLOAT16:
|
||||
executionKernel<half><<<nthreadblocks, nthreads, sharedMemSize, stream>>>(
|
||||
rank, (half*)src, (half*)dst, (half*)scratch, scratchSize, plan, flag);
|
||||
rank, (half*)src, (half*)dst, (half*)scratch, scratchSize, plan, flag
|
||||
#if defined(ENABLE_NPKIT)
|
||||
,
|
||||
NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp());
|
||||
#else
|
||||
);
|
||||
#endif
|
||||
break;
|
||||
case DataType::FLOAT32:
|
||||
executionKernel<float><<<nthreadblocks, nthreads, sharedMemSize, stream>>>(
|
||||
rank, (float*)src, (float*)dst, (float*)scratch, scratchSize, plan, flag);
|
||||
rank, (float*)src, (float*)dst, (float*)scratch, scratchSize, plan, flag
|
||||
#if defined(ENABLE_NPKIT)
|
||||
,
|
||||
NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp());
|
||||
#else
|
||||
);
|
||||
#endif
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -261,7 +261,11 @@ struct Executor::Impl {
|
||||
DataType dataType, cudaStream_t stream, PacketType packetType) {
|
||||
static uint32_t flag = 0;
|
||||
int nthreadblocks = context.deviceExecutionPlans.size();
|
||||
#if defined(ENABLE_NPKIT)
|
||||
size_t sharedMemSize = sizeof(DeviceExecutionPlan) + NPKIT_SHM_NUM_EVENTS * sizeof(NpKitEvent);
|
||||
#else
|
||||
size_t sharedMemSize = sizeof(DeviceExecutionPlan);
|
||||
#endif
|
||||
switch (packetType) {
|
||||
case PacketType::LL16:
|
||||
ExecutionKernel::launchKernel<LL16Packet>(
|
||||
|
||||
@@ -5,6 +5,9 @@
|
||||
#define MSCCLPP_EXECUTION_KERNEL_HPP_
|
||||
|
||||
#include <mscclpp/executor.hpp>
|
||||
#if defined(ENABLE_NPKIT)
|
||||
#include <mscclpp/npkit/npkit.hpp>
|
||||
#endif
|
||||
#include <mscclpp/packet_device.hpp>
|
||||
#include <mscclpp/proxy_channel.hpp>
|
||||
#include <mscclpp/sm_channel.hpp>
|
||||
@@ -333,10 +336,26 @@ MSCCLPP_DEVICE_INLINE void handleReduceSend(T* dst, uint32_t dstOffsetByBytes, T
|
||||
|
||||
template <typename T, typename PacketType = LL16Packet>
|
||||
__global__ void executionKernel([[maybe_unused]] int rank /*for debug*/, T* input, T* output, T* scratch,
|
||||
size_t scratchSize, DeviceExecutionPlan* plan, uint32_t flag) {
|
||||
size_t scratchSize, DeviceExecutionPlan* plan, uint32_t flag
|
||||
#if defined(ENABLE_NPKIT)
|
||||
,
|
||||
NpKitEventCollectContext* npKitEventCollectContexts, uint64_t* cpuTimestamp) {
|
||||
#else
|
||||
) {
|
||||
#endif
|
||||
extern __shared__ int4 sharedMem[];
|
||||
int bid = blockIdx.x;
|
||||
int tid = threadIdx.x;
|
||||
#if defined(ENABLE_NPKIT)
|
||||
NpKitEvent* event_buffer = (NpKitEvent*)((char*)sharedMem + sizeof(DeviceExecutionPlan));
|
||||
uint64_t event_buffer_head = 0;
|
||||
#if defined(ENABLE_NPKIT_EVENT_EXECUTOR_INIT_ENTRY) && defined(ENABLE_NPKIT_EVENT_EXECUTOR_INIT_EXIT)
|
||||
uint64_t npkit_timestamp_entry = 0;
|
||||
if (tid == 0) {
|
||||
npkit_timestamp_entry = NPKIT_GET_GPU_TIMESTAMP();
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
DeviceExecutionPlan* localPlan = plan + bid;
|
||||
for (size_t i = tid; i < sizeof(DeviceExecutionPlan) / sizeof(int4); i += blockDim.x) {
|
||||
sharedMem[i] = ((int4*)localPlan)[i];
|
||||
@@ -352,8 +371,31 @@ __global__ void executionKernel([[maybe_unused]] int rank /*for debug*/, T* inpu
|
||||
DeviceHandle<SmChannel>* smChannels = localPlan->channels.smChannels;
|
||||
DeviceHandle<SimpleProxyChannel>* proxyChannels = localPlan->channels.proxyChannels;
|
||||
|
||||
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_TIME_SYNC_CPU)
|
||||
NpKit::CollectGpuEventShm(NPKIT_EVENT_TIME_SYNC_CPU, 0, 0, *cpuTimestamp, event_buffer, &event_buffer_head);
|
||||
#endif
|
||||
|
||||
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_TIME_SYNC_GPU)
|
||||
NpKit::CollectGpuEventShm(NPKIT_EVENT_TIME_SYNC_GPU, 0, 0, NPKIT_GET_GPU_TIMESTAMP(), event_buffer,
|
||||
&event_buffer_head);
|
||||
#endif
|
||||
|
||||
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_EXECUTOR_INIT_ENTRY) && \
|
||||
defined(ENABLE_NPKIT_EVENT_EXECUTOR_INIT_EXIT)
|
||||
NpKit::CollectGpuEventShm(NPKIT_EVENT_EXECUTOR_INIT_ENTRY, 0, 0, npkit_timestamp_entry, event_buffer,
|
||||
&event_buffer_head);
|
||||
NpKit::CollectGpuEventShm(NPKIT_EVENT_EXECUTOR_INIT_EXIT, 0, 0, NPKIT_GET_GPU_TIMESTAMP(), event_buffer,
|
||||
&event_buffer_head);
|
||||
#endif
|
||||
|
||||
for (int i = 0; i < nOperations; i++) {
|
||||
Operation& op = operations[i];
|
||||
|
||||
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_EXECUTOR_OP_BASE_ENTRY)
|
||||
NpKit::CollectGpuEventShm(NPKIT_EVENT_EXECUTOR_OP_BASE_ENTRY + (int)op.type, op.size, 0, NPKIT_GET_GPU_TIMESTAMP(),
|
||||
event_buffer, &event_buffer_head);
|
||||
#endif
|
||||
|
||||
if (op.type == OperationType::BARRIER) {
|
||||
__syncthreads();
|
||||
} else if (op.type == OperationType::SIGNAL) {
|
||||
@@ -403,7 +445,16 @@ __global__ void executionKernel([[maybe_unused]] int rank /*for debug*/, T* inpu
|
||||
handleReduceSend(dst, op.dstOffset, src, op.srcOffset, tmp, op.inputOffsets, smChannels, op.outputChannelIndexes,
|
||||
op.outputOffsets, op.nOutputs, op.size);
|
||||
}
|
||||
|
||||
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_EXECUTOR_OP_BASE_EXIT)
|
||||
NpKit::CollectGpuEventShm(NPKIT_EVENT_EXECUTOR_OP_BASE_EXIT + (int)op.type, op.size, 0, NPKIT_GET_GPU_TIMESTAMP(),
|
||||
event_buffer, &event_buffer_head);
|
||||
#endif
|
||||
}
|
||||
|
||||
#if defined(ENABLE_NPKIT)
|
||||
NpKit::StoreGpuEventShm(npKitEventCollectContexts, event_buffer, event_buffer_head);
|
||||
#endif
|
||||
}
|
||||
#endif // defined(MSCCLPP_DEVICE_COMPILE)
|
||||
|
||||
@@ -417,19 +468,43 @@ class ExecutionKernel {
|
||||
switch (dataType) {
|
||||
case DataType::INT32:
|
||||
executionKernel<int32_t, PacketType><<<nthreadblocks, nthreads, sharedMemSize, stream>>>(
|
||||
rank, (int32_t*)src, (int32_t*)dst, (int32_t*)scratch, scratchSize, plan, flag);
|
||||
rank, (int32_t*)src, (int32_t*)dst, (int32_t*)scratch, scratchSize, plan, flag
|
||||
#if defined(ENABLE_NPKIT)
|
||||
,
|
||||
NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp());
|
||||
#else
|
||||
);
|
||||
#endif
|
||||
break;
|
||||
case DataType::UINT32:
|
||||
executionKernel<uint32_t, PacketType><<<nthreadblocks, nthreads, sharedMemSize, stream>>>(
|
||||
rank, (uint32_t*)src, (uint32_t*)dst, (uint32_t*)scratch, scratchSize, plan, flag);
|
||||
rank, (uint32_t*)src, (uint32_t*)dst, (uint32_t*)scratch, scratchSize, plan, flag
|
||||
#if defined(ENABLE_NPKIT)
|
||||
,
|
||||
NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp());
|
||||
#else
|
||||
);
|
||||
#endif
|
||||
break;
|
||||
case DataType::FLOAT16:
|
||||
executionKernel<half, PacketType><<<nthreadblocks, nthreads, sharedMemSize, stream>>>(
|
||||
rank, (half*)src, (half*)dst, (half*)scratch, scratchSize, plan, flag);
|
||||
rank, (half*)src, (half*)dst, (half*)scratch, scratchSize, plan, flag
|
||||
#if defined(ENABLE_NPKIT)
|
||||
,
|
||||
NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp());
|
||||
#else
|
||||
);
|
||||
#endif
|
||||
break;
|
||||
case DataType::FLOAT32:
|
||||
executionKernel<float, PacketType><<<nthreadblocks, nthreads, sharedMemSize, stream>>>(
|
||||
rank, (float*)src, (float*)dst, (float*)scratch, scratchSize, plan, flag);
|
||||
rank, (float*)src, (float*)dst, (float*)scratch, scratchSize, plan, flag
|
||||
#if defined(ENABLE_NPKIT)
|
||||
,
|
||||
NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp());
|
||||
#else
|
||||
);
|
||||
#endif
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include "npkit.h"
|
||||
|
||||
#include <unistd.h>
|
||||
|
||||
#include <chrono>
|
||||
#include <fstream>
|
||||
#include <mscclpp/gpu.hpp>
|
||||
#include <mscclpp/npkit/npkit.hpp>
|
||||
|
||||
#include "debug.h"
|
||||
|
||||
uint64_t NpKit::rank_ = 0;
|
||||
|
||||
@@ -16,41 +17,85 @@ std::vector<std::unique_ptr<NpKitEvent[]>> NpKit::cpu_event_buffers_;
|
||||
|
||||
mscclpp::UniqueCudaPtr<NpKitEventCollectContext> NpKit::gpu_collect_contexts_;
|
||||
std::unique_ptr<NpKitEventCollectContext[]> NpKit::cpu_collect_contexts_;
|
||||
uint64_t NpKit::cpu_base_system_timestamp_ = 0;
|
||||
uint64_t NpKit::cpu_base_steady_timestamp_ = 0;
|
||||
|
||||
mscclpp::UniqueCudaHostPtr<uint64_t> NpKit::cpu_timestamp_;
|
||||
std::unique_ptr<std::thread> NpKit::cpu_timestamp_update_thread_;
|
||||
volatile bool NpKit::cpu_timestamp_update_thread_should_stop_ = false;
|
||||
|
||||
void NpKit::CpuTimestampUpdateThread() {
|
||||
uint64_t init_system_clock = std::chrono::system_clock::now().time_since_epoch().count();
|
||||
uint64_t init_steady_clock = std::chrono::steady_clock::now().time_since_epoch().count();
|
||||
uint64_t curr_steady_clock = 0;
|
||||
volatile uint64_t* volatile_cpu_timestamp_ = cpu_timestamp_.get();
|
||||
while (!cpu_timestamp_update_thread_should_stop_) {
|
||||
curr_steady_clock = std::chrono::steady_clock::now().time_since_epoch().count();
|
||||
*volatile_cpu_timestamp_ = init_system_clock + (curr_steady_clock - init_steady_clock);
|
||||
}
|
||||
}
|
||||
|
||||
void NpKit::Init(int rank) {
|
||||
#if defined(ENABLE_NPKIT)
|
||||
uint64_t i = 0;
|
||||
NpKitEventCollectContext ctx;
|
||||
ctx.event_buffer_head = 0;
|
||||
rank_ = rank;
|
||||
|
||||
// Init event data structures
|
||||
gpu_collect_contexts_ = mscclpp::allocUniqueCuda<NpKitEventCollectContext>(kNumGpuEventBuffers);
|
||||
for (i = 0; i < kNumGpuEventBuffers; i++) {
|
||||
gpu_collect_contexts_ = mscclpp::allocUniqueCuda<NpKitEventCollectContext>(NpKit::kNumGpuEventBuffers);
|
||||
for (i = 0; i < NpKit::kNumGpuEventBuffers; i++) {
|
||||
gpu_event_buffers_.emplace_back(mscclpp::allocUniqueCuda<NpKitEvent>(kMaxNumGpuEventsPerBuffer));
|
||||
ctx.event_buffer = gpu_event_buffers_[i].get();
|
||||
mscclpp::memcpyCuda(gpu_collect_contexts_.get() + i, &ctx, 1);
|
||||
}
|
||||
|
||||
cpu_collect_contexts_ = std::make_unique<NpKitEventCollectContext[]>(kNumCpuEventBuffers);
|
||||
for (i = 0; i < kNumCpuEventBuffers; i++) {
|
||||
cpu_collect_contexts_ = std::make_unique<NpKitEventCollectContext[]>(NpKit::kNumCpuEventBuffers);
|
||||
for (i = 0; i < NpKit::kNumCpuEventBuffers; i++) {
|
||||
cpu_event_buffers_.emplace_back(std::make_unique<NpKitEvent[]>(kMaxNumCpuEventsPerBuffer));
|
||||
ctx.event_buffer = cpu_event_buffers_[i].get();
|
||||
cpu_collect_contexts_[i] = ctx;
|
||||
}
|
||||
|
||||
// Init timestamp
|
||||
cpu_base_system_timestamp_ = std::chrono::system_clock::now().time_since_epoch().count();
|
||||
cpu_base_steady_timestamp_ = std::chrono::steady_clock::now().time_since_epoch().count();
|
||||
cpu_timestamp_ = mscclpp::makeUniqueCudaHost<uint64_t>();
|
||||
volatile uint64_t* volatile_cpu_timestamp = cpu_timestamp_.get();
|
||||
*volatile_cpu_timestamp = std::chrono::system_clock::now().time_since_epoch().count();
|
||||
cpu_timestamp_update_thread_should_stop_ = false;
|
||||
cpu_timestamp_update_thread_ = std::make_unique<std::thread>(CpuTimestampUpdateThread);
|
||||
#else
|
||||
WARN("NpKit::Init(%d) : MSCCLPP library was not built with NPKit enabled.", rank);
|
||||
#endif
|
||||
}
|
||||
|
||||
#if defined(ENABLE_NPKIT)
|
||||
// Returns the clock rate, in kHz, that NpKit::Dump() writes to the
// gpu_clock_rate file for the current device.
//
// CUDA: the device's reported `clockRate`.
// HIP: a fixed value chosen from the architecture name: 100000 when it starts
// with "gfx94", 25000 otherwise.
// NOTE(review): presumably these are the wall_clock64() tick rates for those
// architectures -- confirm against the HIP documentation.
//
// Throws (via MSCCLPP_CUDATHROW) if the device or its properties cannot be queried.
static int GetGpuClockRateInKhz() {
  int dev_id;
  MSCCLPP_CUDATHROW(cudaGetDevice(&dev_id));
#if defined(__HIP_PLATFORM_AMD__)
  cudaDeviceProp_t dev_prop;
  MSCCLPP_CUDATHROW(cudaGetDeviceProperties(&dev_prop, dev_id));
  // Only the first five characters decide the result, so compare the prefix
  // of gcnArchName in place. This avoids tokenizing with strtok(), which
  // returns NULL for an empty name (crashing the following strcpy), modifies
  // dev_prop, and is not thread-safe.
  if (strncmp("gfx94", dev_prop.gcnArchName, 5) == 0) {
    return 100000;
  }
  return 25000;
#else
  cudaDeviceProp dev_prop;
  MSCCLPP_CUDATHROW(cudaGetDeviceProperties(&dev_prop, dev_id));
  return dev_prop.clockRate;
#endif
}
|
||||
#endif
|
||||
|
||||
void NpKit::Dump(const std::string& dump_dir) {
|
||||
#if defined(ENABLE_NPKIT)
|
||||
uint64_t i = 0;
|
||||
std::string dump_file_path;
|
||||
|
||||
// Dump CPU events
|
||||
for (i = 0; i < kNumCpuEventBuffers; i++) {
|
||||
for (i = 0; i < NpKit::kNumCpuEventBuffers; i++) {
|
||||
dump_file_path = dump_dir;
|
||||
dump_file_path += "/cpu_events_rank_";
|
||||
dump_file_path += std::to_string(rank_);
|
||||
@@ -80,7 +125,7 @@ void NpKit::Dump(const std::string& dump_dir) {
|
||||
clock_period_den_file.close();
|
||||
|
||||
// Dump GPU events, reuse CPU struct
|
||||
for (i = 0; i < kNumGpuEventBuffers; i++) {
|
||||
for (i = 0; i < NpKit::kNumGpuEventBuffers; i++) {
|
||||
dump_file_path = dump_dir;
|
||||
dump_file_path += "/gpu_events_rank_";
|
||||
dump_file_path += std::to_string(rank_);
|
||||
@@ -98,17 +143,21 @@ void NpKit::Dump(const std::string& dump_dir) {
|
||||
dump_file_path = dump_dir;
|
||||
dump_file_path += "/gpu_clock_rate_rank_";
|
||||
dump_file_path += std::to_string(rank_);
|
||||
cudaDeviceProp dev_prop;
|
||||
int dev;
|
||||
MSCCLPP_CUDATHROW(cudaGetDevice(&dev));
|
||||
MSCCLPP_CUDATHROW(cudaGetDeviceProperties(&dev_prop, dev));
|
||||
std::string clock_rate_str = std::to_string(dev_prop.clockRate);
|
||||
std::string clock_rate_str = std::to_string(GetGpuClockRateInKhz());
|
||||
auto gpu_clock_rate_file = std::fstream(dump_file_path, std::ios::out);
|
||||
gpu_clock_rate_file.write(clock_rate_str.c_str(), clock_rate_str.length());
|
||||
gpu_clock_rate_file.close();
|
||||
#else
|
||||
WARN("NpKit::Dump(%s) : MSCCLPP library was not built with NPKit enabled.", dump_dir.c_str());
|
||||
#endif
|
||||
}
|
||||
|
||||
void NpKit::Shutdown() {
|
||||
#if defined(ENABLE_NPKIT)
|
||||
// Stop CPU timestamp updating thread
|
||||
cpu_timestamp_update_thread_should_stop_ = true;
|
||||
cpu_timestamp_update_thread_->join();
|
||||
|
||||
// Free CPU event data structures
|
||||
cpu_event_buffers_.clear();
|
||||
cpu_collect_contexts_.reset();
|
||||
@@ -116,6 +165,11 @@ void NpKit::Shutdown() {
|
||||
// Free GPU event data structures
|
||||
gpu_event_buffers_.clear();
|
||||
gpu_collect_contexts_.reset();
|
||||
|
||||
// Free timestamp
|
||||
cpu_timestamp_update_thread_.reset();
|
||||
cpu_timestamp_.reset();
|
||||
#endif
|
||||
}
|
||||
|
||||
NpKitEventCollectContext* NpKit::GetGpuEventCollectContexts() { return gpu_collect_contexts_.get(); }
|
||||
@@ -132,7 +186,4 @@ void NpKit::CollectCpuEvent(uint8_t type, uint32_t size, uint32_t rsvd, uint64_t
|
||||
}
|
||||
}
|
||||
|
||||
uint64_t NpKit::GetCpuTimestamp() {
|
||||
uint64_t cpu_curr_steady_timestamp_ = std::chrono::steady_clock::now().time_since_epoch().count();
|
||||
return cpu_base_steady_timestamp_ + (cpu_curr_steady_timestamp_ - cpu_base_steady_timestamp_);
|
||||
}
|
||||
uint64_t* NpKit::GetCpuTimestamp() { return cpu_timestamp_.get(); }
|
||||
|
||||
@@ -1,66 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#ifndef NPKIT_H_
#define NPKIT_H_

// <cstdint> and <memory> are included explicitly because this header names the fixed-width
// integer types and std::unique_ptr directly, instead of relying on transitive includes.
#include <cstdint>
#include <memory>
#include <mscclpp/gpu_utils.hpp>
#include <string>
#include <vector>

#include "npkit_event.h"
#include "npkit_struct.h"

// Static-only facade for NPKit event collection: every member is static, so the class is used
// as a namespace-like holder of the event buffers and the collection entry points.
class NpKit {
 public:
  // Number of GPU event buffers (see the sizing note on kMaxNumGpuEventsPerBuffer).
  static const uint64_t kNumGpuEventBuffers = 512;

  // Number of CPU event buffers (see the sizing note on kMaxNumCpuEventsPerBuffer).
  static const uint64_t kNumCpuEventBuffers = 32;

  // Lifecycle entry points; their definitions live outside this header.
  static void Init(int rank);

  // `dump_dir` is the directory path handed to the dump routine.
  static void Dump(const std::string& dump_dir);

  static void Shutdown();

  // Accessor for the GPU-side collect contexts that CollectGpuEvent() writes through.
  static NpKitEventCollectContext* GetGpuEventCollectContexts();

#ifdef __CUDACC__
  // Appends one event to `ctx->event_buffer` at index `ctx->event_buffer_head` and advances the
  // head by one. Once kMaxNumGpuEventsPerBuffer events are stored, further events are dropped
  // without any error indication.
  //
  // The head is read and incremented without atomics.
  // NOTE(review): presumably each context is written by a single thread -- confirm at call sites.
  static inline __device__ void CollectGpuEvent(uint8_t type, uint32_t size, uint32_t rsvd, uint64_t timestamp,
                                                NpKitEventCollectContext* ctx) {
    uint64_t event_buffer_head = ctx->event_buffer_head;
    if (event_buffer_head < kMaxNumGpuEventsPerBuffer) {
      NpKitEvent& event = ctx->event_buffer[event_buffer_head];
      event.fields.type = type;
      event.fields.size = size;
      event.fields.rsvd = rsvd;
      event.fields.timestamp = timestamp;
      ctx->event_buffer_head++;
    }
  }
#endif  // __CUDACC__

  // Host-side counterpart of CollectGpuEvent(); `channel_id` is an extra argument the GPU variant
  // does not take. Defined outside this header.
  static void CollectCpuEvent(uint8_t type, uint32_t size, uint32_t rsvd, uint64_t timestamp, int channel_id);

  static uint64_t GetCpuTimestamp();

 private:
  // 64K * 512 * 16B = 512MB per GPU
  static const uint64_t kMaxNumGpuEventsPerBuffer = 1ULL << 16;

  // 64K * 2 (send/recv) * (512/32) = 2M, 2M * 32 * 16B = 1GB per CPU
  static const uint64_t kMaxNumCpuEventsPerBuffer = 1ULL << 21;

  // Event storage: one entry per buffer, GPU-side and CPU-side respectively.
  static std::vector<mscclpp::UniqueCudaPtr<NpKitEvent>> gpu_event_buffers_;
  static std::vector<std::unique_ptr<NpKitEvent[]>> cpu_event_buffers_;

  // Collect contexts (buffer pointer + head index) for the GPU and CPU sides.
  static mscclpp::UniqueCudaPtr<NpKitEventCollectContext> gpu_collect_contexts_;
  static std::unique_ptr<NpKitEventCollectContext[]> cpu_collect_contexts_;

  // Base timestamps taken from the system and steady clocks (names only; set outside this header).
  static uint64_t cpu_base_system_timestamp_;
  static uint64_t cpu_base_steady_timestamp_;

  static uint64_t rank_;
};

#endif  // NPKIT_H_
|
||||
@@ -1,23 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#ifndef NPKIT_EVENT_H_
|
||||
#define NPKIT_EVENT_H_
|
||||
|
||||
// NPKit event type identifiers, numbered consecutively from 0x0 through 0xA in this header.
// NOTE(review): the trace generator script appears to read these lines as three whitespace-separated
// fields, so keep every define on its own line and do not append trailing comments to them.
#define NPKIT_EVENT_INVALID 0x0
|
||||
|
||||
// Time synchronization events for the GPU and CPU sides.
#define NPKIT_EVENT_TIME_SYNC_GPU 0x1
|
||||
#define NPKIT_EVENT_TIME_SYNC_CPU 0x2
|
||||
|
||||
// SM reduce events, as an ENTRY and EXIT pair.
#define NPKIT_EVENT_SM_REDUCE_ENTRY 0x3
|
||||
#define NPKIT_EVENT_SM_REDUCE_EXIT 0x4
|
||||
|
||||
// IB send events: two distinct ENTRY ids (data and flag) share one EXIT id.
#define NPKIT_EVENT_IB_SEND_DATA_ENTRY 0x5
|
||||
#define NPKIT_EVENT_IB_SEND_FLAG_ENTRY 0x6
|
||||
#define NPKIT_EVENT_IB_SEND_EXIT 0x7
|
||||
|
||||
// DMA send events: two distinct ENTRY ids (data and flag) share one EXIT id.
#define NPKIT_EVENT_DMA_SEND_DATA_ENTRY 0x8
|
||||
#define NPKIT_EVENT_DMA_SEND_FLAG_ENTRY 0x9
|
||||
#define NPKIT_EVENT_DMA_SEND_EXIT 0xA
|
||||
|
||||
#endif  // NPKIT_EVENT_H_
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
#include <iostream>
|
||||
#include <mscclpp/executor.hpp>
|
||||
#include <mscclpp/npkit/npkit.hpp>
|
||||
#include <mscclpp/utils.hpp>
|
||||
#include <sstream>
|
||||
|
||||
@@ -74,11 +75,13 @@ double benchTime(int rank, std::shared_ptr<mscclpp::Bootstrap> bootstrap, std::s
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[]) {
|
||||
if (argc != 5) {
|
||||
if (argc != 7) {
|
||||
std::cerr << "Usage: " << argv[0] << " <buffer size>"
|
||||
<< " <execution plan name>"
|
||||
<< " <execution plan path>"
|
||||
<< " <nthreads per block>" << std::endl;
|
||||
<< " <nthreads per block>"
|
||||
<< " <number of iterations>"
|
||||
<< " <number of graph iterations>" << std::endl;
|
||||
return 1;
|
||||
}
|
||||
|
||||
@@ -93,6 +96,9 @@ int main(int argc, char* argv[]) {
|
||||
const std::string executionPlanName = argv[2];
|
||||
const std::string executionPlanPath = argv[3];
|
||||
const int nthreadsPerBlock = std::stoi(argv[4]);
|
||||
const int niters = std::stoi(argv[5]);
|
||||
const int ngraphIters = std::stoi(argv[6]);
|
||||
const char* npkitDumpDir = getenv("NPKIT_DUMP_DIR");
|
||||
|
||||
std::shared_ptr<mscclpp::TcpBootstrap> bootstrap;
|
||||
mscclpp::UniqueId id;
|
||||
@@ -103,11 +109,22 @@ int main(int argc, char* argv[]) {
|
||||
std::shared_ptr<mscclpp::Communicator> communicator = std::make_shared<mscclpp::Communicator>(bootstrap);
|
||||
std::shared_ptr<mscclpp::Executor> executor = std::make_shared<mscclpp::Executor>(communicator);
|
||||
|
||||
if (npkitDumpDir != nullptr) {
|
||||
NpKit::Init(rank);
|
||||
}
|
||||
|
||||
mscclpp::ExecutionPlan plan(executionPlanName, executionPlanPath);
|
||||
std::shared_ptr<char> sendbuff = mscclpp::allocExtSharedCuda<char>(bufferSize);
|
||||
std::vector<int> dataHost(bufferSize / sizeof(int), rank);
|
||||
MSCCLPP_CUDATHROW(cudaMemcpy(sendbuff.get(), dataHost.data(), bufferSize, cudaMemcpyHostToDevice));
|
||||
double deltaSec = benchTime(rank, bootstrap, executor, plan, sendbuff, bufferSize, nthreadsPerBlock, 200, 20);
|
||||
double deltaSec =
|
||||
benchTime(rank, bootstrap, executor, plan, sendbuff, bufferSize, nthreadsPerBlock, niters, ngraphIters);
|
||||
|
||||
if (npkitDumpDir != nullptr) {
|
||||
NpKit::Dump(npkitDumpDir);
|
||||
NpKit::Shutdown();
|
||||
}
|
||||
|
||||
std::cout << "Rank " << rank << ": " << bufferSize << " bytes " << deltaSec * 1.e6 << " us" << std::endl;
|
||||
MPI_Finalize();
|
||||
return 0;
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#include <mpi.h>
|
||||
|
||||
#include <filesystem>
|
||||
#include <mscclpp/npkit/npkit.hpp>
|
||||
|
||||
#include "mp_unit_tests.hpp"
|
||||
|
||||
@@ -30,9 +31,17 @@ void ExecutorTest::SetUp() {
|
||||
bootstrap->initialize(id);
|
||||
std::shared_ptr<mscclpp::Communicator> communicator = std::make_shared<mscclpp::Communicator>(bootstrap);
|
||||
executor = std::make_shared<mscclpp::Executor>(communicator);
|
||||
npkitDumpDir = getenv("NPKIT_DUMP_DIR");
|
||||
if (npkitDumpDir != nullptr) {
|
||||
NpKit::Init(gEnv->rank);
|
||||
}
|
||||
}
|
||||
|
||||
void ExecutorTest::TearDown() {
|
||||
if (npkitDumpDir != nullptr) {
|
||||
NpKit::Dump(npkitDumpDir);
|
||||
NpKit::Shutdown();
|
||||
}
|
||||
executor.reset();
|
||||
MultiProcessTest::TearDown();
|
||||
}
|
||||
|
||||
@@ -170,5 +170,6 @@ class ExecutorTest : public MultiProcessTest {
|
||||
void TearDown() override;
|
||||
|
||||
std::shared_ptr<mscclpp::Executor> executor;
|
||||
const char* npkitDumpDir;
|
||||
};
|
||||
#endif // MSCCLPP_MP_UNIT_TESTS_HPP_
|
||||
|
||||
@@ -2,12 +2,34 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import json
|
||||
|
||||
from queue import Queue
|
||||
|
||||
|
||||
def parse_npkit_event_header(npkit_event_header_path):
|
||||
npkit_event_def = {"id_to_type": {}, "type_to_id": {}}
|
||||
executor_ops = [
|
||||
"BARRIER",
|
||||
"PUT",
|
||||
"PUT_PACKET",
|
||||
"GET",
|
||||
"COPY",
|
||||
"COPY_PACKET",
|
||||
"SIGNAL",
|
||||
"WAIT",
|
||||
"FLUSH",
|
||||
"REDUCE",
|
||||
"REDUCE_PACKET",
|
||||
"REDUCE_SEND",
|
||||
"REDUCE_SEND_PACKET",
|
||||
"READ_REDUCE_COPY",
|
||||
"READ_REDUCE_COPY_SEND",
|
||||
]
|
||||
executor_op_to_offset = {}
|
||||
for executor_op in executor_ops:
|
||||
executor_op_to_offset[executor_op] = len(executor_op_to_offset)
|
||||
with open(npkit_event_header_path, "r") as f:
|
||||
lines = [x.strip() for x in f.readlines() if len(x.strip()) != 0]
|
||||
line_idx = 0
|
||||
@@ -17,23 +39,22 @@ def parse_npkit_event_header(npkit_event_header_path):
|
||||
if len(fields) == 3:
|
||||
event_type = fields[1]
|
||||
event_id = int(fields[2], 0)
|
||||
npkit_event_def["type_to_id"][event_type] = event_id
|
||||
npkit_event_def["id_to_type"][event_id] = event_type
|
||||
if lines[line_idx].startswith("#define NPKIT_EVENT_EXECUTOR_OP_BASE"):
|
||||
for executor_op in executor_op_to_offset:
|
||||
real_event_id = event_id + executor_op_to_offset[executor_op]
|
||||
if "ENTRY" in lines[line_idx]:
|
||||
event_type = "NPKIT_EVENT_EXECUTOR_%s_ENTRY" % executor_op
|
||||
elif "EXIT" in lines[line_idx]:
|
||||
event_type = "NPKIT_EVENT_EXECUTOR_%s_EXIT" % executor_op
|
||||
npkit_event_def["type_to_id"][event_type] = real_event_id
|
||||
npkit_event_def["id_to_type"][real_event_id] = event_type
|
||||
else:
|
||||
npkit_event_def["type_to_id"][event_type] = event_id
|
||||
npkit_event_def["id_to_type"][event_id] = event_type
|
||||
line_idx += 1
|
||||
return npkit_event_def
|
||||
|
||||
|
||||
def trim_event_name(event_type):
    """Shorten an NPKit event type name for display.

    The name is split on "_" and the first occurrence of each of the tokens
    "NPKIT", "EVENT" and "ENTRY" (in that order) is dropped when present; the
    remaining tokens are re-joined with "_". Later duplicates of a dropped
    token are kept, and "EXIT" is never removed.

    Args:
        event_type: Event type name, e.g. "NPKIT_EVENT_SM_REDUCE_ENTRY".

    Returns:
        The trimmed name, e.g. "SM_REDUCE".
    """
    tokens = event_type.split("_")
    for dropped in ("NPKIT", "EVENT", "ENTRY"):
        # list.remove() deletes only the first match, so guard it to avoid ValueError.
        if dropped in tokens:
            tokens.remove(dropped)
    return "_".join(tokens)
|
||||
|
||||
|
||||
def parse_gpu_clock_scale(gpu_clock_file_path):
|
||||
with open(gpu_clock_file_path, "r") as f:
|
||||
freq_in_khz = f.read()
|
||||
@@ -103,7 +124,7 @@ def parse_gpu_event_file(npkit_dump_dir, npkit_event_def, rank, buf_idx, gpu_clo
|
||||
event_type_to_seq[event_type] = 0
|
||||
gpu_events[-1].update(
|
||||
{
|
||||
"name": trim_event_name(event_type),
|
||||
"name": event_type,
|
||||
"cat": "GPU",
|
||||
"args": {
|
||||
"rank": rank,
|
||||
@@ -116,12 +137,11 @@ def parse_gpu_event_file(npkit_dump_dir, npkit_event_def, rank, buf_idx, gpu_clo
|
||||
)
|
||||
event_type_to_seq[event_type] += 1
|
||||
else:
|
||||
gpu_events[-1]["args"] = {
|
||||
"size": parsed_gpu_event["size"],
|
||||
"rsvd": parsed_gpu_event["rsvd"],
|
||||
}
|
||||
gpu_events[-1]["args"] = {"size": parsed_gpu_event["size"], "rsvd": parsed_gpu_event["rsvd"]}
|
||||
delta_time = gpu_events[-1]["ts"] - gpu_events[-2]["ts"]
|
||||
gpu_events[-1]["args"]["bw (GB/s)"] = gpu_events[-1]["args"]["size"] / delta_time / 1e3
|
||||
gpu_events[-1]["args"]["bw (GB/s)"] = (
|
||||
0.0 if delta_time == 0.0 else gpu_events[-1]["args"]["size"] / delta_time / 1e3
|
||||
)
|
||||
raw_content_idx += raw_event_size
|
||||
return gpu_events
|
||||
|
||||
@@ -133,7 +153,7 @@ def parse_cpu_event_file(npkit_dump_dir, npkit_event_def, rank, channel, cpu_clo
|
||||
event_type_to_seq = {}
|
||||
|
||||
fiber_is_usable = []
|
||||
fiber_open_info = []
|
||||
fiber_open_ts = []
|
||||
slot_to_fiber_id = {}
|
||||
channel_shift = 1000
|
||||
|
||||
@@ -156,17 +176,16 @@ def parse_cpu_event_file(npkit_dump_dir, npkit_event_def, rank, channel, cpu_clo
|
||||
fiber_id += 1
|
||||
if fiber_id == len(fiber_is_usable):
|
||||
fiber_is_usable.append(True)
|
||||
fiber_open_info.append({"ts": 0.0, "size": 0})
|
||||
fiber_open_ts.append(0.0)
|
||||
slot_to_fiber_id[slot] = fiber_id
|
||||
fiber_open_info[fiber_id]["ts"] = cpu_events[-1]["ts"]
|
||||
fiber_open_info[fiber_id]["size"] = parsed_cpu_event["size"]
|
||||
fiber_open_ts[fiber_id] = cpu_events[-1]["ts"]
|
||||
fiber_is_usable[fiber_id] = False
|
||||
|
||||
if event_type not in event_type_to_seq:
|
||||
event_type_to_seq[event_type] = 0
|
||||
cpu_events[-1].update(
|
||||
{
|
||||
"name": trim_event_name(event_type),
|
||||
"name": event_type,
|
||||
"cat": "CPU",
|
||||
"args": {
|
||||
"rank": rank,
|
||||
@@ -182,16 +201,14 @@ def parse_cpu_event_file(npkit_dump_dir, npkit_event_def, rank, channel, cpu_clo
|
||||
# Close fiber event
|
||||
fiber_id = slot_to_fiber_id[slot]
|
||||
slot_to_fiber_id.pop(slot)
|
||||
last_ts = fiber_open_info[fiber_id]["ts"]
|
||||
last_size = fiber_open_info[fiber_id]["size"]
|
||||
last_ts = fiber_open_ts[fiber_id]
|
||||
fiber_is_usable[fiber_id] = True
|
||||
|
||||
delta_time = max(0.001, cpu_events[-1]["ts"] - last_ts)
|
||||
cpu_events[-1]["args"] = {
|
||||
"size_1": parsed_cpu_event["size"],
|
||||
"size": max(last_size, parsed_cpu_event["size"]),
|
||||
}
|
||||
cpu_events[-1]["args"]["bw (GB/s)"] = cpu_events[-1]["args"]["size"] / delta_time / 1e3
|
||||
cpu_events[-1]["args"] = {"size": parsed_cpu_event["size"]}
|
||||
cpu_events[-1]["args"]["bw (GB/s)"] = (
|
||||
0.0 if delta_time == 0.0 else cpu_events[-1]["args"]["size"] / delta_time / 1e3
|
||||
)
|
||||
|
||||
cpu_events[-1]["tid"] = fiber_id + (channel + 1) * channel_shift
|
||||
|
||||
@@ -239,12 +256,7 @@ def convert_npkit_dump_to_trace(npkit_dump_dir, output_dir, npkit_event_def):
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--npkit_dump_dir", type=str, required=True, help="NPKit dump directory.")
|
||||
parser.add_argument(
|
||||
"--npkit_event_header_path",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to npkit_event.h.",
|
||||
)
|
||||
parser.add_argument("--npkit_event_header_path", type=str, required=True, help="Path to npkit_event.h.")
|
||||
parser.add_argument("--output_dir", type=str, required=True, help="Path to output directory.")
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user