diff --git a/example/ck_tile/07_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.cpp b/example/ck_tile/07_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.cpp index 612ef62f89..43bc9a6cfe 100644 --- a/example/ck_tile/07_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.cpp +++ b/example/ck_tile/07_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.cpp @@ -35,7 +35,7 @@ auto create_args(int argc, char* argv[]) .insert("n", "4096", "n dimension") .insert("stride", "-1", "stride per row, if -1 then equal to n") .insert("e", "1e-5", "epsilon") - .insert("save_rms", "0", "save rms(invrms) or not. set to 1 in training case") + .insert("save_x", "1", "save x or not. set to 1 in training case") .insert("v", "1", "cpu validation or not") .insert("kname", "1", "print kernel name or not") .insert("prec", "fp16", "precision") @@ -257,20 +257,20 @@ int main(int argc, char* argv[]) return -1; const std::string data_type = arg_parser.get_str("prec"); - int save_rms = arg_parser.get_int("save_rms"); - if(data_type == "fp16" && save_rms) + int save_x = arg_parser.get_int("save_x"); + if(data_type == "fp16" && save_x) { return run(arg_parser) ? 0 : -2; } - else if(data_type == "fp16" && !save_rms) + else if(data_type == "fp16" && !save_x) { return run(arg_parser) ? 0 : -2; } - else if(data_type == "bf16" && save_rms) + else if(data_type == "bf16" && save_x) { return run(arg_parser) ? 0 : -2; } - else if(data_type == "bf16" && !save_rms) + else if(data_type == "bf16" && !save_x) { return run(arg_parser) ? 
0 : -2; } diff --git a/example/ck_tile/07_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.hpp b/example/ck_tile/07_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.hpp index 249191b63f..bf70d9d23f 100644 --- a/example/ck_tile/07_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.hpp +++ b/example/ck_tile/07_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.hpp @@ -115,7 +115,7 @@ float add_rmsnorm2d_rdquant_fwd_(const ck_tile::stream_config& s, add_rmsnorm2d_ struct add_rmsnorm2d_rdquant_fwd_traits { std::string data_type; - bool save_rms; + bool save_x; }; float add_rmsnorm2d_rdquant_fwd(add_rmsnorm2d_rdquant_fwd_traits, diff --git a/example/ck_tile/07_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_api.cpp b/example/ck_tile/07_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_api.cpp index 5ba0fac76e..57a0f254d0 100644 --- a/example/ck_tile/07_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_api.cpp +++ b/example/ck_tile/07_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_api.cpp @@ -140,6 +140,8 @@ float add_rmsnorm2d_rdquant_fwd(add_rmsnorm2d_rdquant_fwd_traits t, { float r = -1; + // Only support instance of save_x == true for now + assert(t.save_x); if(t.data_type.compare("fp16") == 0) { return add_rmsnorm2d_rdquant_fwd_b16_(t, a, s);