mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2026-06-07 08:14:14 +00:00
Complete SFT (Supervised Fine-Tuning) backend for MoE models using AMX SIMD: Core C++ implementation: - sft_moe.hpp: Forward/backward with LoRA fused operations (~5500 lines) - moe-sft-tp.hpp: Tensor-parallel wrapper for multi-NUMA - amx/moe-sft-tp.hpp: AMX-specific TP implementation - avx_kernels.hpp: AVX512 SIMD kernels for LoRA GEMM - amx_kernels.hpp: AMX tile kernels for Panel5 rank-outer optimization - worker_pool: RDTSC profiling, Chrome trace output, SFT timer infrastructure - ext_bindings.cpp: SFT MOE pybind bindings (BF16/INT8/INT4 + SkipLoRA variants) Python sft/ submodule (kt_kernel.sft): - base.py: BaseSFTMoEWrapper with buffer management (template method pattern) - amx.py: AMXSFTMoEWrapper (weight loading, C++ task construction) - autograd.py: KTMoEFunction (torch.autograd.Function for distributed training) - layer.py: KTMoELayerWrapper (nn.Module replacing HF MoE layers) - arch.py: MOEArchConfig (Qwen3/DeepSeek/Mixtral architecture detection) - weights.py: Expert weight extraction and checkpoint loading - lora.py: PEFT LoRA adaptation (view buffers, grad buffers, save/load adapter) - wrapper.py: wrap_moe_layers_with_kt_wrapper, load_kt_model, build_kt_device_map - config.py: KTConfig dataclass (DeepSpeed-style opaque config passthrough) - dist_utils.py: Distributed gather/scatter, checkpoint-phase detection Design decisions: - Rank-0-only expert pattern: only rank 0 holds C++ wrapper and expert weights - DeepSpeed-style integration: accelerate keeps only KTransformersPlugin (framework interaction fields), all logic in kt_kernel.sft - Inference isolation: importing kt_kernel does not load sft/ submodule - Old field name compatibility: _get_kt_config() converts kt_xxx→xxx automatically Verified: Qwen3-235B-A22B 4GPU AMXBF16 training, loss converges normally.
926 lines
33 KiB
C++
926 lines
33 KiB
C++
/**
|
|
* @Description :
|
|
* @Author : chenht2022
|
|
* @Date : 2024-07-22 02:03:05
|
|
* @Version : 1.0.0
|
|
* @LastEditors : chenht2022
|
|
* @LastEditTime : 2024-07-25 10:33:34
|
|
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
|
|
**/
|
|
|
|
#include "worker_pool.h"
|
|
|
|
#include <hwloc/bitmap.h>
|
|
#include <numa.h>
|
|
#include <numaif.h>
|
|
#include <signal.h>
|
|
#include <x86intrin.h>
|
|
|
|
#include <algorithm>
|
|
#include <cassert>
|
|
#include <chrono>
|
|
#include <cstdio>
|
|
#include <cstdlib>
|
|
#include <fstream>
|
|
#include <iomanip>
|
|
#include <mutex>
|
|
#include <numeric>
|
|
#include <sstream>
|
|
#include <stdexcept>
|
|
#include <string>
|
|
#include <vector>
|
|
|
|
#include "hwloc.h"
|
|
|
|
// RDTSC-based timer for lightweight timing
|
|
// Uses CPU timestamp counter instead of system clock for lower overhead
|
|
namespace {
|
|
|
|
// Read CPU timestamp counter (RDTSC)
|
|
inline uint64_t rdtsc_now() { return __rdtsc(); }
|
|
|
|
// Estimate RDTSC cycles for given milliseconds
|
|
// This is calculated once at startup
|
|
static uint64_t g_rdtsc_cycles_per_ms = 0;
|
|
|
|
// Initialize RDTSC frequency by measuring against chrono
|
|
static uint64_t init_rdtsc_frequency() {
|
|
auto start_chrono = std::chrono::high_resolution_clock::now();
|
|
uint64_t start_rdtsc = rdtsc_now();
|
|
|
|
// Busy wait for ~10ms to calibrate
|
|
while (true) {
|
|
auto now = std::chrono::high_resolution_clock::now();
|
|
auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(now - start_chrono).count();
|
|
if (elapsed >= 10) break;
|
|
}
|
|
|
|
uint64_t end_rdtsc = rdtsc_now();
|
|
auto end_chrono = std::chrono::high_resolution_clock::now();
|
|
auto elapsed_ms = std::chrono::duration_cast<std::chrono::milliseconds>(end_chrono - start_chrono).count();
|
|
|
|
if (elapsed_ms > 0) {
|
|
return (end_rdtsc - start_rdtsc) / elapsed_ms;
|
|
}
|
|
// Fallback: assume 2.5 GHz CPU
|
|
return 2500000;
|
|
}
|
|
|
|
// Get cycles per millisecond (lazy initialization)
|
|
inline uint64_t get_rdtsc_cycles_per_ms() {
|
|
if (g_rdtsc_cycles_per_ms == 0) {
|
|
g_rdtsc_cycles_per_ms = init_rdtsc_frequency();
|
|
}
|
|
return g_rdtsc_cycles_per_ms;
|
|
}
|
|
|
|
} // namespace
|
|
|
|
// =====================================================
|
|
// Global per-thread timing for SFT MOE forward/backward
|
|
// Collects timing from InNumaPool worker threads
|
|
// =====================================================
|
|
#ifndef SFT_TIMER_DISABLED
|
|
namespace sft_timer {
|
|
|
|
constexpr int MAX_THREADS = 256;
|
|
static uint64_t forward_rt[MAX_THREADS] = {0};
|
|
static uint64_t backward_rt[MAX_THREADS] = {0};
|
|
static int forward_tasks[MAX_THREADS] = {0};
|
|
static int backward_tasks[MAX_THREADS] = {0};
|
|
static int forward_threads = 0;
|
|
static int backward_threads = 0;
|
|
|
|
inline double ticks_to_ms(uint64_t cycles) { return (double)cycles / get_rdtsc_cycles_per_ms(); }
|
|
|
|
// =====================================================
|
|
// Chrome Trace Event Format support
|
|
// =====================================================
|
|
struct TraceEvent {
|
|
std::string name; // event name (op_name)
|
|
std::string cat; // category
|
|
char ph; // phase: 'X' for complete event, 'B' for begin, 'E' for end
|
|
double ts; // timestamp in microseconds (with ns precision via decimals)
|
|
double dur; // duration in microseconds (with ns precision via decimals)
|
|
int pid; // process id (numa_id)
|
|
int tid; // thread id
|
|
int task_count; // number of tasks processed
|
|
std::string args_json; // optional custom args JSON (for kernel traces)
|
|
};
|
|
|
|
static std::vector<TraceEvent> g_trace_events;
|
|
static std::mutex g_trace_mutex;
|
|
static uint64_t g_trace_start_time = 0; // baseline timestamp (RDTSC)
|
|
static double g_trace_start_epoch_us = 0.0; // wall-clock epoch time in microseconds
|
|
static std::string g_trace_output_path = "sft_trace.json";
|
|
|
|
// Thread-safe initialization using std::call_once
|
|
static std::once_flag g_trace_init_flag;
|
|
|
|
// Forward declaration for atexit registration.
|
|
static void write_trace_to_file();
|
|
|
|
// Initialize trace start time (thread-safe)
|
|
static void init_trace() {
|
|
std::call_once(g_trace_init_flag, []() {
|
|
g_trace_start_time = rdtsc_now();
|
|
// Record wall-clock epoch time for cross-process trace alignment
|
|
auto now_wall = std::chrono::system_clock::now();
|
|
auto epoch_us = std::chrono::duration_cast<std::chrono::microseconds>(now_wall.time_since_epoch()).count();
|
|
g_trace_start_epoch_us = static_cast<double>(epoch_us);
|
|
// Check for custom output path from environment
|
|
const char* env_path = std::getenv("SFT_TRACE_PATH");
|
|
if (env_path && env_path[0] != '\0') {
|
|
g_trace_output_path = env_path;
|
|
}
|
|
// Flush trace on normal exit before static destructors run.
|
|
std::atexit(write_trace_to_file);
|
|
});
|
|
}
|
|
|
|
// Convert RDTSC cycles to microseconds with nanosecond precision (as double)
|
|
// Chrome tracing uses microseconds but supports fractional values for sub-us precision
|
|
static double cycles_to_us(uint64_t cycles) {
|
|
// cycles_per_ms * 1000 = cycles_per_us
|
|
// cycles / cycles_per_us = microseconds
|
|
// Using 1e6 for cycles_per_ms -> cycles_per_s, then divide to get us with ns precision
|
|
double cycles_per_us = get_rdtsc_cycles_per_ms() / 1000.0;
|
|
return static_cast<double>(cycles) / cycles_per_us;
|
|
}
|
|
|
|
// Add trace events for an operation using absolute timestamps
|
|
static void add_trace_events(const char* op_name, int numa_id, int thread_count, const uint64_t* start_ts_arr,
|
|
const uint64_t* end_ts_arr, const int* tasks) {
|
|
init_trace();
|
|
|
|
std::lock_guard<std::mutex> lock(g_trace_mutex);
|
|
|
|
for (int i = 0; i < thread_count; i++) {
|
|
// Convert absolute RDTSC timestamps to relative microseconds from trace start
|
|
double start_us = (start_ts_arr[i] > g_trace_start_time) ? cycles_to_us(start_ts_arr[i] - g_trace_start_time) : 0.0;
|
|
double end_us = (end_ts_arr[i] > g_trace_start_time) ? cycles_to_us(end_ts_arr[i] - g_trace_start_time) : 0.0;
|
|
double dur_us = end_us - start_us;
|
|
if (dur_us < 0) dur_us = 0;
|
|
|
|
TraceEvent ev;
|
|
ev.name = op_name;
|
|
ev.cat = "sft_op";
|
|
ev.ph = 'X'; // Complete event
|
|
ev.ts = start_us;
|
|
ev.dur = dur_us;
|
|
ev.pid = numa_id;
|
|
ev.tid = i;
|
|
ev.task_count = tasks[i];
|
|
|
|
g_trace_events.push_back(ev);
|
|
}
|
|
}
|
|
|
|
// Write trace events to JSON file (Chrome Trace Event Format)
|
|
static void write_trace_to_file() {
|
|
std::lock_guard<std::mutex> lock(g_trace_mutex);
|
|
|
|
if (g_trace_events.empty()) {
|
|
return;
|
|
}
|
|
|
|
// Sort events by (pid, tid, ts) to fix overlap issues in Chrome trace viewer
|
|
// Events from same thread should be ordered by start time
|
|
std::sort(g_trace_events.begin(), g_trace_events.end(), [](const TraceEvent& a, const TraceEvent& b) {
|
|
if (a.pid != b.pid) return a.pid < b.pid;
|
|
if (a.tid != b.tid) return a.tid < b.tid;
|
|
return a.ts < b.ts;
|
|
});
|
|
|
|
std::ofstream ofs(g_trace_output_path);
|
|
if (!ofs.is_open()) {
|
|
fprintf(stderr, "sft_timer: Failed to open trace file: %s\n", g_trace_output_path.c_str());
|
|
return;
|
|
}
|
|
|
|
// Use fixed precision for nanosecond accuracy (3 decimal places in microseconds = nanoseconds)
|
|
ofs << std::fixed << std::setprecision(3);
|
|
|
|
ofs << "{\n";
|
|
ofs << " \"traceEvents\": [\n";
|
|
|
|
for (size_t i = 0; i < g_trace_events.size(); i++) {
|
|
const auto& ev = g_trace_events[i];
|
|
ofs << " {";
|
|
ofs << "\"name\":\"" << ev.name << "\",";
|
|
ofs << "\"cat\":\"" << ev.cat << "\",";
|
|
ofs << "\"ph\":\"" << ev.ph << "\",";
|
|
ofs << "\"ts\":" << ev.ts << ",";
|
|
ofs << "\"dur\":" << ev.dur << ",";
|
|
ofs << "\"pid\":" << ev.pid << ",";
|
|
ofs << "\"tid\":" << ev.tid << ",";
|
|
if (!ev.args_json.empty()) {
|
|
ofs << "\"args\":" << ev.args_json;
|
|
} else {
|
|
ofs << "\"args\":{\"task_count\":" << ev.task_count << "}";
|
|
}
|
|
ofs << "}";
|
|
if (i < g_trace_events.size() - 1) {
|
|
ofs << ",";
|
|
}
|
|
ofs << "\n";
|
|
}
|
|
|
|
ofs << " ],\n";
|
|
ofs << " \"metadata\": {\"start_epoch_us\": " << std::setprecision(0) << g_trace_start_epoch_us << "},\n";
|
|
ofs << std::setprecision(3);
|
|
ofs << " \"displayTimeUnit\": \"ns\"\n";
|
|
ofs << "}\n";
|
|
|
|
ofs.close();
|
|
fprintf(stderr, "sft_timer: Trace written to %s (%zu events)\n", g_trace_output_path.c_str(), g_trace_events.size());
|
|
}
|
|
|
|
// Signal handler for SIGTERM
|
|
static void sigterm_handler(int sig) {
|
|
fprintf(stderr, "sft_timer: Received signal %d, writing trace...\n", sig);
|
|
write_trace_to_file();
|
|
// Re-raise the signal with default handler to allow normal termination
|
|
signal(sig, SIG_DFL);
|
|
raise(sig);
|
|
}
|
|
|
|
// Register signal handlers
|
|
static void register_signal_handlers() {
|
|
static bool registered = false;
|
|
if (!registered) {
|
|
signal(SIGTERM, sigterm_handler);
|
|
signal(SIGINT, sigterm_handler);
|
|
registered = true;
|
|
}
|
|
}
|
|
|
|
void print_rt(FILE* out, const char* name, uint64_t* rt, int* tasks, int rt_threads) {
|
|
if (rt_threads <= 0) return;
|
|
FILE* output = out ? out : stderr;
|
|
auto max_val = *std::max_element(rt, rt + rt_threads);
|
|
auto min_val = *std::min_element(rt, rt + rt_threads);
|
|
uint64_t sum = std::accumulate(rt, rt + rt_threads, (uint64_t)0);
|
|
int total_tasks = std::accumulate(tasks, tasks + rt_threads, 0);
|
|
|
|
// Sort to find 20% and 80% percentile thresholds
|
|
std::vector<uint64_t> sorted(rt, rt + rt_threads);
|
|
std::sort(sorted.begin(), sorted.end());
|
|
int p20_idx = rt_threads * 20 / 100;
|
|
int p80_idx = rt_threads * 80 / 100;
|
|
uint64_t p20_threshold = sorted[p20_idx]; // Fast threshold (top 20%)
|
|
uint64_t p80_threshold = sorted[p80_idx]; // Slow threshold (bottom 20%)
|
|
|
|
// ANSI color codes
|
|
const char* GREEN = "\033[32m";
|
|
const char* RED = "\033[31m";
|
|
const char* RESET = "\033[0m";
|
|
|
|
// Line 1: time
|
|
fprintf(output, "%30s max %.3f min %.3f avg %.3f : ", name, ticks_to_ms(max_val), ticks_to_ms(min_val),
|
|
ticks_to_ms(sum / rt_threads));
|
|
for (int i = 0; i < rt_threads; i++) {
|
|
if (rt[i] <= p20_threshold) {
|
|
fprintf(output, "%s%.3f%s ", GREEN, ticks_to_ms(rt[i]), RESET);
|
|
} else if (rt[i] >= p80_threshold) {
|
|
fprintf(output, "%s%.3f%s ", RED, ticks_to_ms(rt[i]), RESET);
|
|
} else {
|
|
fprintf(output, "%.3f ", ticks_to_ms(rt[i]));
|
|
}
|
|
}
|
|
fprintf(output, "\n");
|
|
|
|
// Line 2: task count
|
|
fprintf(output, "%30s total %d : ", "tasks", total_tasks);
|
|
for (int i = 0; i < rt_threads; i++) {
|
|
if (rt[i] <= p20_threshold) {
|
|
fprintf(output, "%s%d%s ", GREEN, tasks[i], RESET);
|
|
} else if (rt[i] >= p80_threshold) {
|
|
fprintf(output, "%s%d%s ", RED, tasks[i], RESET);
|
|
} else {
|
|
fprintf(output, "%d ", tasks[i]);
|
|
}
|
|
}
|
|
fprintf(output, "\n");
|
|
}
|
|
|
|
void reset_forward() {
|
|
std::fill(forward_rt, forward_rt + MAX_THREADS, 0);
|
|
std::fill(forward_tasks, forward_tasks + MAX_THREADS, 0);
|
|
forward_threads = 0;
|
|
}
|
|
|
|
void reset_backward() {
|
|
std::fill(backward_rt, backward_rt + MAX_THREADS, 0);
|
|
std::fill(backward_tasks, backward_tasks + MAX_THREADS, 0);
|
|
backward_threads = 0;
|
|
}
|
|
|
|
void collect_forward(InNumaPool* pool) {
|
|
int n = pool->get_worker_count();
|
|
for (int i = 0; i < n && forward_threads < MAX_THREADS; i++) {
|
|
forward_rt[forward_threads] = pool->get_thread_cycles(i);
|
|
forward_tasks[forward_threads] = pool->get_thread_task_count(i);
|
|
forward_threads++;
|
|
}
|
|
}
|
|
|
|
void collect_backward(InNumaPool* pool) {
|
|
int n = pool->get_worker_count();
|
|
for (int i = 0; i < n && backward_threads < MAX_THREADS; i++) {
|
|
backward_rt[backward_threads] = pool->get_thread_cycles(i);
|
|
backward_tasks[backward_threads] = pool->get_thread_task_count(i);
|
|
backward_threads++;
|
|
}
|
|
}
|
|
|
|
void print_forward() { print_rt(stderr, "forward", forward_rt, forward_tasks, forward_threads); }
|
|
void print_backward(const char* name) { print_rt(stderr, name, backward_rt, backward_tasks, backward_threads); }
|
|
|
|
void print_op_stats(InNumaPool* pool, const char* op_name) {
|
|
if (pool == nullptr || op_name == nullptr || op_name[0] == '\0') {
|
|
return;
|
|
}
|
|
int n = pool->get_worker_count();
|
|
if (n <= 0) {
|
|
return;
|
|
}
|
|
|
|
// Ensure signal handlers are registered on first call
|
|
static bool handlers_registered = false;
|
|
if (!handlers_registered) {
|
|
register_signal_handlers();
|
|
handlers_registered = true;
|
|
}
|
|
|
|
FILE* output = stderr;
|
|
int numa_id = pool->get_numa_id();
|
|
// if (numa_id == 0) {
|
|
// output = stdout;
|
|
// } else if (numa_id == 1) {
|
|
// output = stderr;
|
|
// }
|
|
std::vector<uint64_t> rt(n);
|
|
std::vector<uint64_t> start_ts(n);
|
|
std::vector<uint64_t> end_ts(n);
|
|
std::vector<int> tasks(n);
|
|
for (int i = 0; i < n; i++) {
|
|
rt[i] = pool->get_thread_cycles(i);
|
|
tasks[i] = pool->get_thread_task_count(i);
|
|
start_ts[i] = pool->get_thread_start_ts(i);
|
|
end_ts[i] = pool->get_thread_end_ts(i);
|
|
}
|
|
// print_rt(output, op_name, rt.data(), tasks.data(), n);
|
|
|
|
// Save trace data to memory for later export
|
|
add_trace_events(op_name, numa_id, n, start_ts.data(), end_ts.data(), tasks.data());
|
|
}
|
|
|
|
// =====================================================
|
|
// Kernel-level tracing API implementation
|
|
// =====================================================
|
|
|
|
uint64_t get_trace_timestamp() { return rdtsc_now(); }
|
|
|
|
void add_kernel_trace(const char* name, uint64_t start_ts, uint64_t end_ts, int numa_id, int thread_id,
|
|
const char* args) {
|
|
init_trace();
|
|
|
|
// Convert absolute RDTSC timestamps to relative microseconds from trace start
|
|
double start_us = (start_ts > g_trace_start_time) ? cycles_to_us(start_ts - g_trace_start_time) : 0.0;
|
|
double end_us = (end_ts > g_trace_start_time) ? cycles_to_us(end_ts - g_trace_start_time) : 0.0;
|
|
double dur_us = end_us - start_us;
|
|
if (dur_us < 0) dur_us = 0;
|
|
|
|
std::lock_guard<std::mutex> lock(g_trace_mutex);
|
|
|
|
TraceEvent ev;
|
|
ev.name = name;
|
|
ev.cat = "kernel";
|
|
ev.ph = 'X'; // Complete event
|
|
ev.ts = start_us;
|
|
ev.dur = dur_us;
|
|
ev.pid = numa_id;
|
|
ev.tid = thread_id;
|
|
ev.task_count = 0; // Not applicable for kernel traces
|
|
if (args != nullptr && args[0] != '\0') {
|
|
ev.args_json = args;
|
|
}
|
|
|
|
g_trace_events.push_back(ev);
|
|
}
|
|
|
|
} // namespace sft_timer
|
|
#endif // SFT_TIMER_DISABLED
|
|
|
|
// Intel ITT API for profiler integration (VTune, etc.)
|
|
// Allows profilers to identify spin-wait regions
|
|
#ifdef USE_ITT_NOTIFY
|
|
#include <ittnotify.h>
|
|
static __itt_domain* g_itt_domain = nullptr;
|
|
static __itt_string_handle* g_itt_spin_wait = nullptr;
|
|
|
|
static void init_itt() {
|
|
if (g_itt_domain == nullptr) {
|
|
g_itt_domain = __itt_domain_create("WorkerPool");
|
|
g_itt_spin_wait = __itt_string_handle_create("SpinWait");
|
|
}
|
|
}
|
|
|
|
#define ITT_SYNC_PREPARE(addr) __itt_sync_prepare(addr)
|
|
#define ITT_SYNC_CANCEL(addr) __itt_sync_cancel(addr)
|
|
#define ITT_SYNC_ACQUIRED(addr) __itt_sync_acquired(addr)
|
|
#else
|
|
#define ITT_SYNC_PREPARE(addr) ((void)0)
|
|
#define ITT_SYNC_CANCEL(addr) ((void)0)
|
|
#define ITT_SYNC_ACQUIRED(addr) ((void)0)
|
|
static void init_itt() {}
|
|
#endif
|
|
|
|
thread_local int WorkerPool::thread_local_id = -1;
|
|
|
|
InNumaPool::InNumaPool(int max_thread_num) {
|
|
printf("In Numa Worker Pool at NUMA %d, %d threads\n", numa_node_of_cpu(sched_getcpu()), max_thread_num);
|
|
numa_id_ = numa_node_of_cpu(sched_getcpu());
|
|
total_worker_count = max_thread_num;
|
|
block_size_ = 0;
|
|
set_restricted_worker_count(total_worker_count);
|
|
thread_state_ = std::unique_ptr<ThreadState[]>(new ThreadState[max_thread_num]);
|
|
for (int i = 0; i < total_worker_count; i++) {
|
|
thread_state_[i].status.store(ThreadStatus::WAITING, std::memory_order_release);
|
|
}
|
|
workers_.resize(total_worker_count);
|
|
for (int i = 1; i < total_worker_count; i++) {
|
|
workers_[i] = std::thread(&InNumaPool::worker_thread, this, i, -1);
|
|
}
|
|
}
|
|
|
|
InNumaPool::InNumaPool(int max_thread_num, int numa_id, int threads_id_start) {
|
|
printf("===========In NumaPool============\n");
|
|
hwloc_topology_t topology;
|
|
hwloc_obj_t numa_obj, core_obj;
|
|
hwloc_bitmap_t cpuset;
|
|
hwloc_topology_init(&topology);
|
|
hwloc_topology_load(topology);
|
|
printf("In Numa Worker Pool at NUMA %d, %d threads\n", numa_node_of_cpu(sched_getcpu()), max_thread_num);
|
|
numa_id_ = numa_id;
|
|
total_worker_count = max_thread_num;
|
|
block_size_ = 0;
|
|
set_restricted_worker_count(total_worker_count);
|
|
thread_state_ = std::unique_ptr<ThreadState[]>(new ThreadState[max_thread_num]);
|
|
for (int i = 0; i < total_worker_count; i++) {
|
|
thread_state_[i].status.store(ThreadStatus::WAITING, std::memory_order_release);
|
|
}
|
|
workers_.resize(total_worker_count);
|
|
for (int i = 1; i < total_worker_count; i++) {
|
|
workers_[i] = std::thread(&InNumaPool::worker_thread, this, i, numa_id);
|
|
// set the thread name as: "numa_(numa_id)_t_(i+threads_id_start)"
|
|
std::string thread_name = "numa_" + std::to_string(numa_id) + "_t_" + std::to_string(i + threads_id_start);
|
|
pthread_t native_handle = workers_[i].native_handle();
|
|
auto res_set_name = pthread_setname_np(native_handle, thread_name.c_str());
|
|
if (res_set_name != 0) {
|
|
fprintf(stderr, "Failed to set thread name: %s\n", strerror(res_set_name));
|
|
}
|
|
// 检查线程是否成功命名
|
|
char name[16];
|
|
pthread_getname_np(native_handle, name, sizeof(name));
|
|
if (strcmp(name, thread_name.c_str()) == 0) {
|
|
// printf("Thread name set successfully: %s\n", name);
|
|
} else {
|
|
// printf("Failed to set thread name: %s\n", name);
|
|
}
|
|
// Set the thread affinity to the specified NUMA node's CPU
|
|
numa_obj = hwloc_get_obj_by_type(topology, HWLOC_OBJ_NUMANODE, numa_id);
|
|
if (!numa_obj) {
|
|
fprintf(stderr, "NUMA node %d not found\n", numa_id);
|
|
// throw std::runtime_error("NUMA node not found");
|
|
continue;
|
|
}
|
|
core_obj = hwloc_get_obj_inside_cpuset_by_type(topology, numa_obj->cpuset, HWLOC_OBJ_CORE, i + threads_id_start);
|
|
if (!core_obj) {
|
|
fprintf(stderr, "Core %d inside NUMA node %d not found\n", i, numa_id);
|
|
// throw std::runtime_error("Core not found inside NUMA node");
|
|
continue;
|
|
}
|
|
cpuset = hwloc_bitmap_alloc();
|
|
hwloc_bitmap_copy(cpuset, core_obj->cpuset);
|
|
hwloc_bitmap_singlify(cpuset);
|
|
auto res = hwloc_set_thread_cpubind(topology, native_handle, cpuset, HWLOC_CPUBIND_STRICT);
|
|
if (res != 0) {
|
|
fprintf(stderr, "Failed to set thread CPU binding: %s\n", strerror(errno));
|
|
}
|
|
}
|
|
}
|
|
|
|
InNumaPool::~InNumaPool() {
|
|
for (int i = 0; i < total_worker_count; i++) {
|
|
thread_state_[i].status.store(ThreadStatus::EXIT, std::memory_order_release);
|
|
}
|
|
for (int i = 0; i < total_worker_count; i++) {
|
|
if (workers_[i].joinable()) {
|
|
workers_[i].join();
|
|
}
|
|
}
|
|
}
|
|
|
|
int InNumaPool::get_thread_num() {
|
|
throw std::runtime_error("Deprecated");
|
|
return total_worker_count;
|
|
}
|
|
|
|
void InNumaPool::set_restricted_worker_count(int count) { restricted_worker_count = count; }
|
|
|
|
void InNumaPool::wait() {
|
|
for (int i = 0; i < worker_count; i++) {
|
|
while (thread_state_[i].status.load(std::memory_order_acquire) == ThreadStatus::WORKING) {
|
|
}
|
|
}
|
|
|
|
#ifdef PROFILE_BALANCE
|
|
size_t max_time = 0;
|
|
size_t min_time = thread_state_[0].finish_ns;
|
|
size_t sum = 0;
|
|
for (int i = 0; i < worker_count; i++) {
|
|
sum += thread_state_[i].finish_ns;
|
|
max_time = std::max(max_time, thread_state_[i].finish_ns);
|
|
min_time = std::min(min_time, thread_state_[i].finish_ns);
|
|
}
|
|
double balance = 1.0 * sum / (max_time * worker_count);
|
|
printf("max_time: %ld, min_time: %ld, sum_time: %ld, balance: %f\n", max_time, min_time, sum, balance);
|
|
|
|
#endif
|
|
}
|
|
|
|
void InNumaPool::do_work_stealing_job(int task_num, std::function<void(int)> compute_func, const char* task_name,
|
|
int block_size, bool async) {
|
|
do_work_stealing_job(task_num, nullptr, compute_func, nullptr, task_name, block_size);
|
|
}
|
|
|
|
void InNumaPool::do_work_stealing_job(int task_num, std::function<void(int)> init_func,
|
|
std::function<void(int)> compute_func, std::function<void(int)> finalize_func,
|
|
const char* task_name, int block_size, bool async) {
|
|
bool has_name = task_name != nullptr && task_name[0] != '\0';
|
|
if (has_name) {
|
|
reset_counters();
|
|
}
|
|
do_work_stealing_job_async(task_num, init_func, compute_func, finalize_func, block_size);
|
|
if (!async) wait();
|
|
if (has_name) {
|
|
sft_timer::print_op_stats(this, task_name);
|
|
}
|
|
}
|
|
|
|
void InNumaPool::do_work_stealing_job_async(int task_num, std::function<void(int)> init_func,
|
|
std::function<void(int)> compute_func,
|
|
std::function<void(int)> finalize_func, int block_size) {
|
|
init_func_ = init_func;
|
|
compute_func_ = compute_func;
|
|
finalize_func_ = finalize_func;
|
|
block_size_ = block_size;
|
|
worker_count = std::min(restricted_worker_count, task_num);
|
|
curr_.store(0, std::memory_order_release);
|
|
end_ = task_num;
|
|
|
|
for (int i = 0; i < worker_count; i++) {
|
|
thread_state_[i].status.store(ThreadStatus::WORKING, std::memory_order_release);
|
|
}
|
|
WorkerPool::thread_local_id = 0;
|
|
process_tasks(0);
|
|
}
|
|
|
|
void InNumaPool::process_tasks(int thread_id) {
|
|
uint64_t start_cycles = rdtsc_now();
|
|
auto& s = thread_state_[thread_id];
|
|
int local_task_count = 0;
|
|
|
|
// Record absolute start timestamp
|
|
s.start_ts = start_cycles;
|
|
|
|
if (init_func_ != nullptr) {
|
|
init_func_(thread_id);
|
|
}
|
|
|
|
// omp-guided-style work scheduling
|
|
while (true) {
|
|
int old = curr_.load(std::memory_order_relaxed);
|
|
int rem = end_ - old;
|
|
if (rem <= 0) {
|
|
break;
|
|
}
|
|
|
|
int block = 0;
|
|
if (block_size_ > 0) {
|
|
block = std::min(block_size_, rem);
|
|
} else {
|
|
block = (rem + worker_count - 1) / worker_count;
|
|
}
|
|
block = 1;
|
|
int task_id = curr_.fetch_add(block, std::memory_order_acq_rel);
|
|
if (task_id >= end_) {
|
|
break;
|
|
}
|
|
|
|
for (int i = 0; i < block; i++) {
|
|
if (task_id + i >= end_) {
|
|
break;
|
|
}
|
|
compute_func_(task_id + i);
|
|
local_task_count++;
|
|
}
|
|
}
|
|
|
|
if (finalize_func_ != nullptr) {
|
|
finalize_func_(thread_id);
|
|
}
|
|
|
|
// IMPORTANT: Update timing BEFORE setting status to WAITING
|
|
// The release semantics of status.store() ensures all prior writes are visible
|
|
uint64_t end_cycles = rdtsc_now();
|
|
s.finish_cycles = end_cycles - start_cycles;
|
|
s.task_count = local_task_count;
|
|
s.end_ts = end_cycles;
|
|
|
|
// Signal completion - release ensures timing writes are visible to wait()
|
|
s.status.store(ThreadStatus::WAITING, std::memory_order_release);
|
|
}
|
|
|
|
void InNumaPool::worker_thread(int thread_id, int numa_id) {
|
|
if (numa_id >= 0) {
|
|
set_memory_to_numa(numa_id);
|
|
}
|
|
init_itt(); // Initialize ITT if enabled
|
|
// Use RDTSC for lightweight timing instead of std::chrono
|
|
const uint64_t sleep_threshold_cycles = get_rdtsc_cycles_per_ms() * 50; // 50ms in cycles
|
|
uint64_t start = rdtsc_now();
|
|
WorkerPool::thread_local_id = thread_id; // 设置线程本地变量
|
|
while (true) {
|
|
ITT_SYNC_PREPARE(&thread_state_[thread_id].status); // Signal profiler: about to spin-wait
|
|
ThreadStatus status = thread_state_[thread_id].status.load(std::memory_order_acquire);
|
|
if (status == ThreadStatus::WORKING) {
|
|
ITT_SYNC_ACQUIRED(&thread_state_[thread_id].status); // Signal profiler: acquired work
|
|
process_tasks(thread_id);
|
|
start = rdtsc_now();
|
|
} else if (status == ThreadStatus::WAITING) {
|
|
// PAUSE instruction hints to CPU this is a spin-wait loop
|
|
_mm_pause();
|
|
uint64_t now = rdtsc_now();
|
|
uint64_t elapsed_cycles = now - start;
|
|
if (elapsed_cycles > sleep_threshold_cycles) {
|
|
ITT_SYNC_CANCEL(&thread_state_[thread_id].status); // Signal profiler: going to sleep
|
|
std::this_thread::sleep_for(std::chrono::microseconds(100));
|
|
}
|
|
} else if (status == ThreadStatus::EXIT) {
|
|
ITT_SYNC_CANCEL(&thread_state_[thread_id].status); // Signal profiler: exiting
|
|
return;
|
|
}
|
|
}
|
|
}
|
|
|
|
NumaJobDistributor::NumaJobDistributor(int numa_count) {
|
|
std::vector<int> numa_ids;
|
|
for (int i = 0; i < numa_count; i++) {
|
|
numa_ids.push_back(i);
|
|
}
|
|
init(numa_ids);
|
|
}
|
|
|
|
NumaJobDistributor::NumaJobDistributor(std::vector<int> numa_ids) { init(numa_ids); }
|
|
NumaJobDistributor::NumaJobDistributor(std::vector<int> numa_ids, std::vector<int> thread_count) {
|
|
init(numa_ids, thread_count);
|
|
}
|
|
|
|
void NumaJobDistributor::init(std::vector<int> numa_ids) {
|
|
this->numa_count = numa_ids.size();
|
|
this->ready_bar = std::unique_ptr<std::barrier<>>(new std::barrier<>(numa_count + 1));
|
|
this->numa_ids = numa_ids;
|
|
for (size_t i = 0; i < numa_count; i++) {
|
|
status.push_back(nullptr);
|
|
}
|
|
|
|
workers.resize(numa_count);
|
|
for (int i = 0; i < numa_count; i++) {
|
|
std::thread([this, i]() { workers[i] = std::thread(&NumaJobDistributor::worker_thread, this, i); }).join();
|
|
}
|
|
ready_bar->arrive_and_wait();
|
|
}
|
|
|
|
void NumaJobDistributor::init(std::vector<int> numa_ids, std::vector<int> thread_count) {
|
|
hwloc_topology_t topology;
|
|
hwloc_obj_t numa_obj, core_obj;
|
|
hwloc_bitmap_t cpuset;
|
|
hwloc_topology_init(&topology);
|
|
hwloc_topology_load(topology);
|
|
|
|
this->numa_count = numa_ids.size();
|
|
this->ready_bar = std::unique_ptr<std::barrier<>>(new std::barrier<>(numa_count + 1));
|
|
this->numa_ids = numa_ids;
|
|
for (size_t i = 0; i < numa_count; i++) {
|
|
status.push_back(nullptr);
|
|
}
|
|
|
|
workers.resize(numa_count);
|
|
std::vector<int> numa_threads_count(numa_count, 0);
|
|
for (int i = 0; i < numa_count; i++) {
|
|
workers[i] = std::thread(&NumaJobDistributor::worker_thread, this, i);
|
|
auto this_numa = numa_ids[i];
|
|
auto start_id = numa_threads_count[this_numa];
|
|
// set the thread name as: "worker_numa_(numa_id)_main_start_id(0)"
|
|
// printf("nuam_id %d, start_id %d\n", this_numa, start_id);
|
|
std::string thread_name = "numa_" + std::to_string(numa_ids[i]) + "_m_" + std::to_string(start_id);
|
|
pthread_t native_handle = workers[i].native_handle();
|
|
pthread_setname_np(native_handle, thread_name.c_str());
|
|
// Set the thread affinity to the specified NUMA node's CPU (0)
|
|
numa_obj = hwloc_get_obj_by_type(topology, HWLOC_OBJ_NUMANODE, this_numa);
|
|
if (!numa_obj) {
|
|
fprintf(stderr, "NUMA node %d not found\n", this_numa);
|
|
// throw std::runtime_error("NUMA node not found");
|
|
continue;
|
|
}
|
|
core_obj = hwloc_get_obj_inside_cpuset_by_type(topology, numa_obj->cpuset, HWLOC_OBJ_CORE, start_id);
|
|
if (!core_obj) {
|
|
fprintf(stderr, "Core %d inside NUMA node %d not found\n", 0, this_numa);
|
|
// throw std::runtime_error("Core not found inside NUMA node");
|
|
continue;
|
|
}
|
|
// 精简 cpuset
|
|
auto cpuset_simple = hwloc_bitmap_alloc();
|
|
hwloc_bitmap_copy(cpuset_simple, core_obj->cpuset);
|
|
hwloc_bitmap_singlify(cpuset_simple);
|
|
// 打印绑定的具体的 CPU 物理索引
|
|
unsigned long i_in;
|
|
// hwloc_bitmap_foreach_begin(i_in, cpuset_simple) { printf("Thread %d bound to CPU %ld\n", start_id, i_in); }
|
|
// hwloc_bitmap_foreach_end();
|
|
auto res = hwloc_set_thread_cpubind(topology, native_handle, cpuset_simple, HWLOC_CPUBIND_STRICT);
|
|
if (res != 0) {
|
|
fprintf(stderr, "Failed to set thread CPU binding: %s\n", strerror(errno));
|
|
}
|
|
// 检查线程是否绑定到指定的 核上了
|
|
hwloc_cpuset_t cpuset = hwloc_bitmap_alloc();
|
|
hwloc_get_thread_cpubind(topology, native_handle, cpuset, HWLOC_CPUBIND_THREAD);
|
|
// hwloc_bitmap_foreach_begin(i_in, cpuset) { printf("Thread %d is bound to CPU %ld\n", start_id, i_in); }
|
|
// hwloc_bitmap_foreach_end();
|
|
|
|
numa_threads_count[this_numa] += thread_count[i];
|
|
}
|
|
ready_bar->arrive_and_wait();
|
|
}
|
|
|
|
NumaJobDistributor::~NumaJobDistributor() {
|
|
for (int i = 0; i < numa_count; i++) {
|
|
status[i]->store(ThreadStatus::EXIT, std::memory_order_release);
|
|
}
|
|
for (int i = 0; i < numa_count; i++) {
|
|
if (workers[i].joinable()) {
|
|
workers[i].join();
|
|
}
|
|
}
|
|
}
|
|
|
|
#ifdef USE_NUMA_JOB_DIRECT_WORK
|
|
|
|
void NumaJobDistributor::do_numa_job(std::function<void(int)> compute_func) {
|
|
this->compute_func = compute_func;
|
|
auto me_numa = numa_node_of_cpu(sched_getcpu());
|
|
for (int i = 0; i < numa_count; i++) {
|
|
if (i == me_numa) continue;
|
|
|
|
status[i]->store(ThreadStatus::WORKING, std::memory_order_release);
|
|
}
|
|
compute_func(me_numa);
|
|
for (int i = 0; i < numa_count; i++) {
|
|
if (i == me_numa) continue;
|
|
|
|
while (status[i]->load(std::memory_order_acquire) == ThreadStatus::WORKING) {
|
|
}
|
|
}
|
|
}
|
|
#else
|
|
void NumaJobDistributor::do_numa_job(std::function<void(int)> compute_func) {
|
|
this->compute_func = compute_func;
|
|
for (int i = 0; i < numa_count; i++) {
|
|
status[i]->store(ThreadStatus::WORKING, std::memory_order_release);
|
|
}
|
|
for (int i = 0; i < numa_count; i++) {
|
|
while (status[i]->load(std::memory_order_acquire) == ThreadStatus::WORKING) {
|
|
}
|
|
}
|
|
}
|
|
#endif
|
|
|
|
void NumaJobDistributor::worker_thread(int numa_id) {
|
|
init_itt(); // Initialize ITT if enabled
|
|
// Use RDTSC for lightweight timing instead of std::chrono
|
|
const uint64_t sleep_threshold_cycles = get_rdtsc_cycles_per_ms() * 50; // 50ms in cycles
|
|
uint64_t start = rdtsc_now();
|
|
set_memory_to_numa(numa_id);
|
|
status[numa_id] =
|
|
std::move(std::unique_ptr<std::atomic<ThreadStatus>>(new std::atomic<ThreadStatus>(ThreadStatus::WAITING)));
|
|
ready_bar->arrive_and_wait();
|
|
while (true) {
|
|
ITT_SYNC_PREPARE(status[numa_id].get()); // Signal profiler: about to spin-wait
|
|
auto stat = status[numa_id]->load(std::memory_order_acquire);
|
|
if (stat == ThreadStatus::WORKING) {
|
|
ITT_SYNC_ACQUIRED(status[numa_id].get()); // Signal profiler: acquired work
|
|
auto me_numa = numa_node_of_cpu(sched_getcpu());
|
|
// printf("numa work on %d, me %d\n", numa_id, me_numa);
|
|
compute_func(numa_id);
|
|
status[numa_id]->store(ThreadStatus::WAITING, std::memory_order_release);
|
|
start = rdtsc_now();
|
|
} else if (stat == ThreadStatus::WAITING) {
|
|
// PAUSE instruction hints to CPU this is a spin-wait loop
|
|
_mm_pause();
|
|
uint64_t now = rdtsc_now();
|
|
uint64_t elapsed_cycles = now - start;
|
|
if (elapsed_cycles > sleep_threshold_cycles) {
|
|
ITT_SYNC_CANCEL(status[numa_id].get()); // Signal profiler: going to sleep
|
|
std::this_thread::sleep_for(std::chrono::milliseconds(1));
|
|
}
|
|
} else if (stat == ThreadStatus::EXIT) {
|
|
ITT_SYNC_CANCEL(status[numa_id].get()); // Signal profiler: exiting
|
|
return;
|
|
}
|
|
}
|
|
}
|
|
|
|
void WorkerPool::init(WorkerPoolConfig config) {
|
|
printf("WorkerPool[0x%lx] %d subpools, [numa:threads]", (intptr_t)this, config.subpool_count);
|
|
for (int i = 0; i < config.subpool_count; i++) {
|
|
printf("[%d:%d] ", config.subpool_numa_map[i], config.subpool_thread_count[i]);
|
|
}
|
|
printf("\n");
|
|
|
|
for (int i = 0; i < config.subpool_count; i++) {
|
|
numa_worker_pools.push_back(nullptr);
|
|
}
|
|
std::vector<int> numa_threads_count(config.subpool_count, 0);
|
|
for (int i = 0; i < config.subpool_count; i++) {
|
|
auto this_numa = config.subpool_numa_map[i];
|
|
auto this_thread_count = config.subpool_thread_count[i];
|
|
auto this_thread_id_start = numa_threads_count[this_numa];
|
|
std::thread([this, i, this_numa, this_thread_count, this_thread_id_start]() {
|
|
set_to_numa(this_numa);
|
|
numa_worker_pools[i] =
|
|
std::move(std::unique_ptr<InNumaPool>(new InNumaPool(this_thread_count, this_numa, this_thread_id_start)));
|
|
// numa_worker_pools[i] = std::move(std::unique_ptr<InNumaPool>(new InNumaPool(this_thread_count)));
|
|
}).join();
|
|
numa_threads_count[this_numa] += this_thread_count;
|
|
}
|
|
|
|
distributor = std::move(std::unique_ptr<NumaJobDistributor>(
|
|
new NumaJobDistributor(config.subpool_numa_map, config.subpool_thread_count)));
|
|
// distributor = std::move(std::unique_ptr<NumaJobDistributor>(new NumaJobDistributor(config.subpool_numa_map)));
|
|
}
|
|
|
|
WorkerPool::WorkerPool(WorkerPoolConfig config) : config(config) { init(config); }
|
|
|
|
WorkerPool::WorkerPool(int total_threads) {
|
|
config.subpool_count = numa_num_configured_nodes();
|
|
config.subpool_numa_map.resize(config.subpool_count);
|
|
config.subpool_thread_count.resize(config.subpool_count);
|
|
for (int i = 0; i < config.subpool_count; i++) {
|
|
config.subpool_numa_map[i] = i;
|
|
config.subpool_thread_count[i] = total_threads / config.subpool_count;
|
|
}
|
|
init(config);
|
|
}
|
|
|
|
WorkerPool::WorkerPool(int total_threads, int single_numa_id) {
|
|
set_to_numa(single_numa_id);
|
|
config.subpool_count = numa_num_configured_nodes();
|
|
config.subpool_numa_map.resize(config.subpool_count);
|
|
config.subpool_thread_count.resize(config.subpool_count);
|
|
for (int i = 0; i < config.subpool_count; i++) {
|
|
config.subpool_numa_map[i] = single_numa_id;
|
|
config.subpool_thread_count[i] = total_threads / config.subpool_count;
|
|
}
|
|
init(config);
|
|
}
|
|
|
|
WorkerPool::~WorkerPool() {}
|
|
|
|
int WorkerPool::get_thread_num() { return total_thread_count; }
|
|
|
|
void WorkerPool::set_restricted_worker_count(int count) {
|
|
for (int i = 0; i < numa_count; i++) {
|
|
numa_worker_pools[i]->set_restricted_worker_count(threads_per_numa);
|
|
}
|
|
}
|
|
|
|
InNumaPool* WorkerPool::get_subpool(int numa_id) { return numa_worker_pools[numa_id].get(); }
|
|
|
|
NumaJobDistributor* WorkerPool::dispense_backend() { return distributor.get(); }
|
|
|
|
void WorkerPool::do_work_stealing_job(int task_num, std::function<void(int)> init_func,
|
|
std::function<void(int)> compute_func, std::function<void(int)> finalize_func,
|
|
const char* task_name, int block_size, bool async) {
|
|
numa_worker_pools[0]->do_work_stealing_job(task_num, init_func, compute_func, finalize_func, task_name, block_size,
|
|
async);
|
|
}
|
|
|
|
void WorkerPool::do_work_stealing_job(int task_num, std::function<void(int)> compute_func, const char* task_name,
|
|
int block_size, bool async) {
|
|
do_work_stealing_job(task_num, nullptr, compute_func, nullptr, task_name, block_size, async);
|
|
}
|
|
|
|
void WorkerPool::wait() { numa_worker_pools[0]->wait(); }
|