mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 07:51:52 +00:00
Add save_x to trait
This commit is contained in:
@@ -35,7 +35,7 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("n", "4096", "n dimension")
|
||||
.insert("stride", "-1", "stride per row, if -1 then equal to n")
|
||||
.insert("e", "1e-5", "epsilon")
|
||||
.insert("save_rms", "0", "save rms(invrms) or not. set to 1 in training case")
|
||||
.insert("save_x", "1", "save rms(invrms) or not. set to 1 in training case")
|
||||
.insert("v", "1", "cpu validation or not")
|
||||
.insert("kname", "1", "print kernel name or not")
|
||||
.insert("prec", "fp16", "precision")
|
||||
@@ -257,20 +257,20 @@ int main(int argc, char* argv[])
|
||||
return -1;
|
||||
|
||||
const std::string data_type = arg_parser.get_str("prec");
|
||||
int save_rms = arg_parser.get_int("save_rms");
|
||||
if(data_type == "fp16" && save_rms)
|
||||
int save_x = arg_parser.get_int("save_x");
|
||||
if(data_type == "fp16" && save_x)
|
||||
{
|
||||
return run<ck_tile::half_t, true>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
else if(data_type == "fp16" && !save_rms)
|
||||
else if(data_type == "fp16" && !save_x)
|
||||
{
|
||||
return run<ck_tile::half_t, false>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
else if(data_type == "bf16" && save_rms)
|
||||
else if(data_type == "bf16" && save_x)
|
||||
{
|
||||
return run<ck_tile::bf16_t, true>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
else if(data_type == "bf16" && !save_rms)
|
||||
else if(data_type == "bf16" && !save_x)
|
||||
{
|
||||
return run<ck_tile::bf16_t, true>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
|
||||
@@ -115,7 +115,7 @@ float add_rmsnorm2d_rdquant_fwd_(const ck_tile::stream_config& s, add_rmsnorm2d_
|
||||
// Runtime dispatch descriptor passed as the first argument of
// add_rmsnorm2d_rdquant_fwd().
// NOTE(review): this listing is a rendered diff of the commit "Add save_x to
// trait"; `save_rms` and `save_x` appear side by side here, presumably as the
// removed and the added line respectively -- confirm against the real header.
struct add_rmsnorm2d_rdquant_fwd_traits
|
||||
{
|
||||
// Precision name; add_rmsnorm2d_rdquant_fwd() compares it against "fp16".
std::string data_type;
|
||||
// NOTE(review): presumably the pre-commit flag that save_x replaces -- confirm.
bool save_rms;
|
||||
// add_rmsnorm2d_rdquant_fwd() asserts this is true; only save_x == true
// instances are supported for now.
bool save_x;
|
||||
};
|
||||
|
||||
float add_rmsnorm2d_rdquant_fwd(add_rmsnorm2d_rdquant_fwd_traits,
|
||||
|
||||
@@ -140,6 +140,8 @@ float add_rmsnorm2d_rdquant_fwd(add_rmsnorm2d_rdquant_fwd_traits t,
|
||||
{
|
||||
|
||||
float r = -1;
|
||||
// Only support instance of save_x == true for now
|
||||
assert(t.save_x);
|
||||
if(t.data_type.compare("fp16") == 0)
|
||||
{
|
||||
return add_rmsnorm2d_rdquant_fwd_b16_<ck_tile::fp16_t>(t, a, s);
|
||||
|
||||
Reference in New Issue
Block a user