Mirror of https://github.com/ikawrakow/ik_llama.cpp.git
Graph parallel: the next generation (#1080)
* WIP: absorb adding input into std_attn and std_ffn
* WIP: NCCL infra
* WIP: add reduce and fake_cpy ops
* WIP
* WIP: graph appears to work, layer is broken
* WIP: Qwen3-MoE works with graph, layer still broken
* WIP: GLM-4.5 graph works
* WIP: fix sm layer (dense)
* WIP: fix sm layer (MoE)
* WIP: fast PP with bespoke 4-GPU NCCL

  I guess I'm not using NCCL the right way, as PP is very low with a single communicator group for 3 or more GPUs. But if I create 4 communicator groups for pairs of GPUs (0,1, 2,3, 0,2, 1,3) and use those, PP is fast: I'm hitting 1500 t/s for L3-70B on the 4x3090 system, which is ~20% better than the previous sm graph without NCCL. But that cannot be the solution (I cannot be creating pairwise communicators and the associated logic for every possible number of GPUs).

* WIP: Cohere2
* Explicitly set device
* Bespoke 3-GPU case
* WIP
* Do not repeat get_rows multiple times
* Fix 3 GPUs
* OK, let's leave it in
* Implement the reduce op without NCCL available
* Be able to build without NCCL

  cmake -DGGML_NCCL=OFF disables it

* Make --max-gpu work again
* Slightly better for 4 GPUs without NCCL
* Cleanup

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
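For intuition, here is a minimal standalone sketch of the pairwise ("butterfly") all-reduce described above, outside of ggml: round one reduces within pairs (0,1) and (2,3), round two within (0,2) and (1,3), after which every GPU holds the full four-way sum. It assumes 4 GPUs with NCCL available; buffer names and sizes are illustrative, not taken from the patch.

#include <nccl.h>
#include <cuda_runtime.h>
#include <cstdio>
#include <cstdlib>

#define CHECK_CUDA(x) do { cudaError_t err_ = (x); if (err_ != cudaSuccess) { fprintf(stderr, "CUDA error %d at line %d\n", (int)err_, __LINE__); exit(1); } } while (0)
#define CHECK_NCCL(x) do { ncclResult_t res_ = (x); if (res_ != ncclSuccess) { fprintf(stderr, "NCCL error %d at line %d\n", (int)res_, __LINE__); exit(1); } } while (0)

int main() {
    const int n = 1 << 20;                            // elements per buffer (illustrative)
    static const int pairs[8] = {0,1, 2,3, 0,2, 1,3}; // the four 2-GPU groups

    float * bufs[4];
    cudaStream_t streams[4];
    for (int d = 0; d < 4; ++d) {
        CHECK_CUDA(cudaSetDevice(d));
        CHECK_CUDA(cudaMalloc(&bufs[d], n*sizeof(float)));
        CHECK_CUDA(cudaMemset(bufs[d], 0, n*sizeof(float))); // fill with real partials in practice
        CHECK_CUDA(cudaStreamCreate(&streams[d]));
    }

    // one 2-rank communicator per pair; comms[2*p+k] is rank k of pair p
    ncclComm_t comms[8];
    for (int p = 0; p < 4; ++p) {
        CHECK_NCCL(ncclCommInitAll(comms + 2*p, 2, pairs + 2*p));
    }

    // round 1: all-reduce within (0,1) and (2,3) -> each GPU holds its pair sum
    // round 2: all-reduce within (0,2) and (1,3) -> each GPU holds the full sum
    // per-device stream ordering makes round 2 wait for that device's round 1
    for (int round = 0; round < 2; ++round) {
        for (int p = 2*round; p < 2*round + 2; ++p) {
            CHECK_NCCL(ncclGroupStart());
            for (int k = 0; k < 2; ++k) {
                const int dev = pairs[2*p + k];
                CHECK_CUDA(cudaSetDevice(dev));
                CHECK_NCCL(ncclAllReduce(bufs[dev], bufs[dev], n, ncclFloat, ncclSum,
                                         comms[2*p + k], streams[dev]));
            }
            CHECK_NCCL(ncclGroupEnd());
        }
    }

    for (int d = 0; d < 4; ++d) {
        CHECK_CUDA(cudaSetDevice(d));
        CHECK_CUDA(cudaStreamSynchronize(streams[d])); // every bufs[d] now holds the 4-way sum
    }
    for (int c = 0; c < 8; ++c) CHECK_NCCL(ncclCommDestroy(comms[c]));
    return 0;
}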
@@ -97,6 +97,7 @@ endif()
option(GGML_LASX "ggml: enable lasx" ON)
option(GGML_LSX  "ggml: enable lsx"  ON)
option(GGML_SVE  "ggml: enable SVE"  OFF)
option(GGML_NCCL "ggml: enable NCCL" ON)

if (WIN32)
    set(GGML_WIN_VER "0x602" CACHE STRING "ggml: Windows Version")
@@ -689,6 +689,9 @@ extern "C" {

        GGML_OP_GLU,

        GGML_OP_REDUCE,
        GGML_OP_FAKE_CPY,

        GGML_OP_COUNT,
    };
@@ -3034,6 +3037,17 @@ extern "C" {
        struct ggml_tensor ** splits;
    } ggml_split_tensor_t;

    GGML_API struct ggml_tensor * ggml_reduce(
            struct ggml_context * ctx,
            struct ggml_tensor ** a,
            int n,
            enum ggml_op op);

    GGML_API struct ggml_tensor * ggml_fake_cpy(
            struct ggml_context * ctx,
            struct ggml_tensor * dst,
            struct ggml_tensor * src);

#ifdef  __cplusplus
}
#endif
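To make the intent of the two new ops concrete, a hypothetical graph-building fragment (the helper and tensor names are illustrative, not from the patch; as the implementation further below shows, only GGML_OP_ADD is accepted and NULL entries are allowed as long as at least two sources are set):

// hypothetical helper (not in the patch): build the reduce + alias pair
static struct ggml_tensor * build_cross_gpu_sum(
        struct ggml_context * ctx,
        struct ggml_tensor  * parts[4],   // one partial result per device; NULL entries allowed
        struct ggml_tensor  * dst) {      // tensor that should end up holding the sum
    // sum the partials in place across devices (only GGML_OP_ADD is accepted)
    struct ggml_tensor * sum = ggml_reduce(ctx, parts, 4, GGML_OP_ADD);
    // alias the reduced result into dst without an actual copy
    return ggml_fake_cpy(ctx, dst, sum);
}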
@@ -462,6 +462,21 @@ if (GGML_CUDA)
        set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} CUDA::cuda_driver) # required by cuDeviceGetAttribute(), cuMemGetAllocationGranularity(...), ...
    endif()
endif()

if (GGML_NCCL)
    find_package(NCCL)
    if (NCCL_FOUND)
        message("==================== NCCL found!")
        message("NCCL_LIBRARIES = ${NCCL_LIBRARIES}")
        message("NCCL_INCLUDE_DIRS = ${NCCL_INCLUDE_DIRS}")
        set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} ${NCCL_LIBRARIES})
        set(GGML_EXTRA_INCLUDES ${GGML_EXTRA_INCLUDES} ${NCCL_INCLUDE_DIRS})
        add_compile_definitions(GGML_USE_NCCL)
    else()
        message("==================== NCCL NOT found -> building without NCCL support")
    endif()
endif()

if (NOT GGML_MUSA)
    set(CMAKE_CUDA_USE_RESPONSE_FILE_FOR_INCLUDES 0)
    set(CMAKE_CUDA_USE_RESPONSE_FILE_FOR_LIBRARIES 0)
@@ -1414,13 +1414,59 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
        // do not overwrite user assignments
        if (*leaf_backend_id == -1) {
            *leaf_backend_id = ggml_backend_sched_backend_id_from_cur(sched, leaf);
            //printf("Pass 1: assigned backend %d to leaf %d, %s\n", *leaf_backend_id, i, graph->leafs[i]->name);
        }
    }

    for (int i = 0; i < graph->n_nodes; i++) {
        struct ggml_tensor * node = graph->nodes[i];
        int * node_backend_id = &tensor_backend_id(node);
        if (node->op == GGML_OP_REDUCE) {
            auto view_src = node->view_src;
            int src_id = -1;
            for (int j = 0; j < node->op_params[1]; ++j) {
                if (node->src[j]) {
                    int * this_node_backend_id = &tensor_backend_id(node->src[j]);
                    if (*this_node_backend_id == -1) {
                        *this_node_backend_id = j;
                    } else {
                        GGML_ASSERT(*this_node_backend_id == j);
                    }
                    if (view_src == node->src[j]) {
                        src_id = j;
                    }
                }
            }
            if (src_id >= 0) {
                int * this_node_backend_id = &tensor_backend_id(view_src);
                *this_node_backend_id = tensor_backend_id(node->src[src_id]);
                *node_backend_id = *this_node_backend_id;
            }
        }
        else if (node->op == GGML_OP_MUL && node->src[0]->op == GGML_OP_NORM) {
            // This is a hack for Cohere2. Without this hack the scheduler creates
            // totally nonsensical splits for that arch.
            int * src1_id = &tensor_backend_id(node->src[1]);
            if (*src1_id >= 0) {
                int * src0_id = &tensor_backend_id(node->src[0]);
                int * dst_id  = &tensor_backend_id(node);
                *src0_id = *src1_id;
                *dst_id  = *src1_id;
                // For some reason that I don't understand, the norm backend can already be assigned
                // at this point. How? That's why the more logical approach of checking first is commented out.
                //if (*src0_id < 0) {
                //    *src0_id = *src1_id;
                //} else {
                //    printf("Oops: backend_id_src0(%s) = %d, backend_id_src1(%s) = %d\n", node->src[0]->name, *src0_id, node->src[1]->name, *src1_id);
                //    //GGML_ASSERT(*src0_id == *src1_id);
                //}
                //if (*dst_id < 0) {
                //    *dst_id = *src1_id;
                //} else {
                //    printf("Oops: backend_id_dst(%s) = %d, backend_id_src1(%s) = %d\n", node->name, *dst_id, node->src[1]->name, *src1_id);
                //    //GGML_ASSERT(*dst_id == *src1_id);
                //}
            }
        }
        // do not overwrite user assignments
        if (*node_backend_id == -1) {
            *node_backend_id = ggml_backend_sched_backend_id_from_cur(sched, node);
@@ -1652,6 +1698,8 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
        // check if we should start a new split based on the sources of the current node
        bool need_new_split = false;
        if ((node->op == GGML_OP_ADD && node->op_params[0] == 0xff) ||
            node->op == GGML_OP_REDUCE ||
            node->op == GGML_OP_FAKE_CPY ||
            node->op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t) - 1] == 0xff) {
            need_new_split = true;
        }
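For reference, the last condition checks a sentinel in the final int32 slot of op_params (with GGML_MAX_OP_PARAMS at its usual 64 bytes, that is slot 15). A hypothetical helper in the same spirit, for illustration only (not part of the patch):

// hypothetical: tag a node so that the scheduler starts a new split at it
static void mark_force_new_split(struct ggml_tensor * node) {
    node->op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t) - 1] = 0xff;
}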
@@ -1739,6 +1787,13 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
            if (src_backend_id != cur_backend_id && !ggml_backend_sched_buffer_supported(sched, src, cur_backend_id)) {
                // create a copy of the input in the split's backend
                if (tensor_id_copy(src_id, cur_backend_id, 0) == NULL) {
                    if (node->op == GGML_OP_REDUCE) {
                        //printf("setting tensor_id_copy(reduce, %zu, %d, %s) to %s\n", src_id, cur_backend_id, node->name, src->name);
                        tensor_id_copy(src_id, cur_backend_id, 0) = src;
                    } else if (node->op == GGML_OP_FAKE_CPY && src->op == GGML_OP_REDUCE) {
                        //printf("setting tensor_id_copy(fake_cpy, %zu, %d, %s) to %s\n", src_id, cur_backend_id, node->name, src->src[j]->name);
                        tensor_id_copy(src_id, cur_backend_id, 0) = src->src[j];
                    } else {
                        ggml_backend_t backend = sched->backends[cur_backend_id];
                        for (int c = 0; c < sched->n_copies; c++) {
                            struct ggml_tensor * tensor_copy = ggml_dup_tensor_layout(sched->ctx, src);
@@ -1753,6 +1808,7 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
                        int n_inputs = split->n_inputs++;
                        GGML_ASSERT(n_inputs < GGML_SCHED_MAX_SPLIT_INPUTS);
                        split->inputs[n_inputs] = src;
                    }
                }
                node->src[j] = tensor_id_copy(src_id, cur_backend_id, sched->cur_copy);
            }
@@ -2027,80 +2083,8 @@ static void ggml_backend_sched_copy_inputs(ggml_backend_sched_t sched, ggml_back
    }
}

static ggml_status ggml_backend_sched_compute_splits_sm_graph(ggml_backend_sched_t sched) {
    std::vector<int32_t> ids;
    std::vector<uint32_t> unique_ids;
    ggml_tensor * last_ids_tensor = nullptr;

    std::array<bool, GGML_SCHED_MAX_BACKENDS> needs_sync{{true}};

    auto splits = sched->splits;

    std::vector<ggml_backend_sched_split *> this_split;
    for (int i = 0; i < sched->n_splits; ++i) {
        auto split_i = &splits[i];
        this_split.clear();
        this_split.push_back(split_i);
        for (int j = i+1; j < sched->n_splits; ++j) {
            auto split_j = &splits[j];
            if (split_i->backend_id == split_j->backend_id) {
                break;
            }
            int n_nodes = std::min(split_i->graph.n_nodes, split_j->graph.n_nodes);
            bool same = true;
            for (int k = 0; k < n_nodes; ++k) {
                if (split_i->graph.nodes[k]->op != split_j->graph.nodes[k]->op) {
                    same = false; break;
                }
            }
            if (!same) {
                break;
            }
            this_split.push_back(split_j);
        }
        if (false) {
            auto split = this_split.front();
            if (this_split.size() == 1) {
                printf("=== Split %d with %d inputs on backend %d\n", i, split->n_inputs, split->backend_id);
            } else {
                printf("=== Split %d with %d inputs on backends", i, split->n_inputs);
                for (int j = 0; j < (int)this_split.size(); ++j) printf(" %d", this_split[j]->backend_id);
                printf("\n");
            }
            for (int j = 0; j < split->graph.n_nodes; ++j) {
                printf(" %d %s(%s)\n", j, ggml_op_name(split->graph.nodes[j]->op), split->graph.nodes[j]->name);
            }
        }
        for (auto split : this_split) {
            ggml_backend_sched_copy_inputs(sched, split, needs_sync, ids, unique_ids, last_ids_tensor);
        }
        for (auto split : this_split) {
            auto split_backend_id = split->backend_id;
            if (split->n_inputs > 0) {
                needs_sync[split_backend_id] = true;
            }
            auto split_backend = sched->backends[split_backend_id];
            auto ec = ggml_backend_graph_compute_async(split_backend, &split->graph);
            if (ec != GGML_STATUS_SUCCESS) {
                return ec;
            }
            if (split->n_inputs > 0) {
                if (sched->events[split_backend_id][sched->cur_copy] != NULL) {
                    ggml_backend_event_record(sched->events[split_backend_id][sched->cur_copy]);
                }
            }
        }
        i += this_split.size() - 1;
    }
    return GGML_STATUS_SUCCESS;
}

static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t sched) {

    if (false && sched->split_mode_graph) {
        return ggml_backend_sched_compute_splits_sm_graph(sched);
    }

    std::array<bool, GGML_SCHED_MAX_BACKENDS> needs_sync{{true}};
    std::array<bool, GGML_SCHED_MAX_BACKENDS> own_cpy{{false}};
@@ -48,6 +48,7 @@
#include "ggml-cuda/argmax.cuh"
#include "ggml-cuda/multiadd.cuh"
#include "ggml-cuda/hadamard.cuh"
#include "ggml-cuda/reduce.cuh"

#include <algorithm>
#include <array>
@@ -143,7 +144,7 @@ int ggml_cuda_get_device() {
    return id;
}

static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) {
cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) {
    ggml_cuda_set_device(device);
#if defined(GGML_USE_HIPBLAS) && defined(GGML_HIP_UMA)
    auto res = hipMallocManaged(ptr, size);
@@ -246,6 +247,42 @@ static ggml_cuda_device_info ggml_cuda_init() {
    // configure logging to stdout
    // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, nullptr));

#ifdef GGML_USE_NCCL
    info.have_nccl = false;
    if (info.device_count > 1) {
        int gpu_list[GGML_CUDA_MAX_DEVICES];
        for (int i = 0; i < info.device_count; ++i) gpu_list[i] = i;
        auto status = ncclCommInitAll(info.nccl_coms, info.device_count, gpu_list);
        if (status == ncclSuccess) {
            printf("=============================== NCCL main communicator initialized\n");
            info.have_nccl = true;
        } else {
            printf("=============================== NCCL initialization failed with status %d\n", int(status));
            GGML_ABORT("Fatal error");
        }
        auto com = info.nccl_coms + info.device_count;
        if (info.device_count == 4) {
            int devs[8] = {0,1, 2,3, 0,2, 1,3};
            auto com = info.nccl_coms + info.device_count;
            for (int ip = 0; ip < 4; ++ip) {
                if (auto status = ncclCommInitAll(com+2*ip, 2, devs+2*ip); status != ncclSuccess) {
                    printf("=============================== NCCL initialization of pair %d failed with status %d\n", ip, int(status));
                    GGML_ABORT("Fatal error");
                }
            }
            printf("=============================== NCCL pair communicators for %d GPUs initialized\n", info.device_count);
        } else if (info.device_count == 3) {
            int devs[4] = {0,1, 0,2};
            for (int ip = 0; ip < 2; ++ip) {
                if (auto status = ncclCommInitAll(com+2*ip, 2, devs+2*ip); status != ncclSuccess) {
                    printf("=============================== NCCL initialization of pair %d failed with status %d\n", ip, int(status));
                    GGML_ABORT("Fatal error");
                }
            }
            printf("=============================== NCCL pair communicators for %d GPUs initialized\n", info.device_count);
        }
    }
#endif
    return info;
}
@@ -465,6 +502,11 @@ static std::atomic<int> ggml_cuda_lock_counter;

ggml_backend_cuda_context::ggml_backend_cuda_context(int device) :
    device(device), name(GGML_CUDA_NAME + std::to_string(device)) {
    auto info = const_cast<ggml_cuda_device_info*>(&ggml_cuda_info());
    if (info->all_ctx[device]) {
        GGML_CUDA_LOG_WARN("%s: a context for device %d already exists?\n", __func__, device);
    }
    info->all_ctx[device] = this;
}

ggml_backend_cuda_context::~ggml_backend_cuda_context() {
@@ -472,6 +514,9 @@ ggml_backend_cuda_context::~ggml_backend_cuda_context() {
    std::unique_lock<std::mutex> lock(ggml_cuda_lock);
    ggml_cuda_lock_cv.wait(lock, []{ return ggml_cuda_lock_counter.load(std::memory_order_relaxed) == 0; });

    auto info = const_cast<ggml_cuda_device_info*>(&ggml_cuda_info());
    info->all_ctx[this->device] = nullptr;

    if (copy_event != nullptr) {
        CUDA_CHECK(cudaEventDestroy(copy_event));
    }
@@ -2934,6 +2979,11 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg

    //printf("%4d %s(%s) on device %d. time = %ld\n", i, ggml_op_name(dst->op), dst->name, ctx.device, ggml_time_us());
    switch (dst->op) {
        case GGML_OP_REDUCE:
            ggml_cuda_op_reduce(ctx, dst);
            break;
        case GGML_OP_FAKE_CPY:
            break;
        case GGML_OP_ARGMAX:
            ggml_cuda_argmax(ctx, dst);
            break;
@@ -3451,8 +3501,23 @@ GGML_CALL static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_
            needs_f16_f32_copy = true;

        } else {
#ifdef GGML_USE_NCCL__
            auto & info = ggml_cuda_info();
            auto nbytes = ggml_nbytes(src);
            ncclGroupStart();
            ggml_cuda_set_device(cuda_ctx_src->device);
            auto status1 = ncclSend(src->data, nbytes, ncclUint8, cuda_ctx_dst->device, info.nccl_coms[cuda_ctx_src->device],
                    info.all_ctx[cuda_ctx_src->device]->stream());
            ggml_cuda_set_device(cuda_ctx_dst->device);
            auto status2 = ncclRecv(dst->data, nbytes, ncclUint8, cuda_ctx_src->device, info.nccl_coms[cuda_ctx_dst->device],
                    info.all_ctx[cuda_ctx_dst->device]->stream());
            ncclGroupEnd();
            GGML_ASSERT(status1 == ncclSuccess && status2 == ncclSuccess);
            return true;
#else
            ggml_cuda_set_device(cuda_ctx_src->device);
            CUDA_CHECK(cudaMemcpyPeerAsync(dst->data, cuda_ctx_dst->device, src->data, cuda_ctx_src->device, ggml_nbytes(dst), cuda_ctx_src->stream()));
#endif
        }
#endif
    }
@@ -3733,12 +3798,12 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
        }
#endif
#ifndef NDEBUG
        assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device));
        for (int j = 0; j < GGML_MAX_SRC; j++) {
            if (node->src[j] != nullptr) {
                assert(node->src[j]->buffer);
            }
        }
        //assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device));
        //for (int j = 0; j < GGML_MAX_SRC; j++) {
        //    if (node->src[j] != nullptr) {
        //        assert(node->src[j]->buffer);
        //    }
        //}
#endif // NDEBUG

        bool ok = ggml_cuda_compute_forward(*cuda_ctx, node, cgraph, i);
@@ -4044,6 +4109,8 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
            }
            return false;
        } break;
        case GGML_OP_REDUCE:
        case GGML_OP_FAKE_CPY:
        case GGML_OP_ARGMAX:
            return true;
        case GGML_OP_HADAMARD:
@@ -4372,6 +4439,13 @@ GGML_CALL ggml_backend_t ggml_backend_cuda_init(int device, [[maybe_unused]] con
#endif
    }

#ifdef GGML_USE_NCCL
    if (!enable_p2p) {
        printf("================== P2P disabled, but needed for NCCL\n");
        enable_p2p = true;
    }
#endif

#if !defined(GGML_CUDA_NO_PEER_COPY)
    if (enable_p2p) {
        ggml_cuda_set_peer_access(device);
@@ -34,6 +34,10 @@
#include "vendors/cuda.h"
#endif // defined(GGML_USE_HIPBLAS)

#ifdef GGML_USE_NCCL
#include <nccl.h>
#endif

#define STRINGIZE_IMPL(...) #__VA_ARGS__
#define STRINGIZE(...) STRINGIZE_IMPL(__VA_ARGS__)
@@ -738,6 +742,8 @@ struct ggml_cuda_type_traits<GGML_TYPE_IQ5_K_R4> {

//////////////////////

struct ggml_backend_cuda_context;

struct ggml_cuda_device_info {
    int device_count;

@@ -754,6 +760,12 @@ struct ggml_cuda_device_info {
    cuda_device_info devices[GGML_CUDA_MAX_DEVICES] = {};

    std::array<float, GGML_CUDA_MAX_DEVICES> default_tensor_split = {};

    ggml_backend_cuda_context * all_ctx[GGML_CUDA_MAX_DEVICES] = { nullptr };
#ifdef GGML_USE_NCCL
    ncclComm_t nccl_coms[GGML_CUDA_MAX_DEVICES];
    bool have_nccl;
#endif
};

const ggml_cuda_device_info & ggml_cuda_info();
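For orientation, nccl_coms is shared between the all-device communicator and the bespoke pair communicators; a sketch of the indexing convention that the init code above and reduce.cu below rely on (the helper name is hypothetical, not part of the patch):

// hypothetical helper: rank-k handle of pair communicator p
// (the all-GPU communicator occupies slots [0, device_count); pair handles follow)
static ncclComm_t pair_comm(const ggml_cuda_device_info & info, int p, int k) {
    return info.nccl_coms[info.device_count + 2*p + k];
}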
@@ -844,6 +856,9 @@ struct ggml_backend_cuda_context {
    bool use_cuda_graph = true;
#endif

    void * copy_buffer = nullptr;
    size_t copy_size = 0;

    explicit ggml_backend_cuda_context(int device);

    ~ggml_backend_cuda_context();

@@ -889,3 +904,5 @@ struct ggml_backend_cuda_context {
        return pool(device);
    }
};

cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device);
ggml/src/ggml-cuda/reduce.cu (new file, 319 lines)
@@ -0,0 +1,319 @@
//
// Copyright (C) 2023-2024 The ggml authors
// Copyright (C) 2024 Iwan Kawrakow
// MIT license
// SPDX-License-Identifier: MIT
//

#include "reduce.cuh"

#include <chrono>

template <typename T, int block_size>
static __global__ void k_add(int nelem, const T * src, T * dst) {
    int i = blockIdx.x*block_size + threadIdx.x;
    if (i >= nelem) return;
    dst[i] += src[i];
}

template <typename T, int block_size>
static __global__ void k_add_sym(int nelem, T * src, T * dst) {
    int i = blockIdx.x*block_size + threadIdx.x;
    if (i >= nelem) return;
    dst[i] += src[i];
    src[i] = dst[i];
}

struct copy_task {
    void * ptrs[GGML_CUDA_MAX_DEVICES];
    int nptr;
    int nelem;
};

template <typename T, int block_size>
static __global__ void k_reduce_add(copy_task task) {
    int i = blockIdx.x*block_size + threadIdx.x;
    if (i >= task.nelem) return;
    auto dst = (T *)task.ptrs[0];
    for (int j = 1; j < task.nptr; ++j) {
        auto src = (T *)task.ptrs[j];
        dst[i] += src[i];
    }
    for (int j = 1; j < task.nptr; ++j) {
        auto src = (T *)task.ptrs[j];
        src[i] = dst[i];
    }
}

void ggml_cuda_op_reduce([[maybe_unused]] ggml_backend_cuda_context & ctx, ggml_tensor * dst) {

    auto op = (ggml_op)dst->op_params[0];
    GGML_ASSERT(op == GGML_OP_ADD);
    int nreduce = dst->op_params[1];
    int nhave   = dst->op_params[2];
    GGML_ASSERT(dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
    GGML_ASSERT(ggml_is_contiguous(dst));
    GGML_ASSERT(nhave >= 2 && nhave <= nreduce);

    auto & info = ggml_cuda_info();
#ifdef GGML_USE_NCCL
    if (info.have_nccl && nhave == nreduce) { // somehow I'm not able to figure out how to use NCCL when not all GPUs participate in the reduce op
        GGML_ASSERT(info.have_nccl);
        GGML_ASSERT(info.device_count == nreduce);
        auto type = dst->type;
        //int device = ctx.device;
        if (nreduce != info.device_count) {
            GGML_ABORT("Not implemented");
        }
        //auto tim1 = std::chrono::steady_clock::now();
        auto data_type = type == GGML_TYPE_F32 ? ncclFloat : ncclHalf;
        if (nreduce == 4 && dst->ne[1] > 32) {
            auto com = info.nccl_coms + info.device_count;
            static const int devs[8] = {0,1, 2,3, 0,2, 1,3};
            for (int ip = 0; ip < 4; ++ip) {
                ncclGroupStart();
                ggml_cuda_set_device(devs[2*ip+0]);
                auto status1 = ncclAllReduce(dst->src[devs[2*ip+0]]->data, dst->src[devs[2*ip+0]]->data,
                        ggml_nelements(dst), data_type, ncclSum, com[2*ip+0], info.all_ctx[devs[2*ip+0]]->stream());
                ggml_cuda_set_device(devs[2*ip+1]);
                auto status2 = ncclAllReduce(dst->src[devs[2*ip+1]]->data, dst->src[devs[2*ip+1]]->data,
                        ggml_nelements(dst), data_type, ncclSum, com[2*ip+1], info.all_ctx[devs[2*ip+1]]->stream());
                ncclGroupEnd();
                if (status1 != ncclSuccess || status2 != ncclSuccess) {
                    fprintf(stderr, "%s: ncclAllReduce failed with statuses %d, %d\n", __func__, (int)status1, (int)status2);
                    GGML_ABORT("Fatal error");
                }
            }
        }
        else if (nreduce == 3 && dst->ne[1] > 32) {
            auto com = info.nccl_coms + info.device_count;
            static const int devs[4] = {0,1, 0,2};
            for (int ip = 0; ip < 2; ++ip) {
                ncclGroupStart();
                ggml_cuda_set_device(devs[2*ip+0]);
                auto status1 = ncclAllReduce(dst->src[devs[2*ip+0]]->data, dst->src[devs[2*ip+0]]->data,
                        ggml_nelements(dst), data_type, ncclSum, com[2*ip+0], info.all_ctx[devs[2*ip+0]]->stream());
                ggml_cuda_set_device(devs[2*ip+1]);
                auto status2 = ncclAllReduce(dst->src[devs[2*ip+1]]->data, dst->src[devs[2*ip+1]]->data,
                        ggml_nelements(dst), data_type, ncclSum, com[2*ip+1], info.all_ctx[devs[2*ip+1]]->stream());
                ncclGroupEnd();
                if (status1 != ncclSuccess || status2 != ncclSuccess) {
                    fprintf(stderr, "%s: ncclAllReduce failed with statuses %d, %d\n", __func__, (int)status1, (int)status2);
                    GGML_ABORT("Fatal error");
                }
            }
            ncclGroupStart();
            ggml_cuda_set_device(0);
            auto status1 = ncclSend(dst->src[0]->data, ggml_nelements(dst), data_type, 1, com[0], info.all_ctx[0]->stream());
            ggml_cuda_set_device(1);
            auto status2 = ncclRecv(dst->src[1]->data, ggml_nelements(dst), data_type, 0, com[1], info.all_ctx[1]->stream());
            ncclGroupEnd();
            if (status1 != ncclSuccess || status2 != ncclSuccess) {
                fprintf(stderr, "%s: ncclSend/Recv failed with statuses %d, %d\n", __func__, (int)status1, (int)status2);
                GGML_ABORT("Fatal error");
            }
        }
        else {
            ncclGroupStart();
            for (int i = 0; i < nreduce; ++i) {
                ncclComm_t this_comm;
                if (nhave == nreduce) {
                    this_comm = info.nccl_coms[i];
                } else {
                    auto status = ncclCommSplit(info.nccl_coms[i], dst->src[i] ? 0 : NCCL_SPLIT_NOCOLOR, i, &this_comm, NULL);
                    GGML_ASSERT(status == ncclSuccess);
                }
                ggml_cuda_set_device(i);
                auto stream = info.all_ctx[i]->stream();
                GGML_ASSERT(stream);
                auto status = ncclAllReduce(dst->src[i] ? dst->src[i]->data : nullptr,
                        dst->src[i] ? dst->src[i]->data : nullptr,
                        ggml_nelements(dst), data_type, ncclSum, this_comm, stream);
                if (status != ncclSuccess) {
                    fprintf(stderr, "%s: ncclAllReduce failed with status %d\n", __func__, (int)status);
                    GGML_ABORT("Fatal error");
                }
            }
            ncclGroupEnd();
        }
        ggml_cuda_set_device(ctx.device);
        return;
    }
#endif
    GGML_ASSERT(dst->data == dst->src[ctx.device]->data);
    auto nbytes = ggml_nbytes(dst);
    if (nhave == 2 && (nhave == nreduce || dst->ne[1] <= 8)) {
        int idx[2];
        int ii = 0;
        for (int i = 0; i < nreduce; ++i) {
            if (dst->src[i]) {
                idx[ii++] = i;
            }
        }
        // With P2P access enabled, we can access peer memory as if it were local.
        // Hence, we can launch two reduce kernels, one on each device, each kernel
        // processing half of the data. This very simple approach almost matches NCCL
        // performance (I see ~1% lower PP and TG performance on my 2x3090 system).
        for (int i = 0; i < nhave; ++i) {
            GGML_ASSERT(dst->src[idx[i]]->type == dst->type);
            GGML_ASSERT(ggml_are_same_shape(dst, dst->src[idx[i]]));
            ggml_cuda_set_device(idx[i]);
            if (!info.all_ctx[idx[i]]->copy_event) {
                CUDA_CHECK(cudaEventCreateWithFlags(&info.all_ctx[idx[i]]->copy_event, cudaEventDisableTiming));
            }
            CUDA_CHECK(cudaEventRecord(info.all_ctx[idx[i]]->copy_event, info.all_ctx[idx[i]]->stream()));
        }
        auto nelem = ggml_nelements(dst);
        auto nelem_half = (nelem + 1)/2;
        for (int i = 0; i < nhave; ++i) {
            ggml_cuda_set_device(idx[i]);
            CUDA_CHECK(cudaStreamWaitEvent(info.all_ctx[idx[i]]->stream(), info.all_ctx[idx[(i+1)%2]]->copy_event, 0));
            auto this_nelem = std::min(nelem_half, nelem - i*nelem_half); // device i handles the half starting at i*nelem_half
            int nblock = (this_nelem + CUDA_REDUCE_BLOCK_SIZE - 1)/CUDA_REDUCE_BLOCK_SIZE;
            if (dst->type == GGML_TYPE_F16) {
                auto src_ptr = (half *)dst->src[idx[i]]->data + i*nelem_half;
                auto dst_ptr = (half *)dst->src[idx[(i+1)%2]]->data + i*nelem_half;
                k_add_sym<half, CUDA_REDUCE_BLOCK_SIZE><<<nblock, CUDA_REDUCE_BLOCK_SIZE, 0, info.all_ctx[idx[i]]->stream()>>>(this_nelem, src_ptr, dst_ptr);
            } else {
                auto src_ptr = (float *)dst->src[idx[i]]->data + i*nelem_half;
                auto dst_ptr = (float *)dst->src[idx[(i+1)%2]]->data + i*nelem_half;
                k_add_sym<float, CUDA_REDUCE_BLOCK_SIZE><<<nblock, CUDA_REDUCE_BLOCK_SIZE, 0, info.all_ctx[idx[i]]->stream()>>>(this_nelem, src_ptr, dst_ptr);
            }
        }
        for (int i = 0; i < nhave; ++i) {
            ggml_cuda_set_device(idx[i]);
            CUDA_CHECK(cudaEventRecord(info.all_ctx[idx[i]]->copy_event, info.all_ctx[idx[i]]->stream()));
            ggml_cuda_set_device(idx[(i+1)%2]);
            CUDA_CHECK(cudaStreamWaitEvent(info.all_ctx[idx[(i+1)%2]]->stream(), info.all_ctx[idx[i]]->copy_event));
        }
        ggml_cuda_set_device(ctx.device);
        return;
    }
    int idx[GGML_CUDA_MAX_DEVICES];
    {
        int ii = 0;
        bool have_this_device = false;
        for (int i = 0; i < nreduce; ++i) {
            if (dst->src[i]) {
                idx[ii++] = i;
                if (i == ctx.device) have_this_device = true;
            }
        }
        GGML_ASSERT(ii == nhave);
        GGML_ASSERT(have_this_device);
    }
    if (nhave == 4 && dst->ne[1] <= 8) {
        for (int ii = 0; ii < nhave; ++ii) {
            int i = idx[ii];
            GGML_ASSERT(dst->src[i]->type == dst->type);
            GGML_ASSERT(ggml_are_same_shape(dst, dst->src[i]));
            ggml_cuda_set_device(i);
            if (!info.all_ctx[i]->copy_event) {
                CUDA_CHECK(cudaEventCreateWithFlags(&info.all_ctx[i]->copy_event, cudaEventDisableTiming));
            }
        }
        auto nelem = ggml_nelements(dst);
        for (int ii = 0; ii < nhave/2; ++ii) {
            int i = idx[2*ii+0];
            ggml_cuda_set_device(i);
            int nblocks = (nelem + CUDA_REDUCE_BLOCK_SIZE - 1)/CUDA_REDUCE_BLOCK_SIZE;
            copy_task task;
            task.nptr  = nhave/2;
            task.nelem = nelem;
            task.ptrs[0] = (char *)dst->src[i]->data;
            int j = idx[2*ii+1];
            CUDA_CHECK(cudaEventRecord(info.all_ctx[j]->copy_event, info.all_ctx[j]->stream()));
            task.ptrs[1] = (char *)dst->src[j]->data;
            CUDA_CHECK(cudaStreamWaitEvent(info.all_ctx[i]->stream(), info.all_ctx[j]->copy_event));
            if (dst->type == GGML_TYPE_F16) {
                k_reduce_add<half, CUDA_REDUCE_BLOCK_SIZE><<<nblocks, CUDA_REDUCE_BLOCK_SIZE, 0, info.all_ctx[i]->stream()>>>(task);
            } else {
                k_reduce_add<float, CUDA_REDUCE_BLOCK_SIZE><<<nblocks, CUDA_REDUCE_BLOCK_SIZE, 0, info.all_ctx[i]->stream()>>>(task);
            }
        }
        for (int ii = 0; ii < nhave/2; ++ii) {
            int i = idx[2*ii+0];
            ggml_cuda_set_device(i);
            CUDA_CHECK(cudaEventRecord(info.all_ctx[i]->copy_event, info.all_ctx[i]->stream()));
        }
        for (int ii = 0; ii < nhave/2; ++ii) {
            int i = idx[2*ii+1];
            ggml_cuda_set_device(i);
            int nblocks = (nelem + CUDA_REDUCE_BLOCK_SIZE - 1)/CUDA_REDUCE_BLOCK_SIZE;
            copy_task task;
            task.nptr  = nhave/2;
            task.nelem = nelem;
            task.ptrs[0] = (char *)dst->src[i]->data;
            int j = idx[(2*ii+2)%nhave];
            task.ptrs[1] = (char *)dst->src[j]->data;
            CUDA_CHECK(cudaStreamWaitEvent(info.all_ctx[i]->stream(), info.all_ctx[j]->copy_event));
            if (dst->type == GGML_TYPE_F16) {
                k_reduce_add<half, CUDA_REDUCE_BLOCK_SIZE><<<nblocks, CUDA_REDUCE_BLOCK_SIZE, 0, info.all_ctx[i]->stream()>>>(task);
            } else {
                k_reduce_add<float, CUDA_REDUCE_BLOCK_SIZE><<<nblocks, CUDA_REDUCE_BLOCK_SIZE, 0, info.all_ctx[i]->stream()>>>(task);
            }
        }
        for (int ii = 0; ii < nhave/2; ++ii) {
            int i = idx[2*ii+1];
            ggml_cuda_set_device(i);
            CUDA_CHECK(cudaEventRecord(info.all_ctx[i]->copy_event, info.all_ctx[i]->stream()));
        }
        for (int ii = 0; ii < nhave/2; ++ii) {
            int i = idx[(2*ii+2)%nhave];
            ggml_cuda_set_device(i);
            int j = idx[2*ii+1];
            CUDA_CHECK(cudaStreamWaitEvent(info.all_ctx[i]->stream(), info.all_ctx[j]->copy_event));
        }
        ggml_cuda_set_device(ctx.device);
        return;
    }
    auto required_size = nbytes*(nhave-1);
    if (required_size > ctx.copy_size) {
        if (ctx.copy_buffer) {
            CUDA_CHECK(cudaFree(ctx.copy_buffer));
        }
        CUDA_CHECK(ggml_cuda_device_malloc(&ctx.copy_buffer, required_size, ctx.device));
        ctx.copy_size = required_size;
    }
    auto ptr = (char *)ctx.copy_buffer;
    for (int ii = 0; ii < nhave; ++ii) {
        int i = idx[ii];
        GGML_ASSERT(dst->src[i]->type == dst->type);
        GGML_ASSERT(ggml_are_same_shape(dst, dst->src[i]));
        if (i == ctx.device) continue;
        ggml_cuda_set_device(i);
        CUDA_CHECK(cudaMemcpyPeerAsync(ptr, ctx.device, dst->src[i]->data, i, nbytes, info.all_ctx[i]->stream()));
        if (!info.all_ctx[i]->copy_event) {
            CUDA_CHECK(cudaEventCreateWithFlags(&info.all_ctx[i]->copy_event, cudaEventDisableTiming));
        }
        CUDA_CHECK(cudaEventRecord(info.all_ctx[i]->copy_event, info.all_ctx[i]->stream()));
        ptr += nbytes;
    }
    auto nelem = ggml_nelements(dst);
    int num_blocks = (nelem + CUDA_REDUCE_BLOCK_SIZE - 1)/CUDA_REDUCE_BLOCK_SIZE;
    ggml_cuda_set_device(ctx.device);
    ptr = (char *)ctx.copy_buffer;
    for (int ii = 0; ii < nhave; ++ii) {
        int i = idx[ii];
        if (i == ctx.device) continue;
        CUDA_CHECK(cudaStreamWaitEvent(ctx.stream(), info.all_ctx[i]->copy_event, 0));
        if (dst->type == GGML_TYPE_F16) {
            k_add<half, CUDA_REDUCE_BLOCK_SIZE><<<num_blocks, CUDA_REDUCE_BLOCK_SIZE, 0, ctx.stream()>>>(nelem, (const half *)ptr, (half *)dst->data);
        } else {
            k_add<float, CUDA_REDUCE_BLOCK_SIZE><<<num_blocks, CUDA_REDUCE_BLOCK_SIZE, 0, ctx.stream()>>>(nelem, (const float *)ptr, (float *)dst->data);
        }
        ptr += nbytes;
    }
    if (!ctx.copy_event) {
        CUDA_CHECK(cudaEventCreateWithFlags(&ctx.copy_event, cudaEventDisableTiming));
    }
    CUDA_CHECK(cudaEventRecord(ctx.copy_event, ctx.stream()));
    for (int ii = 0; ii < nhave; ++ii) {
        int i = idx[ii];
        if (i == ctx.device) continue;
        ggml_cuda_set_device(i);
        CUDA_CHECK(cudaStreamWaitEvent(info.all_ctx[i]->stream(), ctx.copy_event, 0));
        CUDA_CHECK(cudaMemcpyPeerAsync(dst->src[i]->data, i, dst->data, ctx.device, nbytes, info.all_ctx[i]->stream()));
    }
    ggml_cuda_set_device(ctx.device);
}
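As background for the comment in the two-GPU branch of ggml_cuda_op_reduce above: once cudaDeviceEnablePeerAccess is in effect, a kernel can dereference a pointer that lives on the peer device directly, which is what lets k_add_sym read and write both halves without staging copies. A standalone illustration of just that principle (assumes two P2P-capable GPUs; names are illustrative and error checking is omitted for brevity):

#include <cuda_runtime.h>

__global__ void add_from_peer(int n, const float * peer_src, float * dst) {
    int i = blockIdx.x*blockDim.x + threadIdx.x;
    if (i < n) dst[i] += peer_src[i];   // peer_src resides on the other GPU
}

int main() {
    const int n = 1 << 20;
    float * buf0, * buf1;
    cudaSetDevice(0); cudaMalloc(&buf0, n*sizeof(float)); cudaMemset(buf0, 0, n*sizeof(float));
    cudaSetDevice(1); cudaMalloc(&buf1, n*sizeof(float)); cudaMemset(buf1, 0, n*sizeof(float));

    // enable bidirectional peer access; this fails on GPUs that are not P2P-capable
    cudaSetDevice(0); cudaDeviceEnablePeerAccess(1, 0);
    cudaSetDevice(1); cudaDeviceEnablePeerAccess(0, 0);

    // a kernel on GPU 0 reads GPU 1's buffer through the peer mapping
    cudaSetDevice(0);
    add_from_peer<<<(n + 255)/256, 256>>>(n, buf1, buf0);
    cudaDeviceSynchronize();
    return 0;
}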
ggml/src/ggml-cuda/reduce.cuh (new file, 7 lines)
@@ -0,0 +1,7 @@
#include "common.cuh"

#define CUDA_REDUCE_BLOCK_SIZE 256

void ggml_cuda_op_reduce(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

void ggml_cuda_op_fake_cpy(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
@@ -4291,9 +4291,12 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
    "CROSS_ENTROPY_LOSS_BACK",

    "GLU",

    "REDUCE",
    "FAKE_CPY",
};

static_assert(GGML_OP_COUNT == 92, "GGML_OP_COUNT != 92");
static_assert(GGML_OP_COUNT == 94, "GGML_OP_COUNT != 94");

static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
    "none",
@@ -4398,10 +4401,13 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
    "cross_entropy_loss(x,y)",
    "cross_entropy_loss_back(x,y)",

    "glu(x),"
    "glu(x),",

    "reduce(x1,x2,...)",
    "fake_cpy(x,y)",
};

static_assert(GGML_OP_COUNT == 92, "GGML_OP_COUNT != 92");
static_assert(GGML_OP_COUNT == 94, "GGML_OP_COUNT != 94");

static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
@@ -6060,6 +6066,43 @@ struct ggml_tensor * ggml_dup_inplace(
    return ggml_dup_impl(ctx, a, true);
}

struct ggml_tensor * ggml_reduce(
        struct ggml_context * ctx,
        struct ggml_tensor ** a,
        int n,
        enum ggml_op op) {
    GGML_ASSERT(n > 1 && n <= GGML_MAX_SRC);
    GGML_ASSERT(op == GGML_OP_ADD); // currently we only handle reduce_add
    struct ggml_tensor * last = NULL;
    int nhave = 0;
    for (int j = 0; j < n; ++j) {
        if (a[j]) { ++nhave; last = a[j]; }
    }
    GGML_ASSERT(last);
    GGML_ASSERT(nhave > 1);
    struct ggml_tensor * result = ggml_view_tensor(ctx, last);
    for (int j = 0; j < n; ++j) {
        result->src[j] = a[j];
    }
    result->op = GGML_OP_REDUCE;
    result->op_params[0] = (int)op;
    result->op_params[1] = n;
    result->op_params[2] = nhave;
    return result;
}

struct ggml_tensor * ggml_fake_cpy(
        struct ggml_context * ctx,
        struct ggml_tensor * dst,
        struct ggml_tensor * src) {
    struct ggml_tensor * result = ggml_view_tensor(ctx, dst);
    result->op = GGML_OP_FAKE_CPY;
    result->src[0] = dst;
    result->src[1] = src;
    return result;
}


// ggml_add

static struct ggml_tensor * ggml_add_impl(
@@ -8433,6 +8476,21 @@ struct ggml_tensor * ggml_get_rows(
    if (a->type == GGML_TYPE_I32) {
        type = a->type;
    }

    //if (a->op == GGML_OP_REDUCE) {
    //    //printf("======================= %s(%s)\n", __func__, a->name);
    //    struct ggml_tensor * result = NULL;
    //    for (int j = a->op_params[1]-1; j >= 0; --j) {
    //        if (a->src[j]) {
    //            struct ggml_tensor * aj = ggml_get_rows(ctx, a->src[j], b);
    //            if (result == NULL) result = ggml_view_tensor(ctx, aj);
    //            result->src[j] = aj;
    //        }
    //    }
    //    GGML_ASSERT(result);
    //    return result;
    //}

    struct ggml_tensor * result = ggml_new_tensor_4d(ctx, type, a->ne[0], b->ne[0], b->ne[1], b->ne[2]);

    result->op = GGML_OP_GET_ROWS;
@@ -22675,6 +22733,14 @@ static int ggml_compute_forward(struct ggml_compute_params * params, struct ggml
#endif

    switch (tensor->op) {
        case GGML_OP_REDUCE:
            {
                GGML_ABORT("REDUCE not implemented");
            }
        case GGML_OP_FAKE_CPY:
            {
                GGML_ABORT("FAKE_CPY not implemented");
            }
        case GGML_OP_DUP:
            {
                ggml_compute_forward_dup(params, tensor);
@@ -23352,6 +23418,14 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
    struct ggml_tensor * src2 = tensor->src[2];

    switch (tensor->op) {
        case GGML_OP_REDUCE:
            {
                GGML_ABORT("REDUCE not implemented");
            }
        case GGML_OP_FAKE_CPY:
            {
                GGML_ABORT("FAKE_CPY not implemented");
            }
        case GGML_OP_DUP:
            {
                if (src0->grad) {