From f2d28e8ab43ac67b033a11e537d91bc5aeb4a3bf Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Tue, 9 Jul 2024 05:22:08 +0000 Subject: [PATCH] Add reference_rotary_position_embedding() (not implemented) --- include/ck_tile/host.hpp | 1 + .../reference_rotary_position_embedding.hpp | 39 +++++++++++++++++++ 2 files changed, 40 insertions(+) create mode 100644 include/ck_tile/host/reference/reference_rotary_position_embedding.hpp diff --git a/include/ck_tile/host.hpp b/include/ck_tile/host.hpp index 0e69a925d5..cab807a4d2 100644 --- a/include/ck_tile/host.hpp +++ b/include/ck_tile/host.hpp @@ -20,6 +20,7 @@ #include "ck_tile/host/reference/reference_im2col.hpp" #include "ck_tile/host/reference/reference_layernorm2d.hpp" #include "ck_tile/host/reference/reference_reduce.hpp" +#include "ck_tile/host/reference/reference_rotary_position_embedding.hpp" #include "ck_tile/host/reference/reference_softmax.hpp" #include "ck_tile/host/stream_config.hpp" #include "ck_tile/host/timer.hpp" diff --git a/include/ck_tile/host/reference/reference_rotary_position_embedding.hpp b/include/ck_tile/host/reference/reference_rotary_position_embedding.hpp new file mode 100644 index 0000000000..712226a8c0 --- /dev/null +++ b/include/ck_tile/host/reference/reference_rotary_position_embedding.hpp @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/host_tensor.hpp" + +#include +#include + +namespace ck_tile { + +namespace detail { + +} + +template +CK_TILE_HOST void reference_rotary_position_embedding(const HostTensor& input_bhsd, + const HostTensor& cos_sd, + const HostTensor& sin_sd, + bool interleaved, + HostTensor& output_bhsd) +{ + assert(cos_sd.get_num_of_dimension() == 2 && sin_sd.get_num_of_dimension() == 2); + assert(cos_sd.get_length(0) == sin_sd.get_length(0) && + cos_sd.get_length(1) == sin_sd.get_length(1)); + + const index_t rotary_dim = cos_sd.get_length(1) * 2; + assert(rotary_dim <= input_bhsd.get_length(3)); + (void)rotary_dim; + (void)input_bhsd; + (void)sin_sd; + (void)cos_sd; + (void)interleaved; + (void)output_bhsd; +} + +} // namespace ck_tile