diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.cpp index 43bc9a6cfe..574edf64d3 100644 --- a/example/ck_tile/11_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.cpp +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.cpp @@ -3,7 +3,7 @@ #include // different threshold for different dtype -template +template auto get_elimit() { double rtol = 1e-2; @@ -39,6 +39,7 @@ auto create_args(int argc, char* argv[]) .insert("v", "1", "cpu validation or not") .insert("kname", "1", "print kernel name or not") .insert("prec", "fp16", "precision") + .insert("quant", "int8", "precision") .insert("warmup", "5", "cold iter") .insert("repeat", "20", "hot iter"); @@ -46,7 +47,7 @@ auto create_args(int argc, char* argv[]) return std::make_tuple(result, arg_parser); } -template +template bool run(const ck_tile::ArgParser& arg_parser) { ck_tile::index_t m = arg_parser.get_int("m"); @@ -54,16 +55,17 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::index_t stride = arg_parser.get_int("stride"); if(stride < 0) stride = n; - float epsilon = arg_parser.get_float("e"); - std::string data_type = arg_parser.get_str("prec"); - int kname = arg_parser.get_int("kname"); - int do_validation = arg_parser.get_int("v"); - int warmup = arg_parser.get_int("warmup"); - int repeat = arg_parser.get_int("repeat"); + float epsilon = arg_parser.get_float("e"); + std::string input_data_type = arg_parser.get_str("prec"); + std::string quantized_data_type = arg_parser.get_str("quant"); + int kname = arg_parser.get_int("kname"); + int do_validation = arg_parser.get_int("v"); + int warmup = arg_parser.get_int("warmup"); + int repeat = arg_parser.get_int("repeat"); assert(stride >= n); - using TypeConfig = AddRmsnormRdquantTypeConfig; + using TypeConfig = AddRmsnormRdquantTypeConfig; using ADataType = typename TypeConfig::ADataType; using BDataType = typename TypeConfig::BDataType; @@ -102,10 +104,10 @@ bool run(const ck_tile::ArgParser& arg_parser) b_buf.ToDevice(b_host.data()); gamma_buf.ToDevice(gamma_host.data()); - std::cout << "[" << data_type << "]" + std::cout << "[" << input_data_type << ", " << quantized_data_type << "]" << " m:" << m << ", n:" << n << ", stride:" << stride << std::flush; - add_rmsnorm2d_rdquant_fwd_traits traits{data_type, SaveX}; + add_rmsnorm2d_rdquant_fwd_traits traits{input_data_type, quantized_data_type, SaveX}; add_rmsnorm2d_rdquant_fwd_args args{a_buf.GetDeviceBuffer(), b_buf.GetDeviceBuffer(), @@ -129,14 +131,14 @@ bool run(const ck_tile::ArgParser& arg_parser) num_byte += sizeof(XDataType) * m * n; float gb_per_sec = num_byte / 1.E6 / ave_time; - std::cout << ", " << ave_time * 1.E3 << " us, " << gb_per_sec << " GB/s" << std::flush; + std::cout << ", " << ave_time * 1.E3 << " us, " << gb_per_sec << " GB/s" << std::endl; bool pass = true; if(do_validation) { using YDataType = ComputeDataType; - using InvRmsDataType = DataType; + using InvRmsDataType = InputDataType; // Add { @@ -144,28 +146,36 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::reference_binary_elementwise( a_host, b_host, x_host_ref, op); - x_buf.FromDevice(x_host_dev.data()); + if constexpr(SaveX) + { + x_buf.FromDevice(x_host_dev.data()); - auto [rtol, atol] = get_elimit(); - if(stride == n) - { - pass = ck_tile::check_err( - x_host_dev, x_host_ref, std::string("x Error: Incorrect results!"), rtol, atol); - } - else - { - for(int i_r = 0; i_r < m; i_r++) + auto [rtol, atol] = get_elimit(); + if(stride == n) { - std::vector x_host_dev_row(x_host_dev.begin() + i_r * stride, - x_host_dev.begin() + i_r * stride + n); - std::vector x_host_ref_row(x_host_ref.begin() + i_r * stride, - x_host_ref.begin() + i_r * stride + n); - pass &= ck_tile::check_err(x_host_dev_row, - x_host_ref_row, - std::string("x[") + std::to_string(i_r) + - std::string("] Error: Incorrect results!"), - rtol, - atol); + pass = ck_tile::check_err(x_host_dev, + x_host_ref, + std::string("x Error: Incorrect results!"), + rtol, + atol); + } + else + { + for(int i_r = 0; i_r < m; i_r++) + { + std::vector x_host_dev_row(x_host_dev.begin() + i_r * stride, + x_host_dev.begin() + i_r * stride + + n); + std::vector x_host_ref_row(x_host_ref.begin() + i_r * stride, + x_host_ref.begin() + i_r * stride + + n); + pass &= ck_tile::check_err(x_host_dev_row, + x_host_ref_row, + std::string("x[") + std::to_string(i_r) + + std::string("] Error: Incorrect results!"), + rtol, + atol); + } } } } @@ -256,23 +266,40 @@ int main(int argc, char* argv[]) if(!result) return -1; - const std::string data_type = arg_parser.get_str("prec"); - int save_x = arg_parser.get_int("save_x"); - if(data_type == "fp16" && save_x) + const std::string input_data_type = arg_parser.get_str("prec"); + const std::string quantized_data_type = arg_parser.get_str("quant"); + int save_x = arg_parser.get_int("save_x"); + if(input_data_type == "fp16" && quantized_data_type == "int8" && save_x) { - return run(arg_parser) ? 0 : -2; + return run(arg_parser) ? 0 : -2; } - else if(data_type == "fp16" && !save_x) + else if(input_data_type == "fp16" && quantized_data_type == "int8" && !save_x) { - return run(arg_parser) ? 0 : -2; + return run(arg_parser) ? 0 : -2; } - else if(data_type == "bf16" && save_x) + else if(input_data_type == "bf16" && quantized_data_type == "int8" && save_x) { - return run(arg_parser) ? 0 : -2; + return run(arg_parser) ? 0 : -2; } - else if(data_type == "bf16" && !save_x) + else if(input_data_type == "bf16" && quantized_data_type == "int8" && !save_x) { - return run(arg_parser) ? 0 : -2; + return run(arg_parser) ? 0 : -2; + } + else if(input_data_type == "fp16" && quantized_data_type == "fp8" && save_x) + { + return run(arg_parser) ? 0 : -2; + } + else if(input_data_type == "fp16" && quantized_data_type == "fp8" && !save_x) + { + return run(arg_parser) ? 0 : -2; + } + else if(input_data_type == "bf16" && quantized_data_type == "fp8" && save_x) + { + return run(arg_parser) ? 0 : -2; + } + else if(input_data_type == "bf16" && quantized_data_type == "fp8" && !save_x) + { + return run(arg_parser) ? 0 : -2; } return -3; diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.hpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.hpp index 443b9b1024..c91b387d62 100644 --- a/example/ck_tile/11_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.hpp +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.hpp @@ -8,11 +8,11 @@ #include "ck_tile/ops/add_rmsnorm2d_rdquant.hpp" #include -template +template struct AddRmsnormRdquantTypeConfig; template <> -struct AddRmsnormRdquantTypeConfig +struct AddRmsnormRdquantTypeConfig { using ADataType = ck_tile::half_t; using BDataType = ck_tile::half_t; @@ -24,7 +24,7 @@ struct AddRmsnormRdquantTypeConfig }; template <> -struct AddRmsnormRdquantTypeConfig +struct AddRmsnormRdquantTypeConfig { using ADataType = ck_tile::bf16_t; using BDataType = ck_tile::bf16_t; @@ -35,13 +35,38 @@ struct AddRmsnormRdquantTypeConfig using ComputeDataType = float; }; +template <> +struct AddRmsnormRdquantTypeConfig +{ + using ADataType = ck_tile::half_t; + using BDataType = ck_tile::half_t; + using GammaDataType = ck_tile::half_t; + using XDataType = ck_tile::half_t; + using YScaleDataType = float; + using QYDataType = ck_tile::fp8_t; + using ComputeDataType = float; +}; + +template <> +struct AddRmsnormRdquantTypeConfig +{ + using ADataType = ck_tile::bf16_t; + using BDataType = ck_tile::bf16_t; + using GammaDataType = ck_tile::bf16_t; + using XDataType = ck_tile::bf16_t; + using YScaleDataType = float; + using QYDataType = ck_tile::fp8_t; + using ComputeDataType = float; +}; + // runtime args struct add_rmsnorm2d_rdquant_fwd_args : public ck_tile::AddRmsnorm2dRdquantFwdHostArgs { }; // this is used to pattern-match internl kernel implementation, not to instantiate kernel -template struct add_rmsnorm2d_rdquant_fwd_traits_ { - using DataType = ck_tile::remove_cvref_t; + using InputDataType = ck_tile::remove_cvref_t; + using QuantizedDataType = ck_tile::remove_cvref_t; static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= warpSize; static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % warpSize == 0); @@ -114,7 +140,8 @@ float add_rmsnorm2d_rdquant_fwd_(const ck_tile::stream_config& s, add_rmsnorm2d_ // This is the public API, will be generated by script struct add_rmsnorm2d_rdquant_fwd_traits { - std::string data_type; + std::string input_data_type; + std::string quantized_data_type; bool save_x; }; diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_api.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_api.cpp index 966c5bd02f..bf8c5c7553 100644 --- a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_api.cpp +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_api.cpp @@ -4,7 +4,8 @@ #include #include "add_rmsnorm2d_rdquant_fwd.hpp" -template -using trait_ = add_rmsnorm2d_rdquant_fwd_traits_; -template -float add_rmsnorm2d_rdquant_fwd_b16_(add_rmsnorm2d_rdquant_fwd_traits /*t*/, +template +float add_rmsnorm2d_rdquant_fwd_b16_(add_rmsnorm2d_rdquant_fwd_traits t, add_rmsnorm2d_rdquant_fwd_args a, const ck_tile::stream_config& s) { @@ -32,99 +34,133 @@ float add_rmsnorm2d_rdquant_fwd_b16_(add_rmsnorm2d_rdquant_fwd_traits /*t*/, // clang-format off // rm rn tm tn vn pd x 3p if(a.n <= 64) { - r = add_rmsnorm2d_rdquant_fwd_>(s, a); + r = add_rmsnorm2d_rdquant_fwd_>(s, a); } else if(a.n <= 128) { if (a.n % 2 == 0) - r = add_rmsnorm2d_rdquant_fwd_>(s, a); + r = add_rmsnorm2d_rdquant_fwd_>(s, a); else - r = add_rmsnorm2d_rdquant_fwd_>(s, a); + r = add_rmsnorm2d_rdquant_fwd_>(s, a); } else if(a.n <= 256) { if (a.n % 4 == 0) - r = add_rmsnorm2d_rdquant_fwd_>(s, a); + r = add_rmsnorm2d_rdquant_fwd_>(s, a); else if (a.n % 2 == 0) - r = add_rmsnorm2d_rdquant_fwd_>(s, a); + r = add_rmsnorm2d_rdquant_fwd_>(s, a); else - r = add_rmsnorm2d_rdquant_fwd_>(s, a); + r = add_rmsnorm2d_rdquant_fwd_>(s, a); } else if(a.n <= 512) { if (a.n % 8 == 0) - r = add_rmsnorm2d_rdquant_fwd_>(s, a); + r = add_rmsnorm2d_rdquant_fwd_>(s, a); else if (a.n % 4 == 0) - r = add_rmsnorm2d_rdquant_fwd_>(s, a); + r = add_rmsnorm2d_rdquant_fwd_>(s, a); else if (a.n % 2 == 0) - r = add_rmsnorm2d_rdquant_fwd_>(s, a); + r = add_rmsnorm2d_rdquant_fwd_>(s, a); else - r = add_rmsnorm2d_rdquant_fwd_>(s, a); + r = add_rmsnorm2d_rdquant_fwd_>(s, a); } else if(a.n <= 768) { if (a.n % 4 == 0) - r = add_rmsnorm2d_rdquant_fwd_>(s, a); + r = add_rmsnorm2d_rdquant_fwd_>(s, a); else if (a.n % 2 == 0) - r = add_rmsnorm2d_rdquant_fwd_>(s, a); + r = add_rmsnorm2d_rdquant_fwd_>(s, a); else - r = add_rmsnorm2d_rdquant_fwd_>(s, a); + r = add_rmsnorm2d_rdquant_fwd_>(s, a); } else if(a.n <= 1024) { if (a.n % 8 == 0) - r = add_rmsnorm2d_rdquant_fwd_>(s, a); + r = add_rmsnorm2d_rdquant_fwd_>(s, a); else if (a.n % 4 == 0) - r = add_rmsnorm2d_rdquant_fwd_>(s, a); + r = add_rmsnorm2d_rdquant_fwd_>(s, a); else if (a.n % 2 == 0) - r = add_rmsnorm2d_rdquant_fwd_>(s, a); + r = add_rmsnorm2d_rdquant_fwd_>(s, a); else - r = add_rmsnorm2d_rdquant_fwd_>(s, a); + r = add_rmsnorm2d_rdquant_fwd_>(s, a); } else if(a.n <= 1536) { if (a.n % 8 == 0) - r = add_rmsnorm2d_rdquant_fwd_>(s, a); + r = add_rmsnorm2d_rdquant_fwd_>(s, a); else if (a.n % 4 == 0) - r = add_rmsnorm2d_rdquant_fwd_>(s, a); + r = add_rmsnorm2d_rdquant_fwd_>(s, a); else if (a.n % 2 == 0) - r = add_rmsnorm2d_rdquant_fwd_>(s, a); + r = add_rmsnorm2d_rdquant_fwd_>(s, a); else - r = add_rmsnorm2d_rdquant_fwd_>(s, a); + r = add_rmsnorm2d_rdquant_fwd_>(s, a); } else if(a.n <= 2048) { if (a.n % 8 == 0) - r = add_rmsnorm2d_rdquant_fwd_>(s, a); + r = add_rmsnorm2d_rdquant_fwd_>(s, a); else if (a.n % 4 == 0) - r = add_rmsnorm2d_rdquant_fwd_>(s, a); + r = add_rmsnorm2d_rdquant_fwd_>(s, a); else if (a.n % 2 == 0) - r = add_rmsnorm2d_rdquant_fwd_>(s, a); + r = add_rmsnorm2d_rdquant_fwd_>(s, a); else - r = add_rmsnorm2d_rdquant_fwd_>(s, a); + r = add_rmsnorm2d_rdquant_fwd_>(s, a); } else if(a.n <= 3072) { if (a.n % 8 == 0) - r = add_rmsnorm2d_rdquant_fwd_>(s, a); + r = add_rmsnorm2d_rdquant_fwd_>(s, a); else if (a.n % 4 == 0) - r = add_rmsnorm2d_rdquant_fwd_>(s, a); + r = add_rmsnorm2d_rdquant_fwd_>(s, a); else if (a.n % 2 == 0) - r = add_rmsnorm2d_rdquant_fwd_>(s, a); + r = add_rmsnorm2d_rdquant_fwd_>(s, a); else - r = add_rmsnorm2d_rdquant_fwd_>(s, a); + r = add_rmsnorm2d_rdquant_fwd_>(s, a); } else if(a.n <= 4096) { if (a.n % 8 == 0) - r = add_rmsnorm2d_rdquant_fwd_>(s, a); + r = add_rmsnorm2d_rdquant_fwd_>(s, a); else if (a.n % 4 == 0) - r = add_rmsnorm2d_rdquant_fwd_>(s, a); + r = add_rmsnorm2d_rdquant_fwd_>(s, a); else if (a.n % 2 == 0) - r = add_rmsnorm2d_rdquant_fwd_>(s, a); + r = add_rmsnorm2d_rdquant_fwd_>(s, a); else - r = add_rmsnorm2d_rdquant_fwd_>(s, a); + r = add_rmsnorm2d_rdquant_fwd_>(s, a); } - else if(a.n > 4096) { + else if(a.n <= 8192) { + if(a.n<8192){ + if(t.save_x){ if (a.n % 8 == 0) - r = add_rmsnorm2d_rdquant_fwd_>(s, a); + r = add_rmsnorm2d_rdquant_fwd_>(s, a); else if (a.n % 4 == 0) - r = add_rmsnorm2d_rdquant_fwd_>(s, a); + r = add_rmsnorm2d_rdquant_fwd_>(s, a); else if (a.n % 2 == 0) - r = add_rmsnorm2d_rdquant_fwd_>(s, a); + r = add_rmsnorm2d_rdquant_fwd_>(s, a); else - r = add_rmsnorm2d_rdquant_fwd_>(s, a); + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + } + else{ + if (a.n % 8 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else if (a.n % 4 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else if (a.n % 2 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + } + } + else{ + if (a.n % 8 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else if (a.n % 4 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else if (a.n % 2 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + } + } + else if(a.n > 8192) { + if (a.n % 8 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else if (a.n % 4 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else if (a.n % 2 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else + r = add_rmsnorm2d_rdquant_fwd_>(s, a); } return r; // clang-format on @@ -134,16 +170,45 @@ float add_rmsnorm2d_rdquant_fwd(add_rmsnorm2d_rdquant_fwd_traits t, add_rmsnorm2d_rdquant_fwd_args a, const ck_tile::stream_config& s) { - - // Only support instance of save_x == true for now - assert(t.save_x); - if(t.data_type.compare("fp16") == 0) + if(t.input_data_type.compare("fp16") == 0 && t.quantized_data_type.compare("int8") == 0 && + t.save_x) { - return add_rmsnorm2d_rdquant_fwd_b16_(t, a, s); + return add_rmsnorm2d_rdquant_fwd_b16_(t, a, s); } - else if(t.data_type.compare("bf16") == 0) + else if(t.input_data_type.compare("fp16") == 0 && t.quantized_data_type.compare("int8") == 0 && + !t.save_x) { - return add_rmsnorm2d_rdquant_fwd_b16_(t, a, s); + return add_rmsnorm2d_rdquant_fwd_b16_(t, a, s); + } + else if(t.input_data_type.compare("bf16") == 0 && t.quantized_data_type.compare("int8") == 0 && + t.save_x) + { + return add_rmsnorm2d_rdquant_fwd_b16_(t, a, s); + } + else if(t.input_data_type.compare("bf16") == 0 && t.quantized_data_type.compare("int8") == 0 && + !t.save_x) + { + return add_rmsnorm2d_rdquant_fwd_b16_(t, a, s); + } + else if(t.input_data_type.compare("fp16") == 0 && t.quantized_data_type.compare("fp8") == 0 && + t.save_x) + { + return add_rmsnorm2d_rdquant_fwd_b16_(t, a, s); + } + else if(t.input_data_type.compare("fp16") == 0 && t.quantized_data_type.compare("fp8") == 0 && + !t.save_x) + { + return add_rmsnorm2d_rdquant_fwd_b16_(t, a, s); + } + else if(t.input_data_type.compare("bf16") == 0 && t.quantized_data_type.compare("fp8") == 0 && + t.save_x) + { + return add_rmsnorm2d_rdquant_fwd_b16_(t, a, s); + } + else if(t.input_data_type.compare("bf16") == 0 && t.quantized_data_type.compare("fp8") == 0 && + !t.save_x) + { + return add_rmsnorm2d_rdquant_fwd_b16_(t, a, s); } else throw std::runtime_error("Without supported instances!"); diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n1024_instance.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n1024_instance.cpp index 5495e3c9ab..00df2f5082 100644 --- a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n1024_instance.cpp +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n1024_instance.cpp @@ -15,8 +15,12 @@ template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); #endif -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); // clang-format on diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n1536_instance.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n1536_instance.cpp index 8bbfdc8589..2adb54c078 100644 --- a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n1536_instance.cpp +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n1536_instance.cpp @@ -6,8 +6,12 @@ // clang-format off // rm rn tm tn vn pd x 3p -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); // clang-format on diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n2048_instance.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n2048_instance.cpp index 381a11fc80..39089843a2 100644 --- a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n2048_instance.cpp +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n2048_instance.cpp @@ -6,9 +6,13 @@ // clang-format off // rm rn tm tn vn pd x 3p -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); // clang-format on diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n256_instance.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n256_instance.cpp index 2fefac6934..ddb8e1b354 100644 --- a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n256_instance.cpp +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n256_instance.cpp @@ -6,7 +6,10 @@ // clang-format off // rm rn tm tn vn pd x 3p -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); // clang-format on diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n3072_instance.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n3072_instance.cpp index 263713bbc7..2a87614403 100644 --- a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n3072_instance.cpp +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n3072_instance.cpp @@ -6,9 +6,12 @@ // clang-format off // rm rn tm tn vn pd x 3p -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); - +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); // clang-format on diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n4096_instance.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n4096_instance.cpp index c62c596fab..045a3b8880 100644 --- a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n4096_instance.cpp +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n4096_instance.cpp @@ -6,9 +6,12 @@ // clang-format off // rm rn tm tn vn pd x 3p -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); - +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); // clang-format on diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n4096_tp_instance.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n4096_tp_instance.cpp deleted file mode 100644 index e4951f6ab9..0000000000 --- a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n4096_tp_instance.cpp +++ /dev/null @@ -1,14 +0,0 @@ - -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp" - -// clang-format off -// rm rn tm tn vn pd x 3p -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); - -// clang-format on diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n512_instance.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n512_instance.cpp index 4c7ee48e8e..1028973e74 100644 --- a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n512_instance.cpp +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n512_instance.cpp @@ -6,8 +6,12 @@ // clang-format off // rm rn tm tn vn pd x 3p -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); // clang-format on diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n64_n128_instance.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n64_n128_instance.cpp index 8659dc82b3..b8439a0ce9 100644 --- a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n64_n128_instance.cpp +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n64_n128_instance.cpp @@ -6,7 +6,10 @@ // clang-format off // rm rn tm tn vn pd x 3p -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); // clang-format on diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n768_instance.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n768_instance.cpp index 5f15f11b47..b24b245757 100644 --- a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n768_instance.cpp +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n768_instance.cpp @@ -6,7 +6,10 @@ // clang-format off // rm rn tm tn vn pd x 3p -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); // clang-format on diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n8192_instance.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n8192_instance.cpp new file mode 100644 index 0000000000..e09f081812 --- /dev/null +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n8192_instance.cpp @@ -0,0 +1,34 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd x 3p +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); + +// clang-format on diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n8192_tp_instance.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n8192_tp_instance.cpp new file mode 100644 index 0000000000..3e3a6d75b9 --- /dev/null +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n8192_tp_instance.cpp @@ -0,0 +1,17 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd x 3p +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n1024_instance.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n1024_instance.cpp index 8ffdacbdcd..04d735c12c 100644 --- a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n1024_instance.cpp +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n1024_instance.cpp @@ -15,8 +15,12 @@ template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); #endif -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); // clang-format on diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n1536_instance.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n1536_instance.cpp index 3551099651..5893d6c3ee 100644 --- a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n1536_instance.cpp +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n1536_instance.cpp @@ -6,8 +6,12 @@ // clang-format off // rm rn tm tn vn pd x 3p -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); // clang-format on diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n2048_instance.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n2048_instance.cpp index d4d0474c27..ec9c417bf3 100644 --- a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n2048_instance.cpp +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n2048_instance.cpp @@ -6,9 +6,13 @@ // clang-format off // rm rn tm tn vn pd x 3p -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); // clang-format on diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n256_instance.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n256_instance.cpp index 2cb300eda6..5bc8245106 100644 --- a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n256_instance.cpp +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n256_instance.cpp @@ -6,7 +6,10 @@ // clang-format off // rm rn tm tn vn pd x 3p -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); // clang-format on diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n3072_instance.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n3072_instance.cpp index fb0ceb4c58..c022c62de6 100644 --- a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n3072_instance.cpp +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n3072_instance.cpp @@ -6,9 +6,12 @@ // clang-format off // rm rn tm tn vn pd x 3p -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); - +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); // clang-format on diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n4096_instance.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n4096_instance.cpp index 3a241a3c93..19172b0793 100644 --- a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n4096_instance.cpp +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n4096_instance.cpp @@ -6,9 +6,12 @@ // clang-format off // rm rn tm tn vn pd x 3p -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); - +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); // clang-format on diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n4096_tp_instance.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n4096_tp_instance.cpp deleted file mode 100644 index d3094679f9..0000000000 --- a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n4096_tp_instance.cpp +++ /dev/null @@ -1,14 +0,0 @@ - -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp" - -// clang-format off -// rm rn tm tn vn pd x 3p -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); - -// clang-format on diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n512_instance.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n512_instance.cpp index 919bc177e8..f491d92787 100644 --- a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n512_instance.cpp +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n512_instance.cpp @@ -6,8 +6,12 @@ // clang-format off // rm rn tm tn vn pd x 3p -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); // clang-format on diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n64_n128_instance.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n64_n128_instance.cpp index 8a44f5e00f..065f0ea4cc 100644 --- a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n64_n128_instance.cpp +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n64_n128_instance.cpp @@ -5,8 +5,11 @@ #include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp" // clang-format off -// rm rn tm tn vn pd x 3p -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +// rm rn tm tn vn pd x 3p +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); // clang-format on diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n768_instance.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n768_instance.cpp index 5c4f05ec3c..be8c6c4de5 100644 --- a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n768_instance.cpp +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n768_instance.cpp @@ -6,7 +6,10 @@ // clang-format off // rm rn tm tn vn pd x 3p -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); // clang-format on diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n8192_instance.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n8192_instance.cpp new file mode 100644 index 0000000000..f60a5c89fa --- /dev/null +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n8192_instance.cpp @@ -0,0 +1,33 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd x 3p +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n8192_tp_instance.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n8192_tp_instance.cpp new file mode 100644 index 0000000000..e3afa07fa4 --- /dev/null +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n8192_tp_instance.cpp @@ -0,0 +1,17 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd x 3p +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_instance_common.hpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_instance_common.hpp index 6baaad471a..25b10e1dc4 100644 --- a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_instance_common.hpp +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_instance_common.hpp @@ -11,7 +11,8 @@ using S = ck_tile::stream_config; using A = add_rmsnorm2d_rdquant_fwd_args; -template -using trait_ = add_rmsnorm2d_rdquant_fwd_traits_ float add_rmsnorm2d_rdquant_fwd_(const S& s, A a) { - using DataType = typename Traits_::DataType; + using InputDataType = typename Traits_::InputDataType; + using QuantizedDataType = typename Traits_::QuantizedDataType; using PipelineProblem = ck_tile::AddRmsnorm2dRdquantFwdPipelineProblem< - typename AddRmsnormRdquantTypeConfig::ADataType, - typename AddRmsnormRdquantTypeConfig::BDataType, - typename AddRmsnormRdquantTypeConfig::GammaDataType, - typename AddRmsnormRdquantTypeConfig::ComputeDataType, - typename AddRmsnormRdquantTypeConfig::XDataType, - typename AddRmsnormRdquantTypeConfig::YScaleDataType, - typename AddRmsnormRdquantTypeConfig::QYDataType, + typename AddRmsnormRdquantTypeConfig::ADataType, + typename AddRmsnormRdquantTypeConfig::BDataType, + typename AddRmsnormRdquantTypeConfig::GammaDataType, + typename AddRmsnormRdquantTypeConfig::ComputeDataType, + typename AddRmsnormRdquantTypeConfig::XDataType, + typename AddRmsnormRdquantTypeConfig::YScaleDataType, + typename AddRmsnormRdquantTypeConfig::QYDataType, typename Traits_::Shape, Traits_::kPadN, Traits_::kSaveX, diff --git a/include/ck_tile/host/reference/reference_rowwise_quantization2d.hpp b/include/ck_tile/host/reference/reference_rowwise_quantization2d.hpp index e9a398876f..aff5e78ff0 100644 --- a/include/ck_tile/host/reference/reference_rowwise_quantization2d.hpp +++ b/include/ck_tile/host/reference/reference_rowwise_quantization2d.hpp @@ -22,7 +22,7 @@ CK_TILE_HOST void reference_rowwise_quantization2d(const HostTensor& // scale = amax / 127 for int8 auto v_scale = type_convert(scale_m(m)); auto v_qx = v_x / v_scale; - qx_m_n(m, n) = saturates{}(v_qx); + qx_m_n(m, n) = type_convert(saturates{}(v_qx)); } }; diff --git a/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_one_pass.hpp b/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_one_pass.hpp index 24f35d3636..64e5224780 100644 --- a/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_one_pass.hpp +++ b/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_one_pass.hpp @@ -157,7 +157,7 @@ struct AddRmsnorm2dRdquantFwdPipelineOnePass sweep_tile(qy, [&, yscale_ = yscale](auto idx) { constexpr auto i_idx = make_tuple(idx[number<0>{}]); auto qy_ = y[idx] / yscale_[i_idx]; - qy(idx) = saturates{}(qy_); + qy(idx) = type_convert(saturates{}(qy_)); }); store_tile(qy_window, qy); } diff --git a/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp b/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp index aec7368e27..ecd4e81b22 100644 --- a/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp +++ b/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp @@ -260,7 +260,7 @@ struct AddRmsnorm2dRdquantFwdPipelineThreePass const auto x_ = type_convert(x[idx]); auto y_ = x_ * inv_rms[i_idx] * gamma_; auto qy_ = y_ / yscale[i_idx]; - qy(idx) = saturates{}(qy_); + qy(idx) = type_convert(saturates{}(qy_)); }); store_tile(qy_window, qy);