Add save_x to trait

This commit is contained in:
rocking
2024-10-28 19:55:44 +00:00
parent b683de6b32
commit 6a54faae25
3 changed files with 9 additions and 7 deletions

View File

@@ -35,7 +35,7 @@ auto create_args(int argc, char* argv[])
.insert("n", "4096", "n dimension")
.insert("stride", "-1", "stride per row, if -1 then equal to n")
.insert("e", "1e-5", "epsilon")
.insert("save_rms", "0", "save rms(invrms) or not. set to 1 in training case")
.insert("save_x", "1", "save rms(invrms) or not. set to 1 in training case")
.insert("v", "1", "cpu validation or not")
.insert("kname", "1", "print kernel name or not")
.insert("prec", "fp16", "precision")
@@ -257,20 +257,20 @@ int main(int argc, char* argv[])
return -1;
const std::string data_type = arg_parser.get_str("prec");
int save_rms = arg_parser.get_int("save_rms");
if(data_type == "fp16" && save_rms)
int save_x = arg_parser.get_int("save_x");
if(data_type == "fp16" && save_x)
{
return run<ck_tile::half_t, true>(arg_parser) ? 0 : -2;
}
else if(data_type == "fp16" && !save_rms)
else if(data_type == "fp16" && !save_x)
{
return run<ck_tile::half_t, false>(arg_parser) ? 0 : -2;
}
else if(data_type == "bf16" && save_rms)
else if(data_type == "bf16" && save_x)
{
return run<ck_tile::bf16_t, true>(arg_parser) ? 0 : -2;
}
else if(data_type == "bf16" && !save_rms)
else if(data_type == "bf16" && !save_x)
{
return run<ck_tile::bf16_t, true>(arg_parser) ? 0 : -2;
}

View File

@@ -115,7 +115,7 @@ float add_rmsnorm2d_rdquant_fwd_(const ck_tile::stream_config& s, add_rmsnorm2d_
struct add_rmsnorm2d_rdquant_fwd_traits
{
std::string data_type;
bool save_rms;
bool save_x;
};
float add_rmsnorm2d_rdquant_fwd(add_rmsnorm2d_rdquant_fwd_traits,

View File

@@ -140,6 +140,8 @@ float add_rmsnorm2d_rdquant_fwd(add_rmsnorm2d_rdquant_fwd_traits t,
{
float r = -1;
// Only support instance of save_x == true for now
assert(t.save_x);
if(t.data_type.compare("fp16") == 0)
{
return add_rmsnorm2d_rdquant_fwd_b16_<ck_tile::fp16_t>(t, a, s);