diff --git a/example/ck_tile/09_topk_softmax/CMakeLists.txt b/example/ck_tile/09_topk_softmax/CMakeLists.txt
new file mode 100644
index 0000000000..b43b989792
--- /dev/null
+++ b/example/ck_tile/09_topk_softmax/CMakeLists.txt
@@ -0,0 +1,8 @@
+add_executable(tile_example_topk_softmax EXCLUDE_FROM_ALL topk_softmax.cpp topk_softmax_api.cpp)
+target_include_directories(tile_example_topk_softmax PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/)
+
+set(EXAMPLE_TOPK_SOFTMAX_COMPILE_OPTIONS)
+# NOTE: we turn off undefined-func-template so the source compiles without explicitly declared function specializations
+list(APPEND EXAMPLE_TOPK_SOFTMAX_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
+# list(APPEND EXAMPLE_TOPK_SOFTMAX_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker)
+target_compile_options(tile_example_topk_softmax PRIVATE ${EXAMPLE_TOPK_SOFTMAX_COMPILE_OPTIONS})
diff --git a/example/ck_tile/09_topk_softmax/README.md b/example/ck_tile/09_topk_softmax/README.md
new file mode 100644
index 0000000000..1043012900
--- /dev/null
+++ b/example/ck_tile/09_topk_softmax/README.md
@@ -0,0 +1,28 @@
+# topk-softmax
+
+This folder contains an example of the topk-softmax kernel using the ck_tile tile-programming implementation. This kernel is often used in Mixture-of-Experts (MoE) models, before launching the fused-moe-gemm block. The input is a `token*expert` 2D matrix. The op performs a softmax over each row (along the `expert` dimension), then selects the `topk` values of each row. The outputs are a `token*topk` weight tensor (usually fp32) and a `token*topk` index tensor (int32).
+
+## build
+```
+# in the root of ck_tile
+mkdir build && cd build
+sh ../script/cmake-ck-dev.sh ../ <arch> # replace <arch> with your target, e.g. gfx90a, gfx942...
+make tile_example_topk_softmax -j
+```
+This will result in an executable `build/bin/tile_example_topk_softmax`
+
+## example
+```
+args:
+          -v    whether to do CPU validation or not (default:1)
+       -pr_i    input data type, fp16/bf16 (default:fp16)
+       -pr_w    output weight data type (currently only fp32 is supported) (default:fp32)
+          -t    number of input tokens (default:32)
+          -e    number of experts (default:8)
+          -k    topk (default:2)
+       -st_i    row stride of input, -1 means same as experts (default:-1)
+       -st_o    row stride of output/indices, -1 means same as topk (default:-1)
+       -seed    seed to be used, -1 means random every time (default:-1)
+      -kname    when set to 1 it will print the kernel name (default:0)
+
+```
diff --git a/example/ck_tile/09_topk_softmax/script/smoke_test.sh b/example/ck_tile/09_topk_softmax/script/smoke_test.sh
new file mode 100644
index 0000000000..646f5889f7
--- /dev/null
+++ b/example/ck_tile/09_topk_softmax/script/smoke_test.sh
@@ -0,0 +1,22 @@
+#!/bin/sh
+
+EXE=./build/bin/tile_example_topk_softmax
+
+for pr_i in "fp16" "bf16" ; do
+$EXE -pr_i=$pr_i -t=80 -e=17
+$EXE -pr_i=$pr_i -t=111 -e=117
+$EXE -pr_i=$pr_i -t=1000 -e=55
+$EXE -pr_i=$pr_i -t=99 -e=180
+$EXE -pr_i=$pr_i -t=175 -e=64 -k=8
+$EXE -pr_i=$pr_i -t=65 -e=8 -k=2
+$EXE -pr_i=$pr_i -t=1 -e=25
+$EXE -pr_i=$pr_i -t=31 -e=19 -k=15
+$EXE -pr_i=$pr_i -t=81 -e=37 -k=7
+$EXE -pr_i=$pr_i -t=199 -e=128 -k=13
+$EXE -pr_i=$pr_i -t=23 -e=1 -k=1
+$EXE -pr_i=$pr_i -t=127 -e=99 -k=19 -st_i=233 -st_o=31
+$EXE -pr_i=$pr_i -t=71 -e=11 -k=11 -st_i=30 -st_o=12
+$EXE -pr_i=$pr_i -t=1 -e=1 -k=1
+$EXE -pr_i=$pr_i -t=99 -e=2 -k=1 -st_i=11 -st_o=5
+$EXE -pr_i=$pr_i -t=333 -e=99 -k=13 -st_i=191 -st_o=17
+done
diff --git a/example/ck_tile/09_topk_softmax/topk_softmax.cpp b/example/ck_tile/09_topk_softmax/topk_softmax.cpp
new file mode 100644
index 0000000000..6fc25631fd
--- /dev/null
+++ b/example/ck_tile/09_topk_softmax/topk_softmax.cpp
@@ -0,0 +1,299 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
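+
+// Example driver for the ck_tile topk-softmax kernel. Overall flow:
+//   1. parse the command-line options (token/expert counts, topk, strides, dtypes)
+//   2. fill the (tokens x experts) input with per-row unique random values
+//   3. run the GPU kernel through the topk_softmax() dispatch API
+//   4. optionally validate against the CPU reference below
+//      (a row-wise softmax followed by a row-wise top-k)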
+ +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/reduce.hpp" +#include "topk_softmax_api.hpp" + +#if 0 +template +void dump_host_tensor_2d(const ck_tile::HostTensor& x) +{ + auto len = x.get_lengths(); + assert(len.size() == 2); + std::cout << "["; + for(size_t i = 0; i < len[0]; i++) + { + std::cout << i << ": ["; + for(size_t j = 0; j < len[1]; j++) + { + if constexpr(std::is_same_v) + { + auto v = ck_tile::type_convert(x(i, j)); + + std::cout << v; + if(j != len[1] - 1) + std::cout << ","; + } + else + { + std::cout << x(i, j) << " "; + } + } + std::cout << "]"; + if(i != len[0] - 1) + std::cout << ","; + else + std::cout << "]"; + std::cout << std::endl; + } + std::cout << "--------------------" << std::endl; +} +#endif + +// CPU reference +template +auto reference_topk_softmax(const ck_tile::HostTensor& x, + ck_tile::index_t k, + ck_tile::index_t dim = -1, + bool largest = true, + bool sorted = true) +{ + using namespace ck_tile; + + auto y = reference_softmax(x, dim); + + auto [y_values, y_indices] = reference_topk(y, k, dim, largest, sorted); + + return ck_tile::make_tuple(y_values, y_indices); +} + +template +auto reference_topk_softmax(const ck_tile::HostTensor& x, + ck_tile::HostTensor& y_values, + ck_tile::HostTensor& y_indices, + ck_tile::index_t k, + ck_tile::index_t dim = -1, + bool largest = true, + bool sorted = true) +{ + using namespace ck_tile; + + auto y = reference_softmax(x, dim); + reference_topk(y, y_values, y_indices, k, dim, largest, sorted); +} + +// different threshold for different dtype +template +auto get_elimit(std::string /*init_method*/) +{ + double rtol = 1e-3; + double atol = 1e-3; + return ck_tile::make_tuple(rtol, atol); +} + +template <> +auto get_elimit(std::string /*init_method*/) +{ + double rtol = 1e-2; + double atol = 1e-2; + return ck_tile::make_tuple(rtol, atol); +} + +template <> +auto get_elimit(std::string init_method) +{ + if(init_method == "ui" || init_method == "ni") + { + unsigned max_rounding_point_distance = 0; + double atol = 2e-3; + return ck_tile::make_tuple(max_rounding_point_distance, atol); + } + else + { + unsigned max_rounding_point_distance = 1; + double atol = 0.0625; + return ck_tile::make_tuple(max_rounding_point_distance, atol); + } +} + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("v", "1", "weather do CPU validation or not") + .insert("pr_i", "fp16", "input data type. 
fp16/fp32 (representing 8/16/32 bit data)") + .insert("pr_w", "fp32", "output weight data type(currently only fp32 supported now)") + .insert("t", "32", "number of input tokens") + .insert("e", "8", "number of experts") + .insert("k", "2", "topk") + .insert("st_i", "-1", "row stride of input, -1 means same as experts") + .insert("st_o", "-1", "row stride of output/indices, -1 means same as topk") + .insert("seed", "-1", "seed to be used, -1 means random every time") + .insert("kname", "0", "when set to 1 it will print kernel name") + .insert("warmup", "5", "number of iterations before benchmark the kernel") + .insert("repeat", "20", "number of iterations to benchmark the kernel"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +template +bool test_topk_softmax(ck_tile::ArgParser args) +{ + int validate = args.get_int("v"); + std::string input_prec = args.get_str("pr_i"); + std::string weight_prec = args.get_str("pr_w"); + int tokens = args.get_int("t"); + int experts = args.get_int("e"); + int topk = args.get_int("k"); + int seed = args.get_int("seed"); + int stride_input = args.get_int("st_i"); + int stride_output = args.get_int("st_o"); + int kname = args.get_int("kname"); + int warmup = args.get_int("warmup"); + int repeat = args.get_int("repeat"); + + if(stride_input < 0) + { + stride_input = experts; + } + if(stride_output < 0) + { + stride_output = topk; + } + assert(stride_input >= experts); + assert(stride_output >= topk); + + if(seed < 0) + { + seed = std::time(nullptr); + } + + if(topk > experts) + { + printf("topk:%d value should be smaller than, or equal to number of experts:%d\n", + topk, + experts); + return false; + } + + // tokens already considered batch size + ck_tile::HostTensor x_host({tokens, experts}, {stride_input, 1}); + ck_tile::HostTensor value_host({tokens, topk}, {stride_output, 1}); + ck_tile::HostTensor index_host({tokens, topk}, {stride_output, 1}); + + { + // random require per-row unique + auto rand_gen = ck_tile::FillUniformDistribution_Unique{ + -5.f, 5.f, static_cast(seed)}; + + for(int i_t = 0; i_t < tokens; i_t++) + { + ck_tile::HostTensor x_row({experts}); + rand_gen(x_row); + std::copy(x_row.begin(), x_row.end(), x_host.begin() + i_t * stride_input); + rand_gen.clear(); + } + } + + ck_tile::DeviceMem x_dev(x_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem value_dev(value_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem index_dev(index_host.get_element_space_size_in_bytes()); + + x_dev.ToDevice(x_host.data()); + + topk_softmax_trait trait{input_prec, weight_prec, experts}; + + topk_softmax_kargs karg{x_dev.GetDeviceBuffer(), + value_dev.GetDeviceBuffer(), + index_dev.GetDeviceBuffer(), + tokens, + experts, + topk, + stride_input, + stride_output}; + + ck_tile::stream_config sc{nullptr, + true, + /* log_level = */ (kname ? 
1 : 0), + warmup, + repeat}; + auto ms = topk_softmax(trait, karg, sc); + printf("[%s|%s]tokens:%d, experts:%d, topk:%d, st_i:%d, st_o:%d, ms:%f, ", + input_prec.c_str(), + weight_prec.c_str(), + tokens, + experts, + topk, + stride_input, + stride_output, + ms); + if(ms < 0) + printf("not supported\n"); + fflush(stdout); + if(ms < 0) + { + return false; + } + + value_dev.FromDevice(value_host.data()); + index_dev.FromDevice(index_host.data()); + + bool rtn = true; + if(validate) + { + ck_tile::HostTensor value_ref({tokens, topk}, {stride_output, 1}); + ck_tile::HostTensor index_ref({tokens, topk}, {stride_output, 1}); + + reference_topk_softmax( + x_host, value_ref, index_ref, topk); + + auto [rtol, atol] = get_elimit(""); + for(int i_t = 0; i_t < tokens; i_t++) + { + auto s_begin = std::vector{static_cast(i_t), static_cast(0)}; + auto s_end = + std::vector{static_cast(i_t + 1), static_cast(topk)}; + auto s_value_host = value_host.slice(s_begin, s_end); + auto s_value_ref = value_ref.slice(s_begin, s_end); + rtn &= ck_tile::check_err(s_value_host, + s_value_ref, + std::string("[") + std::to_string(i_t) + + std::string("] Value Error:"), + rtol, + atol); + auto s_index_host = index_host.slice(s_begin, s_end); + auto s_index_ref = index_ref.slice(s_begin, s_end); + rtn &= ck_tile::check_err(s_index_host, + s_index_ref, + std::string("[") + std::to_string(i_t) + + std::string("] Index Error:"), + rtol, + atol); + } + } + + printf("valid:%s\n", rtn ? "y" : "n"); + fflush(stdout); + return rtn; +} + +int main(int argc, char** argv) +{ + auto [result, args] = create_args(argc, argv); + if(!result) + return -1; + std::string input_prec = args.get_str("pr_i"); + std::string weight_prec = args.get_str("pr_w"); + + bool r = true; + if(input_prec.compare("fp16") == 0 && weight_prec.compare("fp32") == 0) + { + r &= test_topk_softmax(args); + } + else if(input_prec.compare("bf16") == 0 && weight_prec.compare("fp32") == 0) + { + r &= test_topk_softmax(args); + } + + return r ? 0 : -1; +} diff --git a/example/ck_tile/09_topk_softmax/topk_softmax_api.cpp b/example/ck_tile/09_topk_softmax/topk_softmax_api.cpp new file mode 100644 index 0000000000..249a307b81 --- /dev/null +++ b/example/ck_tile/09_topk_softmax/topk_softmax_api.cpp @@ -0,0 +1,96 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
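+
+// Host-side dispatch for the topk-softmax kernel. The expert count is a
+// compile-time parameter of the pipeline problem, so TOPK_SOFTMAX_DISPATCH
+// instantiates one kernel per expert bucket (8/16/32/64/128/192) and the
+// if/else chain below launches the smallest bucket that fits t.experts.
+// Unsupported dtype/expert combinations fall through to the final return -1,
+// which the example driver reports as "not supported".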
+ +#include "topk_softmax_api.hpp" + +#define TOPK_SOFTMAX_DISPATCH(experts_) \ + constexpr ck_tile::index_t ts_experts = experts_; \ + using ts_problem = ck_tile:: \ + TopkSoftmaxWarpPerRowProblem; \ + using ts_pipeline = ck_tile::TopkSoftmaxWarpPerRowPipeline; \ + \ + using kernel = ck_tile::TopkSoftmaxKernel; \ + \ + auto kargs = kernel::MakeKargs(a); \ + \ + const dim3 grids = kernel::GridSize(a); \ + constexpr dim3 blocks = kernel::BlockSize(); \ + \ + float ave_time = ck_tile::launch_kernel( \ + s, ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs)); \ + \ + return ave_time; + +float topk_softmax(topk_softmax_trait t, topk_softmax_kargs a, ck_tile::stream_config s) +{ + if(t.input_type == "fp16" && t.weight_type == "fp32") + { + using ts_input_type = ck_tile::fp16_t; + using ts_weight_type = float; + using ts_index_type = ck_tile::index_t; +#if 1 + if(t.experts <= 8) + { + TOPK_SOFTMAX_DISPATCH(8) + } + else if(t.experts <= 16) + { + TOPK_SOFTMAX_DISPATCH(16) + } + else if(t.experts <= 32) + { + TOPK_SOFTMAX_DISPATCH(32) + } + else if(t.experts <= 64) + { + TOPK_SOFTMAX_DISPATCH(64) + } + else if(t.experts <= 128) + { + TOPK_SOFTMAX_DISPATCH(128) + } + else if(t.experts <= 192) + { + TOPK_SOFTMAX_DISPATCH(192) + } +#else + if(t.experts <= 128) + { + TOPK_SOFTMAX_DISPATCH(128) + } +#endif + } + else if(t.input_type == "bf16" && t.weight_type == "fp32") + { +#if 1 + using ts_input_type = ck_tile::bf16_t; + using ts_weight_type = float; + using ts_index_type = ck_tile::index_t; + if(t.experts <= 8) + { + TOPK_SOFTMAX_DISPATCH(8) + } + else if(t.experts <= 16) + { + TOPK_SOFTMAX_DISPATCH(16) + } + else if(t.experts <= 32) + { + TOPK_SOFTMAX_DISPATCH(32) + } + else if(t.experts <= 64) + { + TOPK_SOFTMAX_DISPATCH(64) + } + else if(t.experts <= 128) + { + TOPK_SOFTMAX_DISPATCH(128) + } + else if(t.experts <= 192) + { + TOPK_SOFTMAX_DISPATCH(192) + } +#endif + } + return -1; +} diff --git a/example/ck_tile/09_topk_softmax/topk_softmax_api.hpp b/example/ck_tile/09_topk_softmax/topk_softmax_api.hpp new file mode 100644 index 0000000000..65651efa4d --- /dev/null +++ b/example/ck_tile/09_topk_softmax/topk_softmax_api.hpp @@ -0,0 +1,21 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
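+
+// Minimal usage sketch of this API (illustrative only; device buffers are
+// assumed to be allocated already, and the field order mirrors the call in
+// topk_softmax.cpp):
+//
+//   topk_softmax_trait trait{"fp16", "fp32", experts};
+//   topk_softmax_kargs karg{p_input, p_weight_out, p_index_out,
+//                           tokens, experts, topk, stride_input, stride_output};
+//   float ms = topk_softmax(trait, karg, ck_tile::stream_config{});
+//   if(ms < 0) { /* this configuration is not supported */ }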
+ +#pragma once +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/ops/topk_softmax.hpp" +#include + +struct topk_softmax_trait +{ + std::string input_type; + std::string weight_type; // currently always float + int experts; +}; + +struct topk_softmax_kargs : public ck_tile::TopkSoftmaxHostArgs +{ +}; + +float topk_softmax(topk_softmax_trait t, topk_softmax_kargs a, ck_tile::stream_config s); diff --git a/example/ck_tile/CMakeLists.txt b/example/ck_tile/CMakeLists.txt index ec4a175d35..366fb18a0f 100644 --- a/example/ck_tile/CMakeLists.txt +++ b/example/ck_tile/CMakeLists.txt @@ -7,3 +7,5 @@ add_subdirectory(02_layernorm2d) add_subdirectory(03_gemm) add_subdirectory(04_img2col) add_subdirectory(05_reduce) +add_subdirectory(09_topk_softmax) + diff --git a/include/ck_tile/core.hpp b/include/ck_tile/core.hpp index d96f14710b..56dfbd636b 100644 --- a/include/ck_tile/core.hpp +++ b/include/ck_tile/core.hpp @@ -49,6 +49,7 @@ #include "ck_tile/core/tensor/tile_distribution_encoding.hpp" #include "ck_tile/core/tensor/tile_elementwise.hpp" #include "ck_tile/core/tensor/tile_window.hpp" +#include "ck_tile/core/tensor/tile_window_linear.hpp" #include "ck_tile/core/tensor/update_tile.hpp" #include "ck_tile/core/utility/bit_cast.hpp" #include "ck_tile/core/utility/functional.hpp" diff --git a/include/ck_tile/core/algorithm/space_filling_curve.hpp b/include/ck_tile/core/algorithm/space_filling_curve.hpp index 77a635611e..6591acddb9 100644 --- a/include/ck_tile/core/algorithm/space_filling_curve.hpp +++ b/include/ck_tile/core/algorithm/space_filling_curve.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -81,8 +81,10 @@ struct space_filling_curve return get_step_between(number{}, number{}); } + // Do not use this function directly! 
+ // TODO: can refactor into generic lambda in the future template - static CK_TILE_HOST_DEVICE constexpr Index get_index(number) + static CK_TILE_HOST_DEVICE constexpr Index _get_index(number) { #if 0 /* @@ -153,11 +155,11 @@ struct space_filling_curve return idx_md; } - // FIXME: rename this function + // FIXME: return tuple of number<>, which is compile time only variable template - static CK_TILE_HOST_DEVICE constexpr auto get_index_tuple_of_number(number) + static CK_TILE_HOST_DEVICE constexpr auto get_index(number) { - constexpr auto idx = get_index(number{}); + constexpr auto idx = _get_index(number{}); return generate_tuple([&](auto i) { return number{}; }, number{}); } diff --git a/include/ck_tile/core/arch/amd_buffer_addressing.hpp b/include/ck_tile/core/arch/amd_buffer_addressing.hpp index 7f488d1b71..3feede4d2e 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing.hpp @@ -621,6 +621,99 @@ CK_TILE_DEVICE void buffer_load_fence(index_t cnt = 0) asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory"); } +namespace impl { +// below type indicate the data type used for buffer load inline asm +// clang-format off +template struct smem_load_trait; + +template struct smem_load_trait<16, T> { using payload_t = fp32x4_t; }; +template struct smem_load_trait<8 , T> { using payload_t = fp32x2_t; }; +template struct smem_load_trait<4 , T> { using payload_t = float; }; +template struct smem_load_trait<2 , T> { using payload_t = float; }; +template struct smem_load_trait<1 , T> { using payload_t = float; }; + +// clang-format on +} // namespace impl + +// NOTE: smem load/store no need pre_nop to make sure dependency by sw, happy :) +template +struct smem_load; + +template <> +struct smem_load<16> +{ + template + CK_TILE_DEVICE void operator()(T& value, index_t v_offset, index_t i_offset) + { + static_assert(sizeof(T) == 16); + using mbuf_t = typename impl::smem_load_trait<16, T>::payload_t; + asm volatile("ds_read_b128 %0, %1 offset:%2" + : "=v"(reinterpret_cast(value)) // ! direct write + : "v"(v_offset), "n"(i_offset) + : "memory"); + } +}; + +template <> +struct smem_load<8> +{ + template + CK_TILE_DEVICE void operator()(T& value, index_t v_offset, index_t i_offset) + { + static_assert(sizeof(T) == 8); + using mbuf_t = typename impl::smem_load_trait<8, T>::payload_t; + asm volatile("ds_read_b64 %0, %1 offset:%2" + : "=v"(reinterpret_cast(value)) // ! direct write + : "v"(v_offset), "n"(i_offset) + : "memory"); + } +}; + +template <> +struct smem_load<4> +{ + template + CK_TILE_DEVICE void operator()(T& value, index_t v_offset, index_t i_offset) + { + static_assert(sizeof(T) == 4); + using mbuf_t = typename impl::smem_load_trait<4, T>::payload_t; + asm volatile("ds_read_b32 %0, %1 offset:%2" + : "=v"(reinterpret_cast(value)) // ! direct write + : "v"(v_offset), "n"(i_offset) + : "memory"); + } +}; + +template <> +struct smem_load<2> +{ + template + CK_TILE_DEVICE void operator()(T& value, index_t v_offset, index_t i_offset) + { + static_assert(sizeof(T) == 4); // subdword is buggy, use dword buf and convert manually + using mbuf_t = typename impl::smem_load_trait<1, T>::payload_t; + asm volatile("ds_read_u16 %0, %1 offset:%2" + : "=v"(reinterpret_cast(value)) // ! 
direct write + : "v"(v_offset), "n"(i_offset) + : "memory"); + } +}; + +template <> +struct smem_load<1> +{ + template + CK_TILE_DEVICE void operator()(T& value, index_t v_offset, index_t i_offset) + { + static_assert(sizeof(T) == 4); + using mbuf_t = typename impl::smem_load_trait<1, T>::payload_t; + asm volatile("ds_read_u8 %0, %1 offset:%2" + : "=v"(reinterpret_cast(value)) // ! direct write + : "v"(v_offset), "n"(i_offset) + : "memory"); + } +}; + // clang-format off namespace impl{ @@ -976,6 +1069,16 @@ llvm_amdgcn_raw_buffer_atomic_max_fp64(double vdata, int soffset, // dst_wave_addr_offset int glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fmax.f64"); +// Direct loads from global to LDS. +CK_TILE_DEVICE_EXTERN void +llvm_amdgcn_raw_buffer_load_lds(int32x4_t rsrc, + __attribute__((address_space(3))) uint32_t* lds_ptr, + index_t size, + index_t voffset, + index_t soffset, + index_t offset, + index_t aux) __asm("llvm.amdgcn.raw.buffer.load.lds"); + template CK_TILE_DEVICE void async_buffer_load_dword_v(void* smem, int32x4_t rsrc, @@ -1313,6 +1416,7 @@ CK_TILE_DEVICE void amd_buffer_load_raw_impl(thread_buffer& dst, int32x4_t src_wave_buffer_resource, index_t src_thread_addr_offset, index_t src_wave_addr_offset, + index_t src_linear_addr_offset, index_t flag = 0, bool_constant = {}) { @@ -1327,7 +1431,7 @@ CK_TILE_DEVICE void amd_buffer_load_raw_impl(thread_buffer& dst, src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, - 0, + src_linear_addr_offset, flag, bool_constant{}); } @@ -1337,7 +1441,7 @@ CK_TILE_DEVICE void amd_buffer_load_raw_impl(thread_buffer& dst, src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, - 0, + src_linear_addr_offset, flag, bool_constant{}); } @@ -1365,6 +1469,43 @@ CK_TILE_DEVICE void amd_async_buffer_load_impl(T* smem, bool_constant{}); } +template +CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem, + int32x4_t src_wave_buffer_resource, + index_t src_thread_addr_offset, + index_t src_wave_addr_offset, + index_t src_immediate_addr_offset = 0, + index_t flag = 0, + bool_constant = {}) +{ + static_assert(sizeof(T) * N == 4, "wrong! not implemented vector size"); + + if constexpr(oob_conditional_check) + { + index_t v_offset = flag ? 
src_thread_addr_offset : src_wave_buffer_resource[2];
+        llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource,
+                                        smem,
+                                        sizeof(uint32_t),
+                                        v_offset,
+                                        src_wave_addr_offset,
+                                        src_immediate_addr_offset,
+                                        static_cast(coherence));
+    }
+    else
+    {
+        llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource,
+                                        smem,
+                                        sizeof(uint32_t),
+                                        src_thread_addr_offset,
+                                        src_wave_addr_offset,
+                                        src_immediate_addr_offset,
+                                        static_cast(coherence));
+    }
+}
+
 template 
 CK_TILE_DEVICE void amd_buffer_store_impl_with_bytes(const thread_buffer src_thread_data,
@@ -1685,6 +1826,7 @@ CK_TILE_DEVICE void amd_buffer_store_raw_impl(const thread_buffer& dst_thr
                                               int32x4_t dst_wave_buffer_resource,
                                               index_t dst_thread_addr_offset,
                                               index_t dst_wave_addr_offset,
+                                              index_t dst_linear_addr_offset,
                                               index_t is_valid_element = 1)
 {
     constexpr index_t bytes = sizeof(T) * N;
@@ -1698,7 +1840,7 @@
             dst_wave_buffer_resource,
             dst_thread_addr_offset,
             dst_wave_addr_offset,
-            0,
+            dst_linear_addr_offset,
             is_valid_element);
     }
     else
     {
@@ -1707,7 +1849,7 @@
             dst_wave_buffer_resource,
             dst_thread_addr_offset,
             dst_wave_addr_offset,
-            0);
+            dst_linear_addr_offset);
     }
 }
 
@@ -2014,6 +2156,7 @@ template & dst,
                                        const T* p_src_wave,
                                        index_t src_thread_element_offset,
+                                       index_t src_linear_element_offset,
                                        index_t src_element_space_size,
                                        index_t is_valid_element = 0,
                                        bool_constant = {})
@@ -2022,12 +2165,14 @@ CK_TILE_DEVICE void amd_buffer_load_raw(thread_buffer& dst,
         make_wave_buffer_resource(p_src_wave, src_element_space_size * sizeof(T));
 
     index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
+    index_t src_linear_addr_offset = src_linear_element_offset * sizeof(T);
 
     amd_buffer_load_raw_impl(
         dst,
         src_wave_buffer_resource,
         src_thread_addr_offset,
         0,
+        src_linear_addr_offset,
         is_valid_element,
         bool_constant{});
 }
@@ -2041,16 +2186,19 @@ template & dst,
                                        const int32x4_t src_wave_buffer_resource,
                                        index_t src_thread_element_offset,
+                                       index_t src_linear_element_offset,
                                        index_t is_valid_element = 0,
                                        bool_constant = {})
 {
     index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
+    index_t src_linear_addr_offset = src_linear_element_offset * sizeof(T);
 
     amd_buffer_load_raw_impl(
         dst,
         src_wave_buffer_resource,
         src_thread_addr_offset,
         0,
+        src_linear_addr_offset,
         is_valid_element,
         bool_constant{});
 }
@@ -2066,6 +2214,7 @@ template 
= {})
 {
@@ -2073,9 +2222,14 @@ CK_TILE_DEVICE void amd_async_buffer_load_with_oob_raw(T* smem,
         make_wave_buffer_resource(p_src_wave, src_element_space_size * sizeof(T));
 
     index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
+    index_t src_linear_addr_offset = src_linear_element_offset * sizeof(T);
 
-    amd_async_buffer_load_impl(
-        smem, src_wave_buffer_resource, src_thread_addr_offset, 0, 0, bool_constant{});
+    amd_async_buffer_load_impl(smem,
+                               src_wave_buffer_resource,
+                               src_thread_addr_offset,
+                               0,
+                               src_linear_addr_offset,
+                               bool_constant{});
 }
 
 // This version support buffer resource as input arg
@@ -2086,12 +2240,42 @@ template 
= {})
 {
     index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
+    index_t src_linear_addr_offset = src_linear_element_offset * sizeof(T);
 
-    amd_async_buffer_load_impl(
-        smem, src_wave_buffer_resource, src_thread_addr_offset, 0, 0, bool_constant{});
+    amd_async_buffer_load_impl(smem,
+                               src_wave_buffer_resource,
+                               src_thread_addr_offset,
+                               0,
+                               src_linear_addr_offset,
+                               bool_constant{});
+}
+
+// This version supports buffer resource as input arg
+template 
+CK_TILE_DEVICE void amd_async_buffer_load_with_oob(CK_TILE_LDS_ADDR T* smem,
+                                                   const int32x4_t src_wave_buffer_resource,
+                                                   index_t src_thread_element_offset,
+                                                   index_t src_linear_element_offset,
+                                                   bool is_valid_element,
+                                                   bool_constant = {})
+{
+    index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
+    index_t src_linear_addr_offset = src_linear_element_offset * sizeof(T);
+
+    amd_async_buffer_load(smem,
+                          src_wave_buffer_resource,
+                          src_thread_addr_offset,
+                          0,
+                          src_linear_addr_offset,
+                          is_valid_element,
+                          bool_constant{});
 }
 
 // buffer_store requires:
@@ -2146,6 +2330,7 @@ template & src_thread_data,
                                         T* p_dst_wave,
                                         const index_t dst_thread_element_offset,
+                                        const index_t dst_linear_element_offset,
                                         const bool dst_thread_element_valid,
                                         const index_t dst_element_space_size)
 {
@@ -2153,11 +2338,13 @@ CK_TILE_DEVICE void amd_buffer_store_raw(const thread_buffer& src_thread_d
         make_wave_buffer_resource(p_dst_wave, dst_element_space_size * sizeof(T));
 
     index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T);
+    index_t dst_linear_addr_offset = dst_linear_element_offset * sizeof(T);
 
     amd_buffer_store_raw_impl(src_thread_data,
                               dst_wave_buffer_resource,
                               dst_thread_addr_offset,
                               0,
+                              dst_linear_addr_offset,
                               dst_thread_element_valid);
 }
 
@@ -2221,16 +2408,6 @@ CK_TILE_DEVICE void amd_buffer_atomic_max(const thread_buffer& src_thread_
 #endif
 }
 
-// Direct loads from global to LDS.
-CK_TILE_DEVICE_EXTERN void
-llvm_amdgcn_raw_buffer_load_lds(int32x4_t rsrc,
-                                __attribute__((address_space(3))) uint32_t* lds_ptr,
-                                index_t size,
-                                index_t voffset,
-                                index_t soffset,
-                                index_t offset,
-                                index_t aux) __asm("llvm.amdgcn.raw.buffer.load.lds");
-
 template 
 CK_TILE_DEVICE void amd_direct_load_global_to_lds(const T* global_base_ptr,
                                                   const index_t global_offset,
diff --git a/include/ck_tile/core/config.hpp b/include/ck_tile/core/config.hpp
index 580faae925..4be50b8656 100644
--- a/include/ck_tile/core/config.hpp
+++ b/include/ck_tile/core/config.hpp
@@ -41,6 +41,19 @@
 #define CK_TILE_HOST_DEVICE_EXTERN
 #endif
 
+// implementing the "memory address space" attribute
+// https://llvm.org/docs/AMDGPUUsage.html#amdgpu-address-spaces-table
+#ifdef __HIPCC__
+#define CK_TILE_GENERIC_ADDR __attribute__((address_space(0)))
+#define CK_TILE_GLOBAL_ADDR __attribute__((address_space(1)))
+#define CK_TILE_LDS_ADDR __attribute__((address_space(3)))
+#define CK_TILE_BUF_RES_ADDR __attribute__((address_space(8)))
+#else
+#define CK_TILE_GENERIC_ADDR
+#define CK_TILE_GLOBAL_ADDR
+#define CK_TILE_LDS_ADDR
+#define CK_TILE_BUF_RES_ADDR
+#endif
 #ifndef CK_TILE_USE_CUSTOM_DATA_TYPE
 #define CK_TILE_USE_CUSTOM_DATA_TYPE 0 // custom data type will generate extra move/bfi code
 #endif
@@ -205,3 +218,8 @@
 #ifndef CK_TILE_BUFFER_LOAD_RAW_BF16_WA
 #define CK_TILE_BUFFER_LOAD_RAW_BF16_WA 1
 #endif
+
+// workaround: compiler not emitting the reciprocal instruction from __frcp_rn()
+#ifndef CK_TILE_WORKAROUND_SWDEV_383542
+#define CK_TILE_WORKAROUND_SWDEV_383542 1
+#endif
diff --git a/include/ck_tile/core/container/tuple.hpp b/include/ck_tile/core/container/tuple.hpp
index 598dfeea3e..19d853ad5c 100644
--- a/include/ck_tile/core/container/tuple.hpp
+++ b/include/ck_tile/core/container/tuple.hpp
@@ -623,7 +623,7 @@ template 
 CK_TILE_HOST_DEVICE constexpr auto operator+=(tuple& y, const X& x)
 {
-    static_assert(X::Size() == sizeof...(Ys), "wrong! size not the same");
+    static_assert(X::size() == sizeof...(Ys), "wrong! 
size not the same"); constexpr index_t NSize = sizeof...(Ys); static_for<0, NSize, 1>{}([&](auto i) { y[i] += x[i]; }); return y; @@ -635,7 +635,7 @@ template CK_TILE_HOST_DEVICE constexpr auto operator-=(tuple& y, const X& x) { - static_assert(X::Size() == sizeof...(Ys), "wrong! size not the same"); + static_assert(X::size() == sizeof...(Ys), "wrong! size not the same"); constexpr index_t NSize = sizeof...(Ys); static_for<0, NSize, 1>{}([&](auto i) { y[i] -= x[i]; }); return y; @@ -647,7 +647,7 @@ template CK_TILE_HOST_DEVICE constexpr auto operator+(const tuple& x, const Y& y) { - static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same"); + static_assert(Y::size() == sizeof...(Xs), "wrong! size not the same"); constexpr index_t NSize = sizeof...(Xs); tuple r; @@ -655,13 +655,21 @@ CK_TILE_HOST_DEVICE constexpr auto operator+(const tuple& x, const Y& y) return r; } +template +CK_TILE_HOST_DEVICE constexpr auto operator+(const tuple& x, const tuple& y) +{ + static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong!"); + constexpr index_t NSize = sizeof...(Xs); + return generate_tuple([&](auto i) { return x[i] + y[i]; }, number{}); +} + template ::value && !std::is_floating_point::value, bool> = false> CK_TILE_HOST_DEVICE constexpr auto operator-(const tuple& x, const Y& y) { - static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same"); + static_assert(Y::size() == sizeof...(Xs), "wrong! size not the same"); constexpr index_t NSize = sizeof...(Xs); tuple r; @@ -669,13 +677,21 @@ CK_TILE_HOST_DEVICE constexpr auto operator-(const tuple& x, const Y& y) return r; } +template +CK_TILE_HOST_DEVICE constexpr auto operator-(const tuple& x, const tuple& y) +{ + static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong!"); + constexpr index_t NSize = sizeof...(Xs); + return generate_tuple([&](auto i) { return x[i] - y[i]; }, number{}); +} + template ::value && !std::is_floating_point::value, bool> = false> CK_TILE_HOST_DEVICE constexpr auto operator*(const tuple& x, const Y& y) { - static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same"); + static_assert(Y::size() == sizeof...(Xs), "wrong! 
size not the same"); constexpr index_t NSize = sizeof...(Xs); tuple r; @@ -706,6 +722,14 @@ CK_TILE_HOST_DEVICE constexpr auto operator*(const tuple& x, Y a) return a * x; } +template +CK_TILE_HOST_DEVICE constexpr auto operator*(const tuple& x, const tuple& y) +{ + static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong!"); + constexpr index_t NSize = sizeof...(Xs); + return generate_tuple([&](auto i) { return x[i] * y[i]; }, number{}); +} + template CK_TILE_HOST_DEVICE constexpr auto operator/(const tuple& x, const tuple& y) { diff --git a/include/ck_tile/core/numeric/math.hpp b/include/ck_tile/core/numeric/math.hpp index f512e50e0a..785691b66f 100644 --- a/include/ck_tile/core/numeric/math.hpp +++ b/include/ck_tile/core/numeric/math.hpp @@ -487,55 +487,12 @@ struct log2e template constexpr T log2e_v = log2e::value; -// math -CK_TILE_HOST_DEVICE -float abs(const float& x) -{ - union - { - float f32; - uint32_t u32; - } y; - y.f32 = x; - y.u32 = y.u32 & 0x7fffffff; - return y.f32; -} - -CK_TILE_HOST_DEVICE -bool isnan(const float& x) -{ - uint32_t xx = bit_cast(x); - return (xx & 0x7fffffff) > 0x7F800000; -} - -CK_TILE_HOST float sqrt(float x) { return std::sqrt(x); }; - -CK_TILE_HOST double sqrt(double x) { return std::sqrt(x); }; - -CK_TILE_DEVICE -float sqrt(float x) { return __builtin_amdgcn_sqrtf(x); }; - -CK_TILE_DEVICE -double sqrt(double x) { return __builtin_amdgcn_sqrt(x); }; - -CK_TILE_DEVICE -float exp(float x) { return __ocml_exp_f32(x); }; - -CK_TILE_HOST -float exp(float x) { return std::expf(x); } - CK_TILE_DEVICE float exp2(float x) { return exp2f(x); }; CK_TILE_HOST float exp2(float x) { return std::exp2f(x); }; -CK_TILE_DEVICE -float log(float x) { return __logf(x); }; - -CK_TILE_HOST -float log(float x) { return std::logf(x); }; - CK_TILE_DEVICE uint16_t sad_u16(uint16_t x, uint16_t y, uint16_t acc) { return __builtin_amdgcn_sad_u16(x, y, acc); @@ -554,4 +511,933 @@ CK_TILE_HOST uint32_t sad_u32(uint32_t x, uint32_t y, uint32_t acc) return (x > y ? 
(x - y) : (y - x)) + acc; } +/////////////////////////////////////////////////////////////// + +} // namespace ck_tile +// blow function need data type pre-defined +#include "ck_tile/core/numeric/half.hpp" +#include "ck_tile/core/numeric/bfloat16.hpp" +#include "ck_tile/core/numeric/float8.hpp" +#include "ck_tile/core/numeric/type_convert.hpp" +#ifndef __HIP_DEVICE_COMPILE__ +#include +#endif + +namespace ck_tile { +#if CK_TILE_WORKAROUND_SWDEV_383542 +extern "C" CK_TILE_DEVICE float __ocml_native_recip_f32(float); +#endif + +// math functions for the host, some are implemented by calling C++ std functions + +CK_TILE_HOST float abs(float x) { return std::abs(x); }; + +CK_TILE_HOST double abs(double x) { return std::abs(x); }; + +CK_TILE_HOST int8_t abs(int8_t x) +{ + int8_t sgn = x >> (8 - 1); + + return (x ^ sgn) - sgn; +}; + +CK_TILE_HOST int32_t abs(int32_t x) +{ + int32_t sgn = x >> (32 - 1); + + return (x ^ sgn) - sgn; +}; + +CK_TILE_HOST fp16_t abs(fp16_t x) +{ + uint16_t xx = bit_cast(x); + + uint16_t abs_xx = xx & 0x7fff; + + fp16_t abs_x = bit_cast(abs_xx); + + return abs_x; +}; + +#ifdef CK_TILE_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 +CK_TILE_HOST int4_t abs(int4_t x) +{ + int4_t sgn = x >> (4 - 1); + return (x ^ sgn) - sgn; +} +#endif + +CK_TILE_HOST bool isnan(float x) { return std::isnan(x); }; + +CK_TILE_HOST bool isnan(double x) { return std::isnan(x); }; + +CK_TILE_HOST bool isnan(int8_t x) +{ + (void)x; + return false; +}; + +CK_TILE_HOST bool isnan(int32_t x) +{ + (void)x; + return false; +}; + +CK_TILE_HOST bool isnan(fp16_t x) +{ + uint16_t xx = bit_cast(x); + + return (xx & 0x7FFF) > 0x7C00; +}; + +#ifdef CK_TILE_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 +CK_TILE_HOST bool isnan(int4_t x) +{ + (void)x; + return false; +}; +#endif + +CK_TILE_HOST fp16_t sqrt(fp16_t x) +{ + return static_cast(std::sqrt(static_cast(x))); +}; + +CK_TILE_HOST float sqrt(float x) { return std::sqrt(x); }; + +CK_TILE_HOST double sqrt(double x) { return std::sqrt(x); }; + +template +CK_TILE_HOST T tanh(T x) +{ + return type_convert(std::tanhf(type_convert(x))); +}; + +template <> +CK_TILE_HOST float tanh(float x) +{ + return std::tanhf(x); +}; + +template <> +CK_TILE_HOST double tanh(double x) +{ + return std::tanh(x); +}; + +template +CK_TILE_HOST T acos(T x) +{ + return type_convert(std::acosf(type_convert(x))); +}; + +template <> +CK_TILE_HOST float acos(float x) +{ + return std::acosf(x); +}; + +template <> +CK_TILE_HOST double acos(double x) +{ + return std::acos(x); +}; + +template +CK_TILE_HOST T neg(T x) +{ + return type_convert(-(type_convert(x))); +}; + +template <> +CK_TILE_HOST float neg(float x) +{ + return -x; +}; + +template <> +CK_TILE_HOST double neg(double x) +{ + return -x; +}; + +template <> +CK_TILE_HOST int32_t neg(int32_t x) +{ + return -x; +}; + +template <> +CK_TILE_HOST int8_t neg(int8_t x) +{ + return -x; +}; + +template +CK_TILE_HOST T atan(T x) +{ + return type_convert(std::atanf(type_convert(x))); +}; + +template <> +CK_TILE_HOST float atan(float x) +{ + return std::atanf(x); +}; + +template <> +CK_TILE_HOST double atan(double x) +{ + return std::atan(x); +}; + +template +CK_TILE_HOST T sin(T x) +{ + return type_convert(std::sinf(type_convert(x))); +}; + +template <> +CK_TILE_HOST float sin(float x) +{ + return std::sinf(x); +}; + +template <> +CK_TILE_HOST double sin(double x) +{ + return std::sin(x); +}; + +template +CK_TILE_HOST T asin(T x) +{ + return type_convert(std::asinf(type_convert(x))); +}; + +template <> +CK_TILE_HOST float asin(float x) +{ + return 
std::asinf(x); +}; + +template <> +CK_TILE_HOST double asin(double x) +{ + return std::asin(x); +}; + +template +CK_TILE_HOST T asinh(T x) +{ + return type_convert(std::asinhf(type_convert(x))); +}; + +template <> +CK_TILE_HOST float asinh(float x) +{ + return std::asinhf(x); +}; + +template <> +CK_TILE_HOST double asinh(double x) +{ + return std::asinh(x); +}; + +template +CK_TILE_HOST T cos(T x) +{ + return type_convert(std::cosf(type_convert(x))); +}; + +template <> +CK_TILE_HOST float cos(float x) +{ + return std::cosf(x); +}; + +template <> +CK_TILE_HOST double cos(double x) +{ + return std::cos(x); +}; + +template +CK_TILE_HOST T acosh(T x) +{ + return type_convert(std::acoshf(type_convert(x))); +}; + +template <> +CK_TILE_HOST float acosh(float x) +{ + return std::acoshf(x); +}; + +template <> +CK_TILE_HOST double acosh(double x) +{ + return std::acosh(x); +}; + +template +CK_TILE_HOST T tan(T x) +{ + return type_convert(std::tanf(type_convert(x))); +}; + +template <> +CK_TILE_HOST float tan(float x) +{ + return std::tanf(x); +}; + +template <> +CK_TILE_HOST double tan(double x) +{ + return std::tan(x); +}; + +template +CK_TILE_HOST T atanh(T x) +{ + return type_convert(std::atanhf(type_convert(x))); +}; + +template <> +CK_TILE_HOST float atanh(float x) +{ + return std::atanhf(x); +}; + +template <> +CK_TILE_HOST double atanh(double x) +{ + return std::atanh(x); +}; + +template +CK_TILE_HOST T sinh(T x) +{ + return type_convert(std::sinhf(type_convert(x))); +}; + +template <> +CK_TILE_HOST float sinh(float x) +{ + return std::sinhf(x); +}; + +template <> +CK_TILE_HOST double sinh(double x) +{ + return std::sinh(x); +}; + +template +CK_TILE_HOST T ceil(T x) +{ + return type_convert(std::ceilf(type_convert(x))); +}; + +template <> +CK_TILE_HOST float ceil(float x) +{ + return std::ceilf(x); +}; + +template <> +CK_TILE_HOST double ceil(double x) +{ + return std::ceil(x); +}; + +template +CK_TILE_HOST T cosh(T x) +{ + return type_convert(std::coshf(type_convert(x))); +}; + +template <> +CK_TILE_HOST float cosh(float x) +{ + return std::coshf(x); +}; + +template <> +CK_TILE_HOST double cosh(double x) +{ + return std::cosh(x); +}; + +template +CK_TILE_HOST T floor(T x) +{ + return type_convert(std::floorf(type_convert(x))); +}; + +template <> +CK_TILE_HOST float floor(float x) +{ + return std::floorf(x); +}; + +template <> +CK_TILE_HOST double floor(double x) +{ + return std::floor(x); +}; + +template +CK_TILE_HOST T rcp(T x) +{ + return type_convert(1.f / type_convert(x)); +}; + +template +CK_TILE_HOST T exp(T x) +{ + return type_convert(std::expf(type_convert(x))); +} + +template <> +CK_TILE_HOST float exp(float x) +{ + return std::expf(x); +} + +template <> +CK_TILE_HOST double exp(double x) +{ + return std::exp(x); +} + +template +CK_TILE_HOST T log(T x) +{ + return type_convert(std::logf(type_convert(x))); +} + +template <> +CK_TILE_HOST float log(float x) +{ + return std::logf(x); +} + +template <> +CK_TILE_HOST double log(double x) +{ + return std::log(x); +} + +template +CK_TILE_HOST T pow(T x, T gamma) +{ + return type_convert(std::powf(type_convert(x), type_convert(gamma))); +} + +template <> +CK_TILE_HOST float pow(float x, float gamma) +{ + return std::powf(x, gamma); +} + +template <> +CK_TILE_HOST double pow(double x, double gamma) +{ + return std::pow(x, gamma); +} + +template +CK_TILE_HOST T expm1(T x) +{ + return type_convert(std::expm1f(type_convert(x))); +} + +template <> +CK_TILE_HOST float expm1(float x) +{ + return std::expm1f(x); +} + +template <> +CK_TILE_HOST 
double expm1(double x) +{ + return std::expm1(x); +} + +// math functions for the HIP kernel, some are implemented by calling hip builtin functions + +CK_TILE_DEVICE float abs(float x) +{ + union + { + float f32; + uint32_t u32; + } y; + y.f32 = x; + y.u32 = y.u32 & 0x7fffffff; + return y.f32; +}; + +CK_TILE_DEVICE double abs(double x) { return ::abs(x); }; + +CK_TILE_DEVICE int8_t abs(int8_t x) +{ + int8_t sgn = x >> (8 - 1); + + return (x ^ sgn) - sgn; +}; + +CK_TILE_DEVICE int32_t abs(int32_t x) +{ + int32_t sgn = x >> (32 - 1); + + return (x ^ sgn) - sgn; +}; + +#ifdef CK_TILE_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 +CK_TILE_DEVICE int4_t abs(int4_t x) +{ + int4_t sgn = x >> (4 - 1); + + return (x ^ sgn) - sgn; +}; +#endif + +CK_TILE_DEVICE fp16_t abs(fp16_t x) +{ + uint16_t xx = bit_cast(x); + + uint16_t abs_xx = xx & 0x7fff; + + fp16_t abs_x = bit_cast(abs_xx); + + return abs_x; +}; + +CK_TILE_DEVICE bool isnan(float x) { return ::isnan(x); }; + +CK_TILE_DEVICE bool isnan(double x) { return ::isnan(x); }; + +CK_TILE_DEVICE bool isnan(int8_t x) +{ + (void)x; + return false; +}; + +CK_TILE_DEVICE bool isnan(int32_t x) +{ + (void)x; + return false; +}; + +#ifdef CK_TILE_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 +CK_TILE_DEVICE bool isnan(int4_t x) +{ + (void)x; + return false; +}; +#endif + +CK_TILE_DEVICE bool isnan(fp16_t x) +{ + uint16_t xx = bit_cast(x); + + return (xx & 0x7FFF) > 0x7C00; +}; + +CK_TILE_DEVICE fp16_t sqrt(fp16_t x) +{ + return static_cast(__builtin_amdgcn_sqrtf(static_cast(x))); +}; + +CK_TILE_DEVICE float sqrt(float x) { return __builtin_amdgcn_sqrtf(x); }; + +CK_TILE_DEVICE double sqrt(double x) { return __builtin_amdgcn_sqrt(x); }; + +template +CK_TILE_DEVICE T tanh(T x) +{ + return type_convert(::tanhf(type_convert(x))); +}; + +template <> +CK_TILE_DEVICE float tanh(float x) +{ + return ::tanhf(x); +}; + +template <> +CK_TILE_DEVICE double tanh(double x) +{ + return ::tanh(x); +}; + +template +CK_TILE_DEVICE T acos(T x) +{ + return type_convert(::acosf(type_convert(x))); +}; + +template <> +CK_TILE_DEVICE float acos(float x) +{ + return ::acosf(x); +}; + +template <> +CK_TILE_DEVICE double acos(double x) +{ + return ::acos(x); +}; + +template +CK_TILE_DEVICE T neg(T x) +{ + return type_convert(-(type_convert(x))); +}; + +template <> +CK_TILE_DEVICE float neg(float x) +{ + return -x; +}; + +template <> +CK_TILE_DEVICE double neg(double x) +{ + return -x; +}; + +template <> +CK_TILE_DEVICE int32_t neg(int32_t x) +{ + return -x; +}; + +template <> +CK_TILE_DEVICE int8_t neg(int8_t x) +{ + return -x; +}; + +template <> +CK_TILE_DEVICE fp16_t neg(fp16_t x) +{ + return __hneg(x); +}; + +template +CK_TILE_DEVICE T atan(T x) +{ + return type_convert(::atanf(type_convert(x))); +}; + +template <> +CK_TILE_DEVICE float atan(float x) +{ + return ::atanf(x); +}; + +template <> +CK_TILE_DEVICE double atan(double x) +{ + return ::atan(x); +}; + +template +CK_TILE_DEVICE T sin(T x) +{ + return type_convert(::sinf(type_convert(x))); +}; + +template <> +CK_TILE_DEVICE float sin(float x) +{ + return ::sinf(x); +}; + +template <> +CK_TILE_DEVICE double sin(double x) +{ + return ::sin(x); +}; + +template <> +CK_TILE_DEVICE fp16_t sin(fp16_t x) +{ + return ::hsin(x); +}; + +template +CK_TILE_DEVICE T asin(T x) +{ + return type_convert(::asinf(type_convert(x))); +}; + +template <> +CK_TILE_DEVICE float asin(float x) +{ + return ::asinf(x); +}; + +template <> +CK_TILE_DEVICE double asin(double x) +{ + return ::asin(x); +}; + +template +CK_TILE_DEVICE T asinh(T x) +{ + return 
type_convert(::asinhf(type_convert(x))); +}; + +template <> +CK_TILE_DEVICE float asinh(float x) +{ + return ::asinhf(x); +}; + +template <> +CK_TILE_DEVICE double asinh(double x) +{ + return ::asinh(x); +}; + +template +CK_TILE_DEVICE T acosh(T x) +{ + return type_convert(::acoshf(type_convert(x))); +}; + +template <> +CK_TILE_DEVICE float acosh(float x) +{ + return ::acoshf(x); +}; + +template <> +CK_TILE_DEVICE double acosh(double x) +{ + return ::acosh(x); +}; + +template +CK_TILE_DEVICE T tan(T x) +{ + return type_convert(::tanf(type_convert(x))); +}; + +template <> +CK_TILE_DEVICE float tan(float x) +{ + return ::tanf(x); +}; + +template <> +CK_TILE_DEVICE double tan(double x) +{ + return ::tan(x); +}; + +template +CK_TILE_DEVICE T atanh(T x) +{ + return type_convert(::atanhf(type_convert(x))); +}; + +template <> +CK_TILE_DEVICE float atanh(float x) +{ + return ::atanhf(x); +}; + +template <> +CK_TILE_DEVICE double atanh(double x) +{ + return ::atanh(x); +}; + +template +CK_TILE_DEVICE T sinh(T x) +{ + return type_convert(::sinhf(type_convert(x))); +}; + +template <> +CK_TILE_DEVICE float sinh(float x) +{ + return ::sinhf(x); +}; + +template <> +CK_TILE_DEVICE double sinh(double x) +{ + return ::sinh(x); +}; + +template +CK_TILE_DEVICE T ceil(T x) +{ + return type_convert(::ceilf(type_convert(x))); +}; + +template <> +CK_TILE_DEVICE float ceil(float x) +{ + return ::ceilf(x); +}; + +template <> +CK_TILE_DEVICE double ceil(double x) +{ + return ::ceil(x); +}; + +template <> +CK_TILE_DEVICE fp16_t ceil(fp16_t x) +{ + return ::hceil(x); +}; + +template +CK_TILE_DEVICE T cosh(T x) +{ + return type_convert(::coshf(type_convert(x))); +}; + +template <> +CK_TILE_DEVICE float cosh(float x) +{ + return ::coshf(x); +}; + +template <> +CK_TILE_DEVICE double cosh(double x) +{ + return ::cosh(x); +}; + +template +CK_TILE_DEVICE T floor(T x) +{ + return type_convert(::floorf(type_convert(x))); +}; + +template <> +CK_TILE_DEVICE float floor(float x) +{ + return ::floorf(x); +}; + +template <> +CK_TILE_DEVICE double floor(double x) +{ + return ::floor(x); +}; + +template <> +CK_TILE_DEVICE fp16_t floor(fp16_t x) +{ + return ::hfloor(x); +}; + +template +CK_TILE_DEVICE T rcp(T x) +{ +#if !CK_TILE_WORKAROUND_SWDEV_383542 + return __frcp_rn(x); +#else + // return __ocml_native_recip_f32(x); + return __builtin_amdgcn_rcpf(x); +#endif +}; + +template +CK_TILE_DEVICE T exp(T x) +{ + return type_convert(__ocml_exp_f32(type_convert(x))); +}; + +template <> +CK_TILE_DEVICE fp16_t exp(fp16_t x) +{ + return hexp(x); +}; + +template <> +CK_TILE_DEVICE float exp(float x) +{ + return __ocml_exp_f32(x); +}; + +template <> +CK_TILE_DEVICE double exp(double x) +{ + return exp(x); +}; + +template +CK_TILE_DEVICE T log(T x) +{ + return type_convert(__logf(type_convert(x))); +}; + +template <> +CK_TILE_DEVICE fp16_t log(fp16_t x) +{ + return hlog(x); +}; + +template <> +CK_TILE_DEVICE float log(float x) +{ + return __logf(x); +}; + +template <> +CK_TILE_DEVICE double log(double x) +{ + return log(x); +}; + +template +CK_TILE_DEVICE T pow(T x, T gamma) +{ + return type_convert(powf(type_convert(x), type_convert(gamma))); +}; + +template <> +CK_TILE_DEVICE float pow(float x, float gamma) +{ + return powf(x, gamma); +}; + +template <> +CK_TILE_DEVICE double pow(double x, double gamma) +{ + return pow(x, gamma); +}; + +template +CK_TILE_DEVICE T expm1(T x) +{ + return type_convert(expm1f(type_convert(x))); +}; + +template <> +CK_TILE_DEVICE float expm1(float x) +{ + return expm1f(x); +}; + +template <> +CK_TILE_DEVICE 
double expm1(double x) +{ + return expm1(x); +}; + } // namespace ck_tile diff --git a/include/ck_tile/core/tensor/buffer_view.hpp b/include/ck_tile/core/tensor/buffer_view.hpp index ed705c91e7..2cc788d422 100644 --- a/include/ck_tile/core/tensor/buffer_view.hpp +++ b/include/ck_tile/core/tensor/buffer_view.hpp @@ -91,8 +91,10 @@ struct buffer_view>::scalar_type, typename vector_traits>::scalar_type>::value, bool>::type = false> - CK_TILE_DEVICE constexpr auto - get(index_t i, bool is_valid_element, bool_constant = {}) const + CK_TILE_DEVICE constexpr auto get(index_t i, + index_t linear_offset, + bool is_valid_element, + bool_constant = {}) const { // X contains multiple T constexpr index_t scalar_per_t_vector = vector_traits>::vector_size; @@ -107,11 +109,11 @@ struct buffer_view(&p_data_[i]); + return *c_style_pointer_cast(&p_data_[i + linear_offset]); #endif } else @@ -134,17 +136,17 @@ struct buffer_view>::scalar_type, typename vector_traits>::scalar_type>::value, bool>::type = false> - CK_TILE_DEVICE void update(index_t i, bool is_valid_element, const X& x) + CK_TILE_DEVICE void update(index_t i, index_t linear_offset, bool is_valid_element, const X& x) { if constexpr(Op == memory_operation_enum::set) { - this->template set(i, is_valid_element, x); + this->template set(i, linear_offset, is_valid_element, x); } // FIXME: remove memory_operation_enum::add else if constexpr(Op == memory_operation_enum::add) { - auto tmp = this->template get(i, is_valid_element); - this->template set(i, is_valid_element, x + tmp); + auto tmp = this->template get(i, linear_offset, is_valid_element); + this->template set(i, linear_offset, is_valid_element, x + tmp); } } @@ -154,7 +156,7 @@ struct buffer_view>::scalar_type, typename vector_traits>::scalar_type>::value, bool>::type = false> - CK_TILE_DEVICE void set(index_t i, bool is_valid_element, const X& x) + CK_TILE_DEVICE void set(index_t i, index_t linear_offset, bool is_valid_element, const X& x) { // X contains multiple T constexpr index_t scalar_per_t_vector = vector_traits>::vector_size; @@ -169,9 +171,9 @@ struct buffer_view(&p_data_[i]) = x; + *c_style_pointer_cast(&p_data_[i + linear_offset]) = x; #endif } } @@ -276,8 +278,10 @@ struct buffer_view>::scalar_type, typename vector_traits>::scalar_type>::value, bool>::type = false> - CK_TILE_DEVICE constexpr auto - get(index_t i, bool is_valid_element, bool_constant = {}) const + CK_TILE_DEVICE constexpr auto get(index_t i, + index_t linear_offset, + bool is_valid_element, + bool_constant = {}) const { // X contains multiple T constexpr index_t scalar_per_t_vector = vector_traits>::vector_size; @@ -303,7 +307,7 @@ struct buffer_view( - p_data_, i, is_valid_element, buffer_size_); + p_data_, i + linear_offset, is_valid_element, buffer_size_); } else { @@ -311,8 +315,11 @@ struct buffer_view, t_per_x, Coherence, - oob_conditional_check>( - p_data_, i, is_valid_element, buffer_size_, invalid_element_value_); + oob_conditional_check>(p_data_, + i + linear_offset, + is_valid_element, + buffer_size_, + invalid_element_value_); } } else @@ -322,11 +329,11 @@ struct buffer_view(&p_data_[i]); + return *c_style_pointer_cast(&p_data_[i + linear_offset]); #endif } else @@ -352,7 +359,8 @@ struct buffer_view>::scalar_type>::value, bool>::type = false> CK_TILE_DEVICE constexpr auto get_raw(remove_cvref_t& dst, - index_t i, + index_t v_offset, + index_t i_offset, bool is_valid_element, bool_constant = {}) const { @@ -366,7 +374,38 @@ struct buffer_view, t_per_x, Coherence, oob_conditional_check, pre_nop>( - dst, 
cached_buf_res_, i, is_valid_element, bool_constant{}); + dst, cached_buf_res_, v_offset, i_offset, is_valid_element, bool_constant{}); + } + + // i is offset of T, not X. i should be aligned to X + template >::scalar_type, + typename vector_traits>::scalar_type>::value, + bool>::type = false> + CK_TILE_DEVICE constexpr auto async_get(CK_TILE_LDS_ADDR remove_cvref_t* smem, + index_t i, + index_t linear_offset, + bool is_valid_element, + bool_constant = {}) const + { + // X is vector of T + constexpr index_t scalar_per_t_vector = vector_traits>::vector_size; + constexpr index_t scalar_per_x_vector = vector_traits>::vector_size; + + static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, + "wrong! X should contain multiple T"); + + constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; + + amd_async_buffer_load_with_oob, t_per_x, Coherence>( + smem, + cached_buf_res_, + i, + linear_offset, + is_valid_element, + bool_constant{}); } // i is offset of T, not X. i should be aligned to X @@ -378,6 +417,7 @@ struct buffer_view::type = false> CK_TILE_DEVICE constexpr auto async_get_raw(remove_cvref_t* smem, index_t i, + index_t linear_offset, bool /*is_valid_element*/, bool_constant = {}) const { @@ -391,7 +431,7 @@ struct buffer_view, t_per_x, Coherence>( - smem, cached_buf_res_, i, bool_constant{}); + smem, cached_buf_res_, i, linear_offset, bool_constant{}); } // i is offset of T, not X. i should be aligned to X @@ -401,25 +441,25 @@ struct buffer_view>::scalar_type, typename vector_traits>::scalar_type>::value, bool>::type = false> - CK_TILE_DEVICE void update(index_t i, bool is_valid_element, const X& x) + CK_TILE_DEVICE void update(index_t i, index_t linear_offset, bool is_valid_element, const X& x) { if constexpr(Op == memory_operation_enum::set) { - this->template set(i, is_valid_element, x); + this->template set(i, linear_offset, is_valid_element, x); } else if constexpr(Op == memory_operation_enum::atomic_add) { - this->template atomic_add(i, is_valid_element, x); + this->template atomic_add(i, linear_offset, is_valid_element, x); } else if constexpr(Op == memory_operation_enum::atomic_max) { - this->template atomic_max(i, is_valid_element, x); + this->template atomic_max(i, linear_offset, is_valid_element, x); } // FIXME: remove memory_operation_enum::add else if constexpr(Op == memory_operation_enum::add) { - auto tmp = this->template get(i, is_valid_element); - this->template set(i, is_valid_element, x + tmp); + auto tmp = this->template get(i, linear_offset, is_valid_element); + this->template set(i, linear_offset, is_valid_element, x + tmp); // tmp += x; // this->template set(i, is_valid_element, tmp); } @@ -432,7 +472,7 @@ struct buffer_view>::scalar_type, typename vector_traits>::scalar_type>::value, bool>::type = false> - CK_TILE_DEVICE void set(index_t i, bool is_valid_element, const X& x) + CK_TILE_DEVICE void set(index_t i, index_t linear_offset, bool is_valid_element, const X& x) { // X contains multiple T constexpr index_t scalar_per_t_vector = vector_traits>::vector_size; @@ -453,7 +493,7 @@ struct buffer_view, t_per_x, Coherence>( - x, p_data_, i, is_valid_element, buffer_size_); + x, p_data_, i + linear_offset, is_valid_element, buffer_size_); } else { @@ -462,9 +502,9 @@ struct buffer_view(&p_data_[i]) = x; + *c_style_pointer_cast(&p_data_[i + linear_offset]) = x; #endif } } @@ -477,7 +517,7 @@ struct buffer_view>::scalar_type, typename vector_traits>::scalar_type>::value, bool>::type = false> - CK_TILE_DEVICE void set_raw(index_t i, bool 
is_valid_element, const X& x) + CK_TILE_DEVICE void set_raw(index_t i, index_t linear_offset, bool is_valid_element, const X& x) { // X contains multiple T constexpr index_t scalar_per_t_vector = vector_traits>::vector_size; @@ -489,7 +529,7 @@ struct buffer_view, t_per_x, Coherence, oob_conditional_check>( - x, p_data_, i, is_valid_element, buffer_size_); + x, p_data_, i, linear_offset, is_valid_element, buffer_size_); } template >::scalar_type, typename vector_traits>::scalar_type>::value, bool>::type = false> - CK_TILE_DEVICE void atomic_add(index_t i, bool is_valid_element, const X& x) + CK_TILE_DEVICE void + atomic_add(index_t i, index_t linear_offset, bool is_valid_element, const X& x) { using scalar_t = typename vector_traits>::scalar_type; @@ -532,13 +573,13 @@ struct buffer_view, t_per_x>( - x, p_data_, i, is_valid_element, buffer_size_); + x, p_data_, i + linear_offset, is_valid_element, buffer_size_); } else { if(is_valid_element) { - atomic_add_g, t_per_x>(&p_data_[i], x); + atomic_add_g, t_per_x>(&p_data_[i + linear_offset], x); } } } @@ -548,7 +589,8 @@ struct buffer_view>::scalar_type, typename vector_traits>::scalar_type>::value, bool>::type = false> - CK_TILE_DEVICE void atomic_max(index_t i, bool is_valid_element, const X& x) + CK_TILE_DEVICE void + atomic_max(index_t i, index_t linear_offset, bool is_valid_element, const X& x) { // X contains multiple T constexpr index_t scalar_per_t_vector = vector_traits>::vector_size; @@ -572,11 +614,11 @@ struct buffer_view, t_per_x>( - x, p_data_, i, is_valid_element, buffer_size_); + x, p_data_, i + linear_offset, is_valid_element, buffer_size_); } else if(is_valid_element) { - atomic_max_g, t_per_x>(&p_data_[i], x); + atomic_max_g, t_per_x>(&p_data_[i + linear_offset], x); } } @@ -668,8 +710,10 @@ struct buffer_view>::scalar_type, typename vector_traits>::scalar_type>::value, bool>::type = false> - CK_TILE_DEVICE constexpr auto - get(index_t i, bool is_valid_element, bool_constant = {}) const + CK_TILE_DEVICE constexpr auto get(index_t i, + index_t linear_offset, + bool is_valid_element, + bool_constant = {}) const { // X contains multiple T constexpr index_t scalar_per_t_vector = vector_traits>::vector_size; @@ -684,14 +728,14 @@ struct buffer_view>::scalar_type, scalar_per_t_vector * scalar_per_x_vector>; // using buf_t = ushort __attribute__((ext_vector_type(8))); - auto rtn = *c_style_pointer_cast(&p_data_[i]); + auto rtn = *c_style_pointer_cast(&p_data_[i + linear_offset]); return bit_cast(rtn); #endif } @@ -708,6 +752,23 @@ struct buffer_view>::scalar_type, + typename vector_traits>::scalar_type>::value, + bool>::type = false> + CK_TILE_DEVICE constexpr auto get_raw(remove_cvref_t& dst, + index_t v_offset, + index_t i_offset, + bool /*is_valid_element*/, + bool_constant = {}) const + { + smem_load{}(dst, v_offset * sizeof(T), i_offset * sizeof(T)); + } + // i is offset of T, not X. 
i should be aligned to X template >::scalar_type, typename vector_traits>::scalar_type>::value, bool>::type = false> - CK_TILE_DEVICE void update(index_t i, bool is_valid_element, const X& x) + CK_TILE_DEVICE void update(index_t i, index_t linear_offset, bool is_valid_element, const X& x) { if constexpr(Op == memory_operation_enum::set) { - this->template set(i, is_valid_element, x); + this->template set(i, linear_offset, is_valid_element, x); } // FIXME: remove memory_operation_enum::add else if constexpr(Op == memory_operation_enum::add) { - auto tmp = this->template get(i, is_valid_element); - this->template set(i, is_valid_element, x + tmp); + auto tmp = this->template get(i, linear_offset, is_valid_element); + this->template set(i, linear_offset, is_valid_element, x + tmp); } } @@ -735,7 +796,7 @@ struct buffer_view>::scalar_type, typename vector_traits>::scalar_type>::value, bool>::type = false> - CK_TILE_DEVICE void set(index_t i, bool is_valid_element, const X& x) + CK_TILE_DEVICE void set(index_t i, index_t linear_offset, bool is_valid_element, const X& x) { // X contains multiple T constexpr index_t scalar_per_t_vector = vector_traits>::vector_size; @@ -751,6 +812,7 @@ struct buffer_view>::scalar_type, int8_t>::value && workaround_int8_ds_write_issue) @@ -952,8 +1014,10 @@ struct buffer_view>::scalar_type, typename vector_traits>::scalar_type>::value, bool>::type = false> - CK_TILE_DEVICE constexpr auto - get(index_t i, bool is_valid_element, bool_constant = {}) const + CK_TILE_DEVICE constexpr auto get(index_t i, + index_t /*linear_offset*/, + bool is_valid_element, + bool_constant = {}) const { // X contains multiple T constexpr index_t scalar_per_t_vector = vector_traits>::vector_size; @@ -995,17 +1059,17 @@ struct buffer_view>::scalar_type, typename vector_traits>::scalar_type>::value, bool>::type = false> - CK_TILE_DEVICE void update(index_t i, bool is_valid_element, const X& x) + CK_TILE_DEVICE void update(index_t i, index_t linear_offset, bool is_valid_element, const X& x) { if constexpr(Op == memory_operation_enum::set) { - this->template set(i, is_valid_element, x); + this->template set(i, linear_offset, is_valid_element, x); } // FIXME: remove memory_operation_enum::add else if constexpr(Op == memory_operation_enum::add) { - auto tmp = this->template get(i, is_valid_element); - this->template set(i, is_valid_element, x + tmp); + auto tmp = this->template get(i, linear_offset, is_valid_element); + this->template set(i, linear_offset, is_valid_element, x + tmp); } } @@ -1015,7 +1079,7 @@ struct buffer_view>::scalar_type, typename vector_traits>::scalar_type>::value, bool>::type = false> - CK_TILE_DEVICE void set(index_t i, bool is_valid_element, const X& x) + CK_TILE_DEVICE void set(index_t i, index_t linear_offset, bool is_valid_element, const X& x) { // X contains multiple T constexpr index_t scalar_per_t_vector = vector_traits>::vector_size; @@ -1030,9 +1094,9 @@ struct buffer_view(&p_data_[i]) = x; + *c_style_pointer_cast(&p_data_[i + linear_offset]) = x; #endif } } diff --git a/include/ck_tile/core/tensor/load_tile.hpp b/include/ck_tile/core/tensor/load_tile.hpp index aeda5e9c06..06b5a8da0b 100644 --- a/include/ck_tile/core/tensor/load_tile.hpp +++ b/include/ck_tile/core/tensor/load_tile.hpp @@ -12,6 +12,7 @@ #include "ck_tile/core/tensor/tile_window.hpp" #include "ck_tile/core/utility/type_traits.hpp" #include "ck_tile/core/tensor/tile_window.hpp" +#include "ck_tile/core/tensor/tile_window_linear.hpp" #include "ck_tile/core/tensor/null_tile_window.hpp" #include 
"ck_tile/core/tensor/null_tensor.hpp" @@ -28,7 +29,21 @@ CK_TILE_DEVICE auto load_tile(const tile_window_with_static_distribution& tile_window, bool_constant = {}) { - return tile_window.load(bool_constant{}); + return tile_window.load(number<-1>{}, bool_constant{}); +} + +template +CK_TILE_DEVICE auto load_tile(const tile_window_linear& tile_window, + bool_constant = {}) +{ + return tile_window.load(number<-1>{}, bool_constant{}); } template = {}, bool_constant = {}) { - tile_window.load_raw(tile, bool_constant{}, bool_constant{}); + tile_window.load_raw( + tile, number<-1>{}, bool_constant{}, bool_constant{}); +} + +template +CK_TILE_DEVICE auto load_tile_raw(T& tile, + const tile_window_linear& tile_window, + bool_constant = {}, + bool_constant = {}) +{ + tile_window.load_raw( + tile, number<-1>{}, bool_constant{}, bool_constant{}); } template = {}) { return tile_window.async_load_raw( - lds_tile, bool_constant{}, bool_constant{}); + lds_tile, number<-1>{}, bool_constant{}, bool_constant{}); +} + +template +CK_TILE_DEVICE auto async_load_tile_raw(LdsTileWindow_&& lds_tile, + const tile_window_linear& tile_window, + bool_constant = {}, + bool_constant = {}) +{ + return tile_window.async_load_raw( + lds_tile, number<-1>{}, bool_constant{}, bool_constant{}); } CK_TILE_DEVICE auto async_load_fence(index_t cnt = 0) diff --git a/include/ck_tile/core/tensor/shuffle_tile.hpp b/include/ck_tile/core/tensor/shuffle_tile.hpp index baf009add2..da3c7117e5 100644 --- a/include/ck_tile/core/tensor/shuffle_tile.hpp +++ b/include/ck_tile/core/tensor/shuffle_tile.hpp @@ -109,7 +109,7 @@ CK_TILE_DEVICE void shuffle_tile_impl_in_thread(OutTensor& out_tensor, const InT // get input vectors static_for<0, num_vec_in, 1>{}([&](auto i) { - constexpr auto idx_y_in = generate_array( + constexpr auto idx_y_in = generate_tuple( [&](auto ii) { return ii == y_dim_vec_out ? 
idx_y_start[ii] + i : idx_y_start[ii]; }, diff --git a/include/ck_tile/core/tensor/store_tile.hpp b/include/ck_tile/core/tensor/store_tile.hpp index 2efc657013..d5a716664d 100644 --- a/include/ck_tile/core/tensor/store_tile.hpp +++ b/include/ck_tile/core/tensor/store_tile.hpp @@ -10,6 +10,7 @@ #include "ck_tile/core/container/container_helper.hpp" #include "ck_tile/core/numeric/math.hpp" #include "ck_tile/core/tensor/tile_window.hpp" +#include "ck_tile/core/tensor/tile_window_linear.hpp" #include "ck_tile/core/utility/type_traits.hpp" namespace ck_tile { @@ -72,7 +73,7 @@ store_tile(tile_window_with_static_distribution& tile_window, const static_distributed_tensor& dstr_tensor) { - tile_window.store(dstr_tensor); + tile_window.store(dstr_tensor, number<-1>{}); } template & tile_window, const static_distributed_tensor& dstr_tensor) { - tile_window.store_raw(dstr_tensor); + tile_window.store_raw(dstr_tensor, number<-1>{}); +} + +template +CK_TILE_DEVICE void store_tile( + tile_window_linear& + tile_window, + const static_distributed_tensor& dstr_tensor) +{ + tile_window.store(dstr_tensor, number<-1>{}); +} + +template +CK_TILE_DEVICE void store_tile_raw( + tile_window_linear& + tile_window, + const static_distributed_tensor& dstr_tensor) +{ + tile_window.store_raw(dstr_tensor, number<-1>{}); } } // namespace ck_tile diff --git a/include/ck_tile/core/tensor/tensor_view.hpp b/include/ck_tile/core/tensor/tensor_view.hpp index 4655eec241..698ce5378d 100644 --- a/include/ck_tile/core/tensor/tensor_view.hpp +++ b/include/ck_tile/core/tensor/tensor_view.hpp @@ -16,6 +16,24 @@ namespace ck_tile { +/* + * tensor_view + * abstracts the underlying memory buffer (global, LDS, etc.) + * and provides a unified get/set interface for access + * + * Addressing into the buffer is controlled by two variables: + * coord : ND tensor coordinate, from which the actual offset is calculated + * linear_offset : 1D offset, placed in the immediate field of + * the buffer instruction to help reduce register usage + * + * Either field, or both, can be used to index into the tensor + * + * We usually provide 2 sets of APIs for buffer get/set, e.g. + * get_vectorized_elements()/get_vectorized_elements_raw() + * the former usually calls an intrinsic or a normal C function, the latter + * usually calls an inline-asm function + * + */ template @@ -49,22 +67,6 @@ struct tensor_view CK_TILE_HOST_DEVICE constexpr auto& get_buffer_view() { return buf_; } -#if 0 - CK_TILE_HOST_DEVICE constexpr DataType get_element(const TensorCoord& coord) const { - return buf_.template get( - coord.get_offset(), - coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord)); - } - - CK_TILE_HOST_DEVICE constexpr void set_element(const TensorCoord& coord, const DataType& x) - { - buf_.template set( - coord.get_offset(), - coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord), - x); - } -#endif // X is vector of DataType. // "coord" is coordinate of DataType, not X. 
"coord" should be aligned to X template ::type = false> CK_TILE_HOST_DEVICE constexpr remove_cvref_t get_vectorized_elements(const TensorCoord& coord, + index_t linear_offset, bool_constant = {}) const { return buf_.template get( coord.get_offset(), + linear_offset, coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord), bool_constant{}); } + template >::scalar_type, + typename vector_traits>::scalar_type>, + bool>::type = false> + CK_TILE_HOST_DEVICE constexpr remove_cvref_t + get_vectorized_elements(const TensorCoord& coord, + index_t linear_offset, + bool is_valid_element, // flag + bool_constant = {}) const + { + return buf_.template get(coord.get_offset(), + linear_offset, + is_valid_element, + bool_constant{}); + } + // X is vector of DataType. // "coord" is coordinate of DataType, not X. "coord" should be aligned to X template ::type = false> CK_TILE_HOST_DEVICE void get_vectorized_elements_raw(remove_cvref_t& dst, const TensorCoord& coord, + index_t linear_offset, bool_constant = {}, bool_constant = {}) const { return buf_.template get_raw( dst, coord.get_offset(), + linear_offset, + coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord), + bool_constant{}); + } + + template >::scalar_type, + typename vector_traits>::scalar_type>, + bool>::type = false> + CK_TILE_HOST_DEVICE void get_vectorized_elements_raw(remove_cvref_t& dst, + const TensorCoord& coord, + index_t linear_offset, + bool is_valid_element, + bool_constant = {}, + bool_constant = {}) const + { + return buf_.template get_raw( + dst, coord.get_offset(), linear_offset, is_valid_element, bool_constant{}); + } + + template >::scalar_type, + typename vector_traits>::scalar_type>, + bool>::type = false> + CK_TILE_HOST_DEVICE constexpr void + async_get_vectorized_elements(CK_TILE_LDS_ADDR remove_cvref_t* smem, + const TensorCoord& coord, + index_t linear_offset) const + { + return buf_.template async_get( + smem, + coord.get_offset(), + linear_offset, + coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord), + bool_constant{}); + } + + template >::scalar_type, + typename vector_traits>::scalar_type>, + bool>::type = false> + CK_TILE_HOST_DEVICE constexpr void + async_get_vectorized_elements(CK_TILE_LDS_ADDR remove_cvref_t* smem, + const TensorCoord& coord, + index_t linear_offset, + bool is_valid_element) const + { + return buf_.template async_get(smem, + coord.get_offset(), + linear_offset, + is_valid_element, + bool_constant{}); + } + + template >::scalar_type, + typename vector_traits>::scalar_type>, + bool>::type = false> + CK_TILE_HOST_DEVICE constexpr void + async_get_vectorized_elements_raw(remove_cvref_t* smem, + const TensorCoord& coord, + index_t linear_offset, + bool_constant = {}) const + { + return buf_.template async_get_raw( + smem, + coord.get_offset(), + linear_offset, coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord), bool_constant{}); } @@ -110,11 +210,15 @@ struct tensor_view std::is_same_v>::scalar_type, typename vector_traits>::scalar_type>, bool>::type = false> - CK_TILE_HOST_DEVICE constexpr void async_get_vectorized_elements_raw( - remove_cvref_t* smem, const TensorCoord& coord, bool_constant = {}) const + CK_TILE_HOST_DEVICE constexpr void + async_get_vectorized_elements_raw(remove_cvref_t* smem, + const TensorCoord& coord, + index_t linear_offset, + bool is_valid_element, + bool_constant = {}) const { return buf_.template async_get_raw( - smem, coord.get_offset(), true /*not used*/, bool_constant{}); + smem, coord.get_offset(), 
linear_offset, is_valid_element, bool_constant{}); } // X is vector of DataType. @@ -125,11 +229,15 @@ struct tensor_view std::is_same_v>::scalar_type, typename vector_traits>::scalar_type>, bool>::type = false> - CK_TILE_HOST_DEVICE constexpr void set_vectorized_elements( - const TensorCoord& coord, const X& x, bool_constant = {}) + CK_TILE_HOST_DEVICE constexpr void + set_vectorized_elements(const TensorCoord& coord, + index_t linear_offset, + const X& x, + bool_constant = {}) { buf_.template set( coord.get_offset(), + linear_offset, coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord), x); } @@ -140,15 +248,53 @@ struct tensor_view std::is_same_v>::scalar_type, typename vector_traits>::scalar_type>, bool>::type = false> - CK_TILE_HOST_DEVICE constexpr void set_vectorized_elements_raw( - const TensorCoord& coord, const X& x, bool_constant = {}) + CK_TILE_HOST_DEVICE constexpr void + set_vectorized_elements(const TensorCoord& coord, + index_t linear_offset, + bool is_valid_element, + const X& x, + bool_constant = {}) + { + buf_.template set( + coord.get_offset(), linear_offset, is_valid_element, x); + } + + template >::scalar_type, + typename vector_traits>::scalar_type>, + bool>::type = false> + CK_TILE_HOST_DEVICE constexpr void + set_vectorized_elements_raw(const TensorCoord& coord, + index_t linear_offset, + const X& x, + bool_constant = {}) { buf_.template set_raw( coord.get_offset(), + linear_offset, coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord), x); } + template >::scalar_type, + typename vector_traits>::scalar_type>, + bool>::type = false> + CK_TILE_HOST_DEVICE constexpr void + set_vectorized_elements_raw(const TensorCoord& coord, + index_t linear_offset, + bool is_valid_element, + const X& x, + bool_constant = {}) + { + buf_.template set_raw( + coord.get_offset(), linear_offset, is_valid_element, x); + } + // X is vector of DataType. // "coord" is coordinate of DataType, not X. 
"coord" should be aligned to X template >::scalar_type, typename vector_traits>::scalar_type>, bool>::type = false> - CK_TILE_HOST_DEVICE constexpr void update_vectorized_elements( - const TensorCoord& coord, const X& x, bool_constant = {}) + CK_TILE_HOST_DEVICE constexpr void + update_vectorized_elements(const TensorCoord& coord, + index_t linear_offset, + const X& x, + bool_constant = {}) { buf_.template update( coord.get_offset(), + linear_offset, coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord), x); } + template >::scalar_type, + typename vector_traits>::scalar_type>, + bool>::type = false> + CK_TILE_HOST_DEVICE constexpr void + update_vectorized_elements(const TensorCoord& coord, + index_t linear_offset, + bool is_valid_element, + const X& x, + bool_constant = {}) + { + buf_.template update( + coord.get_offset(), linear_offset, is_valid_element, x); + } + CK_TILE_HOST_DEVICE void print() const { printf("tensor_view{"); diff --git a/include/ck_tile/core/tensor/tile_window.hpp b/include/ck_tile/core/tensor/tile_window.hpp index 266d623c71..ca35078275 100644 --- a/include/ck_tile/core/tensor/tile_window.hpp +++ b/include/ck_tile/core/tensor/tile_window.hpp @@ -18,6 +18,8 @@ namespace ck_tile { +// Note: this tile window do not support single issue +// you need to use tile_window_linear structure for this purpose template {}; static constexpr auto I1 = number<1>{}; + static_assert(NumCoord == 1); // TODO: check WindowLengths and StaticTileDistribution are consistent @@ -189,7 +192,8 @@ struct tile_window_with_static_distribution constexpr auto idx_diff_ys = SFC_Ys::get_step_between(number<0>{}, number{}); - constexpr auto idx_diff_ps_ys = container_concat(array{0}, idx_diff_ys); + constexpr auto idx_diff_ps_ys = container_concat( + generate_tuple([&](auto) { return number<0>{}; }, number{}), idx_diff_ys); move_window_adaptor_and_bottom_tensor_thread_coordinate( window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); @@ -222,10 +226,11 @@ struct tile_window_with_static_distribution // move thread's window adaptor coordinate and bottom tensor coordinate // [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...] ==> [x0', x1', ...] ==> [offset] + template CK_TILE_DEVICE void move_window_adaptor_and_bottom_tensor_thread_coordinate( WindowAdaptorCoord& window_adaptor_thread_coord, BottomTensorCoord& bottom_tensor_thread_coord, - const AdaptorTopIndex& idx_diff_adaptor_top) const + const ATopIndex& idx_diff_adaptor_top) const { array idx_diff_adaptor_bottom; @@ -279,10 +284,11 @@ struct tile_window_with_static_distribution get_container_subset(window_adaptor_ps_ys_vector_strides, y_dims)); } - CK_TILE_DEVICE constexpr auto get_num_access() const { return load_store_traits::NumAccess; } + CK_TILE_DEVICE constexpr auto get_num_of_access() const { return load_store_traits::NumAccess; } - template - CK_TILE_DEVICE auto load(bool_constant = {}) const + template + CK_TILE_DEVICE auto load(number = {}, + bool_constant = {}) const { using Traits = load_store_traits; @@ -308,11 +314,11 @@ struct tile_window_with_static_distribution // read from bottom tensor const vector_t vec_value = get_bottom_tensor_view().template get_vectorized_elements( - bottom_tensor_thread_coord, bool_constant{}); + bottom_tensor_thread_coord, 0, bool_constant{}); #if 1 // write into distributed tensor static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) { - constexpr auto idx_ys = generate_array( + constexpr auto idx_ys = generate_tuple( [&](auto jj) { return jj == Traits::VectorDimY ? 
(idx_ys_start[jj] + j) : idx_ys_start[jj]; @@ -338,8 +344,9 @@ struct tile_window_with_static_distribution { constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess); - constexpr auto idx_diff_ps_ys = - container_concat(array{0}, idx_diff_ys); + constexpr auto idx_diff_ps_ys = container_concat( + generate_tuple([&](auto) { return number<0>{}; }, number{}), + idx_diff_ys); move_window_adaptor_and_bottom_tensor_thread_coordinate( window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); @@ -350,8 +357,12 @@ struct tile_window_with_static_distribution return dst_tensor; } - template + template CK_TILE_DEVICE void load_raw(DstTile& dst_tensor, + number = {}, bool_constant = {}, bool_constant = {}) const { @@ -397,6 +408,7 @@ struct tile_window_with_static_distribution get_bottom_tensor_view().template get_vectorized_elements_raw( dst_vec_tbuf.template at(), bottom_tensor_thread_coord, + 0 /**/, bool_constant{}, pre_nop_); #if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE || \ @@ -409,23 +421,24 @@ struct tile_window_with_static_distribution { constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess); - constexpr auto idx_diff_ps_ys = - container_concat(array{0}, idx_diff_ys); + constexpr auto idx_diff_ps_ys = container_concat( + generate_tuple([&](auto) { return number<0>{}; }, number{}), + idx_diff_ys); move_window_adaptor_and_bottom_tensor_thread_coordinate( window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); } }); }); -#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE - asm volatile("; this inline asm is workaround to prevent compiler from using too much " - "scratch memory" ::); -#endif } // TODO: currently async load only implemented in inline asm - template + template CK_TILE_DEVICE auto async_load_raw(LdsTileWindow_&& lds_tile, + number = {}, bool_constant = {}, bool_constant = {}) const { @@ -467,7 +480,7 @@ struct tile_window_with_static_distribution // loop over thread tensor space [y0, y1, ...] 
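An editorial aside before the next hunk: the `number<-1>{}` default that this diff threads through `load`/`load_raw`/`async_load_raw`/`store`/`update` is a single-issue hook. `-1` keeps the old behavior of issuing every access in one call, while a non-negative value issues exactly one access, so a caller can interleave individual memory issues with compute. Below is a minimal, self-contained sketch of that dispatch convention; the types are stand-ins for illustration, not the real ck_tile ones.

```
#include <cstdio>

// stand-in for ck_tile::number<I>, for illustration only
template <int I> struct number { static constexpr int value = I; };

constexpr int NumAccess = 4; // hypothetical access count

template <int N, typename F>
void static_for(F&& f)
{
    if constexpr(N > 0)
    {
        static_for<N - 1>(f);  // recurse first so accesses run in order 0..N-1
        f(number<N - 1>{});
    }
}

template <int IAccess>
void issue(number<IAccess>) { std::printf("issue access %d\n", IAccess); }

// -1 (the default) issues every access; 0 <= IAccess < NumAccess issues one
template <int IAccess = -1>
void load(number<IAccess> = {})
{
    if constexpr(IAccess < 0)
        static_for<NumAccess>([](auto ia) { issue(ia); });
    else
    {
        static_assert(IAccess < NumAccess, "access id out of range");
        issue(number<IAccess>{});
    }
}

int main()
{
    load();            // number<-1> default: all 4 accesses
    load(number<2>{}); // single issue, e.g. to software-pipeline with compute
}
```

Because the access id is a template parameter, the single-issue path compiles down to one concrete memory instruction with no runtime branching, which is what makes per-issue scheduling possible.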
static_for<0, NumCoord, 1>{}([&](auto iCoord) { - // TODO: use structure binding (to be captured later) if compiled in C++20 + /// TODO: use structure binding (to be captured later) if compiled in C++20 auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0]; auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1]; @@ -482,15 +495,16 @@ struct tile_window_with_static_distribution // read from bottom tensor get_bottom_tensor_view().template async_get_vectorized_elements_raw( - smem, bottom_tensor_thread_coord, pre_nop_); + smem, bottom_tensor_thread_coord, 0, pre_nop_); // move thread coordinate if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) { constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess); - constexpr auto idx_diff_ps_ys = - container_concat(array{0}, idx_diff_ys); + constexpr auto idx_diff_ps_ys = container_concat( + generate_tuple([&](auto) { return number<0>{}; }, number{}), + idx_diff_ys); move_window_adaptor_and_bottom_tensor_thread_coordinate( window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); @@ -501,8 +515,81 @@ struct tile_window_with_static_distribution }); } - template + template + CK_TILE_DEVICE auto async_load(LdsTileWindow_&& lds_tile, + number = {}, + bool_constant = {}) const + { + using LdsTileWindow = remove_cvref_t; + using LdsDataType = typename LdsTileWindow::DataType; + + // issues * warps * lanes + static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded + + // TODO: LDS offset is not good for intrinsic based implementation(compiler can't figure out + // dependency) hence avoid use offset based solution. size_per_buf should be zero (how to + // check?) + constexpr index_t size_per_buf = + lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset( + make_tuple(number<0>{}, number<0>{}, number<0>{})); + + constexpr index_t size_per_wave = + lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset( + make_tuple(number<0>{}, number<1>{}, number<0>{})) - + size_per_buf; + + constexpr index_t size_per_issue = + lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset( + make_tuple(number<1>{}, number<0>{}, number<0>{})) - + size_per_buf; + + const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id(); + + using Traits = load_store_traits; + + using vector_t = typename Traits::vector_t; + using SFC_Ys = typename Traits::SFC_Ys; + + // TODO: we force CK_TILE_LDS_ADDR + CK_TILE_LDS_ADDR LdsDataType* smem = + lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_ + m0_init_value; + + // loop over thread tensor space [y0, y1, ...] 
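For readers keeping count, the `size_per_buf`/`size_per_wave`/`size_per_issue` values computed just above are ordinary strides of the 3-D (issue, warp, lane) LDS descriptor, recovered by differencing `calculate_offset` at unit indices. Assuming a plain row-major `[issues][warps][lanes]` layout with invented dimensions, the arithmetic reduces to the sketch below.

```
#include <cstdio>

int main()
{
    // hypothetical LDS tile: 4 issues x 2 warps x 64 lanes, row-major
    constexpr int warps = 2, lanes = 64;

    auto offset = [](int issue, int warp, int lane) {
        return (issue * warps + warp) * lanes + lane;
    };

    const int size_per_buf   = offset(0, 0, 0);                // descriptor base, 0 here
    const int size_per_wave  = offset(0, 1, 0) - size_per_buf; // warp stride: 64
    const int size_per_issue = offset(1, 0, 0) - size_per_buf; // issue stride: 128

    // each wave starts at its own disjoint slab; each issue then advances
    // the smem pointer by one issue-row, mirroring `smem += size_per_issue`
    const int warp_id = 1; // stand-in for get_warp_id()
    const int m0_init = size_per_buf + size_per_wave * warp_id;
    std::printf("wave %d, issue %d, start %d\n", size_per_wave, size_per_issue, m0_init);
}
```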
+ static_for<0, NumCoord, 1>{}([&](auto iCoord) { + /// TODO: use structure binding (to be captured later) if compiled in C++20 + auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0]; + auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1]; + + static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) { + constexpr auto iAccess = number{}; + + // read from bottom tensor + get_bottom_tensor_view().template async_get_vectorized_elements( + smem, bottom_tensor_thread_coord, 0, bool_constant{}); + + // move thread coordinate + if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) + { + constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess); + + constexpr auto idx_diff_ps_ys = container_concat( + generate_tuple([&](auto) { return number<0>{}; }, number{}), + idx_diff_ys); + + move_window_adaptor_and_bottom_tensor_thread_coordinate( + window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); + + smem += size_per_issue; // Note we manually increase the per-issue offset + } + }); + }); + } + + template CK_TILE_DEVICE void store(const static_distributed_tensor& dstr_tensor, + number = {}, bool_constant = {}) const { using Traits = load_store_traits; @@ -515,7 +602,6 @@ struct tile_window_with_static_distribution // loop over thread tensor space [y0, y1, ...] static_for<0, NumCoord, 1>{}([&](auto iCoord) { - /// TODO: use structure binding (to be captured later) if compiled in C++20 auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0]; auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1]; @@ -530,7 +616,7 @@ struct tile_window_with_static_distribution vector_t vec_value; static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) { - constexpr auto idx_ys = generate_array( + constexpr auto idx_ys = generate_tuple( [&](auto jj) { return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj]; @@ -548,15 +634,19 @@ struct tile_window_with_static_distribution // write into bottom tensor get_bottom_tensor_view().template set_vectorized_elements( - bottom_tensor_thread_coord, vec_value, bool_constant{}); + bottom_tensor_thread_coord, + 0, + vec_value, + bool_constant{}); // move thread coordinate if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) { constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess); - constexpr auto idx_diff_ps_ys = - container_concat(array{0}, idx_diff_ys); + constexpr auto idx_diff_ps_ys = container_concat( + generate_tuple([&](auto) { return number<0>{}; }, number{}), + idx_diff_ys); move_window_adaptor_and_bottom_tensor_thread_coordinate( window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); @@ -565,8 +655,9 @@ struct tile_window_with_static_distribution }); } - CK_TILE_DEVICE void - store_raw(const static_distributed_tensor& dstr_tensor) const + template + CK_TILE_DEVICE void store_raw(const static_distributed_tensor& dstr_tensor, + number = {}) const { using Traits = load_store_traits; @@ -591,7 +682,7 @@ struct tile_window_with_static_distribution // read from distributed tensor vector_t vec_value; static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) { - constexpr auto idx_ys = generate_array( + constexpr auto idx_ys = generate_tuple( [&](auto jj) { return jj == Traits::VectorDimY ? 
(idx_ys_start[jj] + j) : idx_ys_start[jj]; @@ -606,15 +697,16 @@ struct tile_window_with_static_distribution // write into bottom tensor get_bottom_tensor_view() .template set_vectorized_elements_raw( - bottom_tensor_thread_coord, vec_value); + bottom_tensor_thread_coord, 0, vec_value); // move thread coordinate if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) { constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess); - constexpr auto idx_diff_ps_ys = - container_concat(array{0}, idx_diff_ys); + constexpr auto idx_diff_ps_ys = container_concat( + generate_tuple([&](auto) { return number<0>{}; }, number{}), + idx_diff_ys); move_window_adaptor_and_bottom_tensor_thread_coordinate( window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); @@ -623,8 +715,9 @@ struct tile_window_with_static_distribution }); } - template + template CK_TILE_DEVICE void update(const static_distributed_tensor& dstr_tensor, + number = {}, bool_constant = {}) const { using Traits = load_store_traits; @@ -650,7 +743,7 @@ struct tile_window_with_static_distribution vector_t vec_value; static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) { - constexpr auto idx_ys = generate_array( + constexpr auto idx_ys = generate_tuple( [&](auto jj) { return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj]; @@ -666,15 +759,19 @@ struct tile_window_with_static_distribution // write into bottom tensor get_bottom_tensor_view().template update_vectorized_elements( - bottom_tensor_thread_coord, vec_value, bool_constant{}); + bottom_tensor_thread_coord, + 0, + vec_value, + bool_constant{}); // move thread coordinate if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) { constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess); - constexpr auto idx_diff_ps_ys = - container_concat(array{0}, idx_diff_ys); + constexpr auto idx_diff_ps_ys = container_concat( + generate_tuple([&](auto) { return number<0>{}; }, number{}), + idx_diff_ys); move_window_adaptor_and_bottom_tensor_thread_coordinate( window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); @@ -746,7 +843,8 @@ struct tile_window_with_static_distribution constexpr auto idx_diff_ys = SFC_Ys::get_step_between(number<0>{}, number{}); - constexpr auto idx_diff_ps_ys = container_concat(array{0}, idx_diff_ys); + constexpr auto idx_diff_ps_ys = container_concat( + generate_tuple([&](auto) { return number<0>{}; }, number{}), idx_diff_ys); move_window_adaptor_and_bottom_tensor_thread_coordinate( window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); @@ -798,6 +896,27 @@ make_tile_window(const TensorView_& tensor_view, tensor_view, window_lengths, origin, tile_distribution}; } +// this version can't be called in a constexpr context +template +CK_TILE_DEVICE auto +make_tile_window_raw(const TensorView_& tensor_view, + const WindowLengths_& window_lengths, + const multi_index& origin, + const StaticTileDistribution_& tile_distribution, + number = {}) +{ + auto w = tile_window_with_static_distribution, + remove_cvref_t, + remove_cvref_t, + NumCoord>{ + tensor_view, window_lengths, origin, tile_distribution}; + w.init_raw(); + return w; +} + template +CK_TILE_DEVICE constexpr auto +make_tile_window_raw(const tile_window_with_static_lengths& tile_window, + const StaticTileDistribution& tile_distribution) +{ + auto w = make_tile_window(tile_window.get_bottom_tensor_view(), + tile_window.get_window_lengths(), + tile_window.get_window_origin(), + tile_distribution); + w.init_raw(); + return w; +} + template 
CK_TILE_DEVICE void move_tile_window( tile_window_with_static_lengths& window, diff --git a/include/ck_tile/core/tensor/tile_window_linear.hpp b/include/ck_tile/core/tensor/tile_window_linear.hpp new file mode 100644 index 0000000000..4b921ec5b9 --- /dev/null +++ b/include/ck_tile/core/tensor/tile_window_linear.hpp @@ -0,0 +1,1082 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once +#include "ck_tile/core/arch/arch.hpp" +#include "ck_tile/core/arch/utility.hpp" +#include "ck_tile/core/algorithm/space_filling_curve.hpp" +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/container/array.hpp" +#include "ck_tile/core/container/sequence.hpp" +#include "ck_tile/core/container/tuple.hpp" +#include "ck_tile/core/container/container_helper.hpp" +#include "ck_tile/core/tensor/static_distributed_tensor.hpp" +#include "ck_tile/core/tensor/tensor_adaptor.hpp" +#include "ck_tile/core/tensor/tile_distribution.hpp" +#include "ck_tile/core/utility/functional.hpp" +#include "ck_tile/core/utility/type_traits.hpp" + +namespace ck_tile { + +// dispatch helper used by the member functions below: i_access < 0 issues all +// accesses in a static loop, while a non-negative i_access issues exactly one +#define WINDOW_DISPATCH_ISSUE() \ + if constexpr(i_access < 0) \ + { \ + static_for<0, NumAccess, 1>{}([&](auto ia) { issue(ia); }); \ + } \ + else \ + { \ + static_assert(i_access < NumAccess); \ + issue(number{}); \ + } + +// +// This version of the tile window pre-caches offsets/flags as needed. +// +// LinearBottomDims_, e.g. seq<0, 1> for a 2d tensor: the last dim is the linear dim, +// so it can be indexed with an immediate offset, which saves registers. +// TODO: when using this struct, prefer load_raw()/store_raw(), which can control +// the immediate offset on the fly. +// Note that the space-filling curve is non-snaked here! +// +template +struct tile_window_linear +{ + using BottomTensorView = remove_reference_t; + using WindowLengths = remove_cvref_t; + using TileDstr = remove_cvref_t; + + using WindowAdaptor = typename TileDstr::PsYs2XsAdaptor; + using BottomTensorDesc = typename BottomTensorView::TensorDesc; + + using DataType = remove_cvref_t; + using LinearBottomDims = remove_cvref_t; + + static_assert(LinearBottomDims::size() == BottomTensorView::get_num_of_dimension()); + + static constexpr index_t NDimWindowAdaptorTop = WindowAdaptor::get_num_of_top_dimension(); + static constexpr index_t NDimBottomTensor = BottomTensorDesc::get_num_of_dimension(); + + static constexpr index_t NDimP = TileDstr::get_num_of_dimension_p(); + static constexpr index_t NDimY = TileDstr::get_num_of_dimension_y(); + + static constexpr auto I0 = number<0>{}; + static constexpr auto I1 = number<1>{}; + + // TODO: check WindowLengths and StaticTileDistribution are consistent + + static_assert(ck_tile::is_known_at_compile_time::value, + "wrong! lengths should be static"); + static_assert(TileDstr::is_static(), "wrong!"); + + static_assert(NDimBottomTensor == WindowAdaptor::get_num_of_bottom_dimension(), + "wrong! inconsistent # of dimensions"); + + using AdaptorTopIndex = array; + using BottomTensorIndex = array; + + using WindowAdaptorCoord = + decltype(make_tensor_adaptor_coordinate(WindowAdaptor{}, AdaptorTopIndex{})); + + using BottomTensorCoord = + decltype(make_tensor_coordinate(BottomTensorDesc{}, BottomTensorIndex{})); + + struct traits + { + private: + // return vector dimension among [y0, y1, ...] 
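+        // (editorial note, hypothetical numbers) the helpers below pick the Y dim used
+        // for vectorized IO: among [y0, y1, ...], choose the dim with unit stride and
+        // the longest safe run. E.g. for ys_vector_lengths = [4, 8] and
+        // ys_vector_strides = [32, 1], they resolve to VectorDimY = 1 and
+        // ScalarPerVector = 8, i.e. one 8-wide vector access per issue.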
+ CK_TILE_DEVICE static constexpr auto get_window_adaptor_ys_safe_vector_length_strides() + { + // bottom tensor top dimension vector lengths and strides + const auto [bottom_tensor_top_dim_vector_lengths, + bottom_tensor_top_dim_vector_strides] = + BottomTensorDesc::get_top_dimension_safe_vector_length_strides(); + + // window vector lengths/strides + const auto window_adaptor_bottom_dim_vector_lengths = + bottom_tensor_top_dim_vector_lengths; + const auto window_adaptor_bottom_dim_vector_strides = + bottom_tensor_top_dim_vector_strides; + + // window adaptor [p0, p1, ..., y0, y1, ...] + array + window_adaptor_vector_lengths{-1}; + array + window_adaptor_vector_strides{-1}; + + constexpr auto window_adaptor_bottom_dims = + WindowAdaptor::get_bottom_dimension_hidden_ids(); + + set_container_subset(window_adaptor_vector_lengths, + window_adaptor_bottom_dims, + window_adaptor_bottom_dim_vector_lengths); + set_container_subset(window_adaptor_vector_strides, + window_adaptor_bottom_dims, + window_adaptor_bottom_dim_vector_strides); + + const auto [window_adaptor_ps_ys_vector_lengths, window_adaptor_ps_ys_vector_strides] = + WindowAdaptor{}.get_top_dimension_safe_vector_length_strides( + window_adaptor_vector_lengths, window_adaptor_vector_strides); + + // [y0, y1, ...] + constexpr auto y_dims = + typename arithmetic_sequence_gen::type{}; + + return make_tuple(get_container_subset(window_adaptor_ps_ys_vector_lengths, y_dims), + get_container_subset(window_adaptor_ps_ys_vector_strides, y_dims)); + } + + static constexpr auto get_vector_dim_y_scalar_per_vector() + { + const auto [ys_vector_lengths, ys_vector_strides] = + get_window_adaptor_ys_safe_vector_length_strides(); + + index_t VectorDimY_ = 0; + index_t ScalarPerVector_ = 1; + + for(index_t i = 0; i < NDimY; ++i) + { + if(ys_vector_strides[i] == 1 && ys_vector_lengths[i] > ScalarPerVector_) + { + ScalarPerVector_ = ys_vector_lengths[i]; + VectorDimY_ = i; + } + } + + return make_tuple(VectorDimY_, ScalarPerVector_); + } + + public: + static constexpr index_t VectorDimY = get_vector_dim_y_scalar_per_vector().template at<0>(); + static constexpr index_t ScalarPerVector = + get_vector_dim_y_scalar_per_vector().template at<1>(); + + using vector_t = thread_buffer; + + private: + static constexpr auto scalars_per_access_ = [] { + constexpr auto scalars_per_access_arr = generate_array( + [&](auto i) { return (i == VectorDimY) ? ScalarPerVector : 1; }, number{}); + + /// TODO: add non-automatic storage argument support to macro TO_SEQUENCE() + constexpr auto NDimY_ = NDimY; + + return TO_SEQUENCE(scalars_per_access_arr, NDimY_); + }(); + + static constexpr auto get_space_filling_curve() + { + constexpr auto thread_tensor_lengths_ys = + to_sequence(TileDstr{}.get_ys_to_d_descriptor().get_lengths()); + + // FIXME: need logic to judge dim access order + using DimAccessOrder = typename arithmetic_sequence_gen<0, NDimY, 1>::type; + + return space_filling_curve{}; + } + + public: + using SFC_Ys = decltype(get_space_filling_curve()); + + static constexpr index_t NumAccess = SFC_Ys::get_num_of_access(); + + static_assert(0 < NumAccess, "Wrong! 
NumAccess should be larger than 0"); + + private: + static constexpr auto get_num_non_linear_access() + { + constexpr auto sfc_access_lens = SFC_Ys::access_lengths; + using ys_to_rhs_major = + typename decltype(TileDstr{}.get_static_tile_distribution_encoding())::Ys2RHsMajor; + + constexpr auto non_linear = [&]() { + index_t cnt = 1; + static_for<0, NDimY, 1>{}([&](auto i_dim_y) { + constexpr auto rhs_major = ys_to_rhs_major{}[i_dim_y]; + constexpr auto target_h_dim = number{}; // no r dim here! + if constexpr(LinearBottomDims{}[target_h_dim] == 0) + { + cnt *= sfc_access_lens[i_dim_y]; + } + }); + return cnt; + }(); + + return non_linear; + } + + // example: + // non_linear_access_map: sequence<0, 0, 0, 0, 1, 1, 1, 1> for 8 access, totally 2 register + // used + // -> histogram : sequence<4, 4> + // -> prefixsum : seqneuce<0, 4, 8> + // non_linear_access_map: sequence<0, 1, 2, 3, 4, 5, 6, 7> for 8 access, totally 8 register + // used, will pre-cache 8 + // -> histogram : sequence<1, 1, 1, 1, 1, 1, 1, 1> + // -> prefixsum : seqneuce<0, 1, 2, 3, 4, 5, 6, 7, 8> + // non_linear_access_map: sequence<0, 0, 1, 1, 2, 2, 3, 3> for 8 access, totally 4 register + // used, will pre-cache 4 + // -> histogram : sequence<2, 2, 2, 2> + // -> prefixsum : seqneuce<0, 2, 4, 6, 8> + static constexpr auto get_non_linear_access_map() + { + constexpr auto sfc_access_lens = SFC_Ys::access_lengths; + using ys_to_rhs_major = + typename decltype(TileDstr{}.get_static_tile_distribution_encoding())::Ys2RHsMajor; + constexpr auto non_linear_map = [&]() { + array m_{0}; + index_t cumulative_len_ = 1; + index_t cumulative_non_linear_len_ = 1; + static_for<0, NDimY, 1>{}([&](auto i_y) { + constexpr auto i_dim_y = number{}; // from right to left + constexpr auto rhs_major = ys_to_rhs_major{}[i_dim_y]; + constexpr auto target_h_dim = number{}; // no r dim here! + constexpr auto is_linear_dim = LinearBottomDims{}[target_h_dim]; + + array current_m_{0}; + constexpr auto current_len_ = sfc_access_lens[i_dim_y]; + + // copy cumulative length as current pattern + for(auto i_ = 0; i_ < cumulative_len_; i_++) + { + current_m_(i_) = m_[i_]; + } + for(auto j_ = 0; j_ < current_len_; j_++) + { + auto j_offset_ = is_linear_dim ? 
0 : j_ * cumulative_non_linear_len_; + for(auto i_ = 0; i_ < cumulative_len_; i_++) + { + m_(j_ * cumulative_len_ + i_) = current_m_[i_] + j_offset_; + } + } + cumulative_len_ *= current_len_; + if(!is_linear_dim) + cumulative_non_linear_len_ *= current_len_; + }); + return m_; + }(); + + return TO_SEQUENCE(non_linear_map, NumAccess); + } + + static constexpr auto get_non_linear_access_histogram() + { + constexpr auto m_ = get_non_linear_access_map(); + // m_.foo(); + + constexpr auto r_ = + typename arithmetic_sequence_gen<0, get_num_non_linear_access() + 1, 1>::type{}; + + constexpr auto h_ = histogram_sorted_sequence(m_, r_); + + return h_; + } + + static constexpr auto get_non_linear_access_histogram_prefix_sum() + { + constexpr auto h_ = get_non_linear_access_histogram(); + constexpr auto h_prefix_sum_ = prefix_sum_sequence(h_); + return h_prefix_sum_; + } + + public: + static constexpr index_t NumAccess_NonLinear = get_num_non_linear_access(); + using AccessMap_NonLinear = decltype(get_non_linear_access_map()); // sequence + using AccessHistogram_NonLinear = decltype(get_non_linear_access_histogram()); + using AccessPrefixSum_NonLinear = decltype(get_non_linear_access_histogram_prefix_sum()); + }; + + static constexpr index_t NumAccess = traits::NumAccess; + static constexpr index_t NumAccess_NonLinear = traits::NumAccess_NonLinear; + using AccessMap_NonLinear = typename traits::AccessMap_NonLinear; + using AccessHistogram_NonLinear = typename traits::AccessHistogram_NonLinear; + using AccessPrefixSum_NonLinear = typename traits::AccessPrefixSum_NonLinear; + + CK_TILE_DEVICE constexpr tile_window_linear() = default; + + CK_TILE_DEVICE constexpr tile_window_linear(const BottomTensorView& bottom_tensor_view, + const WindowLengths& window_lengths, + const BottomTensorIndex& window_origin, + const TileDstr& tile_distribution) + : bottom_tensor_view_{bottom_tensor_view}, + window_lengths_{window_lengths}, + window_origin_{window_origin}, + tile_dstr_{tile_distribution}, + cached_coords_{}, + cached_flags_{} + { + auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate( + tile_distribution.get_ps_ys_to_xs_adaptor(), + container_concat(make_tuple(get_warp_id(), get_lane_id()), + generate_tuple([&](auto) { return number<0>{}; }, number{}))); + + BottomTensorIndex bottom_tensor_thread_origin_idx_tmp = + window_origin + window_adaptor_thread_coord_tmp.get_bottom_index(); + + auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate( + bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp); + + // future load/store() calls (might allocate more registers) + using SFC_Ys = typename traits::SFC_Ys; + + static_for<0, NumAccess, 1>{}([&](auto i_access) { + constexpr auto non_linear_id = number{}; + constexpr auto need_save_non_linear_coord = + bool_constant{}; + + if constexpr(need_save_non_linear_coord) + { + cached_coords_(non_linear_id) = bottom_tensor_thread_coord_tmp; + } + + // TODO: need pad_tensor_view to check which dim need use flag to check + // cached flag is independent from non-linear-coord + // but need be updated in move_tile, with proper dims + cached_flags_(i_access) = coordinate_has_valid_offset_assuming_top_index_is_valid( + bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_coord_tmp); + + if constexpr(i_access != (NumAccess - 1)) + { + constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(i_access); // tuple of number + constexpr auto idx_diff_ps_ys = container_concat( + generate_tuple([&](auto) { return number<0>{}; }, 
number{}), + idx_diff_ys); + + move_window_adaptor_and_bottom_tensor_thread_coordinate( + window_adaptor_thread_coord_tmp, + bottom_tensor_thread_coord_tmp, + idx_diff_ps_ys); + } + }); + } + + CK_TILE_DEVICE static constexpr index_t get_num_of_dimension() { return NDimBottomTensor; } + + CK_TILE_DEVICE static constexpr bool has_static_tile_distribution() + { + return TileDstr::is_static(); + } + + CK_TILE_DEVICE constexpr auto get_window_lengths() const { return window_lengths_; } + + CK_TILE_DEVICE constexpr auto get_tile_distribution() const { return tile_dstr_; } + + CK_TILE_DEVICE constexpr auto get_bottom_tensor_view() const { return bottom_tensor_view_; } + + CK_TILE_DEVICE constexpr auto get_window_origin() const { return window_origin_; } + + CK_TILE_DEVICE constexpr void + set_bottom_tensor_view_data_ptr(typename BottomTensorView::DataType* data) + { + bottom_tensor_view_.buf_.p_data_ = data; + } + + // move thread's window adaptor coordinate and bottom tensor coordinate + // [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...] ==> [x0', x1', ...] ==> [offset] + template + CK_TILE_DEVICE void move_window_adaptor_and_bottom_tensor_thread_coordinate( + WindowAdaptorCoord& window_adaptor_thread_coord, + BottomTensorCoord& bottom_tensor_thread_coord, + const ATopIndex& idx_diff_adaptor_top) const + { + array idx_diff_adaptor_bottom; + + move_tensor_adaptor_coordinate(tile_dstr_.get_ps_ys_to_xs_adaptor(), + window_adaptor_thread_coord, + idx_diff_adaptor_top, + idx_diff_adaptor_bottom); + + move_tensor_coordinate(bottom_tensor_view_.get_tensor_descriptor(), + bottom_tensor_thread_coord, + idx_diff_adaptor_bottom); + } + + template + CK_TILE_DEVICE static constexpr auto get_bottom_linear_coordinate(number) + { + using SFC_Ys = typename traits::SFC_Ys; + constexpr auto idx_ys = SFC_Ys::get_index(number{}); + using ys_to_rhs_major = + typename decltype(TileDstr{}.get_static_tile_distribution_encoding())::Ys2RHsMajor; + + constexpr auto modified_idx_ys = generate_tuple( + [&](auto i_dim_y) { + constexpr auto rhs_major = ys_to_rhs_major{}[i_dim_y]; + constexpr auto target_h_dim = number{}; // no r dim here! 
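+                // worked example (hypothetical): with LinearBottomDims = seq<0, 1>, a Y dim
+                // mapping to the last (linear) H dim keeps its SFC index below, which folds
+                // into the compile-time linear offset; a Y dim mapping to H dim 0 is zeroed
+                // here, since its contribution lives in the cached, register-held coordinate.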
+ if constexpr(LinearBottomDims{}[target_h_dim] == 0) + { + return number<0>{}; + } + else + { + return number{}; + } + }, + number{}); + + constexpr auto adaptor_ = TileDstr{}.get_ps_ys_to_xs_adaptor(); + constexpr auto idx_ = + container_concat(make_tuple(number<0>{}, number<0>{}), modified_idx_ys); + + return adaptor_.calculate_bottom_index(idx_); + } + + template + CK_TILE_DEVICE static constexpr index_t get_bottom_linear_offset(number) + { + constexpr auto linear_coord = get_bottom_linear_coordinate(number{}); + // since this is a linear offset, we assume the bottom X tensor is always linear + constexpr index_t linear_offset = [&]() { + constexpr auto x_idx_ = linear_coord; + constexpr auto x_len_ = TileDstr{}.get_lengths(); + static_assert(x_idx_.size() == x_len_.size()); + constexpr index_t x_dims_ = x_idx_.size(); + index_t cu_stride_ = 1; + index_t cu_offset_ = 0; + static_for<0, x_dims_, 1>{}([&](auto i_) { + auto r_i_ = number{}; + cu_offset_ += x_idx_[r_i_] * cu_stride_; + cu_stride_ *= x_len_[r_i_]; + }); + return cu_offset_; + }(); + + return linear_offset; + } + + CK_TILE_DEVICE constexpr auto get_num_of_access() const { return traits::NumAccess; } + + template + CK_TILE_DEVICE auto load(number = {}, bool_constant = {}) const + { + using vector_t = typename traits::vector_t; + using SFC_Ys = typename traits::SFC_Ys; + + constexpr auto tile_dstr = TileDstr{}; + + auto dst_tensor = make_static_distributed_tensor(tile_dstr); + + auto issue = [&](auto i_access_) { + constexpr auto IAccess = number{}; + + constexpr auto non_linear_id = number{}; + auto bottom_tensor_thread_coord = cached_coords_[non_linear_id]; + auto bottom_tensor_flag = cached_flags_[IAccess]; + + constexpr auto linear_offset = get_bottom_linear_offset(IAccess); + + // read from bottom tensor + const vector_t vec_value = + get_bottom_tensor_view().template get_vectorized_elements( + bottom_tensor_thread_coord, + linear_offset, + bottom_tensor_flag, + bool_constant{}); +#if 1 + // data index [y0, y1, ...] + constexpr auto idx_diff_ys = SFC_Ys::get_index(IAccess); + // write into distributed tensor + static_for<0, traits::ScalarPerVector, 1>{}([&](auto j) { + constexpr auto idx_ys = generate_tuple( + [&](auto jj) { + return jj == traits::VectorDimY ? 
(idx_diff_ys[jj] + j) : idx_diff_ys[jj]; + }, + number{}); + + constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys); + + dst_tensor.get_thread_buffer().template at() = + vec_value.template get_as()[j]; + }); +#else + constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start); + static_assert(d % traits::ScalarPerVector == 0); + + dst_tensor.get_thread_buffer().template get_as()( + number{}) = bit_cast(vec_value); +#endif + }; + + WINDOW_DISPATCH_ISSUE(); + + return dst_tensor; + } + + template + CK_TILE_DEVICE void load_raw(DstTile& dst_tensor, + number = {}, // negative means loop over all num_access + bool_constant = {}, + bool_constant = {}) const + { + using vector_t = typename traits::vector_t; + using SFC_Ys = typename traits::SFC_Ys; + static constexpr index_t YElementSize = + TileDstr{}.get_ys_to_d_descriptor().get_element_space_size(); + static_assert(YElementSize % traits::ScalarPerVector == 0); + using vectorized_tbuf = array; + + constexpr auto tile_dstr = TileDstr{}; + + auto& dst_vec_tbuf = reinterpret_cast(dst_tensor.get_thread_buffer()); + + auto issue = [&](auto i_access_) { + constexpr auto IAccess = number{}; + constexpr auto pre_nop_ = [&]() { + if constexpr(pre_nop && i_access_ == 0 && + BottomTensorView::buffer_view::get_address_space() == + address_space_enum::global) + return bool_constant{}; + else + return bool_constant{}; + }(); + + constexpr auto non_linear_id = number{}; + auto bottom_tensor_thread_coord = cached_coords_[non_linear_id]; + constexpr auto linear_offset = get_bottom_linear_offset(IAccess); + auto bottom_tensor_flag = cached_flags_[IAccess]; + + // data index [y0, y1, ...] + constexpr auto idx_ys_start = SFC_Ys::get_index(IAccess); + constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start); + static_assert(d % traits::ScalarPerVector == 0); + + get_bottom_tensor_view().template get_vectorized_elements_raw( + dst_vec_tbuf.template at(), + bottom_tensor_thread_coord, + linear_offset /**/, + bottom_tensor_flag, + bool_constant{}, + pre_nop_); +#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE || \ + CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE + asm volatile(""); // this is starting from rocm-6.2, but same sympton, reuse this flag +#endif + }; + + WINDOW_DISPATCH_ISSUE(); + } + + // TODO: currently async load only implemented in inline asm + template + CK_TILE_DEVICE auto async_load_raw(LdsTileWindow_&& lds_tile, + number = {}, + bool_constant = {}, + bool_constant = {}) const + { + using LdsTileWindow = remove_cvref_t; + using LdsDataType = typename LdsTileWindow::DataType; + + // currently we only support everything is non linear dim + // actually it's not performant if we have linear dim(e.g. 
fast changing) + static_assert(NumAccess_NonLinear == NumAccess); + static_assert(BottomTensorView::buffer_view::get_address_space() == + address_space_enum::global); + + // issues * warps * lanes + static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded + + const index_t size_per_buf = + lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset( + make_tuple(number<0>{}, number<0>{}, number<0>{})) * + sizeof(LdsDataType); + + const index_t size_per_wave = + lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset( + make_tuple(number<0>{}, number<1>{}, number<0>{})) * + sizeof(LdsDataType) - + size_per_buf; + + const index_t size_per_issue = + lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset( + make_tuple(number<1>{}, number<0>{}, number<0>{})) * + sizeof(LdsDataType) - + size_per_buf; + + const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id(); + m0_set_with_memory(m0_init_value); // This should be wave independent + + using vector_t = typename traits::vector_t; + + LdsDataType* smem = lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_; + + // loop over thread tensor space [y0, y1, ...] + auto issue = [&](auto i_access_) { + constexpr auto IAccess = number{}; + constexpr auto pre_nop_ = [&]() { + if constexpr(pre_nop && i_access_ == 0) + return bool_constant{}; + else + return bool_constant{}; + }(); + + constexpr auto non_linear_id = number{}; + auto bottom_tensor_thread_coord = cached_coords_[non_linear_id]; + auto bottom_tensor_flag = cached_flags_[IAccess]; // get this flag anyway + + // read from bottom tensor + get_bottom_tensor_view().template async_get_vectorized_elements_raw( + smem, bottom_tensor_thread_coord, 0, bottom_tensor_flag, pre_nop_); + + // move thread coordinate + if constexpr(i_access_ != (NumAccess - 1)) + { + m0_inc_with_memory(size_per_issue); + } + }; + + WINDOW_DISPATCH_ISSUE(); + } + + template + CK_TILE_DEVICE auto async_load(LdsTileWindow_&& lds_tile, + number = {}, + bool_constant = {}) const + { + using LdsTileWindow = remove_cvref_t; + using LdsDataType = typename LdsTileWindow::DataType; + + // currently we only support everything is non linear dim + // actually it's not performant if we have linear dim(e.g. fast changing) + static_assert(NumAccess_NonLinear == NumAccess); + static_assert(BottomTensorView::buffer_view::get_address_space() == + address_space_enum::global); + + // issues * warps * lanes + static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded + + // TODO: LDS offset is not good for intrinsic based implementation(compiler can't figure out + // dependency) hence avoid use offset based solution. size_per_buf should be zero (how to + // check?) 
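+        // (editorial sketch, invented shape) for an issues x warps x lanes = 4 x 2 x 64
+        // row-major LDS view these fold to size_per_buf = 0, size_per_wave = 64 and
+        // size_per_issue = 128, so every wave owns a disjoint slab and each issue below
+        // advances smem by one 128-element row.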
+ constexpr index_t size_per_buf = + lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset( + make_tuple(number<0>{}, number<0>{}, number<0>{})); + + constexpr index_t size_per_wave = + lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset( + make_tuple(number<0>{}, number<1>{}, number<0>{})) - + size_per_buf; + + constexpr index_t size_per_issue = + lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset( + make_tuple(number<1>{}, number<0>{}, number<0>{})) - + size_per_buf; + + const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id(); + + using vector_t = typename traits::vector_t; + + // TODO: we force CK_TILE_LDS_ADDR + CK_TILE_LDS_ADDR LdsDataType* smem = + lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_ + m0_init_value; + + // loop over thread tensor space [y0, y1, ...] + auto issue = [&](auto i_access_) { + constexpr auto IAccess = number{}; + constexpr auto non_linear_id = number{}; + auto bottom_tensor_thread_coord = cached_coords_[non_linear_id]; + auto bottom_tensor_flag = cached_flags_[IAccess]; + + // read from bottom tensor + get_bottom_tensor_view().template async_get_vectorized_elements( + smem, + bottom_tensor_thread_coord, + 0, + bottom_tensor_flag, + bool_constant{}); + + // move thread coordinate + if constexpr(i_access_ != (NumAccess - 1)) + { + smem += size_per_issue; // Note we manually increase the per-issue offset + } + }; + + WINDOW_DISPATCH_ISSUE(); + } + + template + CK_TILE_DEVICE void store(const static_distributed_tensor& dstr_tensor, + number = {}, + bool_constant = {}) const + { + + using vector_t = typename traits::vector_t; + using SFC_Ys = typename traits::SFC_Ys; + + constexpr auto tile_dstr = TileDstr{}; + + // loop over thread tensor space [y0, y1, ...] + auto issue = [&](auto i_access_) { + constexpr auto IAccess = number{}; + constexpr auto non_linear_id = number{}; + auto bottom_tensor_thread_coord = cached_coords_[non_linear_id]; + constexpr auto linear_offset = get_bottom_linear_offset(IAccess); + auto bottom_tensor_flag = cached_flags_[IAccess]; + // data index [y0, y1, ...] + constexpr auto idx_ys_start = SFC_Ys::get_index(IAccess); + + // read from distributed tensor + vector_t vec_value; + + static_for<0, traits::ScalarPerVector, 1>{}([&](auto j) { + constexpr auto idx_ys = generate_tuple( + [&](auto jj) { + return jj == traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj]; + }, + number{}); + + constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys); + + vec_value.template get_as()(j) = + dstr_tensor.get_thread_buffer().template at(); + }); + + // write into bottom tensor + get_bottom_tensor_view().template set_vectorized_elements( + bottom_tensor_thread_coord, + linear_offset, + bottom_tensor_flag, + vec_value, + bool_constant{}); + }; + + WINDOW_DISPATCH_ISSUE(); + } + + template + CK_TILE_DEVICE void store_raw(const static_distributed_tensor& dstr_tensor, + number = {}) const + { + using vector_t = typename traits::vector_t; + using SFC_Ys = typename traits::SFC_Ys; + + constexpr auto tile_dstr = TileDstr{}; + static constexpr bool oob_conditional_check = true; + + // loop over thread tensor space [y0, y1, ...] 
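+        // (editorial note) as in load()/store() above, each access resolves its
+        // register-resident coordinate through AccessPrefixSum_NonLinear: with a
+        // hypothetical AccessMap_NonLinear of sequence<0, 0, 1, 1>, accesses 0/1 reuse
+        // cached_coords_[0] and accesses 2/3 reuse cached_coords_[1], while the rest of
+        // the distance is supplied by the constexpr linear offset in the immediate field.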
+ auto issue = [&](auto i_access_) { + constexpr auto IAccess = number{}; + constexpr auto non_linear_id = number{}; + auto bottom_tensor_thread_coord = cached_coords_[non_linear_id]; + constexpr auto linear_offset = get_bottom_linear_offset(IAccess); + auto bottom_tensor_flag = cached_flags_[IAccess]; + + // data index [y0, y1, ...] + constexpr auto idx_ys_start = SFC_Ys::get_index(IAccess); + + // read from distributed tensor + vector_t vec_value; + static_for<0, traits::ScalarPerVector, 1>{}([&](auto j) { + constexpr auto idx_ys = generate_tuple( + [&](auto jj) { + return jj == traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj]; + }, + number{}); + constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys); + vec_value.template get_as()(j) = + dstr_tensor.get_thread_buffer().template at(); + }); + + // write into bottom tensor + get_bottom_tensor_view() + .template set_vectorized_elements_raw( + bottom_tensor_thread_coord, linear_offset, bottom_tensor_flag, vec_value); + }; + + WINDOW_DISPATCH_ISSUE(); + } + + template + CK_TILE_DEVICE void update(const static_distributed_tensor& dstr_tensor, + number = {}, + bool_constant = {}) const + { + + using vector_t = typename traits::vector_t; + using SFC_Ys = typename traits::SFC_Ys; + + constexpr auto tile_dstr = TileDstr{}; + + // loop over thread tensor space [y0, y1, ...] + auto issue = [&](auto i_access_) { + constexpr auto IAccess = number{}; + constexpr auto non_linear_id = number{}; + auto bottom_tensor_thread_coord = cached_coords_[non_linear_id]; + constexpr auto linear_offset = get_bottom_linear_offset(IAccess); + auto bottom_tensor_flag = cached_flags_[IAccess]; + + // data index [y0, y1, ...] + constexpr auto idx_ys_start = SFC_Ys::get_index(IAccess); + + // read from distributed tensor + vector_t vec_value; + + static_for<0, traits::ScalarPerVector, 1>{}([&](auto j) { + constexpr auto idx_ys = generate_tuple( + [&](auto jj) { + return jj == traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj]; + }, + number{}); + + constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys); + + vec_value.template get_as()(j) = + dstr_tensor.get_thread_buffer().template at(); + }); + + // write into bottom tensor + get_bottom_tensor_view().template update_vectorized_elements( + bottom_tensor_thread_coord, + linear_offset, + bottom_tensor_flag, + vec_value, + bool_constant{}); + }; + + WINDOW_DISPATCH_ISSUE(); + } + + // move thread's bottom tensor coordinate + // [x0', x1', ... 
] ==> [offset] + // also move window-origin + CK_TILE_DEVICE void move(const BottomTensorIndex& step) + { + window_origin_ += step; + + static_for<0, NumAccess, 1>{}([&](auto i_access) { + constexpr auto IAccess = number{}; + constexpr auto non_linear_id = number{}; + constexpr auto need_update_non_linear_coord = + bool_constant{}; + + if constexpr(need_update_non_linear_coord) + { + move_tensor_coordinate(bottom_tensor_view_.get_tensor_descriptor(), + cached_coords_(non_linear_id), + step); + } + + // move the current coord with linear_coords + auto tmp_coords = cached_coords_[non_linear_id]; + constexpr auto linear_coord = get_bottom_linear_coordinate(IAccess); + move_tensor_coordinate( + bottom_tensor_view_.get_tensor_descriptor(), tmp_coords, linear_coord); + + cached_flags_(IAccess) = coordinate_has_valid_offset_assuming_top_index_is_valid( + bottom_tensor_view_.get_tensor_descriptor(), tmp_coords); + }); + } + + CK_TILE_DEVICE void set_window_origin(const BottomTensorIndex& new_window_origin) + { + window_origin_ = new_window_origin; + + auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate( + TileDstr{}.get_ps_ys_to_xs_adaptor(), + container_concat(make_tuple(get_warp_id(), get_lane_id()), + generate_tuple([&](auto) { return number<0>{}; }, number{}))); + + BottomTensorIndex bottom_tensor_thread_origin_idx_tmp = + window_origin_ + window_adaptor_thread_coord_tmp.get_bottom_index(); + + auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate( + bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp); + + // future load/store() calls (might allocate more registers) + using SFC_Ys = typename traits::SFC_Ys; + + static_for<0, NumAccess, 1>{}([&](auto i_access) { + constexpr auto non_linear_id = number{}; + constexpr auto need_save_non_linear_coord = + bool_constant{}; + + if constexpr(need_save_non_linear_coord) + { + cached_coords_(non_linear_id) = bottom_tensor_thread_coord_tmp; + } + + if constexpr(i_access != (NumAccess - 1)) + { + constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(i_access); // tuple of number + constexpr auto idx_diff_ps_ys = container_concat( + generate_tuple([&](auto) { return number<0>{}; }, number{}), + idx_diff_ys); + + move_window_adaptor_and_bottom_tensor_thread_coordinate( + window_adaptor_thread_coord_tmp, + bottom_tensor_thread_coord_tmp, + idx_diff_ps_ys); + } + }); + } + + CK_TILE_HOST_DEVICE void init_raw() { bottom_tensor_view_.init_raw(); } + + // this is the bottom tensor view + // [x0', x1', ...] ==> [offset] + BottomTensorView bottom_tensor_view_; + + // + WindowLengths window_lengths_; + + // origin ([x0', x1', ...]) of window on bottom tensor + BottomTensorIndex window_origin_; + + // Tile tensor distribution, which contains: + // 1. adaptor for window: [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...] + // 2. thread descriptor for thread tensor in register: [y0, y1, ...] 
==> [d] + TileDstr tile_dstr_; + + // pre-computed per-access coordinates and validity flags: + array cached_coords_; + array cached_flags_; +}; + +#undef WINDOW_DISPATCH_ISSUE + +namespace impl { +template +struct default_linear_bottom_dims_impl +{ + using type = typename uniform_sequence_gen::type; +}; + +template +struct default_linear_bottom_dims_impl +{ + // global default: seq<0, 0, ..., 1> + using type = typename sequence_merge::type, + sequence<1>>::type; +}; + +template +struct default_linear_bottom_dims_impl +{ + // LDS default: seq<1, 1, ..., 1> + using type = typename uniform_sequence_gen::type; +}; +} // namespace impl + +template +using default_linear_bottom_dims = + typename impl::default_linear_bottom_dims_impl::type; + +// this API creates a tile_window_linear, a structure that has the chance to use +// immediate values and thus save registers. +// Pass in LinearBottomDims_ properly to control which dim is linear, so that a +// constexpr linear_offset is generated for that dim (and finally passed to the +// immediate-offset field of the buffer/LDS instruction). +// +// Note: there is no internal check of which dims may safely use a linear offset; +// users must ensure this themselves +// +// e.g. +// for a 2d global matrix, set LinearBottomDims_=seq<0, 1>; the last dim will generate an +// immediate offset if each thread has multiple issues along the last dim +// +// for a 2d LDS buffer, set LinearBottomDims_=seq<1, 1>; then only one vgpr is used as the +// offset and everything else uses immediate offsets. +// +template > +CK_TILE_DEVICE constexpr auto +make_tile_window_linear(const TensorView_& tensor_view, + const WindowLengths_& window_lengths, + const multi_index& origin, + const StaticTileDistribution_& tile_distribution, + LinearBottomDims_ = {}) +{ + static_assert(LinearBottomDims_::size() == TensorView_::get_num_of_dimension()); + return tile_window_linear, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t>{ + tensor_view, window_lengths, origin, tile_distribution}; +} + +template < + typename TileWindow_, + typename StaticTileDistribution_, + typename LinearBottomDims_ = default_linear_bottom_dims> +CK_TILE_DEVICE constexpr auto +make_tile_window_linear(const TileWindow_& tile_window, + const StaticTileDistribution_& tile_distribution, + LinearBottomDims_ = {}) +{ + return make_tile_window_linear(tile_window.get_bottom_tensor_view(), + tile_window.get_window_lengths(), + tile_window.get_window_origin(), + tile_distribution, + LinearBottomDims_{}); +} + +// this version must not be called under a constexpr context +template > +CK_TILE_DEVICE auto +make_tile_window_linear_raw(const TensorView_& tensor_view, + const WindowLengths_& window_lengths, + const multi_index& origin, + const StaticTileDistribution_& tile_distribution, + LinearBottomDims_ = {}) +{ + static_assert(LinearBottomDims_::size() == TensorView_::get_num_of_dimension()); + auto w = tile_window_linear, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t>{ + tensor_view, window_lengths, origin, tile_distribution}; + w.init_raw(); + return w; +} + +template < + typename TileWindow_, + typename StaticTileDistribution_, + typename LinearBottomDims_ = default_linear_bottom_dims> +CK_TILE_DEVICE constexpr auto +make_tile_window_linear_raw(const TileWindow_& tile_window, + const StaticTileDistribution_& tile_distribution, + LinearBottomDims_ = {}) +{ + return make_tile_window_linear_raw(tile_window.get_bottom_tensor_view(), + tile_window.get_window_lengths(), + tile_window.get_window_origin(), + tile_distribution, + LinearBottomDims_{}); +} + +template +CK_TILE_DEVICE void 
move_tile_window( + tile_window_linear& + window, + const typename tile_window_linear::BottomTensorIndex& step) +{ + window.move(step); +} + +} // namespace ck_tile diff --git a/include/ck_tile/core/utility/magic_div.hpp b/include/ck_tile/core/utility/magic_div.hpp index 09038ba296..fd9c733c52 100644 --- a/include/ck_tile/core/utility/magic_div.hpp +++ b/include/ck_tile/core/utility/magic_div.hpp @@ -59,8 +59,16 @@ struct magic_division32_bit_range CK_TILE_DEVICE static constexpr uint32_t do_magic_division(uint32_t dividend, uint32_t multiplier, uint32_t shift) { - uint32_t tmp = __umulhi(dividend, multiplier); - return (tmp + dividend) >> shift; + if(__builtin_is_constant_evaluated()) + { + uint32_t tmp = (static_cast(dividend) * multiplier) >> 32; + return (tmp + dividend) >> shift; + } + else + { + uint32_t tmp = __umulhi(dividend, multiplier); + return (tmp + dividend) >> shift; + } } CK_TILE_HOST static constexpr uint32_t @@ -77,9 +85,18 @@ struct magic_division32_bit_range CK_TILE_DEVICE static constexpr int32_t do_magic_division(int32_t dividend_i32, uint32_t multiplier, uint32_t shift) { - uint32_t dividend_u32 = bit_cast(dividend_i32); - uint32_t tmp = __umulhi(dividend_u32, multiplier); - return (tmp + dividend_u32) >> shift; + if(__builtin_is_constant_evaluated()) + { + uint32_t dividend_u32 = bit_cast(dividend_i32); + uint32_t tmp = (static_cast(dividend_u32) * multiplier) >> 32; + return (tmp + dividend_u32) >> shift; + } + else + { + uint32_t dividend_u32 = bit_cast(dividend_i32); + uint32_t tmp = __umulhi(dividend_u32, multiplier); + return (tmp + dividend_u32) >> shift; + } } CK_TILE_HOST static constexpr int32_t diff --git a/include/ck_tile/host.hpp b/include/ck_tile/host.hpp index dbc1f5d23a..e17d7c22a2 100644 --- a/include/ck_tile/host.hpp +++ b/include/ck_tile/host.hpp @@ -24,5 +24,6 @@ #include "ck_tile/host/reference/reference_layernorm2d_fwd.hpp" #include "ck_tile/host/reference/reference_reduce.hpp" #include "ck_tile/host/reference/reference_softmax.hpp" +#include "ck_tile/host/reference/reference_topk.hpp" #include "ck_tile/host/stream_config.hpp" #include "ck_tile/host/timer.hpp" diff --git a/include/ck_tile/host/fill.hpp b/include/ck_tile/host/fill.hpp index f490bbdeba..335911860a 100644 --- a/include/ck_tile/host/fill.hpp +++ b/include/ck_tile/host/fill.hpp @@ -10,6 +10,7 @@ #include #include #include +#include #include "ck_tile/core.hpp" @@ -41,6 +42,73 @@ struct FillUniformDistribution } }; +namespace impl { + +// clang-format off +template struct RawIntegerType_ {}; +template<> struct RawIntegerType_<1> { using type = uint8_t;}; +template<> struct RawIntegerType_<2> { using type = uint16_t;}; +template<> struct RawIntegerType_<4> { using type = uint32_t;}; +template<> struct RawIntegerType_<8> { using type = uint64_t;}; +// clang-format on + +template +using RawIntegerType = typename RawIntegerType_::type; +} // namespace impl + +// Note: this struct will have no const-ness will generate random +template +struct FillUniformDistribution_Unique +{ + float a_{-5.f}; + float b_{5.f}; + std::optional seed_{11939}; + + std::mt19937 gen_{}; + std::unordered_set> set_{}; + + FillUniformDistribution_Unique(float a = -5.f, + float b = 5.f, + std::optional seed = {11939}) + : a_(a), + b_(b), + seed_(seed), + gen_{seed_.has_value() ? 
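// A worked example of the magic-division identity changed above in magic_div.hpp
// (the divisor-3 multiplier/shift pair is computed by hand for illustration, not
// taken from this library's host-side calculation):
//   n / 3 == (__umulhi(n, 0x55555556u) + n) >> 2   for 32-bit unsigned n
//   n = 9: __umulhi(9, 0x55555556u) = 3, and (3 + 9) >> 2 = 3
// The __builtin_is_constant_evaluated() branch computes the same high 32 bits as
// __umulhi via (static_cast<uint64_t>(dividend) * multiplier) >> 32, which, unlike
// the device intrinsic, is also valid when the division happens at compile time.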
*seed_ : std::random_device{}()}, + set_{} + { + } + + template + void operator()(ForwardIter first, ForwardIter last) + { + std::mt19937& gen = gen_; + std::uniform_real_distribution dis(a_, b_); + auto& set = set_; + std::generate(first, last, [&dis, &gen, &set]() { + T v = static_cast(0); + do + { + v = ck_tile::type_convert(dis(gen)); + } while(set.count(bit_cast>(v)) == 1); + set.insert(bit_cast>(v)); + + return v; + }); + } + + template + auto operator()(ForwardRange&& range) + -> std::void_t()( + std::begin(std::forward(range)), + std::end(std::forward(range))))> + { + (*this)(std::begin(std::forward(range)), + std::end(std::forward(range))); + } + + void clear() { set_.clear(); } +}; + template struct FillNormalDistribution { diff --git a/include/ck_tile/host/host_tensor.hpp b/include/ck_tile/host/host_tensor.hpp index f533d5c189..5610ba324d 100644 --- a/include/ck_tile/host/host_tensor.hpp +++ b/include/ck_tile/host/host_tensor.hpp @@ -11,6 +11,7 @@ #include #include #include +#include #include "ck_tile/core.hpp" #include "ck_tile/host/ranges.hpp" @@ -545,6 +546,28 @@ struct HostTensor typename Data::size_type size() const { return mData.size(); } + // return a slice of this tensor + // for simplicity we just copy the data and return a new tensor + auto slice(std::vector s_begin, std::vector s_end) const + { + assert(s_begin.size() == s_end.size()); + assert(s_begin.size() == get_num_of_dimension()); + + std::vector s_len(s_begin.size()); + std::transform( + s_end.begin(), s_end.end(), s_begin.begin(), s_len.begin(), std::minus{}); + HostTensor sliced_tensor(s_len); + + sliced_tensor.ForEach([&](auto& self, auto idx) { + std::vector src_idx(idx.size()); + std::transform( + idx.begin(), idx.end(), s_begin.begin(), src_idx.begin(), std::plus{}); + self(idx) = operator()(src_idx); + }); + + return sliced_tensor; + } + template auto AsSpan() const { diff --git a/include/ck_tile/host/reference/reference_softmax.hpp b/include/ck_tile/host/reference/reference_softmax.hpp index f1404f85a8..d86e879944 100644 --- a/include/ck_tile/host/reference/reference_softmax.hpp +++ b/include/ck_tile/host/reference/reference_softmax.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -9,43 +9,81 @@ namespace ck_tile { -template -CK_TILE_HOST void reference_softmax(const HostTensor& a_m_n, - HostTensor& b_m_n) +template +CK_TILE_HOST void +reference_softmax(const HostTensor& x, HostTensor& y, index_t dim = -1) { - auto f = [&](auto m) { - const int N = a_m_n.mDesc.get_lengths()[1]; + index_t rank = x.get_num_of_dimension(); + assert(rank == y.get_num_of_dimension()); + assert(dim == -1 || dim < rank); - AccDataType v_max = ck_tile::numeric::Lowest(); + index_t target_dim = dim == -1 ? (rank - 1) : dim; + index_t softmax_len = x.get_length(target_dim); + index_t n_parallel = x.get_element_size() / softmax_len; + auto x_len = x.get_lengths(); - // max - for(int n = 0; n < N; ++n) + auto f = [&](auto i_element) { + std::vector coord = [&]() { + std::vector t_(rank, 0); + size_t r = i_element; + for(index_t i = rank - 1; i >= 0; i--) + { + if(i == target_dim) + continue; + t_[i] = r % x_len[i]; + r = r / x_len[i]; + } + return t_; + }(); + + ComputeType v_max = -ck_tile::numeric::infinity(); + + // compute max + for(auto idx = 0; idx < softmax_len; idx++) { - const ADataType v_a = a_m_n(m, n); - - v_max = v_max < v_a ? 
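// Why the uniqueness-enforcing filler above exists (the rationale is an inference
// from its use in this example): duplicated input values make top-k indices
// ambiguous, so a CPU reference and the GPU kernel could both be correct yet
// disagree. A usage sketch (element type and range are illustrative):
//   auto fill = FillUniformDistribution_Unique<ck_tile::fp16_t>{-5.f, 5.f, 11939};
//   fill(x_host);   // every generated value is distinct
//   fill.clear();   // drop the seen-set before filling another tensor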
v_a : v_max; + auto c_ = coord; + c_[target_dim] = idx; + const ComputeType v_x = ck_tile::type_convert(x(c_)); + v_max = v_max < v_x ? v_x : v_max; } - AccDataType v_exp_sum = 0; + ComputeType v_exp_sum = static_cast(0); // sum - for(int n = 0; n < N; ++n) + for(auto idx = 0; idx < softmax_len; idx++) { - const ADataType v_a = a_m_n(m, n); + auto c_ = coord; + c_[target_dim] = idx; - v_exp_sum += ck_tile::exp(v_a - v_max); + const ComputeType v_x = ck_tile::type_convert(x(c_)); + + v_exp_sum += ck_tile::exp(v_x - v_max); } // elementwise - for(int n = 0; n < N; ++n) + for(auto idx = 0; idx < softmax_len; idx++) { - const ADataType v_a = a_m_n(m, n); + auto c_ = coord; + c_[target_dim] = idx; - b_m_n(m, n) = ck_tile::exp(v_a - v_max) / v_exp_sum; + const ComputeType v_x = ck_tile::type_convert(x(c_)); + + auto out = ck_tile::exp(v_x - v_max) / v_exp_sum; + + y(c_) = ck_tile::type_convert(out); } }; - make_ParallelTensorFunctor(f, - b_m_n.mDesc.get_lengths()[0])(std::thread::hardware_concurrency()); + make_ParallelTensorFunctor(f, n_parallel)(std::thread::hardware_concurrency()); +} + +template +CK_TILE_HOST auto reference_softmax(const HostTensor& x, index_t dim = -1) +{ + HostTensor y(x.get_lengths(), x.get_strides()); + + reference_softmax(x, y, dim); + + return y; } } // namespace ck_tile diff --git a/include/ck_tile/host/reference/reference_topk.hpp b/include/ck_tile/host/reference/reference_topk.hpp new file mode 100644 index 0000000000..3d0404a2e5 --- /dev/null +++ b/include/ck_tile/host/reference/reference_topk.hpp @@ -0,0 +1,124 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/host_tensor.hpp" +#include +#include +#include +#include +#include + +namespace ck_tile { + +/* + similiar to torch.topk() + x (Tensor) – the input tensor. + k (int) – the k in “top-k” + dim (int, optional) – the dimension to sort along + largest (bool, optional) – largest or smallest elements + sorted (bool, optional) – elements in sorted order or not + + output: + y_values + y_indices + + https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/TopKImpl.h +*/ +template +CK_TILE_HOST void reference_topk(const HostTensor& x, + HostTensor& y_values, + HostTensor& y_indices, + index_t k, + index_t dim = -1, + bool largest = true, + bool sorted = true) +{ + // rank must be the same + index_t rank = x.get_num_of_dimension(); + assert(rank == y_values.get_num_of_dimension()); + assert(rank == y_indices.get_num_of_dimension()); + assert(dim == -1 || dim < rank); + + index_t topk_dim = dim == -1 ? 
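// The flat-index decoding shared by reference_softmax above and reference_topk
// below, traced with illustrative numbers: for x of shape {2, 3, 4} and dim = 1,
// the reduced length is 3 and n_parallel = 2 * 4 = 8. Decoding i_element = 5
// right-to-left while skipping the target dim:
//   i = 2 (len 4): coord[2] = 5 % 4 = 1, r = 5 / 4 = 1
//   i = 1        : skipped (the softmax/topk dim stays 0)
//   i = 0 (len 2): coord[0] = 1 % 2 = 1, r = 0
// giving base coordinate {1, 0, 1}; the per-row loops then sweep index 1 over 0..2.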
(rank - 1) : dim; + index_t topk_src_len = x.get_length(topk_dim); + auto x_len = x.get_lengths(); + + assert(k <= topk_src_len); + assert(k == y_values.get_length(topk_dim) && k == y_indices.get_length(topk_dim)); + + index_t n_parallel = x.get_element_size() / topk_src_len; + + // clang-format off + auto f = [&](auto i_element) { + std::vector topk_coord = [&](){ + std::vector t_(rank, 0); + size_t r = i_element; + for(index_t i = rank - 1; i >= 0; i--) { + if(i == topk_dim) continue; // topk dim should be zero + t_[i] = r % x_len[i]; r = r / x_len[i]; + } + return t_; + }(); + + using elem_t = std::pair; + std::vector q = [&](){ + std::vector t_(topk_src_len); + for(index_t i = 0; i < topk_src_len; i++) { + auto c_ = topk_coord; c_[topk_dim] = i; + t_[i].first = x(c_); t_[i].second = i; + } + return t_; + }(); + + // run topk + if(largest) { + std::nth_element(q.begin(), q.begin() + k - 1, q.end(), + [](const elem_t& lhs, const elem_t& rhs) -> bool { return lhs.first > rhs.first; }); + if(sorted) { + std::sort(q.begin(), q.begin() + k - 1, + [](const elem_t& lhs, const elem_t& rhs) -> bool { return lhs.first > rhs.first; }); + } + } else { + std::nth_element(q.begin(), q.begin() + k - 1, q.end(), + [](const elem_t& lhs, const elem_t& rhs) -> bool { return lhs.first < rhs.first; }); + if(sorted) { + std::sort(q.begin(), q.begin() + k - 1, + [](const elem_t& lhs, const elem_t& rhs) -> bool { return lhs.first < rhs.first; }); + } + } + + // write out + for(index_t i = 0; i < k; i++) { + auto c_ = topk_coord; c_[topk_dim] = i; + y_values(c_) = q[i].first; y_indices(c_) = q[i].second; + } + }; + // clang-format on + + make_ParallelTensorFunctor(f, n_parallel)(std::thread::hardware_concurrency()); +} + +// TODO: if using this method, the return tensor would be dense(no stride) +template +CK_TILE_HOST auto reference_topk(const HostTensor& x, + index_t k, + index_t dim = -1, + bool largest = true, + bool sorted = true) +{ + auto lens = x.get_lengths(); + index_t target_dim = (dim == -1) ? (lens.size() - 1) : dim; + assert(target_dim < lens.size()); + assert(k <= lens[target_dim]); + lens[target_dim] = k; + HostTensor y_values(lens); + HostTensor y_indices(lens); + + reference_topk(x, y_values, y_indices, k, dim, largest, sorted); + + return ck_tile::make_tuple(y_values, y_indices); +} +} // namespace ck_tile diff --git a/include/ck_tile/ops/elementwise.hpp b/include/ck_tile/ops/elementwise.hpp new file mode 100644 index 0000000000..62ba9dc0b3 --- /dev/null +++ b/include/ck_tile/ops/elementwise.hpp @@ -0,0 +1,7 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" diff --git a/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp b/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp new file mode 100644 index 0000000000..01217e16ce --- /dev/null +++ b/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp @@ -0,0 +1,1163 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
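// A usage sketch for the reference above, mirroring torch.topk (template arguments
// and shapes are assumptions for illustration):
//   HostTensor<float> scores({tokens, experts});
//   auto vals_and_idxs = reference_topk<float, int32_t>(scores, k);
//   // -> (values, indices), each of shape {tokens, k}, sorted descending
// Selection uses the standard two-step pattern visible above: std::nth_element
// partitions the k largest elements to the front in O(n), after which only the
// first k - 1 need sorting (position k - 1 is already in its sorted place).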
+ +#pragma once + +#include "ck_tile/core.hpp" +#include + +namespace ck_tile { +namespace element_wise { + +#if 0 +struct PassThroughPack2 +{ + template + CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const; + + CK_TILE_HOST_DEVICE constexpr void operator()(ck_tile::half2_t& y, const ck_tile::f8x2_t& x) const + { + auto t = type_convert(x); + y = type_convert(t); + } + constexpr const static bool is_pack2_invocable = true; +}; +#endif + +struct PassThrough +{ + template + CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const; + + template <> + CK_TILE_HOST_DEVICE void operator()(double& y, const double& x) const + { + y = x; + } + + template <> + CK_TILE_HOST_DEVICE void operator()(float& y, const double& x) const + { + y = type_convert(x); + } + + template <> + CK_TILE_HOST_DEVICE void operator()(double& y, const float& x) const + { + y = type_convert(x); + } + + template <> + CK_TILE_HOST_DEVICE void operator()(float& y, const float& x) const + { + y = x; + } + + template <> + CK_TILE_HOST_DEVICE void + operator()(ck_tile::fp16_t& y, const ck_tile::fp16_t& x) const + { + y = x; + } + + template <> + CK_TILE_HOST_DEVICE void operator()(ck_tile::fp16_t& y, + const float& x) const + { + y = type_convert(x); + } + + template <> + CK_TILE_HOST_DEVICE void + operator()(ck_tile::bf16_t& y, const ck_tile::bf16_t& x) const + { + y = x; + } + + template <> + CK_TILE_HOST_DEVICE void operator()(int32_t& y, const int32_t& x) const + { + y = x; + } + + template <> + CK_TILE_HOST_DEVICE void operator()(ck_tile::bf16_t& y, + const float& x) const + { + y = type_convert(x); + } + + template <> + CK_TILE_HOST_DEVICE void operator()(float& y, + const ck_tile::bf16_t& x) const + { + y = type_convert(x); + } + + template <> + CK_TILE_HOST_DEVICE void + operator()(ck_tile::bf16_t& y, const ck_tile::fp16_t& x) const + { + y = type_convert(x); + } + + template <> + CK_TILE_HOST_DEVICE void operator()(float& y, + const ck_tile::fp16_t& x) const + { + y = type_convert(x); + } + + template <> + CK_TILE_HOST_DEVICE void operator()(int8_t& y, const int8_t& x) const + { + y = x; + } + + template <> + CK_TILE_HOST_DEVICE void operator()(ck_tile::fp16_t& y, + const int8_t& x) const + { + y = type_convert(x); + } + + template <> + CK_TILE_HOST_DEVICE void operator()(ck_tile::bf16_t& y, + const int8_t& x) const + { + y = type_convert(x); + } + + template <> + CK_TILE_HOST_DEVICE void operator()(uint8_t& y, const uint8_t& x) const + { + y = x; + } + + template <> + CK_TILE_HOST_DEVICE void operator()(int8_t& y, const int32_t& x) const + { + y = type_convert(x); + } + + template <> + CK_TILE_HOST_DEVICE void operator()(int32_t& y, const int8_t& x) const + { + y = type_convert(x); + } + + template <> + CK_TILE_HOST_DEVICE void operator()(int8_t& y, const float& x) const + { + y = type_convert(x); + } + + template <> + CK_TILE_HOST_DEVICE void operator()(float& y, const int8_t& x) const + { + y = type_convert(x); + } + +#ifdef CK_TILE_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 + template <> + CK_TILE_HOST_DEVICE void operator()(int4_t& y, const int4_t& x) const + { + y = x; + } + template <> + CK_TILE_HOST_DEVICE void operator()(int4_t& y, const int& x) const + { + y = type_convert(x); + } +#endif + + template <> + CK_TILE_HOST_DEVICE void + operator()(ck_tile::fp8_t& y, const ck_tile::fp8_t& x) const + { + y = x; + } + + template <> + CK_TILE_HOST_DEVICE void operator()(float& y, + const ck_tile::fp8_t& x) const + { + y = type_convert(x); + } + + template <> + CK_TILE_HOST_DEVICE void operator()(ck_tile::fp8_t& y, 
+ const float& x) const + { + y = type_convert(x); + } + + template <> + CK_TILE_HOST_DEVICE void + operator()(ck_tile::fp16_t& y, const ck_tile::fp8_t& x) const + { + y = type_convert(x); + } + + template <> + CK_TILE_HOST_DEVICE void + operator()(ck_tile::fp8_t& y, const ck_tile::fp16_t& x) const + { + y = type_convert(x); + } + + template <> + CK_TILE_HOST_DEVICE void + operator()(ck_tile::bf8_t& y, const ck_tile::bf8_t& x) const + { + y = x; + } + + template <> + CK_TILE_HOST_DEVICE void operator()(float& y, + const ck_tile::bf8_t& x) const + { + y = type_convert(x); + } + + template <> + CK_TILE_HOST_DEVICE void operator()(ck_tile::bf8_t& y, + const float& x) const + { + y = type_convert(x); + } + + template <> + CK_TILE_HOST_DEVICE void + operator()(ck_tile::fp16_t& y, const ck_tile::bf8_t& x) const + { + y = type_convert(x); + } + + template <> + CK_TILE_HOST_DEVICE void + operator()(ck_tile::bf8_t& y, const ck_tile::fp16_t& x) const + { + y = ck_tile::type_convert(x); + } +}; + +#if 0 +struct UnaryConvert +{ + template + CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const + { + y = type_convert(x); + } +}; + +struct ConvertBF16RTN +{ + // convert to bf16 using round to nearest (rtn) + template + CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const + { + // check Y datatype + static_assert(std::is_same_v, "Data type is not supported by this operation!"); + + // check X datatype + static_assert(std::is_same_v || std::is_same_v, + "Data type is not supported by this operation!"); + + y = bf16_convert_rtn(x); + } +}; + +struct ConvertF8SR +{ + // convert to fp8 using stochastic rounding (SR) + template + CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const + { + // check Y datatype + static_assert(std::is_same_v || std::is_same_v, + "Data type is not supported by this operation!"); + + // check X datatype + static_assert(std::is_same_v || std::is_same_v, + "Data type is not supported by this operation!"); + + y = f8_convert_sr(x); + } +}; + +struct ConvertF8RNE +{ + // convert to fp8 using rounding to nearest even + template + CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const + { + // check Y datatype + static_assert(std::is_same_v || std::is_same_v, + "Data type is not supported by this operation!"); + + // check X datatype + static_assert(std::is_same_v || std::is_same_v, + "Data type is not supported by this operation!"); + + y = f8_convert_rne(x); + } +}; +#endif + +struct Scale +{ + CK_TILE_HOST_DEVICE Scale(float scale = 1.f) : scale_(scale) {} + + template + CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const + { + y = ck_tile::type_convert(ck_tile::type_convert(x) * scale_); + } + + template <> + CK_TILE_HOST_DEVICE void + operator()(ck_tile::fp16_t& y, const ck_tile::fp16_t& x) const + { + y = ck_tile::type_convert(scale_) * x; + }; + + template <> + CK_TILE_HOST_DEVICE void + operator()(ck_tile::bf16_t& y, const ck_tile::bf16_t& x) const + { + const float x_tmp = ck_tile::type_convert(x); + const float y_tmp = scale_ * x_tmp; + y = ck_tile::type_convert(y_tmp); + }; + + template <> + CK_TILE_HOST_DEVICE void operator()(float& y, const float& x) const + { + y = scale_ * x; + }; + + template <> + CK_TILE_HOST_DEVICE void operator()(double& y, const double& x) const + { + y = scale_ * x; + }; + + template <> + CK_TILE_HOST_DEVICE void operator()(int8_t& y, const int8_t& x) const + { + y = ck_tile::type_convert(scale_ * ck_tile::type_convert(x)); + }; + + float scale_; +}; + +struct ScaleAndResetNaNToMinusInfinity +{ + CK_TILE_HOST_DEVICE 
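// All functors in this header follow one convention: operator()(Y& y, const X& x)
// writes through an output reference, with a declared-but-undefined primary
// template plus explicit specializations for the supported type pairs. A sketch
// (values illustrative):
//   element_wise::PassThrough{}(y_fp16, x_fp32);  // convert-and-assign
//   element_wise::Scale s{0.5f};
//   float y; s(y, 3.0f);                          // y == 1.5f
// Note Scale multiplies fp16 directly in half precision but routes bf16 through
// float, presumably because bf16's 7-bit mantissa makes a native multiply too lossy.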
ScaleAndResetNaNToMinusInfinity(float scale) : scale_(scale) {} + + template + CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const; + + template <> + CK_TILE_HOST_DEVICE void operator()(float& y, const float& x) const + { + y = ck_tile::isnan(x) ? -numeric::infinity() : scale_ * x; + }; + + float scale_; +}; + +struct UnaryDivide +{ + CK_TILE_HOST_DEVICE UnaryDivide(const int32_t divider = 1) : divider_(divider) {} + + template + CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const + { + static_assert(std::is_same_v || std::is_same_v || + std::is_same_v, + "Data type is not supported by this operation!"); + + y = x / type_convert(divider_); + }; + + int32_t divider_ = 1; +}; + +struct UnarySquare +{ + template + CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const + { + static_assert(std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v +#ifdef CK_TILE_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 + || std::is_same_v +#endif + , + "Data type is not supported by this operation!"); + y = x * x; + }; +}; + +struct UnaryAbs +{ + template + CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const + { + static_assert(std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v, + "Data type is not supported by this operation!"); + + y = ck_tile::abs(x); + }; +}; + +struct UnarySqrt +{ + template + CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const + { + static_assert(std::is_same_v || std::is_same_v, + "Data type is not supported by this operation!"); + + y = ck_tile::sqrt(x); + }; +}; + +struct Relu +{ + template + CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const + { + static_assert(std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v, + "Data type is not supported by this operation!"); + y = x > 0 ? x : 0; + } + + template <> + CK_TILE_HOST_DEVICE void operator()(ck_tile::bf16_t& y, const ck_tile::bf16_t& x) const + { + float x_f32 = ck_tile::type_convert(x); + float y_f32 = x_f32 > 0 ? 
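// In Relu, bf16 is the one type that detours through float: the generic path's
// bare "x > 0 ? x : 0" assumes a natively comparable arithmetic type. Sketch:
//   element_wise::Relu r;
//   float y; r(y, -2.5f);   // y == 0.0f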
x_f32 : 0; + y = ck_tile::type_convert(y_f32); + } +}; + +// Fast GeLU +// https://paperswithcode.com/method/gelu +// y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3))) +// host code use higher accuracy "exp" and "div" +// gpu code use lower accuracy "_ocml_exp_f32" and "rcp" function +struct FastGelu +{ + template + CK_TILE_HOST void operator()(Y& y, const X& x) const; + + template + CK_TILE_DEVICE void operator()(Y& y, const X& x) const; + + template <> + CK_TILE_HOST void operator()(float& y, const float& x) const + { + // const float u = -2.f * x * (0.035677f * x * x + 0.797885f); + const float c1 = -2.0 * 0.035677f; + const float c2 = -2.0 * 0.797885f; + const float u = x * (c1 * x * x + c2); + const float emu = exp(u); + y = x / (1.f + emu); + } + + // device code, use lower precision "__ocml_exp_f32" and "rcp" + template <> + CK_TILE_DEVICE void operator()(float& y, const float& x) const + { + // const float u = 2.f * x * (0.035677f * x * x + 0.797885f); + const float c1 = -2.0 * 0.035677f; + const float c2 = -2.0 * 0.797885f; + const float u = x * (c1 * x * x + c2); + const float emu = __ocml_exp_f32(u); + + y = x * ck_tile::rcp(1.f + emu); + } + + template <> + CK_TILE_HOST void operator()(ck_tile::fp16_t& y, + const ck_tile::fp16_t& x) const + { + float y_f; + + this->operator()(y_f, type_convert(x)); + + y = type_convert(y_f); + } + + template <> + CK_TILE_DEVICE void operator()(ck_tile::fp16_t& y, + const ck_tile::fp16_t& x) const + { + float y_f; + + this->operator()(y_f, type_convert(x)); + + y = type_convert(y_f); + } + + template <> + CK_TILE_HOST void operator()(ck_tile::fp16_t& y, const float& x) const + { + float y_f; + + this->operator()(y_f, x); + + y = type_convert(y_f); + } + + template <> + CK_TILE_DEVICE void operator()(ck_tile::fp16_t& y, const float& x) const + { + float y_f; + + this->operator()(y_f, x); + + y = type_convert(y_f); + } + + template <> + CK_TILE_HOST void operator()(ck_tile::bf16_t& y, const float& x) const + { + float y_f; + + this->operator()(y_f, x); + + y = type_convert(y_f); + } + + template <> + CK_TILE_DEVICE void operator()(ck_tile::bf16_t& y, const float& x) const + { + float y_f; + + this->operator()(y_f, x); + + y = type_convert(y_f); + } + + template <> + CK_TILE_DEVICE void operator()(ck_tile::bf16_t& y, + const ck_tile::bf16_t& x) const + { + float y_f; + + this->operator()(y_f, type_convert(x)); + + y = type_convert(y_f); + } + + template <> + CK_TILE_HOST void operator()(ck_tile::bf16_t& y, + const ck_tile::bf16_t& x) const + { + float y_f; + + this->operator()(y_f, type_convert(x)); + + y = type_convert(y_f); + } +}; + +// https://paperswithcode.com/method/gelu +// y = 0.5*x*(1+erf(x/sqrt(2))) +struct Gelu +{ + template + CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const; + + template <> + CK_TILE_HOST_DEVICE void operator()(float& y, const float& x) const + { + y = 0.5f * x * (1.f + erf(float(0.70710678118f * x))); + } + + template <> + CK_TILE_HOST_DEVICE void + operator()(ck_tile::fp16_t& y, const ck_tile::fp16_t& x) const + { + y = ck_tile::fp16_t(0.5) * x * + (ck_tile::fp16_t(1) + ck_tile::fp16_t(erf(float(0.70710678118f * x)))); + } +}; + +struct Sigmoid +{ + template + CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const + { + static_assert(std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v, + "Data type is not supported by this operation!"); + constexpr T one = type_convert(1); + y = one / (one + ck_tile::exp(-x)); + }; +}; + +struct Silu +{ + template + 
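// Where FastGelu's c1/c2 constants come from: with
//   u = sqrt(2/pi) * (x + 0.044715 * x^3), sqrt(2/pi) ~= 0.797885,
//   0.797885 * 0.044715 ~= 0.035677,
// the identity 0.5 * x * (1 + tanh(u)) == x / (1 + exp(-2u)) trades the tanh for a
// single exp; expanding -2u = x * (c1 * x * x + c2) gives c1 = -2 * 0.035677 and
// c2 = -2 * 0.797885, matching the host path x / (1 + emu) and the device path
// x * rcp(1 + emu) above.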
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const + { + static_assert(std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v, + "Data type is not supported by this operation!"); + constexpr T one = type_convert(1); + y = x * (one / (one + ck_tile::exp(-x))); + }; +}; + +struct TanH +{ + template + CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const + { + static_assert(std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v, + "Data type is not supported by this operation!"); + + y = ck_tile::tanh(x); + }; +}; + +struct ACos +{ + template + CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const + { + static_assert(std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v, + "Data type is not supported by this operation!"); + + y = ck_tile::acos(x); + }; +}; + +struct Neg +{ + template + CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const + { + static_assert(std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v, + "Data type is not supported by this operation!"); + + y = ck_tile::neg(x); + }; +}; + +struct ATan +{ + template + CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const + { + static_assert(std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v, + "Data type is not supported by this operation!"); + + y = ck_tile::atan(x); + }; +}; + +struct Sin +{ + template + CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const + { + static_assert(std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v, + "Data type is not supported by this operation!"); + + y = ck_tile::sin(x); + }; +}; + +struct ASinH +{ + template + CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const + { + static_assert(std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v, + "Data type is not supported by this operation!"); + + y = ck_tile::asinh(x); + }; +}; + +struct Cos +{ + template + CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const + { + static_assert(std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v, + "Data type is not supported by this operation!"); + + y = ck_tile::cos(x); + }; +}; + +struct ACosH +{ + template + CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const + { + static_assert(std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v, + "Data type is not supported by this operation!"); + + y = ck_tile::acosh(x); + }; +}; + +struct Tan +{ + template + CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const + { + static_assert(std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v, + "Data type is not supported by this operation!"); + + y = ck_tile::tan(x); + }; +}; + +struct ATanH +{ + template + CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const + { + static_assert(std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v, + "Data type is not supported by this operation!"); + + y = ck_tile::atanh(x); + }; +}; + +struct SinH +{ + template + CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const + { + static_assert(std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v, + "Data type is not supported by this operation!"); + + y = ck_tile::sinh(x); + }; +}; + +struct Ceil +{ + template + CK_TILE_HOST_DEVICE 
void operator()(T& y, const T& x) const + { + static_assert(std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v, + "Data type is not supported by this operation!"); + + y = ck_tile::ceil(x); + }; +}; + +struct Exp +{ + template + CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const + { + static_assert(std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v, + "Data type is not supported by this operation!"); + + y = ck_tile::exp(x); + }; +}; + +struct CosH +{ + template + CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const + { + static_assert(std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v, + "Data type is not supported by this operation!"); + + y = ck_tile::cosh(x); + }; +}; + +struct Floor +{ + template + CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const + { + static_assert(std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v, + "Data type is not supported by this operation!"); + + y = ck_tile::floor(x); + }; +}; + +struct Log +{ + template + CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const + { + static_assert(std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v, + "Data type is not supported by this operation!"); + + y = ck_tile::log(x); + }; +}; + +struct ASin +{ + template + CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const + { + static_assert(std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v, + "Data type is not supported by this operation!"); + + y = ck_tile::asin(x); + }; +}; + +struct Rcp +{ + template + CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const + { + static_assert(std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v, + "Data type is not supported by this operation!"); + + y = ck_tile::rcp(x); + }; +}; + +struct Swish +{ + Swish(float beta = 1.0f) : beta_(beta) {} + + template + CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const + { + static_assert(std::is_same_v || std::is_same_v || + std::is_same_v, + "Data type is not supported by this operation!"); + + static_assert(std::is_same_v || std::is_same_v || + std::is_same_v, + "Data type is not supported by this operation!"); + + float bx = -beta_ * type_convert(x); + y = type_convert(x / (1.f + ck_tile::exp(bx))); + }; + + const float beta_; +}; + +struct SoftRelu +{ + SoftRelu(float alpha = 1.f) : alpha_(alpha){}; + + template + CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const + { + static_assert(std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v, + "Data type is not supported by this operation!"); + T casted_alpha = type_convert(alpha_); + constexpr T one = type_convert(1); + y = ck_tile::log(one + ck_tile::exp(x * casted_alpha)) / casted_alpha; + } + const float alpha_; +}; + +struct Power +{ + Power(float alpha = 0.f, float beta = 1.f, float gamma = 2.f) + : alpha_(alpha), beta_(beta), gamma_(gamma){}; + + template + CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const + { + static_assert(std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v, + "Data type is not supported by this operation!"); + T casted_alpha = type_convert(alpha_); + T casted_beta = type_convert(beta_); + T casted_gamma = type_convert(gamma_); + T shifted_scaled_x = casted_alpha + casted_beta * x; + y = 
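// The parameterized functors (Swish, SoftRelu, Power, ...) carry their constants
// as members, so one configured instance can be applied across a whole tile.
// Sketch (parameter values illustrative):
//   element_wise::Swish silu{1.0f};   // beta == 1 reduces Swish to SiLU
//   element_wise::SoftRelu sr{1.0f};  // y = log(1 + exp(x)) for alpha == 1
//   float y; silu(y, 2.0f);           // y == 2 / (1 + exp(-2))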
ck_tile::pow(shifted_scaled_x, casted_gamma); + } + const float alpha_; + const float beta_; + const float gamma_; +}; + +struct ClippedRelu +{ + ClippedRelu(float alpha = 0.f, float beta = 1.f) : alpha_(alpha), beta_(beta){}; + + template + CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const + { + static_assert(std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v, + "Data type is not supported by this operation!"); + T casted_alpha = type_convert(alpha_); + T casted_beta = type_convert(beta_); + y = ck_tile::min(casted_beta, ck_tile::max(casted_alpha, x)); + } + const float alpha_; + const float beta_; +}; + +struct LeakyRelu +{ + LeakyRelu(float alpha = 0.01f) : alpha_(alpha){}; + + template + CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const + { + static_assert(std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v, + "Data type is not supported by this operation!"); + T casted_alpha = type_convert(alpha_); + y = x >= 0 ? x : x * casted_alpha; + } + const float alpha_; +}; + +struct Elu +{ + Elu(float alpha = 1.f) : alpha_(alpha){}; + + template + CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const + { + static_assert(std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v, + "Data type is not supported by this operation!"); + T casted_alpha = type_convert(alpha_); + y = x > 0 ? x : casted_alpha * ck_tile::expm1(x); + } + const float alpha_; +}; + +struct Logistic +{ + Logistic(float alpha = 1.f) : alpha_(alpha){}; + + template + CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const + { + static_assert(std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v, + "Data type is not supported by this operation!"); + T casted_alpha = type_convert(alpha_); + constexpr T one = type_convert(1); + y = casted_alpha / (one + ck_tile::exp(-x) * casted_alpha); + } + const float alpha_; +}; + +struct ConvInvscale +{ + CK_TILE_HOST_DEVICE + ConvInvscale(float scale_in = 1.f, float scale_wei = 1.f, float scale_out = 1.f) + : scale_in_(scale_in), scale_wei_(scale_wei), scale_out_(scale_out) + { + } + + template + CK_TILE_HOST_DEVICE void operator()(E& e, const C& c) const; + + template <> + CK_TILE_HOST_DEVICE void operator()(ck_tile::fp8_t& e, + const float& c) const + { + e = type_convert(c / scale_in_ / scale_wei_ / scale_out_); + }; + + float scale_in_; + float scale_wei_; + float scale_out_; +}; + +struct ConvScale +{ + CK_TILE_HOST_DEVICE + ConvScale(float scale_in = 1.f, float scale_wei = 1.f, float scale_out = 1.f) + : scale_in_(scale_in), scale_wei_(scale_wei), scale_out_(scale_out) + { + } + + template + CK_TILE_HOST_DEVICE void operator()(E& e, const C& c) const; + + template <> + CK_TILE_HOST_DEVICE void operator()(ck_tile::fp8_t& e, + const float& c) const + { + e = type_convert(c * scale_in_ * scale_wei_ * scale_out_); + }; + + float scale_in_; + float scale_wei_; + float scale_out_; +}; + +struct ConvScaleRelu +{ + CK_TILE_HOST_DEVICE + ConvScaleRelu(float scale_in = 1.f, float scale_wei = 1.f, float scale_out = 1.f) + : scale_in_(scale_in), scale_wei_(scale_wei), scale_out_(scale_out) + { + } + + template + CK_TILE_HOST_DEVICE void operator()(E& e, const C& c) const; + + template <> + CK_TILE_HOST_DEVICE void operator()(ck_tile::fp8_t& e, + const float& c) const + { + float x; + Relu{}.template operator()(x, c * scale_in_ * scale_wei_); + e = type_convert(x * scale_out_); + }; + + float scale_in_; + float 
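// The ConvScale* epilogues fold fp8 quantization scales into the final cast, where
// c is the convolution's float accumulator:
//   ConvScale:     e = fp8( c * scale_in * scale_wei * scale_out )
//   ConvInvscale:  e = fp8( c / scale_in / scale_wei / scale_out )
//   ConvScaleRelu: e = fp8( relu(c * scale_in * scale_wei) * scale_out )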
scale_wei_; + float scale_out_; +}; + +template +struct Cast +{ + template + CK_TILE_HOST_DEVICE void operator()(DstType& y, const SrcType& x) const + { + y = ck_tile::type_convert(x); + }; +}; + +// support fastconvert of int8 to fp16 +#if 0 +template +struct FastNumericArrayConverter +{ +}; + +template <> +struct FastNumericArrayConverter +{ + using InputArray = vector_type; + using OutputArray = vector_type; + + CK_TILE_DEVICE static OutputArray convert(InputArray const& Input) + { + OutputArray Output; + + uint32_t* half_2 = reinterpret_cast(&Output); + uint32_t const uint8_4 = reinterpret_cast(Input); + + static constexpr uint32_t byte_selector_01 = 0x05010500; + static constexpr uint32_t byte_selector_23 = 0x05030502; + static constexpr uint32_t fp16_adder = 0x64646464; + half_2[0] = __builtin_amdgcn_perm(fp16_adder, uint8_4, byte_selector_01); + half_2[1] = __builtin_amdgcn_perm(fp16_adder, uint8_4, byte_selector_23); + + static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; + asm volatile("v_pk_add_f16 %0, %1, %2 neg_lo:[0,1] neg_hi:[0,1]" + : "=v"(half_2[0]) + : "v"(half_2[0]), "s"(I8s_TO_F16s_MAGIC_NUM)); + asm volatile("v_pk_add_f16 %0, %1, %2 neg_lo:[0,1] neg_hi:[0,1]" + : "=v"(half_2[1]) + : "v"(half_2[1]), "s"(I8s_TO_F16s_MAGIC_NUM)); + + return Output; + } + + CK_TILE_DEVICE OutputArray operator()(InputArray const& Input) { return convert(Input); } +}; + +template +struct FastNumericArrayConverter +{ + static constexpr int VEC_WIDTH = 4; + static_assert(!(N % VEC_WIDTH), "N must be multiple of 4."); + + using InputArray = vector_type; + using OutputArray = vector_type; + + CK_TILE_DEVICE static OutputArray convert(InputArray const& Input) + { + FastNumericArrayConverter converter; + + OutputArray Output; + + using Vec_InputArray = vector_type; + using Vec_OutputArray = vector_type; + + Vec_OutputArray* half_4_ptr = reinterpret_cast(&Output); + Vec_InputArray const* uint8_4_ptr = reinterpret_cast(&Input); + + static_for<0, N / VEC_WIDTH, 1>{}( + [&](auto i) { half_4_ptr[i] = converter(uint8_4_ptr[i]); }); + + return Output; + } + + CK_TILE_DEVICE OutputArray operator()(InputArray const& Input) { return convert(Input); } +}; +#endif +} // namespace element_wise +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp index c4872def1d..05d3dae1cc 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp @@ -334,7 +334,7 @@ struct BlockFmhaPipelineQRKSVSAsync move_tile_window(k_dram_window, {0, kK0}); __builtin_amdgcn_sched_barrier(0); - buffer_load_fence(k_dram_window.get_num_access(), q.get_thread_buffer()); + buffer_load_fence(k_dram_window.get_num_of_access(), q.get_thread_buffer()); (void)q_element_func; // ??? 
rocm-6.x if use q element func will have scratch on hdim=64/32 // auto q_tile = q; // tile_elementwise_in(q_element_func, q); @@ -359,7 +359,7 @@ struct BlockFmhaPipelineQRKSVSAsync if constexpr(i_k0 < k0_loops - 1) move_tile_window(k_dram_window, {0, kK0}); - async_load_fence(k_dram_window.get_num_access()); + async_load_fence(k_dram_window.get_num_of_access()); __builtin_amdgcn_s_barrier(); __builtin_amdgcn_sched_barrier(0); gemm_0(s_acc, diff --git a/include/ck_tile/ops/reduce/block/block_reduce.hpp b/include/ck_tile/ops/reduce/block/block_reduce.hpp index a01265ad5d..51d55235e8 100644 --- a/include/ck_tile/ops/reduce/block/block_reduce.hpp +++ b/include/ck_tile/ops/reduce/block/block_reduce.hpp @@ -4,9 +4,14 @@ #pragma once #include "ck_tile/core.hpp" +#include namespace ck_tile { +/* + * TODO: block_tile_reduce_sync() currently has a limitation + * Y dim must have at least one dim not been reduced + */ // synchronize reduce result (cross lane reduction and broadcast on replicated dimension) template CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_& acc_tensor, @@ -104,6 +109,65 @@ CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_& acc_tensor, }); } +/* + * this version is faster, using xor to do reduce, no need broadcast anymore + * TODO: the limitation is to-be-reduced P dim can only mapping to one R dim? + */ +template +CK_TILE_DEVICE void block_tile_reduce_xor_sync(AccDistributedTensor_& acc_tensor, + const ReduceFunc& reduce_func) +{ + using Dstr = typename AccDistributedTensor_::StaticTileDistribution; + using DstrEncode = typename Dstr::DstrEncode; + using DstrEncodeDetail = typename DstrEncode::detail; + + constexpr index_t NDimP = Dstr::get_num_of_dimension_p(); + constexpr index_t NDimR = Dstr::get_num_of_dimension_r(); + + constexpr index_t idim_p_lane = NDimP - 1; + + constexpr index_t thread_buf_size = AccDistributedTensor_::get_thread_buffer_size(); + + // loop over thread data + static_for<0, thread_buf_size, 1>{}([&](auto i) { + auto v_local = acc_tensor.get_thread_buffer()[i]; + + // cross-lane reduce for replication + // only reduce on R dimension correspond to lane + // (lane id maps to this R dimension) + static_for<0, NDimR, 1>{}([&](auto idim_r) { + // FIXME: nasty to use does_p_own_r_ + if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_lane][idim_r]) + { + constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r]; + + constexpr index_t lid_over_rid_derivative = + DstrEncodeDetail::ps_over_rs_derivative_[idim_p_lane][idim_r]; + + static_assert(is_power_of_two_integer(r_length), + "wrong! 
only support power of 2 reduction"); + + constexpr index_t nstage = integer_log2_floor(r_length); + + // reduction sweep forward + static_for<0, nstage, 1>{}([&](auto istage) { + // xor + index_t src_lane = + __lane_id() ^ (number{}.value); + + // pull data from remote lane + const auto v_remote = warp_shuffle(v_local, src_lane); + + // reduce + v_local = reduce_func(v_local, v_remote); + }); + } + }); + + acc_tensor.get_thread_buffer()(i) = v_local; + }); +} + // FIXME: this is for 2D to 1D reduce only, need to support n-D template 1D reduce (reduce-dim=seq<0, 1>) +// this version only support in/acc/out datatypes are the same +// this version will call thread/warp+sync in one function call +// +template +struct BlockReduce2D +{ + using InDistributedTensor = remove_cvref_t; + using InDataType = typename InDistributedTensor::DataType; + + CK_TILE_HOST_DEVICE BlockReduce2D(const InDistributedTensor& t_, const InDataType& reduce_init_) + : t(t_), reduce_init(reduce_init_) + { + } + + CK_TILE_HOST_DEVICE constexpr auto MakeDstBlockTile() const + { + using ReduceDim = sequence<1>; // hard coded + constexpr auto acc_dstr = + make_static_tile_distribution(ck_tile::detail::make_reduce_tile_distribution_encoding( + InDistributedTensor::get_tile_distribution() + .get_static_tile_distribution_encoding(), + ReduceDim{})); + + return make_static_distributed_tensor(acc_dstr); + } + + // return number of pixels each lane need to reduce + CK_TILE_HOST_DEVICE constexpr auto get_reduce_length_y() const + { + constexpr auto spans = InDistributedTensor::get_distributed_spans(); + } + + // Here ReducePacksPerXDim is not the same meaning as that in static_uford/sweep_tile_uspan + // this is number of packs along the X-dim. We need to compute the Unpacks along the Y dim + // internally + // For simplicity, we just support along the row dimension, ReducePacksPerXDim is always 2 + // element , and the first element is always ignored For simplicity, will always try from + // right-to-left to find alone which Y dim to split + template > + CK_TILE_HOST_DEVICE auto operator()(const ReduceFunc& reduce_func, + const ReduceSyncFunc& reduce_sync_func, + ReducePacksPerXDim = {}) const + { + constexpr auto spans = InDistributedTensor::get_distributed_spans(); + + constexpr auto row_y_unpacks = [&]() { + constexpr auto row_y_lengths = typename decltype(spans[number<1>{}])::Impl{}; + constexpr auto row_y_size = + reduce_on_sequence(row_y_lengths, multiplies{}, number<1>{}); + constexpr auto row_y_packs = ReducePacksPerXDim{}.at(number<1>{}); + + static_assert(row_y_size % row_y_packs == 0); + + constexpr auto row_y_slice_size = row_y_size / row_y_packs; + + constexpr auto slice_info = slice_sequence(row_y_lengths, number{}); + constexpr auto unpacks = slice_info[number<1>{}]; + return unpacks; + }(); + + auto acc_tensor = MakeDstBlockTile(); + + // in-thread reduction + // FIXME: hard coded to be 2D to 1D reduction + sweep_tile_span(spans[number<0>{}], [&](auto dstr_idx_i0) { + constexpr auto acc_dstr_idx = make_tuple(dstr_idx_i0); + + auto acc = acc_tensor[acc_dstr_idx]; + + sweep_tile_uspan( + spans[number<1>{}], + [&](auto... 
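// The xor exchange above, traced for r_length == 4 with lid_over_rid_derivative
// == 1 (illustrative values): stage 0 pairs lanes 0<->1 and 2<->3, stage 1 pairs
// lanes 0<->2 and 1<->3, so after log2(4) = 2 stages every lane already holds the
// full reduction and no separate broadcast pass is needed:
//   for(int istage = 0; istage < 2; ++istage) {
//       int src = lane ^ (1 << istage);                        // butterfly partner
//       v_local = reduce_func(v_local, warp_shuffle(v_local, src));
//   }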
dstr_idx_i1) { + acc = reduce_func(acc, t[make_tuple(dstr_idx_i0, dstr_idx_i1)]...); + }, + row_y_unpacks); + + acc_tensor(acc_dstr_idx) = acc; + }); + + // TODO: always use xor to do cross-lane reduce + block_tile_reduce_xor_sync(acc_tensor, reduce_sync_func); + + return acc_tensor; + } + + template + CK_TILE_HOST_DEVICE auto operator()(const ReduceFunc& reduce_func) const + { + return operator()(reduce_func, reduce_func); + } + + InDistributedTensor t; + InDataType reduce_init; +}; + +// deduction guide +template +CK_TILE_HOST_DEVICE_EXTERN BlockReduce2D(const T&, const typename T::DataType&)->BlockReduce2D; + } // namespace ck_tile diff --git a/include/ck_tile/ops/softmax.hpp b/include/ck_tile/ops/softmax.hpp new file mode 100644 index 0000000000..584ca70689 --- /dev/null +++ b/include/ck_tile/ops/softmax.hpp @@ -0,0 +1,8 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/ops/softmax/block/block_softmax_2d.hpp" +#include "ck_tile/ops/softmax/block/block_softmax_2d_problem.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" diff --git a/include/ck_tile/ops/softmax/block/block_softmax_2d.hpp b/include/ck_tile/ops/softmax/block/block_softmax_2d.hpp new file mode 100644 index 0000000000..607ec7eb53 --- /dev/null +++ b/include/ck_tile/ops/softmax/block/block_softmax_2d.hpp @@ -0,0 +1,81 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/reduce.hpp" + +#define _BLOCK_SOFTMAX_USE_UNPACK2 0 + +namespace ck_tile { + +/* +simple 2d softmax implementation, along row (dim=1) +requirement: + 1). each row is within a warp + 2). 
data type must be a dword +*/ +template +struct BlockSoftmax2D +{ + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + + using DataType = typename Problem::DataType; + + template + CK_TILE_DEVICE void + operator()(const DistributedTensor& x, DistributedTensor& y, number = {}) + { + const auto f_max = [](auto e0, auto e1) { return max(e0, e1); }; + const auto f_sum = [](auto e0, auto e1) { return e0 + e1; }; +#if _BLOCK_SOFTMAX_USE_UNPACK2 + const auto f_max3 = [](auto e0, auto e1, auto e2) { + float rtn; + asm volatile("v_max3_f32 %0, %1, %2, %3" : "=v"(rtn) : "v"(e0), "v"(e1), "v"(e2)); + return rtn; + }; + const auto f_sum3 = [](auto e0, auto e1, auto e2) { return e0 + e1 + e2; }; +#endif + + // compute row max + auto reduce_row_max = BlockReduce2D{x, -numeric::infinity()}; +#if _BLOCK_SOFTMAX_USE_UNPACK2 + auto row_max = reduce_row_max(f_max3, f_max, sequence<1, 2>{}); +#else + auto row_max = reduce_row_max(f_max); +#endif + sweep_tile([&](auto idx) { + constexpr auto row_id = make_tuple(idx[number<0>{}]); + y(idx) = exp(x[idx] - row_max[row_id]); + }); + + // compute row sum + auto reduce_row_sum = BlockReduce2D{y, DataType{0}}; +#if _BLOCK_SOFTMAX_USE_UNPACK2 + auto row_sum = reduce_row_sum(f_sum3, f_sum, sequence<1, 2>{}); +#else + auto row_sum = reduce_row_sum(f_sum); +#endif + // reciprocal + auto r = make_static_distributed_tensor(row_sum.get_tile_distribution()); + sweep_tile(row_sum, [&](auto idx) { r(idx) = DataType{1} / row_sum(idx); }); + + // scale + sweep_tile([&](auto idx) { + constexpr auto row_id = make_tuple(idx[number<0>{}]); + y(idx) = y(idx) * r(row_id); + }); + } + + template + CK_TILE_DEVICE decltype(auto) operator()(const DistributedTensor& x, number = {}) + { + auto y = DistributedTensor{}; // distributed tensor + operator()(x, y, number{}); + return y; + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/softmax/block/block_softmax_2d_problem.hpp b/include/ck_tile/ops/softmax/block/block_softmax_2d_problem.hpp new file mode 100644 index 0000000000..82b9a5a486 --- /dev/null +++ b/include/ck_tile/ops/softmax/block/block_softmax_2d_problem.hpp @@ -0,0 +1,16 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +template +struct BlockSoftmax2DProblem +{ + using DataType = remove_cvref_t; +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/topk.hpp b/include/ck_tile/ops/topk.hpp new file mode 100644 index 0000000000..b1143e4a06 --- /dev/null +++ b/include/ck_tile/ops/topk.hpp @@ -0,0 +1,8 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/ops/topk/block/block_topk_stream_2d.hpp" +#include "ck_tile/ops/topk/block/block_topk_stream_2d_problem.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" diff --git a/include/ck_tile/ops/topk/block/block_topk_stream_2d.hpp b/include/ck_tile/ops/topk/block/block_topk_stream_2d.hpp new file mode 100644 index 0000000000..164685f980 --- /dev/null +++ b/include/ck_tile/ops/topk/block/block_topk_stream_2d.hpp @@ -0,0 +1,113 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +/* +simple 2d topk implementation, along row (dim=1) +requirement: + 1). 
each row is within a warp +*/ +template +struct BlockTopkStream2D +{ + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + + using DataType = typename Problem::DataType; + using IndexType = typename Problem::IndexType; + + // TODO: if DataType is subdword, need pack into single dword to use argmax + struct ArgmaxPacket + { + DataType arg; + index_t value; + }; + + template + CK_TILE_DEVICE void operator()(const DistributedTensor& x, + const OutWindow& out_window, + const IdxWindow& idx_window, + index_t k, + number = {}) + { + OutWindow out_window_tmp = out_window; + IdxWindow idx_window_tmp = idx_window; + static_assert( + std::is_same_v && + std::is_same_v); + static_assert(std::is_same_v); + + DistributedTensor x_tmp = x; + constexpr auto dst_dist = typename IdxWindow::TileDstr{}; + + // argmax for topk + const auto f_argmax = [](ArgmaxPacket e0, ArgmaxPacket e1) { + return e0.arg > e1.arg ? e0 : e1; + }; + + for(index_t i_k = 0; i_k < k; i_k++) + { + constexpr auto span_2d = DistributedTensor::get_distributed_spans(); + auto packet = [&]() { + auto tmp = make_static_distributed_tensor(x.get_tile_distribution()); + + sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) { + sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) { + const auto tile_idx = get_x_indices_from_distributed_indices( + tmp.get_tile_distribution(), make_tuple(idx0, idx1)); + constexpr auto i_j_idx = make_tuple(idx0, idx1); + ArgmaxPacket t; + t.arg = x_tmp(i_j_idx); // !!! we reference x here + t.value = tile_idx.at(number<1>{}); + tmp(i_j_idx) = t; + }); + }); + return tmp; + }(); + + auto argmax_init = ArgmaxPacket{-numeric::infinity(), 0}; + auto r = block_tile_reduce(packet, sequence<1>{}, f_argmax, argmax_init); + block_tile_reduce_xor_sync(r, f_argmax); + + auto o = make_static_distributed_tensor(dst_dist); + auto i = make_static_distributed_tensor(dst_dist); + sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) { + sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + ArgmaxPacket tmp = r(i_j_idx); + o(i_j_idx) = tmp.arg; + i(i_j_idx) = tmp.value; + }); + }); + + // update value + sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) { + sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) { + const auto tile_idx = get_x_indices_from_distributed_indices( + x.get_tile_distribution(), make_tuple(idx0, idx1)); + auto col_id = tile_idx.at(number<1>{}); + + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + x_tmp(i_j_idx) = (col_id == r(i_j_idx).value) ? -numeric::infinity() + : x_tmp(i_j_idx); + }); + }); + + if(threadIdx.x % Problem::ColLanes == 0) + { + store_tile(out_window_tmp, o); + store_tile(idx_window_tmp, i); + } + move_tile_window(out_window_tmp, {number<0>{}, number<1>{}}); + move_tile_window(idx_window_tmp, {number<0>{}, number<1>{}}); + } + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/topk/block/block_topk_stream_2d_problem.hpp b/include/ck_tile/ops/topk/block/block_topk_stream_2d_problem.hpp new file mode 100644 index 0000000000..d47188d862 --- /dev/null +++ b/include/ck_tile/ops/topk/block/block_topk_stream_2d_problem.hpp @@ -0,0 +1,22 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +/* +simple 2d topk implementation, along row (dim=1) +requirement: + 1). 
each row is within a warp +*/ +template +struct BlockTopkStream2DProblem +{ + using DataType = remove_cvref_t; + using IndexType = remove_cvref_t; + static constexpr index_t ColLanes = ColLanes_; +}; +} // namespace ck_tile diff --git a/include/ck_tile/ops/topk_softmax.hpp b/include/ck_tile/ops/topk_softmax.hpp new file mode 100644 index 0000000000..809473d53b --- /dev/null +++ b/include/ck_tile/ops/topk_softmax.hpp @@ -0,0 +1,10 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/ops/topk_softmax/kernel/topk_softmax_kernel.hpp" +#include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_pipeline.hpp" +#include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_policy.hpp" +#include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_problem.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" diff --git a/include/ck_tile/ops/topk_softmax/kernel/topk_softmax_kernel.hpp b/include/ck_tile/ops/topk_softmax/kernel/topk_softmax_kernel.hpp new file mode 100644 index 0000000000..b8520ae61a --- /dev/null +++ b/include/ck_tile/ops/topk_softmax/kernel/topk_softmax_kernel.hpp @@ -0,0 +1,166 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common.hpp" +#include "ck_tile/ops/elementwise.hpp" +#include "ck_tile/host/hip_check_error.hpp" +#include +#include + +namespace ck_tile { + +struct TopkSoftmaxHostArgs +{ + const void* p_input; + void* p_output; + void* p_indices; + index_t num_rows; + index_t num_experts; + index_t topk; + index_t stride_input; // row stride for input, at least experts + index_t stride_output; // row stride for output/indices, at least tpok +}; + +template +struct TopkSoftmaxKernel +{ + using Pipeline = remove_cvref_t; + using Problem = remove_cvref_t; + + using InputType = typename Problem::InputType; + using WeightType = typename Problem::WeightType; + using IndexType = typename Problem::IndexType; + + struct TopkSoftmaxKargs + { + const void* p_input; + void* p_output; + void* p_indices; + index_t num_rows; + index_t num_experts; + index_t topk; + index_t stride_input; // row stride for input, at least experts + index_t stride_output; // row stride for output/indices, at least tpok + }; + + using Kargs = TopkSoftmaxKargs; + using Hargs = TopkSoftmaxHostArgs; + + CK_TILE_HOST static constexpr auto GridSize(const Hargs& h) + { + if constexpr(Problem::LaunchType > 0) + { + int num_cu = [&]() { + hipDeviceProp_t dev_prop; + hipDevice_t dev; + HIP_CHECK_ERROR(hipGetDevice(&dev)); + HIP_CHECK_ERROR(hipGetDeviceProperties(&dev_prop, dev)); + return dev_prop.multiProcessorCount; + }(); + return dim3(num_cu * Problem::LaunchType); + } + else + { + const int num_warps = (h.num_rows + Problem::RowsPerWarp - 1) / Problem::RowsPerWarp; + const int num_blocks = + (num_warps + Problem::WarpsPerBlock - 1) / Problem::WarpsPerBlock; + return dim3(num_blocks); + } + } + + CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h) + { + Kargs k; + k.p_input = h.p_input; + k.p_output = h.p_output; + k.p_indices = h.p_indices; + k.num_rows = h.num_rows; + k.num_experts = h.num_experts; + k.topk = h.topk; + k.stride_input = h.stride_input; + k.stride_output = h.stride_output; + return k; + } + + CK_TILE_HOST_DEVICE static constexpr auto BlockSize() { return Problem::BlockSize; } + + CK_TILE_DEVICE void operator()(Kargs kargs) const + { 
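// Launch-shape arithmetic from GridSize above, with illustrative numbers and
// assuming RowsPerBlock == RowsPerWarp * WarpsPerBlock: for RowsPerWarp = 4,
// WarpsPerBlock = 4 and num_rows = 100, the default path launches
// ceil(ceil(100 / 4) / 4) = 7 blocks, and the last block starts at row 96 and
// covers the 4 remaining rows via padding. With LaunchType > 0 the grid is instead
// fixed at num_cu * LaunchType persistent blocks that loop over rows.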
+        index_t block_row_id = static_cast<index_t>(blockIdx.x * Problem::RowsPerBlock);
+
+        if(block_row_id >= kargs.num_rows)
+            return;
+
+        index_t block_os_inp = __builtin_amdgcn_readfirstlane(block_row_id * kargs.stride_input);
+        index_t block_os_out = __builtin_amdgcn_readfirstlane(block_row_id * kargs.stride_output);
+        index_t num_rows_rem = __builtin_amdgcn_readfirstlane(kargs.num_rows - block_row_id);
+
+        const auto input_window = [&]() {
+            const InputType* p_input =
+                reinterpret_cast<const InputType*>(kargs.p_input) + block_os_inp;
+
+            auto tmp = make_naive_tensor_view<address_space_enum::global>(
+                p_input,
+                make_tuple(num_rows_rem, kargs.num_experts),
+                make_tuple(kargs.stride_input, 1),
+                number<Problem::VectorSize>{},
+                number<1>{});
+
+            auto view = pad_tensor_view(
+                tmp,
+                make_tuple(number<Problem::RowsPerBlock>{}, number<Problem::Experts>{}),
+                sequence<0, 1>{}); // outermost dim needs no pad (leverage OOB)
+
+            return make_tile_window(
+                view,
+                make_tuple(number<Problem::RowsPerBlock>{}, number<Problem::Experts>{}),
+                {0, 0});
+        }();
+
+        auto output_window = [&]() {
+            WeightType* p_output = reinterpret_cast<WeightType*>(kargs.p_output) + block_os_out;
+            auto tmp = make_naive_tensor_view<address_space_enum::global>(
+                p_output,
+                make_tuple(num_rows_rem, kargs.topk),
+                make_tuple(kargs.stride_output, 1),
+                number<1>{},
+                number<1>{});
+            auto view =
+                pad_tensor_view(tmp,
+                                make_tuple(number<Problem::RowsPerBlock>{}, number<1>{}),
+                                sequence<0, 0>{}); // 1. outermost dim needs no pad (leverage OOB)
+                                                   // 2. we loop over topk 1-1, no need padding
+            return make_tile_window(
+                view, make_tuple(number<Problem::RowsPerBlock>{}, number<1>{}), {0, 0});
+        }();
+
+        auto indices_window = [&]() {
+            IndexType* p_indices = reinterpret_cast<IndexType*>(kargs.p_indices) + block_os_out;
+            auto tmp = make_naive_tensor_view<address_space_enum::global>(
+                p_indices,
+                make_tuple(num_rows_rem, kargs.topk),
+                make_tuple(kargs.stride_output, 1),
+                number<1>{},
+                number<1>{});
+            auto view =
+                pad_tensor_view(tmp,
+                                make_tuple(number<Problem::RowsPerBlock>{}, number<1>{}),
+                                sequence<0, 0>{}); // 1. outermost dim needs no pad (leverage OOB)
+                                                   // 2. we loop over topk 1-1, no need padding
+            return make_tile_window(
+                view, make_tuple(number<Problem::RowsPerBlock>{}, number<1>{}), {0, 0});
+        }();
+
+        Pipeline{}(input_window,
+                   output_window,
+                   indices_window,
+                   kargs.num_rows,
+                   kargs.num_experts,
+                   kargs.topk,
+                   block_row_id);
+    }
+};
+} // namespace ck_tile
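The host-side entry points above (`GridSize`, `MakeKargs`, `BlockSize`) are all a caller needs to dispatch this kernel. A minimal wiring sketch, assuming the Problem template-argument order reconstructed in the problem header further below and illustrative pointer/stream variables; `launch_kernel`, `make_kernel`, and `stream_config` are the usual ck_tile host helpers:
```
// Hedged sketch: compose Problem -> Pipeline -> Kernel, then launch.
// The Problem argument order and all variable names are assumptions.
using Problem  = ck_tile::TopkSoftmaxWarpPerRowProblem<ck_tile::fp16_t,  // InputType
                                                       float,            // WeightType
                                                       ck_tile::index_t, // IndexType
                                                       64,               // Experts
                                                       16,               // BytesPerIssue
                                                       1,                // IssuesPerCol
                                                       0>;               // LaunchType (non-persistent)
using Pipeline = ck_tile::TopkSoftmaxWarpPerRowPipeline<Problem>;
using Kernel   = ck_tile::TopkSoftmaxKernel<Pipeline>;

ck_tile::TopkSoftmaxHostArgs hargs{p_input, p_output, p_indices,
                                   num_tokens, 64 /*experts*/, topk,
                                   /*stride_input=*/64, /*stride_output=*/topk};

auto kargs = Kernel::MakeKargs(hargs);
ck_tile::launch_kernel(ck_tile::stream_config{stream},
                       ck_tile::make_kernel(Kernel{}, Kernel::GridSize(hargs),
                                            Kernel::BlockSize(), 0, kargs));
```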
diff --git a/include/ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_pipeline.hpp b/include/ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_pipeline.hpp
new file mode 100644
index 0000000000..d620d9bec9
--- /dev/null
+++ b/include/ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_pipeline.hpp
@@ -0,0 +1,123 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+#include "ck_tile/core.hpp"
+#include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_policy.hpp"
+#include <string>
+#include <type_traits>
+
+#ifndef TOPK_SOFTMAX_USE_RAW_TILE_WINDOW
+#define TOPK_SOFTMAX_USE_RAW_TILE_WINDOW 0
+#endif
+
+namespace ck_tile {
+
+template <typename Problem_, typename Policy_ = TopkSoftmaxWarpPerRowPolicy>
+struct TopkSoftmaxWarpPerRowPipeline
+{
+    // TODO: this kernel only supports one warp per row
+    using Problem    = remove_cvref_t<Problem_>;
+    using Policy     = remove_cvref_t<Policy_>;
+    using WeightType = typename Problem::WeightType;
+
+    template <typename InputWindow, typename OutputWindow, typename IndexWindow>
+    CK_TILE_DEVICE auto operator()(const InputWindow& input_window,
+                                   OutputWindow& out_window,
+                                   IndexWindow& idx_window,
+                                   index_t rows,
+                                   index_t experts,
+                                   index_t k,
+                                   index_t block_row_id)
+    {
+#if TOPK_SOFTMAX_USE_RAW_TILE_WINDOW
+        auto inp_win = make_tile_window_linear_raw(
+            input_window, Policy::template MakeInputDistribution<Problem>(), sequence<0, 1>{});
+#else
+        auto inp_win = make_tile_window_linear(
+            input_window, Policy::template MakeInputDistribution<Problem>(), sequence<0, 1>{});
+#endif
+        auto out_win = make_tile_window_linear(out_window.get_bottom_tensor_view(),
+                                               out_window.get_window_lengths(),
+                                               out_window.get_window_origin(),
+                                               Policy::template MakeOutputDistribution<Problem>());
+        auto idx_win = make_tile_window_linear(idx_window.get_bottom_tensor_view(),
+                                               idx_window.get_window_lengths(),
+                                               idx_window.get_window_origin(),
+                                               Policy::template MakeOutputDistribution<Problem>());
+
+        auto softmax = Policy::template GetSoftmax<Problem>();
+        auto topk    = Policy::template GetTopk<Problem>();
+
+        const index_t grid_rows_per_loop = gridDim.x * Problem::RowsPerBlock;
+
+        while(1)
+        {
+#if TOPK_SOFTMAX_USE_RAW_TILE_WINDOW
+            __builtin_amdgcn_sched_barrier(0);
+            auto x =
+                load_tile_raw(inp_win, number<-1>{}, bool_constant<true>{}, bool_constant<true>{});
+            buffer_load_fence(number<0>{});
+            __builtin_amdgcn_sched_barrier(0);
+#else
+            auto x = load_tile(inp_win);
+#endif
+            // cast and pad input data
+            auto w = [&]() {
+#if 0
+                auto w_ = cast_tile<WeightType>(x);
+
+                constexpr auto span_2d = decltype(w_)::get_distributed_spans();
+                sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
+                    sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
+                        constexpr auto i_j_idx = make_tuple(idx0, idx1);
+                        const auto x_indices   = get_x_indices_from_distributed_indices(
+                            w_.get_tile_distribution(), i_j_idx);
+                        const auto current_expert = x_indices.at(number<1>{});
+                        // set to -INF if OOB so that later softmax can work properly
+                        w_(i_j_idx) = current_expert >= experts ? -numeric<WeightType>::infinity()
+                                                                : w_(i_j_idx);
+                    });
+                });
+                return w_;
+#else
+                auto w_  = make_static_distributed_tensor<WeightType>(x.get_tile_distribution());
+                auto w_f = [&](auto idx) {
+                    w_(idx) = type_convert<WeightType>(x(idx));
+                    const auto x_indices =
+                        get_x_indices_from_distributed_indices(w_.get_tile_distribution(), idx);
+                    const auto current_expert = x_indices.at(number<1>{});
+                    // set to -INF if OOB so that later softmax can work properly
+                    w_(idx) =
+                        current_expert >= experts ? -numeric<WeightType>::infinity() : w_(idx);
+                };
+                tile_sweeper ts{w_, w_f};
+                ts();
+                return w_;
+#endif
+            }();
+
+            // softmax
+            auto y = softmax(w);
+
+            topk(y, out_win, idx_win, k);
+
+            // check exit
+            if constexpr(Problem::LaunchType == 0)
+            {
+                break;
+            }
+            else
+            {
+                block_row_id += grid_rows_per_loop;
+                if(block_row_id >= rows)
+                    break;
+            }
+
+            move_tile_window(inp_win, {grid_rows_per_loop, number<0>{}});
+            move_tile_window(out_win, {grid_rows_per_loop, number<0>{}});
+            move_tile_window(idx_win, {grid_rows_per_loop, number<0>{}});
+        }
+    }
+};
+} // namespace ck_tile
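The pipeline's per-row contract is: mask out-of-range experts to -inf, softmax the row, then stream out the k largest values with their expert ids. A plain-C++ host reference of the same row-wise semantics, useful for cross-checking; the function name and signature are illustrative, not code from this diff:
```
#include <algorithm>
#include <cmath>
#include <limits>
#include <numeric>
#include <vector>

// Host reference of what the device pipeline computes per row (sketch).
template <typename T>
void topk_softmax_ref(const T* input, float* weights, int* indices,
                      int rows, int experts, int k,
                      int stride_in, int stride_out)
{
    for(int r = 0; r < rows; ++r)
    {
        const T* x = input + r * stride_in;

        // numerically stable softmax over one row of expert logits
        float m = -std::numeric_limits<float>::infinity();
        for(int e = 0; e < experts; ++e)
            m = std::max(m, static_cast<float>(x[e]));
        std::vector<float> p(experts);
        float sum = 0.f;
        for(int e = 0; e < experts; ++e)
            sum += (p[e] = std::exp(static_cast<float>(x[e]) - m));
        for(float& v : p)
            v /= sum;

        // top-k: expert ids ordered by descending probability
        std::vector<int> ids(experts);
        std::iota(ids.begin(), ids.end(), 0);
        std::partial_sort(ids.begin(), ids.begin() + k, ids.end(),
                          [&](int a, int b) { return p[a] > p[b]; });

        for(int j = 0; j < k; ++j)
        {
            weights[r * stride_out + j] = p[ids[j]];
            indices[r * stride_out + j] = ids[j];
        }
    }
}
```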
diff --git a/include/ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_policy.hpp b/include/ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_policy.hpp
new file mode 100644
index 0000000000..a6e886bd39
--- /dev/null
+++ b/include/ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_policy.hpp
@@ -0,0 +1,63 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+#include "ck_tile/core.hpp"
+#include "ck_tile/ops/softmax.hpp"
+#include "ck_tile/ops/topk.hpp"
+
+namespace ck_tile {
+
+struct TopkSoftmaxWarpPerRowPolicy
+{
+    template <typename Problem>
+    CK_TILE_HOST_DEVICE static constexpr auto MakeInputDistribution()
+    {
+        // TODO: Y dim must have one dim that is not reduced
+        return make_static_tile_distribution(
+            tile_distribution_encoding<
+                sequence<1>,
+                tuple<sequence<Problem::IssuesPerCol,
+                               Problem::WarpsPerBlock,
+                               Problem::RowsPerWarpPerColIssue>,
+                      sequence<Problem::IssuesPerRow, Problem::LanesPerRow, Problem::VectorSize>>,
+                tuple<sequence<1>, sequence<1, 2>>,
+                tuple<sequence<1>, sequence<2, 1>>,
+                sequence<1, 2, 2>,
+                sequence<0, 0, 2>>{});
+    }
+
+    template <typename Problem>
+    CK_TILE_HOST_DEVICE static constexpr auto MakeOutputDistribution()
+    {
+        return make_static_tile_distribution(
+            tile_distribution_encoding<
+                sequence<Problem::LanesPerRow>, // repeat this one
+                tuple<sequence<Problem::IssuesPerCol,
+                               Problem::WarpsPerBlock,
+                               Problem::RowsPerWarpPerColIssue>,
+                      sequence<1>>, // each row write out single element
+                tuple<sequence<1>, sequence<1, 0>>,
+                tuple<sequence<1>, sequence<2, 0>>,
+                sequence<1, 2>,
+                sequence<0, 0>>{});
+    }
+
+    template <typename Problem>
+    CK_TILE_HOST_DEVICE static constexpr auto GetSoftmax()
+    {
+        using softmax_problem = BlockSoftmax2DProblem<typename Problem::WeightType>;
+        return BlockSoftmax2D<softmax_problem>{};
+    }
+
+    template <typename Problem>
+    CK_TILE_HOST_DEVICE static constexpr auto GetTopk()
+    {
+        using topk_problem = BlockTopkStream2DProblem<typename Problem::WeightType,
+                                                      typename Problem::IndexType,
+                                                      Problem::LanesPerRow>;
+        // Note: replicate is LanesPerRow
+        return BlockTopkStream2D<topk_problem>{};
+    }
+};
+} // namespace ck_tile
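Read against the constants of the Problem struct in the next header, the two encodings can be glossed as follows; this is an explanatory reading of the reconstruction above, not additional code from the diff:
```
// Gloss of the distributions above (explanatory only).
// Input tile  : [rows, experts]
//   rows    = IssuesPerCol * WarpsPerBlock * RowsPerWarpPerColIssue
//   experts = IssuesPerRow * LanesPerRow  * VectorSize
//   P0 (warp id in block) selects the WarpsPerBlock slice of the row dim
//   P1 (lane id in warp)  splits into RowsPerWarpPerColIssue rows x LanesPerRow lanes
//   Y  (per-lane data)    holds IssuesPerCol x IssuesPerRow x VectorSize elements
// Output tile : [rows, 1] per topk step; the LanesPerRow lanes that shared a
// row map to the replicate (R) dim, so they carry identical results and the
// store lands one weight/index pair per row.
```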
diff --git a/include/ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_problem.hpp b/include/ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_problem.hpp
new file mode 100644
index 0000000000..917096ad5e
--- /dev/null
+++ b/include/ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_problem.hpp
@@ -0,0 +1,46 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+#include "ck_tile/core.hpp"
+#include <string>
+#include <type_traits>
+
+namespace ck_tile {
+
+template <typename InputType_,
+          typename WeightType_,
+          typename IndexType_,
+          index_t Experts_,
+          index_t BytesPerIssue_,
+          index_t IssuesPerCol_,
+          index_t LaunchType_, // 0: non-persistent; > 0, persistent #occupancy
+          index_t BlockSize_ = 256>
+struct TopkSoftmaxWarpPerRowProblem
+{
+    // TODO: this kernel only supports one warp per row
+    using InputType  = remove_cvref_t<InputType_>;
+    using WeightType = remove_cvref_t<WeightType_>;
+    using IndexType  = remove_cvref_t<IndexType_>;
+
+    static constexpr index_t LaunchType    = LaunchType_;
+    static constexpr index_t Experts       = Experts_;
+    static constexpr index_t BytesPerIssue = BytesPerIssue_;
+    static constexpr index_t IssuesPerCol  = IssuesPerCol_;
+    static constexpr index_t BlockSize     = BlockSize_;
+    static constexpr index_t WarpSize      = get_warp_size();
+
+    static_assert(BytesPerIssue % sizeof(InputType) == 0);
+    static constexpr index_t VectorSize = BytesPerIssue / sizeof(InputType);
+    static_assert(Experts % VectorSize == 0);
+    static constexpr index_t LanesPerRow = min(Experts / VectorSize, WarpSize);
+    static_assert(WarpSize % LanesPerRow == 0);
+    static constexpr index_t RowsPerWarpPerColIssue = WarpSize / LanesPerRow;
+    static constexpr index_t RowsPerWarp  = IssuesPerCol * RowsPerWarpPerColIssue;
+    static constexpr index_t IssuesPerRow = Experts / (LanesPerRow * VectorSize);
+
+    static constexpr index_t WarpsPerBlock = BlockSize / WarpSize;
+    static constexpr index_t RowsPerBlock  = RowsPerWarp * WarpsPerBlock;
+};
+} // namespace ck_tile
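To make the derived constants concrete, here is one worked instantiation on a 64-lane wavefront; the template-argument order matches the sketch used earlier and is an assumption, but the arithmetic follows directly from the definitions above:
```
// Worked example (assumed parameter order: InputType, WeightType, IndexType,
// Experts, BytesPerIssue, IssuesPerCol, LaunchType, BlockSize).
using P = ck_tile::TopkSoftmaxWarpPerRowProblem<ck_tile::fp16_t, float,
                                                ck_tile::index_t,
                                                64,  // Experts
                                                16,  // BytesPerIssue
                                                1,   // IssuesPerCol
                                                0>;  // LaunchType

static_assert(P::VectorSize == 8);             // 16 B / 2 B per fp16
static_assert(P::LanesPerRow == 8);            // min(64 / 8, 64)
static_assert(P::RowsPerWarpPerColIssue == 8); // 64 lanes / 8 lanes per row
static_assert(P::RowsPerWarp == 8);            // 1 col issue * 8
static_assert(P::IssuesPerRow == 1);           // 64 / (8 * 8)
static_assert(P::WarpsPerBlock == 4);          // 256 / 64
static_assert(P::RowsPerBlock == 32);          // 8 * 4
```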