// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT #include #include #include #include #include #include "ck_tile/host.hpp" #include "ck_tile/utility/json_dump.hpp" #include "fused_moe.hpp" // different threshold for different dtype template auto get_elimit() { double rtol = 1e-2; double atol = 1e-2; return ck_tile::make_tuple(rtol, atol); } template <> auto get_elimit() { double rtol = 1e-2; double atol = 1e-2; return ck_tile::make_tuple(rtol, atol); } // mfma_type, 0:32x32, 1:16x16 // TODO: padding? template auto shuffle_moe_weight(const ck_tile::HostTensor& t, std::string mfma_dtype, int mfma_type = 0) { assert(t.get_lengths().size() == 3); int b_ = t.get_lengths()[0]; int n_ = t.get_lengths()[1]; int k_ = t.get_lengths()[2]; if((mfma_dtype == "bf16" || mfma_dtype == "fp16") && mfma_type == 0) { ck_tile::HostTensor t_view({b_, n_ / 32, 32, k_ / 16, 2, 8}); std::copy(t.begin(), t.end(), t_view.begin()); return ck_tile::reference_permute(t_view, {0, 1, 3, 4, 2, 5}); } else if((mfma_dtype == "bf16" || mfma_dtype == "fp16") && mfma_type == 1) { ck_tile::HostTensor t_view({b_, n_ / 16, 16, k_ / 32, 4, 8}); std::copy(t.begin(), t.end(), t_view.begin()); return ck_tile::reference_permute(t_view, {0, 1, 3, 4, 2, 5}); } else if((mfma_dtype == "int8" || mfma_dtype == "fp8") && mfma_type == 0) { ck_tile::HostTensor t_view({b_, n_ / 32, 32, k_ / 32, 2, 16}); std::copy(t.begin(), t.end(), t_view.begin()); return ck_tile::reference_permute(t_view, {0, 1, 3, 4, 2, 5}); } else if((mfma_dtype == "int8" || mfma_dtype == "fp8") && mfma_type == 1) { ck_tile::HostTensor t_view({b_, n_ / 16, 16, k_ / 64, 4, 16}); std::copy(t.begin(), t.end(), t_view.begin()); return ck_tile::reference_permute(t_view, {0, 1, 3, 4, 2, 5}); } return t; } template void topid_unique_gen( std::vector& host_tensor, int tokens, int topk, int num_expert, int seed) { size_t total_size = topk * tokens; std::srand(seed); std::set unique_set; IndexType current_v; for(size_t i = 0; i < total_size; i++) { if(i % topk == 0) { unique_set.clear(); } current_v = std::rand() % num_expert; while(unique_set.find(current_v) != unique_set.end()) { current_v = std::rand() % num_expert; } unique_set.insert(current_v); host_tensor[i] = current_v; } } auto create_args(int argc, char* argv[]) { ck_tile::ArgParser arg_parser; arg_parser .insert("t", "128", "number of input tokens.\n" "If \"local_t\" presents, this value indicates global concurrency of all ranks.") .insert( "local_t", "-1", "Number of local input tokens for curent rank.\n" "This value must be within range \"[0, t)\", or \"-1\"(no such feature)\n" "This feature is to simulate EP case where where each rank has different tokens.\n" "Besides, this value will be stored in a GPU buffer, which is friendly for CUDA graph.") .insert("e", "32", "num of experts") .insert("k", "5", "topk") .insert("h", "8192", "hidden_size of this model") .insert("i", "8192", "intermediate_size between 2 gemms of FFN") .insert("stride", "-1", "stride per row, if -1 then equal to hidden_size") .insert("bm", "32", "blocking factor for sorted tokens") .insert("tp", "8", "tensor parallel size") .insert("v", "1", "cpu validation or not") .insert("kname", "1", "print kernel name or not") .insert("prec_i", "bf16", "input precision") .insert("prec_w", "bf16", "weight precision") .insert("prec_o", "bf16", "output precision") .insert("prec_st", "auto", "token scale data type. auto will set to fp32") .insert("prec_sw", "auto", "weight scale data type. auto will set to fp32") .insert("prec_sq", "auto", "(dynamic) smooth quant data type. auto will set to fp32") .insert("prec_kw", "auto", "topk-weight data type. auto will set to fp32") .insert("fquant", "0", "fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant") .insert( "gate_only", "1", "w0(gate/up) style, 0:gate+up will double interm size, 1:only gate") .insert("api", "0", "benchmark api set: 0:fused-moe(moe-gemm+moe-sorting), 1:moe-gemm") .insert("act", "0", "activation after first gemm. 0:gelu, 1:silu") .insert("balance", "0", "if set to 1, will try balance the expert in topk-ids(convenient for testing)") .insert("init", "1", "init method. 0:random stepped float(fast). 1: random uniform[-0.5, 0.5], 2:rand " "normalized[0, 1]" "normalized(slow)") .insert("seed", "11939", "seed used to do random") .insert("warmup", "5", "cold iter") .insert("repeat", "20", "hot iter") .insert("json", "0", "0: No Json, 1: Dump Results in Json format") .insert("jsonfile", "fused_moe.json", "json file name to dump results"); bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); } // I:input-type, W:weight-type, O:output-type, ST:toke-scale-tpye, SW:weight-scale-type, // SQ:smooth-quant-type, KW:topk-weight-type template bool run(const ck_tile::ArgParser& arg_parser) { ck_tile::index_t tokens = arg_parser.get_int("t"); ck_tile::index_t local_tokens = arg_parser.get_int("local_t"); ck_tile::index_t experts = arg_parser.get_int("e"); ck_tile::index_t topk = arg_parser.get_int("k"); ck_tile::index_t hidden_size = arg_parser.get_int("h"); ck_tile::index_t intermediate_size = arg_parser.get_int("i"); ck_tile::index_t stride = arg_parser.get_int("stride"); ck_tile::index_t block_m = arg_parser.get_int("bm"); ck_tile::index_t activation = arg_parser.get_int("act"); if(stride < 0) stride = hidden_size; std::string prec_i = arg_parser.get_str("prec_i"); std::string prec_w = arg_parser.get_str("prec_w"); std::string prec_o = arg_parser.get_str("prec_o"); std::string prec_st = arg_parser.get_str("prec_st"); std::string prec_sw = arg_parser.get_str("prec_sw"); std::string prec_sq = arg_parser.get_str("prec_sq"); std::string prec_kw = arg_parser.get_str("prec_kw"); prec_st = (prec_st == "auto") ? "fp32" : prec_st; prec_sw = (prec_sw == "auto") ? "fp32" : prec_sw; prec_sq = (prec_sq == "auto") ? "fp32" : prec_sq; prec_kw = (prec_kw == "auto") ? "fp32" : prec_kw; int kname = arg_parser.get_int("kname"); int do_validation = arg_parser.get_int("v"); int warmup = arg_parser.get_int("warmup"); int repeat = arg_parser.get_int("repeat"); int fused_quant = arg_parser.get_int("fquant"); int gate_only = arg_parser.get_int("gate_only"); int api = arg_parser.get_int("api"); int balance = arg_parser.get_int("balance"); int tp = arg_parser.get_int("tp"); int init = arg_parser.get_int("init"); uint32_t seed = arg_parser.get_uint32("seed"); bool local_expert_masking = false; // TODO... // w0 (Gate+Up or Gate only, N size) ck_tile::index_t shared_intermediate_size_0 = intermediate_size * (gate_only ? 1 : 2) / tp; // w1 (Down, N size) ck_tile::index_t shared_intermediate_size_1 = intermediate_size / tp; bool is_local_token = local_tokens >= 0 && local_tokens < tokens; if(local_tokens > tokens) { printf("local_tokens:%d larger than tokens:%d, invalid\n", local_tokens, tokens); return false; } auto prec_str = [&]() { auto base_str = prec_i; if(prec_i != prec_w) base_str += "x" + prec_w; if(prec_i != prec_o) base_str += "=" + prec_o; if(fused_quant != 0) { base_str += std::string("(") + prec_st + "|" + prec_sw + "|" + prec_sq + ")"; } return base_str; }(); auto api_str = [&]() { if(api == 0) return std::string("fmoe"); else if(api == 1) return std::string("moeg"); else if(api == 2) return std::string("moes"); return std::string(""); }(); auto stride_str = [&]() { if(stride == hidden_size) return std::string(""); else return std::string(", st:") + std::to_string(stride); }(); std::cout << "[" << api_str << "|" << prec_str << "]" << " t:" << tokens; if(is_local_token) { std::cout << "(" << local_tokens << ")"; } std::cout << ", e:" << experts << ", k:" << topk << stride_str << ", hidden:" << hidden_size << ", interm:" << intermediate_size << ", tp:" << tp << ", act:" << activation // << ", shrd_interm:" << shared_intermediate_size_0 << "|" << shared_intermediate_size_1 << (gate_only ? ", g1u0" : ", g1u1") << ", q:" << fused_quant << std::flush; using TypeConfig = FusedMoeGemmTypeConfig; using ADataType = typename TypeConfig::ADataType; using GDataType = typename TypeConfig::GDataType; using DDataType = typename TypeConfig::DDataType; using AccDataType = typename TypeConfig::AccDataType; using ODataType = typename TypeConfig::ODataType; using AScaleDataType = typename TypeConfig::AScaleDataType; using GScaleDataType = typename TypeConfig::GScaleDataType; using DScaleDataType = typename TypeConfig::DScaleDataType; using YSmoothScaleDataType = typename TypeConfig::YSmoothScaleDataType; using TopkWeightDataType = typename TypeConfig::TopkWeightDataType; using IndexDataType = typename TypeConfig::IndexDataType; // host verify ck_tile::HostTensor a_host({tokens, hidden_size}, {stride, 1}); ck_tile::HostTensor g_host({experts, shared_intermediate_size_0, hidden_size}); ck_tile::HostTensor d_host({experts, hidden_size, shared_intermediate_size_1}); ck_tile::HostTensor o_host({tokens, hidden_size}, {stride, 1}); ck_tile::HostTensor sa_host({tokens}); ck_tile::HostTensor sg_host({shared_intermediate_size_0}); ck_tile::HostTensor sd_host({shared_intermediate_size_1}); ck_tile::HostTensor sy_host({shared_intermediate_size_1}); // smooth-quant ck_tile::HostTensor topk_ids_host({tokens, topk}); // to be sort ck_tile::HostTensor topk_weight_host({tokens, topk}); // to be sort ck_tile::HostTensor local_expert_mask_host({experts}); int max_num_tokens_padded = topk * tokens + experts * block_m - topk; ck_tile::HostTensor sorted_token_ids_host({max_num_tokens_padded}); ck_tile::HostTensor sorted_weight_host({max_num_tokens_padded}); ck_tile::HostTensor sorted_expert_ids_host( {(max_num_tokens_padded + block_m - 1) / block_m}); ck_tile::HostTensor num_sorted_tiles_host({1}); if(init == 0) { ck_tile::FillStepRange{-.5f, .5f, 0.01f}(a_host); ck_tile::FillStepRange{-.5f, .5f, 0.01f}(g_host); ck_tile::FillStepRange{.5f, -.5f, -0.01f}(d_host); ck_tile::FillStepRange{0.f, 1.f, 0.01f}(sa_host); ck_tile::FillStepRange{0.f, 1.f, 0.01f}(sg_host); ck_tile::FillStepRange{0.f, 1.f, 0.01f}(sd_host); ck_tile::FillStepRange{0.f, 1.f, 0.01f}(sy_host); ck_tile::FillStepRange{-.5f, .5f, 0.01f}(topk_weight_host); } else if(init == 1) { ck_tile::FillUniformDistribution{-.5f, .5f, seed}(a_host); ck_tile::FillUniformDistribution{-.5f, .5f, seed}(g_host); ck_tile::FillUniformDistribution{-.5f, .5f, seed}(d_host); ck_tile::FillUniformDistribution{-.5f, .5f, seed}(sa_host); ck_tile::FillUniformDistribution{-.5f, .5f, seed}(sg_host); ck_tile::FillUniformDistribution{-.5f, .5f, seed}(sd_host); ck_tile::FillUniformDistribution{-.5f, .5f, seed}(sy_host); ck_tile::FillUniformDistribution{-.5f, .5f, seed}(topk_weight_host); } else if(init == 2) { ck_tile::FillNormalDistribution{0.f, 1.f, seed}(a_host); ck_tile::FillNormalDistribution{0.f, 1.f, seed}(g_host); ck_tile::FillNormalDistribution{0.f, 1.f, seed}(d_host); ck_tile::FillNormalDistribution{0.f, 1.f, seed}(sa_host); ck_tile::FillNormalDistribution{0.f, 1.f, seed}(sg_host); ck_tile::FillNormalDistribution{0.f, 1.f, seed}(sd_host); ck_tile::FillNormalDistribution{0.f, 1.f, seed}(sy_host); ck_tile::FillNormalDistribution{0.f, 1.f, seed}(topk_weight_host); } // permute weight ck_tile::HostTensor g_perm_host = shuffle_moe_weight(g_host, prec_w, 1); ck_tile::HostTensor d_perm_host = shuffle_moe_weight(d_host, prec_w, 1); // do moe sorting if(balance) { int e_cnt = 0; for(int i = 0; i < static_cast(topk_ids_host.mData.size()); i++) { topk_ids_host.mData[i] = e_cnt; e_cnt++; if(e_cnt >= experts) e_cnt = 0; } } else { topid_unique_gen(topk_ids_host.mData, tokens, topk, experts, 11913); } // leave it here for future debug purpose #if 0 a_host.loadtxt("../../ater/input_torch.txt"); topk_ids_host.loadtxt("../../ater/topk_ids_torch.txt", "int"); // topk_ids_host.savetxt("topk_ids_2.txt"); topk_weight_host.loadtxt("../../ater/topk_weights_torch.txt", "float"); std::cout << "------- @@@ " << __LINE__ << std::flush << std::endl; g_host.loadtxt("../../ater/w1_torch.txt", "float"); std::cout << "------- @@@ " << __LINE__ << std::flush << std::endl; d_host.loadtxt("../../ater/w2_torch.txt", "float"); std::cout << "------- @@@ " << __LINE__ << std::flush << std::endl; ck_tile::HostTensor g_perm_host = shuffle_moe_weight(g_host, prec_w, 1); std::cout << "------- @@@ " << __LINE__ << std::flush << std::endl; ck_tile::HostTensor d_perm_host = shuffle_moe_weight(d_host, prec_w, 1); std::cout << "------- @@@ " << __LINE__ << std::flush << std::endl; #endif #if 0 std::cout << "sorted_token_ids_host:" << sorted_token_ids_host << std::endl; std::cout << "num_sorted_tiles_host:" << num_sorted_tiles_host << std::endl; std::cout << "sorted_expert_ids_host:" << sorted_expert_ids_host << std::endl; std::cout << "topk_weight_host:" << topk_weight_host << std::endl; std::cout << "sorted_weight_host:" << sorted_weight_host << std::endl; #endif auto cal_tflops = [&](auto ms) { double flop_gemm_0 = 2 * static_cast(tokens) * topk * shared_intermediate_size_0 * hidden_size; double flop_gemm_1 = 2 * static_cast(tokens) * topk * shared_intermediate_size_1 * hidden_size; return (flop_gemm_0 + flop_gemm_1) / (static_cast(ms) * 1e-3) / 1e12; }; // TODO: this method we use expert-by-expert view, just for reference auto cal_tbps = [&](auto ms) { double token_bytes = static_cast(tokens) * topk / experts * hidden_size * sizeof(ADataType); double w0_bytes = static_cast(shared_intermediate_size_0) * experts * hidden_size * sizeof(GDataType); double w1_bytes = static_cast(shared_intermediate_size_1) * experts * hidden_size * sizeof(DDataType); double o_bytes = static_cast(tokens) * topk / experts * hidden_size * sizeof(ODataType); double topk_weights_bytes = static_cast(tokens) * topk * sizeof(TopkWeightDataType); // ignore index, they are too small return (token_bytes + w0_bytes + w1_bytes + o_bytes + topk_weights_bytes) / (static_cast(ms) * 1e-3) / 1e12; }; if(api == 0) { ck_tile::DeviceMem a_buf(a_host); ck_tile::DeviceMem g_perm_buf(g_perm_host); ck_tile::DeviceMem d_perm_buf(d_perm_host); ck_tile::DeviceMem sa_buf(sa_host); ck_tile::DeviceMem sg_buf(sg_host); ck_tile::DeviceMem sd_buf(sd_host); ck_tile::DeviceMem sy_buf(sy_host); ck_tile::DeviceMem local_expert_mask_buf(local_expert_mask_host); ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem topk_ids_buf(topk_ids_host); ck_tile::DeviceMem topk_weight_buf(topk_weight_host); ck_tile::DeviceMem sorted_token_ids_buf( sorted_token_ids_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem sorted_weight_buf(sorted_weight_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem sorted_expert_ids_buf( sorted_expert_ids_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem num_sorted_tiles_buf( num_sorted_tiles_host.get_element_space_size_in_bytes()); // if return zero, means no need workspace, can set moe_sorting_args.p_ws to nullptr ck_tile::index_t workspace_size = ck_tile::moe_sorting_get_workspace_size(tokens, experts, topk, 0 /*dispatch_policy*/); ck_tile::DeviceMem moe_sorting_ws(workspace_size != 0 ? workspace_size : 0); if(workspace_size != 0) moe_sorting_ws.SetZero(); // note, clear here!!!! ck_tile::DeviceMem local_tokens_dev(sizeof(ck_tile::index_t)); if(is_local_token) { local_tokens_dev.ToDevice(&local_tokens); } fused_moe_traits traits{prec_i, prec_w, prec_o, prec_st, prec_sw, prec_sq, prec_kw, block_m, activation, gate_only, fused_quant, local_expert_masking}; fused_moe_args args{a_buf.GetDeviceBuffer(), fused_quant != 0 ? sa_buf.GetDeviceBuffer() : nullptr, g_perm_buf.GetDeviceBuffer(), d_perm_buf.GetDeviceBuffer(), fused_quant != 0 ? sg_buf.GetDeviceBuffer() : nullptr, fused_quant != 0 ? sd_buf.GetDeviceBuffer() : nullptr, fused_quant == 1 ? sy_buf.GetDeviceBuffer() : nullptr, local_expert_masking ? local_expert_mask_buf.GetDeviceBuffer() : nullptr, is_local_token ? local_tokens_dev.GetDeviceBuffer() : nullptr, o_buf.GetDeviceBuffer(), workspace_size != 0 ? moe_sorting_ws.GetDeviceBuffer() : nullptr, topk_ids_buf.GetDeviceBuffer(), topk_weight_buf.GetDeviceBuffer(), sorted_token_ids_buf.GetDeviceBuffer(), sorted_weight_buf.GetDeviceBuffer(), sorted_expert_ids_buf.GetDeviceBuffer(), num_sorted_tiles_buf.GetDeviceBuffer(), block_m, hidden_size, intermediate_size / tp, tokens, experts, topk, stride}; float ave_time = fused_moe( traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat}); if(ave_time < 0) { std::cout << " not supported!" << std::endl << std::flush; return false; } // float gb_per_sec = num_byte / 1.E6 / ave_time; std::cout << ", " << ave_time * 1.E3 << " us, " << cal_tflops(ave_time) << " tflops, " << cal_tbps(ave_time) << " TB/s" << std::flush; bool pass = true; #define CPU_FUSED_MOE(act_type_) \ ck_tile::reference_fused_moe(a_host, \ g_host, \ d_host, \ sa_host, \ sg_host, \ sd_host, \ sy_host, \ o_host, \ sorted_token_ids_host, \ sorted_weight_host, \ sorted_expert_ids_host, \ num_sorted_tiles_host, \ topk_ids_host, \ block_m, \ tokens, \ experts, \ hidden_size, \ intermediate_size / tp, \ topk, \ gate_only) if(do_validation) { ck_tile::reference_moe_sorting( topk_ids_host, topk_weight_host, local_expert_mask_host, sorted_token_ids_host, sorted_weight_host, sorted_expert_ids_host, num_sorted_tiles_host.mData[0], experts, block_m, is_local_token ? local_tokens : tokens, local_expert_masking); if(activation == 0) { CPU_FUSED_MOE(ck_tile::element_wise::Gelu); } else { CPU_FUSED_MOE(ck_tile::element_wise::Silu); } auto o_dev = o_buf.ToHost(); // o_dev.savetxt("gpu-out.txt", "float"); auto [rtol, atol] = get_elimit(); pass &= ck_tile::check_err( o_dev, o_host, std::string("OUT Error: Incorrect results!"), rtol, atol); std::cout << ", valid:" << (pass ? "y" : "n") << std::flush; } std::cout << std::flush << std::endl; if(arg_parser.get_int("json") == 1) { dump_fused_moe_json(arg_parser.get_str("jsonfile"), api_str, prec_str, tokens, is_local_token, local_tokens, experts, topk, hidden_size, intermediate_size, stride, block_m, activation, gate_only, fused_quant, pass, ave_time, cal_tflops(ave_time), cal_tbps(ave_time)); } return pass; } else if(api == 1) { ck_tile::reference_moe_sorting( topk_ids_host, topk_weight_host, local_expert_mask_host, sorted_token_ids_host, sorted_weight_host, sorted_expert_ids_host, num_sorted_tiles_host.mData[0], experts, block_m, is_local_token ? local_tokens : tokens, local_expert_masking); // done, preparing GPU buffer ck_tile::DeviceMem a_buf(a_host); ck_tile::DeviceMem g_perm_buf(g_perm_host); ck_tile::DeviceMem d_perm_buf(d_perm_host); ck_tile::DeviceMem sa_buf(sa_host); ck_tile::DeviceMem sg_buf(sg_host); ck_tile::DeviceMem sd_buf(sd_host); ck_tile::DeviceMem sy_buf(sy_host); ck_tile::DeviceMem o_buf(o_host); ck_tile::DeviceMem local_tokens_dev(sizeof(ck_tile::index_t)); if(is_local_token) { local_tokens_dev.ToDevice(&local_tokens); } // manually clear output buffer for atomic o_buf.SetZero(); // ck_tile::DeviceMem sorted_token_ids_buf(sorted_token_ids_host); ck_tile::DeviceMem sorted_weight_buf(sorted_weight_host); ck_tile::DeviceMem sorted_expert_ids_buf(sorted_expert_ids_host); ck_tile::DeviceMem num_sorted_tiles_buf(num_sorted_tiles_host); fused_moegemm_traits traits{prec_i, prec_w, prec_o, prec_st, prec_sw, prec_sq, prec_kw, block_m, activation, gate_only, fused_quant}; fused_moegemm_args args{a_buf.GetDeviceBuffer(), fused_quant != 0 ? sa_buf.GetDeviceBuffer() : nullptr, g_perm_buf.GetDeviceBuffer(), d_perm_buf.GetDeviceBuffer(), fused_quant != 0 ? sg_buf.GetDeviceBuffer() : nullptr, fused_quant != 0 ? sd_buf.GetDeviceBuffer() : nullptr, fused_quant == 1 ? sy_buf.GetDeviceBuffer() : nullptr, o_buf.GetDeviceBuffer(), sorted_token_ids_buf.GetDeviceBuffer(), sorted_weight_buf.GetDeviceBuffer(), sorted_expert_ids_buf.GetDeviceBuffer(), num_sorted_tiles_buf.GetDeviceBuffer(), hidden_size, intermediate_size / tp, is_local_token ? local_tokens : tokens, experts, topk, stride}; float ave_time = fused_moegemm( traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat}); if(ave_time < 0) { std::cout << " not supported!" << std::endl << std::flush; return false; } // float gb_per_sec = num_byte / 1.E6 / ave_time; std::cout << ", " << ave_time * 1.E3 << " us, " << cal_tflops(ave_time) << " tflops, " << cal_tbps(ave_time) << " TB/s" << std::flush; bool pass = true; if(do_validation) { if(activation == 0) { CPU_FUSED_MOE(ck_tile::element_wise::Gelu); } else { CPU_FUSED_MOE(ck_tile::element_wise::Silu); } auto o_dev = o_buf.ToHost(); // o_dev.savetxt("gpu-out.txt", "float"); auto [rtol, atol] = get_elimit(); pass &= ck_tile::check_err( o_dev, o_host, std::string("OUT Error: Incorrect results!"), rtol, atol); std::cout << ", valid:" << (pass ? "y" : "n") << std::flush; } std::cout << std::flush << std::endl; if(arg_parser.get_int("json") == 1) { dump_fused_moe_json(arg_parser.get_str("jsonfile"), api_str, prec_str, tokens, is_local_token, local_tokens, experts, topk, hidden_size, intermediate_size, stride, block_m, activation, gate_only, fused_quant, pass, ave_time, cal_tflops(ave_time), cal_tbps(ave_time)); } return pass; } return false; } int main(int argc, char* argv[]) { auto [result, arg_parser] = create_args(argc, argv); if(!result) return -1; std::string prec_i = arg_parser.get_str("prec_i"); std::string prec_w = arg_parser.get_str("prec_w"); std::string prec_o = arg_parser.get_str("prec_o"); std::string prec_st = arg_parser.get_str("prec_st"); std::string prec_sw = arg_parser.get_str("prec_sw"); std::string prec_sq = arg_parser.get_str("prec_sq"); std::string prec_kw = arg_parser.get_str("prec_kw"); prec_st = (prec_st == "auto") ? "fp32" : prec_st; prec_sw = (prec_sw == "auto") ? "fp32" : prec_sw; prec_sq = (prec_sq == "auto") ? "fp32" : prec_sq; prec_kw = (prec_kw == "auto") ? "fp32" : prec_kw; // no dynamic quant case if(prec_i == "bf16" && prec_w == "bf16" && prec_o == "bf16" && prec_kw == "fp32") { return run( arg_parser) ? 0 : -2; } else if(prec_i == "fp16" && prec_w == "fp16" && prec_o == "fp16" && prec_kw == "fp32") { return run( arg_parser) ? 0 : -2; } return -3; }