mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 12:41:26 +00:00
@@ -3,8 +3,7 @@ set(TILE_ADD_RMSNORM2D_RDQUANT_FWD "tile_add_rmsnorm2d_rdquant_fwd")
|
||||
# to be included in "make all/install/check"
|
||||
message("adding ${TILE_ADD_RMSNORM2D_RDQUANT_FWD}")
|
||||
file(GLOB INSTANCE_SRCS instances/*.cpp)
|
||||
add_executable(${TILE_ADD_RMSNORM2D_RDQUANT_FWD} add_rmsnorm2d_rdquant_fwd.cpp)
|
||||
rocm_install(TARGETS ${TILE_ADD_RMSNORM2D_RDQUANT_FWD} COMPONENT examples)
|
||||
add_executable(${TILE_ADD_RMSNORM2D_RDQUANT_FWD} EXCLUDE_FROM_ALL add_rmsnorm2d_rdquant_fwd.cpp)
|
||||
target_include_directories(${TILE_ADD_RMSNORM2D_RDQUANT_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
|
||||
target_sources(${TILE_ADD_RMSNORM2D_RDQUANT_FWD} PRIVATE ${INSTANCE_SRCS})
|
||||
|
||||
@@ -16,8 +15,7 @@ list(APPEND TILE_ADD_RMSNORM2D_RDQUANT_FWD_COMPILE_OPTIONS -Wno-undefined-func-t
|
||||
target_compile_options(${TILE_ADD_RMSNORM2D_RDQUANT_FWD} PRIVATE ${TILE_ADD_RMSNORM2D_RDQUANT_FWD_COMPILE_OPTIONS})
|
||||
|
||||
set(EXAMPLE_ADD_RMSNORM2D_RDQUANT_FWD "tile_example_add_rmsnorm2d_rdquant_fwd")
|
||||
add_executable(${EXAMPLE_ADD_RMSNORM2D_RDQUANT_FWD} example_add_rmsnorm2d_rdquant_fwd.cpp)
|
||||
rocm_install(TARGETS ${EXAMPLE_ADD_RMSNORM2D_RDQUANT_FWD} COMPONENT examples)
|
||||
add_executable(${EXAMPLE_ADD_RMSNORM2D_RDQUANT_FWD} EXCLUDE_FROM_ALL example_add_rmsnorm2d_rdquant_fwd.cpp)
|
||||
target_compile_options(${EXAMPLE_ADD_RMSNORM2D_RDQUANT_FWD} PRIVATE ${TILE_ADD_RMSNORM2D_RDQUANT_FWD_COMPILE_OPTIONS})
|
||||
|
||||
# TODO: we have to turn off this global prop, otherwise the progress bar generated
|
||||
|
||||
@@ -67,14 +67,13 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
using TypeConfig = AddRmsnormRdquantTypeConfig<InputDataType, QuantizedDataType>;
|
||||
|
||||
using ADataType = typename TypeConfig::ADataType;
|
||||
using BDataType = typename TypeConfig::BDataType;
|
||||
using GammaDataType = typename TypeConfig::GammaDataType;
|
||||
using XDataType = typename TypeConfig::XDataType;
|
||||
using UnquantYDataType = ck_tile::null_type;
|
||||
using YScaleDataType = typename TypeConfig::YScaleDataType;
|
||||
using QYDataType = typename TypeConfig::QYDataType;
|
||||
using ComputeDataType = float;
|
||||
using ADataType = typename TypeConfig::ADataType;
|
||||
using BDataType = typename TypeConfig::BDataType;
|
||||
using GammaDataType = typename TypeConfig::GammaDataType;
|
||||
using XDataType = typename TypeConfig::XDataType;
|
||||
using YScaleDataType = typename TypeConfig::YScaleDataType;
|
||||
using QYDataType = typename TypeConfig::QYDataType;
|
||||
using ComputeDataType = float;
|
||||
|
||||
// host verify
|
||||
ck_tile::HostTensor<ADataType> a_host({m, n}, {stride, 1});
|
||||
@@ -89,7 +88,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
ck_tile::HostTensor<QYDataType> qy_host_ref({m, n}, {stride, 1});
|
||||
ck_tile::HostTensor<QYDataType> qy_host_dev({m, n}, {stride, 1});
|
||||
ck_tile::HostTensor<UnquantYDataType> unquant_y_host_ref({m, n}, {stride, 1});
|
||||
|
||||
ck_tile::FillUniformDistribution<ADataType>{-.5f, .5f}(a_host);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-.5f, .5f}(b_host);
|
||||
@@ -193,9 +191,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
GammaDataType,
|
||||
ComputeDataType,
|
||||
YDataType,
|
||||
InvRmsDataType,
|
||||
UnquantYDataType>(
|
||||
x_host_ref, gamma_host, y_host, invRms_host_ref, unquant_y_host_ref, epsilon);
|
||||
InvRmsDataType>(
|
||||
x_host_ref, gamma_host, y_host, invRms_host_ref, epsilon);
|
||||
}
|
||||
|
||||
// yscale
|
||||
|
||||
21
example/ck_tile/11_add_rmsnorm2d_rdquant/example_add_rmsnorm2d_rdquant_fwd.cpp
Executable file → Normal file
21
example/ck_tile/11_add_rmsnorm2d_rdquant/example_add_rmsnorm2d_rdquant_fwd.cpp
Executable file → Normal file
@@ -62,14 +62,13 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
assert(stride >= n);
|
||||
|
||||
using ADataType = DataType;
|
||||
using BDataType = DataType;
|
||||
using GammaDataType = DataType;
|
||||
using XDataType = DataType;
|
||||
using UnquantYDataType = ck_tile::null_type;
|
||||
using YScaleDataType = float;
|
||||
using QYDataType = ck_tile::int8_t;
|
||||
using ComputeDataType = float;
|
||||
using ADataType = DataType;
|
||||
using BDataType = DataType;
|
||||
using GammaDataType = DataType;
|
||||
using XDataType = DataType;
|
||||
using YScaleDataType = float;
|
||||
using QYDataType = ck_tile::int8_t;
|
||||
using ComputeDataType = float;
|
||||
|
||||
// host verify
|
||||
ck_tile::HostTensor<ADataType> a_host({m, n}, {stride, 1});
|
||||
@@ -82,7 +81,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::HostTensor<YScaleDataType> yscale_host_dev({m}, {1});
|
||||
ck_tile::HostTensor<QYDataType> qy_host_ref({m, n}, {stride, 1});
|
||||
ck_tile::HostTensor<QYDataType> qy_host_dev({m, n}, {stride, 1});
|
||||
ck_tile::HostTensor<UnquantYDataType> unquant_y_host_ref({m, n}, {stride, 1});
|
||||
|
||||
ck_tile::FillUniformDistribution<ADataType>{-.5f, .5f}(a_host);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-.5f, .5f}(b_host);
|
||||
@@ -195,9 +193,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
GammaDataType,
|
||||
ComputeDataType,
|
||||
YDataType,
|
||||
InvRmsDataType,
|
||||
UnquantYDataType>(
|
||||
x_host_ref, gamma_host, y_host, invRms_host_ref, unquant_y_host_ref, epsilon);
|
||||
InvRmsDataType>(
|
||||
x_host_ref, gamma_host, y_host, invRms_host_ref, epsilon);
|
||||
}
|
||||
|
||||
// yscale
|
||||
|
||||
Reference in New Issue
Block a user