mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2026-03-14 18:37:23 +00:00
Fix moe bug. (#1783)
* [fix]: fix moe.hpp load from file bug. * [fix]: fix all moe hpp init bug. * [fix]: fix moe & awq-moe ug.
This commit is contained in:
@@ -42,8 +42,6 @@ class AMX_AWQ_MOE_TP : public AMX_MOE_BASE<T, AMX_AWQ_MOE_TP<T>> {
|
||||
using Base::up_bb_;
|
||||
using Base::up_bc_;
|
||||
|
||||
std::filesystem::path prefix;
|
||||
|
||||
#ifdef CHECK
|
||||
char verify_bb[100000000];
|
||||
char check_bb[100000000];
|
||||
@@ -406,7 +404,7 @@ class AMX_AWQ_MOE_TP : public AMX_MOE_BASE<T, AMX_AWQ_MOE_TP<T>> {
|
||||
auto& load = config_.load;
|
||||
auto& save = config_.save;
|
||||
|
||||
prefix = config_.path;
|
||||
std::filesystem::path prefix = config_.path;
|
||||
prefix = prefix / ("_layer_" + std::to_string(config_.layer_idx)) / ("_numa_" + std::to_string(tp_part_idx));
|
||||
if (save) {
|
||||
std::cout << "Creating " << prefix << std::endl;
|
||||
@@ -498,6 +496,9 @@ class AMX_AWQ_MOE_TP : public AMX_MOE_BASE<T, AMX_AWQ_MOE_TP<T>> {
|
||||
throw std::runtime_error("AMX load weights from gate_projs is not supported");
|
||||
} else {
|
||||
int nth = T::recommended_nth(config_.intermediate_size);
|
||||
std::filesystem::path prefix = config_.path;
|
||||
prefix = prefix / ("_layer_" + std::to_string(config_.layer_idx)) / ("_numa_" + std::to_string(tp_part_idx));
|
||||
|
||||
if (config_.load) {
|
||||
throw std::runtime_error("AMX load weights from file is not supported");
|
||||
}
|
||||
|
||||
@@ -32,13 +32,6 @@ class AMX_MOE_TP : public AMX_MOE_BASE<T, AMX_MOE_TP<T>> {
|
||||
using Base::up_bb_;
|
||||
using Base::up_bc_;
|
||||
|
||||
void* gate_proj_; // [expert_num * intermediate_size * hidden_size ( /32 if
|
||||
// quantized)]
|
||||
void* up_proj_; // [expert_num * intermediate_size * hidden_size ( /32 if
|
||||
// quantized)]
|
||||
void* down_proj_; // [expert_num * hidden_size * intermediate_size ( /32 if
|
||||
// quantized)]
|
||||
|
||||
#ifdef CHECK
|
||||
char verify_bb[100000000];
|
||||
char check_bb[100000000];
|
||||
@@ -159,10 +152,6 @@ class AMX_MOE_TP : public AMX_MOE_BASE<T, AMX_MOE_TP<T>> {
|
||||
throw std::runtime_error("Path not found: " + prefix.string());
|
||||
}
|
||||
}
|
||||
|
||||
gate_proj_ = config_.gate_proj;
|
||||
up_proj_ = config_.up_proj;
|
||||
down_proj_ = config_.down_proj;
|
||||
}
|
||||
|
||||
~AMX_MOE_TP() = default;
|
||||
|
||||
Reference in New Issue
Block a user