diff --git a/example/ck_tile/07_add_rmsnorm2d_rdquant/example_add_rmsnorm2d_rdquant_fwd.cpp b/example/ck_tile/07_add_rmsnorm2d_rdquant/example_add_rmsnorm2d_rdquant_fwd.cpp index f006a6c527..ca84f4c5e8 100644 --- a/example/ck_tile/07_add_rmsnorm2d_rdquant/example_add_rmsnorm2d_rdquant_fwd.cpp +++ b/example/ck_tile/07_add_rmsnorm2d_rdquant/example_add_rmsnorm2d_rdquant_fwd.cpp @@ -4,6 +4,32 @@ #include "ck_tile/ops/add_rmsnorm2d_rdquant.hpp" #include +// per-dtype (rtol, atol) error thresholds handed to ck_tile::check_err during validation +template +auto get_elimit() +{ + double rtol = 1e-2; + double atol = 1e-2; + return ck_tile::make_tuple(rtol, atol); +} + +template <> +auto get_elimit() +{ + double rtol = 1e-2; + double atol = 1e-2; + return ck_tile::make_tuple(rtol, atol); +} + +template <> +auto get_elimit() +{ + // due to rounding, int8 quantization might have 1 abs error + double rtol = 1; + double atol = 1; + return ck_tile::make_tuple(rtol, atol); +} + auto create_args(int argc, char* argv[]) { ck_tile::ArgParser arg_parser; @@ -41,12 +67,12 @@ bool run(const ck_tile::ArgParser& arg_parser) using GammaDataType = DataType; using XDataType = DataType; using YScaleDataType = DataType; - using QYDataType = int8_t; + using QYDataType = ck_tile::int8_t; using ComputeDataType = float; // host verify - ck_tile::HostTensor a_host({m, n}, {stride, 1}); - ck_tile::HostTensor b_host({m, n}, {stride, 1}); + ck_tile::HostTensor a_host({m, n}, {stride, 1}); + ck_tile::HostTensor b_host({m, n}, {stride, 1}); ck_tile::HostTensor gamma_host({n}); ck_tile::HostTensor x_host_ref({m, n}, {stride, 1}); @@ -121,41 +147,95 @@ bool run(const ck_tile::ArgParser& arg_parser) if(do_validation) { + using YDataType = ComputeDataType; using InvRmsDataType = DataType; - ck_tile::HostTensor invRms_host_ref({m}); - // ck_tile::reference_rmsnorm2d_fwd( - // x_host, gamma_host, qy_host_ref, invRms_host_ref, epsilon); + // Add: x = a + b, compared against the device x output + { + auto op = [](const auto& v0, const auto& v1) { return v0 + v1; }; + ck_tile::reference_binary_elementwise( 
+ a_host, b_host, x_host_ref, op); - // qy_buf.FromDevice(qy_host_dev.data()); + x_buf.FromDevice(x_host_dev.data()); - // auto [rtol, atol] = ck_tile::make_tuple(1e-3, 1e-3); - // if(stride == n) - // { - // pass = ck_tile::check_err( - // qy_host_dev, qy_host_ref, std::string("OUT Error: Incorrect results!"), rtol, - // atol); - // } - // else - // { - // for(int i_r = 0; i_r < m; i_r++) - // { - // std::vector qy_host_dev_row(qy_host_dev.begin() + i_r * stride, - // qy_host_dev.begin() + i_r * stride + n); - // std::vector qy_host_ref_row(qy_host_ref.begin() + i_r * stride, - // qy_host_ref.begin() + i_r * stride + n); - // pass &= ck_tile::check_err(qy_host_dev_row, - // qy_host_ref_row, - // std::string("OUT[") + std::to_string(i_r) + - // std::string("] Error: Incorrect results!"), - // rtol, - // atol); - // } - // } + auto [rtol, atol] = get_elimit(); + pass = ck_tile::check_err( + x_host_dev, x_host_ref, std::string("x Error: Incorrect results!"), rtol, atol); + } + + ck_tile::HostTensor y_host({m, n}); + // RmsNorm2d + { + ck_tile::HostTensor invRms_host_ref({m}); + + // CAUTION: the kernel uses the ComputeDataType version of x, but we use XDataType here + // for simplicity + ck_tile::reference_rmsnorm2d_fwd( + x_host_ref, gamma_host, y_host, invRms_host_ref, epsilon); + } + + // yscale: row-wise abs-max of y divided by the largest QYDataType value + { + ck_tile::HostTensor y_rowwise_amax_host({m}); + + using ReduceAmax = ck_tile::ReduceOp::AbsMax; + ck_tile::reference_reduce( + y_host, y_rowwise_amax_host, ReduceAmax{}); + + auto op = [](const auto& v0) { + return v0 / + ck_tile::type_convert(ck_tile::numeric::max()); + }; + ck_tile::reference_unary_elementwise( + y_rowwise_amax_host, yscale_host_ref, op); + + yscale_buf.FromDevice(yscale_host_dev.mData.data()); + + auto [rtol, atol] = get_elimit(); + pass &= ck_tile::check_err(yscale_host_dev, + yscale_host_ref, + std::string("yscale Error: Incorrect results!"), + rtol, + atol); + } + + // rowwise quantization (result is AND-ed into pass so earlier x/yscale failures are kept) + { + ck_tile::reference_rowwise_quantization2d( 
+ y_host, yscale_host_ref, qy_host_ref); + + qy_buf.FromDevice(qy_host_dev.data()); + auto [rtol, atol] = get_elimit(); + + if(stride == n) + { + pass &= ck_tile::check_err(qy_host_dev, + qy_host_ref, + std::string("qy Error: Incorrect results!"), + rtol, + atol); + } + else + { + for(int i_r = 0; i_r < m; i_r++) + { + std::vector qy_host_dev_row(qy_host_dev.begin() + i_r * stride, + qy_host_dev.begin() + i_r * stride + n); + std::vector qy_host_ref_row(qy_host_ref.begin() + i_r * stride, + qy_host_ref.begin() + i_r * stride + n); + pass &= ck_tile::check_err(qy_host_dev_row, + qy_host_ref_row, + std::string("qy[") + std::to_string(i_r) + + std::string("] Error: Incorrect results!"), + rtol, + atol); + } + } + } std::cout << "[" << data_type << "]" << " m:" << m << ", n:" << n << ", stride:" << stride diff --git a/include/ck_tile/host.hpp b/include/ck_tile/host.hpp index 8108b0f1ae..7e33cb3076 100644 --- a/include/ck_tile/host.hpp +++ b/include/ck_tile/host.hpp @@ -19,11 +19,13 @@ #include "ck_tile/host/reference/reference_batched_masking.hpp" #include "ck_tile/host/reference/reference_batched_rotary_position_embedding.hpp" #include "ck_tile/host/reference/reference_batched_softmax.hpp" +#include "ck_tile/host/reference/reference_elementwise.hpp" #include "ck_tile/host/reference/reference_gemm.hpp" #include "ck_tile/host/reference/reference_im2col.hpp" #include "ck_tile/host/reference/reference_layernorm2d_fwd.hpp" #include "ck_tile/host/reference/reference_reduce.hpp" #include "ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp" +#include "ck_tile/host/reference/reference_rowwise_quantization2d.hpp" #include "ck_tile/host/reference/reference_softmax.hpp" #include "ck_tile/host/stream_config.hpp" #include "ck_tile/host/timer.hpp" diff --git a/include/ck_tile/host/reference/reference_elementwise.hpp b/include/ck_tile/host/reference/reference_elementwise.hpp new file mode 100644 index 0000000000..809049fa64 --- /dev/null +++ 
b/include/ck_tile/host/reference/reference_elementwise.hpp @@ -0,0 +1,47 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/host_tensor.hpp" +#include + +namespace ck_tile { +template +CK_TILE_HOST void reference_unary_elementwise(const HostTensor& a, + HostTensor& b, + ElementOp element_op) +{ + // CPU reference: b[i] = element_op(a[i]) over b's element space; TODO: implement a GPU version + auto f = [&](auto i) { + auto v_a = type_convert(a.mData[i]); + auto v_b = element_op(v_a); + b.mData[i] = ck_tile::type_convert(v_b); + }; + + make_ParallelTensorFunctor(f, b.get_element_space_size())(std::thread::hardware_concurrency()); +} + +template +CK_TILE_HOST void reference_binary_elementwise(const HostTensor& a, + const HostTensor& b, + HostTensor& c, + ElementOp element_op) +{ + // CPU reference: c[i] = element_op(a[i], b[i]) over c's element space; TODO: implement a GPU version + auto f = [&](auto i) { + auto v_a = type_convert(a.mData[i]); + auto v_b = type_convert(b.mData[i]); + auto v_c = element_op(v_a, v_b); + c.mData[i] = ck_tile::type_convert(v_c); + }; + + make_ParallelTensorFunctor(f, c.get_element_space_size())(std::thread::hardware_concurrency()); +} + +} // namespace ck_tile diff --git a/include/ck_tile/host/reference/reference_rowwise_quantization2d.hpp b/include/ck_tile/host/reference/reference_rowwise_quantization2d.hpp new file mode 100644 index 0000000000..e9a398876f --- /dev/null +++ b/include/ck_tile/host/reference/reference_rowwise_quantization2d.hpp @@ -0,0 +1,33 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/host_tensor.hpp" +#include + +namespace ck_tile { +template +CK_TILE_HOST void reference_rowwise_quantization2d(const HostTensor& x_m_n, + const HostTensor& scale_m, + HostTensor& qx_m_n) +{ + auto f = [&](auto m) { + const int N = x_m_n.mDesc.get_lengths()[1]; + + for(int n = 0; n < N; ++n) + { + auto v_x = x_m_n(m, n); + // scale = amax / 127 for int8; NOTE(review): a zero scale (all-zero row) divides by zero below - confirm + auto v_scale = type_convert(scale_m(m)); + auto v_qx = v_x / v_scale; + qx_m_n(m, n) = saturates{}(v_qx); + } + }; + + make_ParallelTensorFunctor(f, + scale_m.mDesc.get_lengths()[0])(std::thread::hardware_concurrency()); +} + +} // namespace ck_tile