mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-11 17:00:22 +00:00
update
This commit is contained in:
@@ -8,7 +8,6 @@
|
||||
#include <mscclpp/fifo.hpp>
|
||||
#include <mscclpp/proxy.hpp>
|
||||
#include <mscclpp/semaphore.hpp>
|
||||
#include <unordered_map>
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
@@ -41,10 +40,10 @@ class ProxyService : public BaseProxyService {
|
||||
/// @return The ID of the semaphore.
|
||||
SemaphoreId addSemaphore(std::shared_ptr<Connection> connection);
|
||||
|
||||
/// Add a pitch pair to the proxy service.
|
||||
/// @param id The ID of the semaphore.
|
||||
/// Add a 2D channel to the proxy service.
|
||||
/// @param connection The connection associated with the channel.
|
||||
/// @param pitch The pitch pair.
|
||||
void addPitch(SemaphoreId id, std::pair<uint64_t, uint64_t> pitch);
|
||||
SemaphoreId add2DChannel(std::shared_ptr<Connection> connection, std::pair<uint64_t, uint64_t> pitch);
|
||||
|
||||
/// Register a memory region with the proxy service.
|
||||
/// @param memory The memory region to register.
|
||||
@@ -71,7 +70,7 @@ class ProxyService : public BaseProxyService {
|
||||
Communicator& communicator_;
|
||||
std::vector<std::shared_ptr<Host2DeviceSemaphore>> semaphores_;
|
||||
std::vector<RegisteredMemory> memories_;
|
||||
std::unordered_map<SemaphoreId, std::pair<uint64_t, uint64_t>> pitches_;
|
||||
std::vector<std::pair<uint64_t, uint64_t>> pitches_;
|
||||
Proxy proxy_;
|
||||
int deviceNumaNode;
|
||||
|
||||
|
||||
@@ -29,8 +29,13 @@ MSCCLPP_API_CPP SemaphoreId ProxyService::addSemaphore(std::shared_ptr<Connectio
|
||||
return semaphores_.size() - 1;
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP void ProxyService::addPitch(SemaphoreId id, std::pair<uint64_t, uint64_t> pitch) {
|
||||
MSCCLPP_API_CPP SemaphoreId ProxyService::add2DChannel(std::shared_ptr<Connection> connection,
|
||||
std::pair<uint64_t, uint64_t> pitch) {
|
||||
semaphores_.push_back(std::make_shared<Host2DeviceSemaphore>(communicator_, connection));
|
||||
SemaphoreId id = semaphores_.size() - 1;
|
||||
if (id >= pitches_.size()) pitches_.resize(id + 1, std::pair<uint64_t, uint64_t>(0, 0));
|
||||
pitches_[id] = pitch;
|
||||
return id;
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP MemoryId ProxyService::addMemory(RegisteredMemory memory) {
|
||||
|
||||
@@ -58,8 +58,7 @@ void ProxyChannelOneToOneTest::setupMeshConnections(
|
||||
|
||||
communicator->setup();
|
||||
|
||||
mscclpp::SemaphoreId cid = channelService->addSemaphore(conn);
|
||||
channelService->addPitch(cid, std::pair<size_t, size_t>(pitch, pitch));
|
||||
mscclpp::SemaphoreId cid = channelService->add2DChannel(conn, std::pair<size_t, size_t>(pitch, pitch));
|
||||
communicator->setup();
|
||||
|
||||
proxyChannels.emplace_back(mscclpp::deviceHandle(
|
||||
@@ -77,13 +76,13 @@ __device__ size_t getTileElementOffset(int elementId, int width, int rowIndex, i
|
||||
}
|
||||
|
||||
__global__ void kernelProxyTilePingPong(int* buff, int rank, int pitch, int rowIndex, int colIndex, int width,
|
||||
int hight, int* ret) {
|
||||
int height, int* ret) {
|
||||
DeviceHandle<mscclpp::SimpleProxyChannel>& proxyChan = gChannelOneToOneTestConstProxyChans;
|
||||
volatile int* sendBuff = (volatile int*)buff;
|
||||
int nTries = 1000;
|
||||
int flusher = 0;
|
||||
size_t offset = rowIndex * pitch + colIndex * sizeof(int);
|
||||
size_t nElem = width * hight;
|
||||
size_t nElem = width * height;
|
||||
size_t nElemPerPitch = pitch / sizeof(int);
|
||||
for (int i = 0; i < nTries; i++) {
|
||||
if (rank == 0) {
|
||||
@@ -105,7 +104,7 @@ __global__ void kernelProxyTilePingPong(int* buff, int rank, int pitch, int rowI
|
||||
}
|
||||
__syncthreads();
|
||||
// __threadfence_system(); // not necessary if we make sendBuff volatile
|
||||
if (threadIdx.x == 0) proxyChan.put2DWithSignal(offset, width * sizeof(int), hight);
|
||||
if (threadIdx.x == 0) proxyChan.put2DWithSignal(offset, width * sizeof(int), height);
|
||||
}
|
||||
if (rank == 1) {
|
||||
if (threadIdx.x == 0) proxyChan.wait();
|
||||
@@ -125,7 +124,7 @@ __global__ void kernelProxyTilePingPong(int* buff, int rank, int pitch, int rowI
|
||||
}
|
||||
__syncthreads();
|
||||
// __threadfence_system(); // not necessary if we make sendBuff volatile
|
||||
if (threadIdx.x == 0) proxyChan.put2DWithSignal(offset, width * sizeof(int), hight);
|
||||
if (threadIdx.x == 0) proxyChan.put2DWithSignal(offset, width * sizeof(int), height);
|
||||
}
|
||||
}
|
||||
flusher++;
|
||||
|
||||
Reference in New Issue
Block a user