diff --git a/kt-kernel/cpu_backend/worker_pool.cpp b/kt-kernel/cpu_backend/worker_pool.cpp
index 70a250ee..05564fc9 100644
--- a/kt-kernel/cpu_backend/worker_pool.cpp
+++ b/kt-kernel/cpu_backend/worker_pool.cpp
@@ -110,7 +110,8 @@ struct TraceEvent {
 
 static std::vector<TraceEvent> g_trace_events;
 static std::mutex g_trace_mutex;
-static uint64_t g_trace_start_time = 0; // baseline timestamp
+static uint64_t g_trace_start_time = 0; // baseline timestamp (RDTSC)
+static double g_trace_start_epoch_us = 0.0; // wall-clock epoch time in microseconds
 static std::string g_trace_output_path = "sft_trace.json";
 
 // Thread-safe initialization using std::call_once
@@ -123,6 +124,10 @@ static void write_trace_to_file();
 static void init_trace() {
   std::call_once(g_trace_init_flag, []() {
     g_trace_start_time = rdtsc_now();
+    // Record wall-clock epoch time for cross-process trace alignment
+    auto now_wall = std::chrono::system_clock::now();
+    auto epoch_us = std::chrono::duration_cast<std::chrono::microseconds>(now_wall.time_since_epoch()).count();
+    g_trace_start_epoch_us = static_cast<double>(epoch_us);
     // Check for custom output path from environment
     const char* env_path = std::getenv("SFT_TRACE_PATH");
     if (env_path && env_path[0] != '\0') {
@@ -222,6 +227,8 @@ static void write_trace_to_file() {
   }
 
   ofs << " ],\n";
+  // Print the epoch through an integer cast so the digits do not depend on the stream's float format/precision.
+  ofs << " \"metadata\": {\"start_epoch_us\": " << static_cast<long long>(g_trace_start_epoch_us) << "},\n";
   ofs << " \"displayTimeUnit\": \"ns\"\n";
   ofs << "}\n";
 
diff --git a/kt-kernel/ext_bindings.cpp b/kt-kernel/ext_bindings.cpp
index 0e6e8c61..b21eb474 100644
--- a/kt-kernel/ext_bindings.cpp
+++ b/kt-kernel/ext_bindings.cpp
@@ -763,6 +763,7 @@ PYBIND11_MODULE(kt_kernel_ext, m) {
       .def_readwrite("up_type", &GeneralMOEConfig::up_type)
       .def_readwrite("down_type", &GeneralMOEConfig::down_type)
       .def_readwrite("hidden_type", &GeneralMOEConfig::hidden_type)
+      .def_readwrite("max_cache_depth", &GeneralMOEConfig::max_cache_depth)
       ;
 
 
@@ -774,7 +775,6 @@ PYBIND11_MODULE(kt_kernel_ext, m) {
       }))
       .def_readwrite("lora_rank", &MOESFTConfig::lora_rank)
       .def_readwrite("lora_alpha", &MOESFTConfig::lora_alpha)
-      .def_readwrite("max_cache_depth", &MOESFTConfig::max_cache_depth)
       .DEF_PTR_PROPERTY(MOESFTConfig, gate_lora_a)
       .DEF_PTR_PROPERTY(MOESFTConfig, gate_lora_b)
       .DEF_PTR_PROPERTY(MOESFTConfig, up_lora_a)
diff --git a/kt-kernel/operators/amx/sft_moe.hpp b/kt-kernel/operators/amx/sft_moe.hpp
index 26b9c7d0..207e0d7b 100644
--- a/kt-kernel/operators/amx/sft_moe.hpp
+++ b/kt-kernel/operators/amx/sft_moe.hpp
@@ -3715,13 +3715,14 @@ class AMX_SFT_MOE_TP : public BaseMOE {
 
   ForwardCache& push_cache() {
     if (cache_stack_top_ >= max_cache_depth_) {
-      std::cerr << "[KT-MOE ERROR] Forward cache stack overflow!" << std::endl;
-      std::cerr << " cache_stack_top_ = " << cache_stack_top_ << std::endl;
-      std::cerr << " max_cache_depth_ = " << max_cache_depth_ << std::endl;
-      std::cerr << " Hint: If you are doing inference (forward only without backward)," << std::endl;
-      std::cerr << " set save_for_backward=False in forward_sft() call." << std::endl;
-      std::cerr << " Or increase max_cache_depth in MOESFTConfig." << std::endl;
-      throw std::runtime_error("Forward cache stack overflow");
+      // A depth below 1 leaves no slot to wrap to, so that case still fails loudly.
+      // NOTE(review): assumes cache_stack_ holds max_cache_depth_ entries - confirm.
+      if (max_cache_depth_ <= 0) {
+        throw std::runtime_error("Forward cache stack overflow");
+      }
+      // NOTE(review): wrapping reuses slot 0 even if a pending backward still needs it;
+      // only safe for forward-only use - confirm training never reaches this path.
+      cache_stack_top_ = 0; // Wrap around (for inference only)
     }
     return cache_stack_[cache_stack_top_++];
   }
diff --git a/kt-kernel/operators/common.hpp b/kt-kernel/operators/common.hpp
index 647146c1..0d1dba77 100644
--- a/kt-kernel/operators/common.hpp
+++ b/kt-kernel/operators/common.hpp
@@ -279,6 +279,8 @@ struct GeneralMOEConfig {
   int down_type;
   int hidden_type;
 
+  int max_cache_depth = 1; // Maximum cache depth (support N forwards before backward)
+
   GeneralMOEConfig() {}
 
   GeneralMOEConfig(int expert_num, int routed_expert_num, int hidden_size, int intermediate_size)
@@ -306,9 +308,6 @@ struct MOESFTConfig : public GeneralMOEConfig {
   void* down_lora_a = nullptr; // [expert_num, lora_rank, intermediate_size]
   void* down_lora_b = nullptr; // [expert_num, hidden_size, lora_rank]
 
-  // Gradient checkpointing configuration
-  int max_cache_depth = 1; // Maximum cache depth (support N forwards before backward)
-
   MOESFTConfig() : GeneralMOEConfig() {}
 
   MOESFTConfig(int expert_num, int routed_expert_num, int hidden_size, int intermediate_size)
diff --git a/kt-kernel/python/utils/amx_sft.py b/kt-kernel/python/utils/amx_sft.py
index c4e48ed1..89b5cf3b 100644
--- a/kt-kernel/python/utils/amx_sft.py
+++ b/kt-kernel/python/utils/amx_sft.py
@@ -302,7 +302,7 @@ class AMXSFTMoEWrapper(BaseSFTMoEWrapper):
         )
 
         # Determine loader and base key format based on method
-        if self.method == "AMXBF16_SFT":
+        if "BF16" in self.method:
             # BF16 mode: Load from HuggingFace model path
             loader = BF16SafeTensorLoader(self.weight_path)
             base_key = f"model.layers.{self.layer_idx}"
@@ -323,7 +323,7 @@ class AMXSFTMoEWrapper(BaseSFTMoEWrapper):
         # Stack expert weights: [num_experts, ...]
         # For BF16: weights are already tensors
         # For SafeTensorLoader: weights might be numpy arrays in nested lists
-        if self.method == "AMXBF16_SFT":
+        if "BF16" in self.method:
             # BF16SafeTensorLoader returns list of tensors
             self.gate_proj = torch.stack(gate_weights, dim=0).contiguous()
             self.up_proj = torch.stack(up_weights, dim=0).contiguous()