From 112d521b0939083573f931cba31089c3f31d6ac2 Mon Sep 17 00:00:00 2001 From: carlushuang Date: Sun, 3 Mar 2024 23:48:31 +0000 Subject: [PATCH] fix xx --- example/ck_tile/01_fmha/fmha_fwd.cpp | 4 +- example/ck_tile/01_fmha/fmha_fwd.hpp | 15 +- example/ck_tile/01_fmha/generate.py | 2 +- include/ck_tile/core.hpp | 5 +- .../ck_tile/core/arch/amd_address_space.hpp | 20 - .../core/arch/amd_buffer_addressing.hpp | 1034 ++++++++--------- include/ck_tile/core/arch/arch.hpp | 61 + include/ck_tile/core/arch/utility.hpp | 27 + include/ck_tile/core/config.hpp | 12 + include/ck_tile/core/container/array.hpp | 34 +- .../core/container/container_helper.hpp | 17 + include/ck_tile/core/container/sequence.hpp | 5 +- include/ck_tile/core/container/tuple.hpp | 135 ++- include/ck_tile/core/numeric/bfloat16.hpp | 41 +- include/ck_tile/core/numeric/float8.hpp | 95 +- include/ck_tile/core/numeric/half.hpp | 5 +- include/ck_tile/core/numeric/math.hpp | 7 +- include/ck_tile/core/numeric/type_convert.hpp | 49 +- include/ck_tile/core/numeric/vector_type.hpp | 327 ++---- include/ck_tile/core/tensor/buffer_view.hpp | 316 ++--- include/ck_tile/core/tensor/load_tile.hpp | 9 +- .../ck_tile/core/tensor/null_tile_window.hpp | 1 + include/ck_tile/core/tensor/shuffle_tile.hpp | 10 +- include/ck_tile/core/tensor/slice_tile.hpp | 2 +- .../core/tensor/static_distributed_tensor.hpp | 1 + include/ck_tile/core/tensor/store_tile.hpp | 4 +- .../ck_tile/core/tensor/tensor_adaptor.hpp | 16 +- .../ck_tile/core/tensor/tensor_descriptor.hpp | 8 +- include/ck_tile/core/tensor/tensor_view.hpp | 84 +- .../ck_tile/core/tensor/tile_distribution.hpp | 6 +- .../ck_tile/core/tensor/tile_elementwise.hpp | 35 +- include/ck_tile/core/tensor/tile_window.hpp | 64 +- include/ck_tile/core/utility/functional.hpp | 14 + include/ck_tile/core/utility/to_sequence.hpp | 9 +- .../core/utility/transpose_vectors.hpp | 123 ++ include/ck_tile/core/utility/type_convert.hpp | 57 - include/ck_tile/core/utility/type_traits.hpp | 14 + include/ck_tile/host/check_err.hpp | 6 +- include/ck_tile/host/kernel_launch.hpp | 4 +- include/ck_tile/ops/common/README.md | 4 + include/ck_tile/ops/epilogue.hpp | 1 + include/ck_tile/ops/fmha.hpp | 1 + .../ops/fmha/kernel/fmha_fwd_kernel.hpp | 9 +- .../pipeline/block_fmha_pipeline_problem.hpp | 12 +- .../pipeline/block_fmha_pipeline_qr_ks_vs.hpp | 13 +- .../block_fmha_pipeline_qr_ks_vs_async.hpp | 15 +- .../block_fmha_pipeline_qr_ks_vs_fp8.hpp | 13 +- .../pipeline/block_fmha_pipeline_qs_ks_vs.hpp | 11 +- ...k_fmha_pipeline_qx_ks_vs_custom_policy.hpp | 66 +- include/ck_tile/ops/gemm.hpp | 1 + .../block/block_gemm_areg_bgmem_creg_v1.hpp | 13 +- .../block/block_gemm_areg_bsmem_creg_v1.hpp | 22 +- ...gemm_areg_bsmem_creg_v1_default_policy.hpp | 12 +- .../block/block_gemm_areg_bsmem_creg_v2.hpp | 15 +- .../block/block_gemm_asmem_bsmem_creg_v1.hpp | 6 +- ...emm_asmem_bsmem_creg_v1_default_policy.hpp | 12 +- ...lock_gemm_pipeline_agmem_bgmem_creg_v1.hpp | 4 +- ...lock_gemm_pipeline_agmem_bgmem_creg_v2.hpp | 4 +- include/ck_tile/ops/gemm/warp/warp_gemm.hpp | 2 + .../gemm/warp/warp_gemm_attribute_mfma.hpp | 69 +- .../warp/warp_gemm_attribute_mfma_impl.hpp | 192 ++- .../ops/gemm/warp/warp_gemm_dispatcher.hpp | 17 +- .../ck_tile/ops/gemm/warp/warp_gemm_impl.hpp | 12 +- include/ck_tile/ops/reduce.hpp | 1 + .../ck_tile/ops/reduce/block/block_reduce.hpp | 2 +- include/ck_tile/remod.py | 11 + 66 files changed, 1720 insertions(+), 1498 deletions(-) delete mode 100644 include/ck_tile/core/arch/amd_address_space.hpp create mode 100644 include/ck_tile/core/arch/arch.hpp create mode 100644 include/ck_tile/core/arch/utility.hpp create mode 100644 include/ck_tile/core/utility/transpose_vectors.hpp delete mode 100644 include/ck_tile/core/utility/type_convert.hpp create mode 100644 include/ck_tile/ops/common/README.md diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index 1b2183960c..9e11f4d19e 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -72,7 +72,7 @@ auto get_elimit(int /*init_method*/) } template <> -auto get_elimit(int init_method) +auto get_elimit(int init_method) { if(init_method == 0) { @@ -510,7 +510,7 @@ int main(int argc, char* argv[]) } else if(data_type == "bf16") { - return run(arg_parser) ? 0 : -2; + return run(arg_parser) ? 0 : -2; } else if(data_type == "fp8") { diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index eb11efb2e2..325ff6b78a 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -7,7 +7,6 @@ #include "ck_tile/host/kernel_launch.hpp" #include "ck_tile/ops/fmha.hpp" #include "ck_tile/ops/epilogue.hpp" -#include "ck_tile/ops/common.hpp" #include "mask.hpp" template @@ -29,18 +28,18 @@ struct FmhaFwdTypeConfig }; template <> -struct FmhaFwdTypeConfig +struct FmhaFwdTypeConfig { - using QDataType = ck_tile::bhalf_t; - using KDataType = ck_tile::bhalf_t; - using VDataType = ck_tile::bhalf_t; - using BiasDataType = ck_tile::bhalf_t; + using QDataType = ck_tile::bf16_t; + using KDataType = ck_tile::bf16_t; + using VDataType = ck_tile::bf16_t; + using BiasDataType = ck_tile::bf16_t; using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) using SaccDataType = float; // data type for first gemm accumulation using SMPLComputeDataType = float; // data type for reduction, softmax - using PDataType = ck_tile::bhalf_t; // data type for A matrix of second gemm + using PDataType = ck_tile::bf16_t; // data type for A matrix of second gemm using OaccDataType = float; // data type for second gemm accumulation - using ODataType = ck_tile::bhalf_t; + using ODataType = ck_tile::bf16_t; }; template <> diff --git a/example/ck_tile/01_fmha/generate.py b/example/ck_tile/01_fmha/generate.py index b3be008f09..66feae6a5d 100644 --- a/example/ck_tile/01_fmha/generate.py +++ b/example/ck_tile/01_fmha/generate.py @@ -11,7 +11,7 @@ import copy DTYPE_MAP = { "fp16": "ck_tile::half_t", - "bf16": "ck_tile::bhalf_t", + "bf16": "ck_tile::bf16_t", "fp8" : "ck_tile::fp8_t" } diff --git a/include/ck_tile/core.hpp b/include/ck_tile/core.hpp index 0123163a65..f53a6b0fd6 100644 --- a/include/ck_tile/core.hpp +++ b/include/ck_tile/core.hpp @@ -6,8 +6,9 @@ #include "ck_tile/core/algorithm/cluster_descriptor.hpp" #include "ck_tile/core/algorithm/coordinate_transform.hpp" #include "ck_tile/core/algorithm/space_filling_curve.hpp" -#include "ck_tile/core/arch/amd_address_space.hpp" #include "ck_tile/core/arch/amd_buffer_addressing.hpp" +#include "ck_tile/core/arch/arch.hpp" +#include "ck_tile/core/arch/utility.hpp" #include "ck_tile/core/config.hpp" #include "ck_tile/core/container/array.hpp" #include "ck_tile/core/container/container_helper.hpp" @@ -51,6 +52,6 @@ #include "ck_tile/core/utility/magic_div.hpp" #include "ck_tile/core/utility/random.hpp" #include "ck_tile/core/utility/to_sequence.hpp" -#include "ck_tile/core/utility/type_convert.hpp" +#include "ck_tile/core/utility/transpose_vectors.hpp" #include "ck_tile/core/utility/type_traits.hpp" diff --git a/include/ck_tile/core/arch/amd_address_space.hpp b/include/ck_tile/core/arch/amd_address_space.hpp deleted file mode 100644 index 19a9ded568..0000000000 --- a/include/ck_tile/core/arch/amd_address_space.hpp +++ /dev/null @@ -1,20 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -// Address Space for AMDGCN -// https://llvm.org/docs/AMDGPUUsage.html#address-space - -namespace ck_tile { - -enum struct address_space_enum -{ - generic, - global, - lds, - sgpr, - vgpr, -}; - -} // namespace ck_tile diff --git a/include/ck_tile/core/arch/amd_buffer_addressing.hpp b/include/ck_tile/core/arch/amd_buffer_addressing.hpp index cfba73f74d..9a7c95f4c2 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing.hpp @@ -5,53 +5,27 @@ #include "ck_tile/core/numeric/integer.hpp" #include "ck_tile/core/numeric/integral_constant.hpp" +#include "ck_tile/core/numeric/vector_type.hpp" #include "ck_tile/core/container/container_helper.hpp" #include "ck_tile/core/utility/type_traits.hpp" -#include "ck_tile/core/numeric/vector_type.hpp" +#include "ck_tile/core/utility/bit_cast.hpp" +#include "ck_tile/core/utility/functional.hpp" namespace ck_tile { -template -union buffer_resource +// 128 bit SGPRs to supply buffer resource in buffer instructions +// https://rocm-documentation.readthedocs.io/en/latest/GCN_ISA_Manuals/testdocbook.html#vector-memory-buffer-instructions +struct __attribute__((packed)) buffer_resource { - CK_TILE_DEVICE constexpr buffer_resource() : content{} {} - - // 128 bit SGPRs to supply buffer resource in buffer instructions - // https://rocm-documentation.readthedocs.io/en/latest/GCN_ISA_Manuals/testdocbook.html#vector-memory-buffer-instructions - int32x4_t content; - statically_indexed_array address; - statically_indexed_array range; - statically_indexed_array config; + const void* ptr; + uint32_t range; + uint32_t config; }; -template -CK_TILE_DEVICE int32x4_t make_wave_buffer_resource(T* p_wave, index_t element_space_size) +CK_TILE_DEVICE int32x4_t make_wave_buffer_resource(const void* ptr, uint32_t size = 0xffffffff) { - buffer_resource wave_buffer_resource; - - // wavewise base address (64 bit) - wave_buffer_resource.address(number<0>{}) = const_cast*>(p_wave); - // wavewise range (32 bit) - wave_buffer_resource.range(number<2>{}) = element_space_size * sizeof(T); - // wavewise setting (32 bit) - wave_buffer_resource.config(number<3>{}) = CK_BUFFER_RESOURCE_3RD_DWORD; - - return wave_buffer_resource.content; -} - -template -CK_TILE_DEVICE int32x4_t make_wave_buffer_resource_with_default_range(T* p_wave) -{ - buffer_resource wave_buffer_resource; - - // wavewise base address (64 bit) - wave_buffer_resource.address(number<0>{}) = const_cast*>(p_wave); - // wavewise range (32 bit) - wave_buffer_resource.range(number<2>{}) = 0xffffffff; // max possible range - // wavewise setting (32 bit) - wave_buffer_resource.config(number<3>{}) = CK_BUFFER_RESOURCE_3RD_DWORD; - - return wave_buffer_resource.content; + buffer_resource res{ptr, size, CK_TILE_BUFFER_RESOURCE_3RD_DWORD}; + return __builtin_bit_cast(int32x4_t, res); } // TODO: glc/slc/... @@ -73,9 +47,9 @@ struct buffer_load<16> index_t /*flag*/ = 0) { static_assert(sizeof(T) == 16); - using mubuf_t = float __attribute__((ext_vector_type(4))); + using mbuf_t = fp32x4_t; asm volatile("buffer_load_dwordx4 %0, %1, %2, %3 offen offset:%4" - : "+v"(reinterpret_cast(value)) + : "+v"(reinterpret_cast(value)) : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) : "memory"); } @@ -93,9 +67,9 @@ struct buffer_load<8> index_t /*flag*/ = 0) { static_assert(sizeof(T) == 8); - using mubuf_t = float __attribute__((ext_vector_type(2))); + using mbuf_t = fp32x2_t; asm volatile("buffer_load_dwordx2 %0, %1, %2, %3 offen offset:%4" - : "+v"(reinterpret_cast(value)) + : "+v"(reinterpret_cast(value)) : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) : "memory"); } @@ -113,9 +87,9 @@ struct buffer_load<4> index_t /*flag*/ = 0) { static_assert(sizeof(T) == 4); - using mubuf_t = float; + using mbuf_t = float; asm volatile("buffer_load_dword %0, %1, %2, %3 offen offset:%4" - : "+v"(reinterpret_cast(value)) + : "+v"(reinterpret_cast(value)) : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) : "memory"); } @@ -133,9 +107,9 @@ struct buffer_load<2> index_t /*flag*/ = 0) { static_assert(sizeof(T) == 4); // subdword is buggy, use dword buf and convert manually - using mubuf_t = float; + using mbuf_t = float; asm volatile("buffer_load_ushort %0, %1, %2, %3 offen offset:%4" - : "+v"(reinterpret_cast(value)) + : "+v"(reinterpret_cast(value)) : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) : "memory"); } @@ -153,9 +127,9 @@ struct buffer_load<1> index_t /*flag*/ = 0) { static_assert(sizeof(T) == 4); - using mubuf_t = float; + using mbuf_t = float; asm volatile("buffer_load_ubyte %0, %1, %2, %3 offen offset:%4" - : "+v"(reinterpret_cast(value)) + : "+v"(reinterpret_cast(value)) : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) : "memory"); } @@ -177,13 +151,13 @@ struct buffer_load_if<16> { static_assert(sizeof(T) == 16); auto saved_exec = __builtin_amdgcn_read_exec(); - using mubuf_t = float __attribute__((ext_vector_type(4))); - static_assert(sizeof(mubuf_t) == sizeof(T)); + using mbuf_t = fp32x4_t; + static_assert(sizeof(mbuf_t) == sizeof(T)); asm volatile( "v_cmpx_le_u32 exec, 1, %5\n" "buffer_load_dwordx4 %0, %1, %2, %3 offen offset:%4\n" "s_mov_b64 exec %6" - : "+v"(reinterpret_cast(value)) + : "+v"(reinterpret_cast(value)) : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset), "v"(flag), "s"(saved_exec) : "memory"); } @@ -202,12 +176,12 @@ struct buffer_load_if<8> { static_assert(sizeof(T) == 8); auto saved_exec = __builtin_amdgcn_read_exec(); - using mubuf_t = float __attribute__((ext_vector_type(2))); + using mbuf_t = fp32x2_t; asm volatile( "v_cmpx_le_u32 exec, 1, %5\n" "buffer_load_dwordx2 %0, %1, %2, %3 offen offset:%4\n" "s_mov_b64 exec %6" - : "+v"(reinterpret_cast(value)) + : "+v"(reinterpret_cast(value)) : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset), "v"(flag), "s"(saved_exec) : "memory"); } @@ -226,12 +200,12 @@ struct buffer_load_if<4> { static_assert(sizeof(T) == 4); auto saved_exec = __builtin_amdgcn_read_exec(); - using mubuf_t = float; + using mbuf_t = float; asm volatile( "v_cmpx_le_u32 exec, 1, %5\n" "buffer_load_dword %0, %1, %2, %3 offen offset:%4\n" "s_mov_b64 exec %6" - : "+v"(reinterpret_cast(value)) + : "+v"(reinterpret_cast(value)) : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset), "v"(flag), "s"(saved_exec) : "memory"); } @@ -250,12 +224,12 @@ struct buffer_load_if<2> { static_assert(sizeof(T) == 4); auto saved_exec = __builtin_amdgcn_read_exec(); - using mubuf_t = float; + using mbuf_t = float; asm volatile( "v_cmpx_le_u32 exec, 1, %5\n" "buffer_load_ushort %0, %1, %2, %3 offen offset:%4\n" "s_mov_b64 exec %6" - : "+v"(reinterpret_cast(value)) + : "+v"(reinterpret_cast(value)) : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset), "v"(flag), "s"(saved_exec) : "memory"); } @@ -274,12 +248,12 @@ struct buffer_load_if<1> { static_assert(sizeof(T) == 4); auto saved_exec = __builtin_amdgcn_read_exec(); - using mubuf_t = float; + using mbuf_t = float; asm volatile( "v_cmpx_le_u32 exec, 1, %5\n" "buffer_load_ubyte %0, %1, %2, %3 offen offset:%4\n" "s_mov_b64 exec %6" - : "+v"(reinterpret_cast(value)) + : "+v"(reinterpret_cast(value)) : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset), "v"(flag), "s"(saved_exec) : "memory"); } @@ -300,10 +274,12 @@ struct buffer_store<16> index_t /*flag*/ = 1) { static_assert(sizeof(T) == 16); - asm volatile("buffer_store_dwordx4 %0, %1, %2, %3 offen offset:%4" - : - : "v"(value), "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) - : "memory"); + using mbuf_t = fp32x4_t; + asm volatile( + "buffer_store_dwordx4 %0, %1, %2, %3 offen offset:%4" + : + : "v"(bit_cast(value)), "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) + : "memory"); } }; @@ -319,10 +295,12 @@ struct buffer_store<8> index_t /*flag*/ = 1) { static_assert(sizeof(T) == 8); - asm volatile("buffer_store_dwordx2 %0, %1, %2, %3 offen offset:%4" - : - : "v"(value), "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) - : "memory"); + using mbuf_t = fp32x2_t; + asm volatile( + "buffer_store_dwordx2 %0, %1, %2, %3 offen offset:%4" + : + : "v"(bit_cast(value)), "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) + : "memory"); } }; @@ -338,10 +316,12 @@ struct buffer_store<4> index_t /*flag*/ = 1) { static_assert(sizeof(T) == 4); - asm volatile("buffer_store_dword %0, %1, %2, %3 offen offset:%4" - : - : "v"(value), "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) - : "memory"); + using mbuf_t = float; + asm volatile( + "buffer_store_dword %0, %1, %2, %3 offen offset:%4" + : + : "v"(bit_cast(value)), "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) + : "memory"); } }; @@ -356,11 +336,13 @@ struct buffer_store<2> index_t i_offset /*max 0xFFF*/, index_t /*flag*/ = 1) { - static_assert(sizeof(T) == 2); - asm volatile("buffer_store_short %0, %1, %2, %3 offen offset:%4" - : - : "v"(value), "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) - : "memory"); + static_assert(sizeof(T) == 4); + using mbuf_t = float; + asm volatile( + "buffer_store_short %0, %1, %2, %3 offen offset:%4" + : + : "v"(bit_cast(value)), "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) + : "memory"); } }; @@ -375,11 +357,13 @@ struct buffer_store<1> index_t i_offset /*max 0xFFF*/, index_t /*flag*/ = 1) { - static_assert(sizeof(T) == 1); - asm volatile("buffer_store_byte %0, %1, %2, %3 offen offset:%4" - : - : "v"(value), "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) - : "memory"); + static_assert(sizeof(T) == 4); + using mbuf_t = float; + asm volatile( + "buffer_store_byte %0, %1, %2, %3 offen offset:%4" + : + : "v"(bit_cast(value)), "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) + : "memory"); } }; @@ -399,11 +383,12 @@ struct buffer_store_if<16> { static_assert(sizeof(T) == 16); auto save_exec = __builtin_amdgcn_read_exec(); + using mbuf_t = fp32x4_t; asm volatile("v_cmpx_le_u32 exec, 1, %5\n" "buffer_store_dwordx4 %0, %1, %2, %3 offen offset:%4\n" "s_mov_b64 exec %6" : - : "v"(value), + : "v"(bit_cast(value)), "v"(v_offset), "s"(res), "s"(s_offset), @@ -427,11 +412,12 @@ struct buffer_store_if<8> { static_assert(sizeof(T) == 8); auto save_exec = __builtin_amdgcn_read_exec(); + using mbuf_t = fp32x2_t; asm volatile("v_cmpx_le_u32 exec, 1, %5\n" "buffer_store_dwordx2 %0, %1, %2, %3 offen offset:%4\n" "s_mov_b64 exec %6" : - : "v"(value), + : "v"(bit_cast(value)), "v"(v_offset), "s"(res), "s"(s_offset), @@ -455,11 +441,12 @@ struct buffer_store_if<4> { static_assert(sizeof(T) == 4); auto save_exec = __builtin_amdgcn_read_exec(); + using mbuf_t = float; asm volatile("v_cmpx_le_u32 exec, 1, %5\n" "buffer_store_dword %0, %1, %2, %3 offen offset:%4\n" "s_mov_b64 exec %6" : - : "v"(value), + : "v"(bit_cast(value)), "v"(v_offset), "s"(res), "s"(s_offset), @@ -481,13 +468,14 @@ struct buffer_store_if<2> index_t i_offset /*max 0xFFF*/, index_t flag = 1) { - static_assert(sizeof(T) == 2); + static_assert(sizeof(T) == 4); auto save_exec = __builtin_amdgcn_read_exec(); + using mbuf_t = float; asm volatile("v_cmpx_le_u32 exec, 1, %5\n" "buffer_store_short %0, %1, %2, %3 offen offset:%4\n" "s_mov_b64 exec %6" : - : "v"(value), + : "v"(bit_cast(value)), "v"(v_offset), "s"(res), "s"(s_offset), @@ -509,13 +497,14 @@ struct buffer_store_if<1> index_t i_offset /*max 0xFFF*/, index_t flag = 1) { - static_assert(sizeof(T) == 1); + static_assert(sizeof(T) == 4); auto save_exec = __builtin_amdgcn_read_exec(); + using mbuf_t = float; asm volatile("v_cmpx_le_u32 exec, 1, %5\n" "buffer_store_byte %0, %1, %2, %3 offen offset:%4\n" "s_mov_b64 exec %6" : - : "v"(value), + : "v"(bit_cast(value)), "v"(v_offset), "s"(res), "s"(s_offset), @@ -540,44 +529,44 @@ namespace impl{ // TODO: may have scratch (because this is memory?) // need to reduce extra move inside compiler template -CK_TILE_DEVICE void insert_dummy_dep_per_dword(static_buffer_c& b) +CK_TILE_DEVICE void insert_dummy_dep_per_dword(array& b) { for (auto i = 0; i < b.size(); i++) asm volatile(" " : : "v"(b.get(i)) : "memory"); } template<> -CK_TILE_DEVICE void insert_dummy_dep_per_dword<2>(static_buffer_c& b) +CK_TILE_DEVICE void insert_dummy_dep_per_dword<2>(array& b) { asm volatile(" " : : "v"(b.get(0)), "v"(b.get(1)) : "memory"); } template<> -CK_TILE_DEVICE void insert_dummy_dep_per_dword<3>(static_buffer_c& b) +CK_TILE_DEVICE void insert_dummy_dep_per_dword<3>(array& b) { asm volatile(" " : : "v"(b.get(0)), "v"(b.get(1)), "v"(b.get(2)) : "memory"); } template<> -CK_TILE_DEVICE void insert_dummy_dep_per_dword<4>(static_buffer_c& b) +CK_TILE_DEVICE void insert_dummy_dep_per_dword<4>(array& b) { asm volatile(" " : : "v"(b.get(0)), "v"(b.get(1)), "v"(b.get(2)), "v"(b.get(3)) : "memory"); } template<> -CK_TILE_DEVICE void insert_dummy_dep_per_dword<8>(static_buffer_c& b) +CK_TILE_DEVICE void insert_dummy_dep_per_dword<8>(array& b) { asm volatile(" " : : "v"(b.get(0)), "v"(b.get(1)), "v"(b.get(2)), "v"(b.get(3)), "v"(b.get(4)), "v"(b.get(5)), "v"(b.get(6)), "v"(b.get(7)) : "memory"); } template<> -CK_TILE_DEVICE void insert_dummy_dep_per_dword<16>(static_buffer_c& b) +CK_TILE_DEVICE void insert_dummy_dep_per_dword<16>(array& b) { asm volatile(" " : : "v"(b.get(0)), "v"(b.get(1)), "v"(b.get(2)), "v"(b.get(3)), "v"(b.get(4)), "v"(b.get(5)), "v"(b.get(6)), "v"(b.get(7)), "v"(b.get(8)), "v"(b.get(9)), "v"(b.get(10)), "v"(b.get(11)), "v"(b.get(12)), "v"(b.get(13)), "v"(b.get(14)), "v"(b.get(15)) : "memory"); } template<> -CK_TILE_DEVICE void insert_dummy_dep_per_dword<32>(static_buffer_c& b) +CK_TILE_DEVICE void insert_dummy_dep_per_dword<32>(array& b) { asm volatile(" " : : "v"(b.get(0)), "v"(b.get(1)), "v"(b.get(2)), "v"(b.get(3)), "v"(b.get(4)), "v"(b.get(5)), "v"(b.get(6)), "v"(b.get(7)), "v"(b.get(8)), "v"(b.get(9)), "v"(b.get(10)), "v"(b.get(11)), "v"(b.get(12)), "v"(b.get(13)), "v"(b.get(14)), "v"(b.get(15)), @@ -591,7 +580,7 @@ template CK_TILE_DEVICE void insert_dummy_dep(T & buffer) { // TODO: indeed we expect T to be multiple of dword. subdword is always buggy - using da_type = static_buffer_c; + using da_type = array; auto & dummy = reinterpret_cast(buffer); insert_dummy_dep_per_dword(dummy); } @@ -636,19 +625,19 @@ llvm_amdgcn_raw_buffer_load_i8x4(int32x4_t srsrc, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i8"); // buffer load i16 -CK_TILE_DEVICE bhalf_t +CK_TILE_DEVICE int16_t llvm_amdgcn_raw_buffer_load_i16(int32x4_t srsrc, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i16"); -CK_TILE_DEVICE bhalf2_t +CK_TILE_DEVICE int16x2_t llvm_amdgcn_raw_buffer_load_i16x2(int32x4_t srsrc, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i16"); -CK_TILE_DEVICE bhalf4_t +CK_TILE_DEVICE int16x4_t llvm_amdgcn_raw_buffer_load_i16x4(int32x4_t srsrc, index_t voffset, index_t soffset, @@ -674,19 +663,19 @@ llvm_amdgcn_raw_buffer_load_i32x4(int32x4_t srsrc, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i32"); // buffer load fp16 -CK_TILE_DEVICE half_t +CK_TILE_DEVICE fp16_t llvm_amdgcn_raw_buffer_load_fp16(int32x4_t srsrc, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f16"); -CK_TILE_DEVICE half2_t +CK_TILE_DEVICE fp16x2_t llvm_amdgcn_raw_buffer_load_fp16x2(int32x4_t srsrc, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2f16"); -CK_TILE_DEVICE half4_t +CK_TILE_DEVICE fp16x4_t llvm_amdgcn_raw_buffer_load_fp16x4(int32x4_t srsrc, index_t voffset, index_t soffset, @@ -699,13 +688,13 @@ llvm_amdgcn_raw_buffer_load_fp32(int32x4_t srsrc, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f32"); -CK_TILE_DEVICE float2_t +CK_TILE_DEVICE fp32x2_t llvm_amdgcn_raw_buffer_load_fp32x2(int32x4_t srsrc, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2f32"); -CK_TILE_DEVICE float4_t +CK_TILE_DEVICE fp32x4_t llvm_amdgcn_raw_buffer_load_fp32x4(int32x4_t srsrc, index_t voffset, index_t soffset, @@ -735,21 +724,21 @@ llvm_amdgcn_raw_buffer_store_i8x4(int8x4_t vdata, // buffer store i16 CK_TILE_DEVICE void -llvm_amdgcn_raw_buffer_store_i16(bhalf_t vdata, +llvm_amdgcn_raw_buffer_store_i16(bf16_t vdata, int32x4_t rsrc, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i16"); CK_TILE_DEVICE void -llvm_amdgcn_raw_buffer_store_i16x2(bhalf2_t vdata, +llvm_amdgcn_raw_buffer_store_i16x2(bf16x2_t vdata, int32x4_t rsrc, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i16"); CK_TILE_DEVICE void -llvm_amdgcn_raw_buffer_store_i16x4(bhalf4_t vdata, +llvm_amdgcn_raw_buffer_store_i16x4(bf16x4_t vdata, int32x4_t rsrc, index_t voffset, index_t soffset, @@ -779,21 +768,21 @@ llvm_amdgcn_raw_buffer_store_i32x4(int32x4_t vdata, // buffer store fp16 CK_TILE_DEVICE void -llvm_amdgcn_raw_buffer_store_fp16(half_t vdata, +llvm_amdgcn_raw_buffer_store_fp16(fp16_t vdata, int32x4_t rsrc, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.f16"); CK_TILE_DEVICE void -llvm_amdgcn_raw_buffer_store_fp16x2(half2_t vdata, +llvm_amdgcn_raw_buffer_store_fp16x2(fp16x2_t vdata, int32x4_t rsrc, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2f16"); CK_TILE_DEVICE void -llvm_amdgcn_raw_buffer_store_fp16x4(half4_t vdata, +llvm_amdgcn_raw_buffer_store_fp16x4(fp16x4_t vdata, int32x4_t rsrc, index_t voffset, index_t soffset, @@ -808,22 +797,22 @@ llvm_amdgcn_raw_buffer_store_fp32(float vdata, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.f32"); CK_TILE_DEVICE void -llvm_amdgcn_raw_buffer_store_fp32x2(float2_t vdata, +llvm_amdgcn_raw_buffer_store_fp32x2(fp32x2_t vdata, int32x4_t rsrc, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2f32"); CK_TILE_DEVICE void -llvm_amdgcn_raw_buffer_store_fp32x4(float4_t vdata, +llvm_amdgcn_raw_buffer_store_fp32x4(fp32x4_t vdata, int32x4_t rsrc, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f32"); // buffer atomic-add fp16 -CK_TILE_DEVICE half2_t llvm_amdgcn_raw_buffer_atomic_add_fp16x2( - half2_t vdata, +CK_TILE_DEVICE fp16x2_t llvm_amdgcn_raw_buffer_atomic_add_fp16x2( + fp16x2_t vdata, int32x4_t rsrc, index_t voffset, index_t soffset, @@ -886,20 +875,21 @@ enum struct amd_buffer_coherence_enum template -CK_TILE_DEVICE typename vector_type::type -amd_buffer_load_impl_raw(int32x4_t src_wave_buffer_resource, - index_t src_thread_addr_offset, - index_t src_wave_addr_offset) +CK_TILE_DEVICE array amd_buffer_load_impl_with_bytes(int32x4_t src_wave_buffer_resource, + index_t src_thread_addr_offset, + index_t src_wave_addr_offset) { static_assert(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32 || N == 64, "wrong! not implemented"); + using rtn_type = array; + if constexpr(N == 1) { - return llvm_amdgcn_raw_buffer_load_i8(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence)); + return bit_cast(llvm_amdgcn_raw_buffer_load_i8(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence))); } else if constexpr(N == 2) { @@ -909,7 +899,7 @@ amd_buffer_load_impl_raw(int32x4_t src_wave_buffer_resource, src_wave_addr_offset, static_cast(coherence)); - return bit_cast(tmp); + return bit_cast(tmp); } else if constexpr(N == 4) { @@ -918,7 +908,7 @@ amd_buffer_load_impl_raw(int32x4_t src_wave_buffer_resource, src_wave_addr_offset, static_cast(coherence)); - return bit_cast(tmp); + return bit_cast(tmp); } else if constexpr(N == 8) { @@ -927,7 +917,7 @@ amd_buffer_load_impl_raw(int32x4_t src_wave_buffer_resource, src_wave_addr_offset, static_cast(coherence)); - return bit_cast(tmp); + return bit_cast(tmp); } else if constexpr(N == 16) { @@ -935,7 +925,7 @@ amd_buffer_load_impl_raw(int32x4_t src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, static_cast(coherence)); - return bit_cast(tmp); + return bit_cast(tmp); } else if constexpr(N == 32) { @@ -948,12 +938,12 @@ amd_buffer_load_impl_raw(int32x4_t src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset + 4 * sizeof(int32_t), static_cast(coherence)); - vector_type tmp; + array tmp; - tmp.AsType()(number<0>{}) = tmp0; - tmp.AsType()(number<1>{}) = tmp1; + tmp.template get_as()(number<0>{}) = tmp0; + tmp.template get_as()(number<1>{}) = tmp1; - return bit_cast(tmp); + return bit_cast(tmp); } else if constexpr(N == 64) { @@ -977,14 +967,14 @@ amd_buffer_load_impl_raw(int32x4_t src_wave_buffer_resource, src_wave_addr_offset + 12 * sizeof(int32_t), static_cast(coherence)); - vector_type tmp; + array tmp; - tmp.AsType()(number<0>{}) = tmp0; - tmp.AsType()(number<1>{}) = tmp1; - tmp.AsType()(number<2>{}) = tmp2; - tmp.AsType()(number<3>{}) = tmp3; + tmp.template get_as()(number<0>{}) = tmp0; + tmp.template get_as()(number<1>{}) = tmp1; + tmp.template get_as()(number<2>{}) = tmp2; + tmp.template get_as()(number<3>{}) = tmp3; - return bit_cast(tmp); + return bit_cast(tmp); } } @@ -995,150 +985,161 @@ amd_buffer_load_impl_raw(int32x4_t src_wave_buffer_resource, template -CK_TILE_DEVICE typename vector_type::type -amd_buffer_load_impl(int32x4_t src_wave_buffer_resource, - index_t src_thread_addr_offset, - index_t src_wave_addr_offset) +CK_TILE_DEVICE array amd_buffer_load_impl(int32x4_t src_wave_buffer_resource, + index_t src_thread_addr_offset, + index_t src_wave_addr_offset) { static_assert( - (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || - (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || - (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || - (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || - (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || - (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || - (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || - (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)), + (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || + (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (std::is_same::value && + (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)), "wrong! not implemented"); - if constexpr(is_same::value) // fp32 + using rtn_type = array; + + if constexpr(std::is_same::value) // fp32 { if constexpr(N == 1) { - return llvm_amdgcn_raw_buffer_load_fp32(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence)); + return bit_cast( + llvm_amdgcn_raw_buffer_load_fp32(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence))); } else if constexpr(N == 2) { - return llvm_amdgcn_raw_buffer_load_fp32x2(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence)); + return bit_cast( + llvm_amdgcn_raw_buffer_load_fp32x2(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence))); } else if constexpr(N == 4) { - return llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence)); + return bit_cast( + llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence))); } else if constexpr(N == 8) { - vector_type tmp; + array tmp; - tmp.AsType()(number<0>{}) = + tmp.template get_as()(number<0>{}) = llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, static_cast(coherence)); - tmp.AsType()(number<1>{}) = + tmp.template get_as()(number<1>{}) = llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset + 4 * sizeof(float), static_cast(coherence)); - return tmp.AsType()(number<0>{}); + return tmp; } else if constexpr(N == 16) { - vector_type tmp; + array tmp; - tmp.AsType()(number<0>{}) = + tmp.template get_as()(number<0>{}) = llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, static_cast(coherence)); - tmp.AsType()(number<1>{}) = + tmp.template get_as()(number<1>{}) = llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset + 4 * sizeof(float), static_cast(coherence)); - tmp.AsType()(number<2>{}) = + tmp.template get_as()(number<2>{}) = llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset + 8 * sizeof(float), static_cast(coherence)); - tmp.AsType()(number<3>{}) = + tmp.template get_as()(number<3>{}) = llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset + 12 * sizeof(float), static_cast(coherence)); - return tmp.AsType()(number<0>{}); + return tmp; } } - else if constexpr(is_same::value) // fp16 + else if constexpr(std::is_same::value) // fp16 { if constexpr(N == 1) { - return llvm_amdgcn_raw_buffer_load_fp16(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence)); + return bit_cast( + llvm_amdgcn_raw_buffer_load_fp16(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence))); } else if constexpr(N == 2) { - return llvm_amdgcn_raw_buffer_load_fp16x2(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence)); + return bit_cast( + llvm_amdgcn_raw_buffer_load_fp16x2(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence))); } else if constexpr(N == 4) { - return llvm_amdgcn_raw_buffer_load_fp16x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence)); + return bit_cast( + llvm_amdgcn_raw_buffer_load_fp16x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence))); } else if constexpr(N == 8) { // use fp32 load to mimic fp16 load - float4_t tmp = llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, + fp32x4_t tmp = llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, static_cast(coherence)); - return bit_cast(tmp); + return bit_cast(tmp); } } - else if constexpr(is_same::value) // bf16 + else if constexpr(std::is_same::value) // bf16 { if constexpr(N == 1) { - return llvm_amdgcn_raw_buffer_load_i16(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence)); + return bit_cast( + llvm_amdgcn_raw_buffer_load_i16(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence))); } else if constexpr(N == 2) { - return llvm_amdgcn_raw_buffer_load_i16x2(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence)); + return bit_cast( + llvm_amdgcn_raw_buffer_load_i16x2(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence))); } else if constexpr(N == 4) { - return llvm_amdgcn_raw_buffer_load_i16x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence)); + return bit_cast( + llvm_amdgcn_raw_buffer_load_i16x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence))); } else if constexpr(N == 8) { @@ -1147,17 +1148,15 @@ amd_buffer_load_impl(int32x4_t src_wave_buffer_resource, src_wave_addr_offset, static_cast(coherence)); - return bit_cast(tmp); + return bit_cast(tmp); } } else // other datatype { - using r_t = typename vector_type::type; - - auto raw_data = amd_buffer_load_impl_raw( + auto raw_data = amd_buffer_load_impl_with_bytes( src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset); - return bit_cast(raw_data); + return bit_cast(raw_data); } } @@ -1165,7 +1164,7 @@ template -CK_TILE_DEVICE void amd_buffer_load_raw_impl(typename vector_type::type& dst, +CK_TILE_DEVICE void amd_buffer_load_raw_impl(array& dst, int32x4_t src_wave_buffer_resource, index_t src_thread_addr_offset, index_t src_wave_addr_offset, @@ -1175,7 +1174,7 @@ CK_TILE_DEVICE void amd_buffer_load_raw_impl(typename vector_type::type& d static_assert(bytes == 1 || bytes == 2 || bytes == 4 || bytes == 8 || bytes == 16, "wrong! not supported by buffer_load instruction"); - using type = typename vector_type::type; + using type = array; if constexpr(oob_conditional_check) { buffer_load_if{}( @@ -1208,18 +1207,17 @@ CK_TILE_DEVICE void amd_async_buffer_load_impl(T* smem, template -CK_TILE_DEVICE void -amd_buffer_store_impl_raw(const typename vector_type::type src_thread_data, - int32x4_t dst_wave_buffer_resource, - index_t dst_thread_addr_offset, - index_t dst_wave_addr_offset) +CK_TILE_DEVICE void amd_buffer_store_impl_with_bytes(const array src_thread_data, + int32x4_t dst_wave_buffer_resource, + index_t dst_thread_addr_offset, + index_t dst_wave_addr_offset) { static_assert(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32 || N == 64, "wrong! not implemented"); if constexpr(N == 1) { - llvm_amdgcn_raw_buffer_store_i8(src_thread_data, + llvm_amdgcn_raw_buffer_store_i8(bit_cast(src_thread_data), dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, @@ -1260,74 +1258,77 @@ amd_buffer_store_impl_raw(const typename vector_type::type src_thread } else if constexpr(N == 32) { - vector_type tmp{bit_cast(src_thread_data)}; + llvm_amdgcn_raw_buffer_store_i32x4( + src_thread_data.template get_as()[number<0>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); - llvm_amdgcn_raw_buffer_store_i32x4(tmp.template AsType()[number<0>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); - - llvm_amdgcn_raw_buffer_store_i32x4(tmp.template AsType()[number<1>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset + sizeof(int32_t) * 4, - static_cast(coherence)); + llvm_amdgcn_raw_buffer_store_i32x4( + src_thread_data.template get_as()[number<1>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + sizeof(int32_t) * 4, + static_cast(coherence)); } else if constexpr(N == 64) { - vector_type tmp{bit_cast(src_thread_data)}; + llvm_amdgcn_raw_buffer_store_i32x4( + src_thread_data.template get_as()[number<0>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); - llvm_amdgcn_raw_buffer_store_i32x4(tmp.template AsType()[number<0>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); + llvm_amdgcn_raw_buffer_store_i32x4( + src_thread_data.template get_as()[number<1>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + sizeof(int32_t) * 4, + static_cast(coherence)); - llvm_amdgcn_raw_buffer_store_i32x4(tmp.template AsType()[number<1>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset + sizeof(int32_t) * 4, - static_cast(coherence)); + llvm_amdgcn_raw_buffer_store_i32x4( + src_thread_data.template get_as()[number<2>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + sizeof(int32_t) * 8, + static_cast(coherence)); - llvm_amdgcn_raw_buffer_store_i32x4(tmp.template AsType()[number<2>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset + sizeof(int32_t) * 8, - static_cast(coherence)); - - llvm_amdgcn_raw_buffer_store_i32x4(tmp.template AsType()[number<3>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset + sizeof(int32_t) * 12, - static_cast(coherence)); + llvm_amdgcn_raw_buffer_store_i32x4( + src_thread_data.template get_as()[number<3>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + sizeof(int32_t) * 12, + static_cast(coherence)); } } template -CK_TILE_DEVICE void amd_buffer_store_impl(const typename vector_type::type src_thread_data, +CK_TILE_DEVICE void amd_buffer_store_impl(const array src_thread_data, int32x4_t dst_wave_buffer_resource, index_t dst_thread_addr_offset, index_t dst_wave_addr_offset) { static_assert( - (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || - (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || - (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || - (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || - (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || - (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || - (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || - (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)), + (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || + (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (std::is_same::value && + (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)), "wrong! not implemented"); - if constexpr(is_same::value) // fp32 + if constexpr(std::is_same::value) // fp32 { if constexpr(N == 1) { - llvm_amdgcn_raw_buffer_store_fp32(src_thread_data, + llvm_amdgcn_raw_buffer_store_fp32(bit_cast(src_thread_data), dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, @@ -1335,7 +1336,7 @@ CK_TILE_DEVICE void amd_buffer_store_impl(const typename vector_type::type } else if constexpr(N == 2) { - llvm_amdgcn_raw_buffer_store_fp32x2(src_thread_data, + llvm_amdgcn_raw_buffer_store_fp32x2(bit_cast(src_thread_data), dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, @@ -1343,7 +1344,7 @@ CK_TILE_DEVICE void amd_buffer_store_impl(const typename vector_type::type } else if constexpr(N == 4) { - llvm_amdgcn_raw_buffer_store_fp32x4(src_thread_data, + llvm_amdgcn_raw_buffer_store_fp32x4(bit_cast(src_thread_data), dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, @@ -1351,24 +1352,25 @@ CK_TILE_DEVICE void amd_buffer_store_impl(const typename vector_type::type } else if constexpr(N == 8) { - vector_type tmp{src_thread_data}; - llvm_amdgcn_raw_buffer_store_fp32x4(tmp.AsType()[number<0>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); - llvm_amdgcn_raw_buffer_store_fp32x4(tmp.AsType()[number<1>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset + 4 * sizeof(float), - static_cast(coherence)); + llvm_amdgcn_raw_buffer_store_fp32x4( + src_thread_data.template get_as()[number<0>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); + llvm_amdgcn_raw_buffer_store_fp32x4( + src_thread_data.template get_as()[number<1>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + 4 * sizeof(float), + static_cast(coherence)); } } - else if constexpr(is_same::value) // fp16 + else if constexpr(std::is_same::value) // fp16 { if constexpr(N == 1) { - llvm_amdgcn_raw_buffer_store_fp16(src_thread_data, + llvm_amdgcn_raw_buffer_store_fp16(bit_cast(src_thread_data), dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, @@ -1376,7 +1378,7 @@ CK_TILE_DEVICE void amd_buffer_store_impl(const typename vector_type::type } else if constexpr(N == 2) { - llvm_amdgcn_raw_buffer_store_fp16x2(src_thread_data, + llvm_amdgcn_raw_buffer_store_fp16x2(bit_cast(src_thread_data), dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, @@ -1384,7 +1386,7 @@ CK_TILE_DEVICE void amd_buffer_store_impl(const typename vector_type::type } else if constexpr(N == 4) { - llvm_amdgcn_raw_buffer_store_fp16x4(src_thread_data, + llvm_amdgcn_raw_buffer_store_fp16x4(bit_cast(src_thread_data), dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, @@ -1393,21 +1395,21 @@ CK_TILE_DEVICE void amd_buffer_store_impl(const typename vector_type::type else if constexpr(N == 8) { #if 0 - vector_type tmp{src_thread_data}; + array tmp{src_thread_data}; - llvm_amdgcn_raw_buffer_store_fp16x4(tmp.AsType()[number<0>{}], + llvm_amdgcn_raw_buffer_store_fp16x4(tmp.template get_as()[number<0>{}], dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, static_cast(coherence)); - llvm_amdgcn_raw_buffer_store_fp16x4(tmp.AsType()[number<1>{}], + llvm_amdgcn_raw_buffer_store_fp16x4(tmp.template get_as()[number<1>{}], dst_wave_buffer_resource, dst_thread_addr_offset, - dst_wave_addr_offset + 4 * sizeof(half_t), + dst_wave_addr_offset + 4 * sizeof(fp16_t), static_cast(coherence)); #else - llvm_amdgcn_raw_buffer_store_fp32x4(bit_cast(src_thread_data), + llvm_amdgcn_raw_buffer_store_fp32x4(bit_cast(src_thread_data), dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, @@ -1415,11 +1417,11 @@ CK_TILE_DEVICE void amd_buffer_store_impl(const typename vector_type::type #endif } } - else if constexpr(is_same::value) // bf16 + else if constexpr(std::is_same::value) // bf16 { if constexpr(N == 1) { - llvm_amdgcn_raw_buffer_store_i16(src_thread_data, + llvm_amdgcn_raw_buffer_store_i16(bit_cast(src_thread_data), dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, @@ -1427,7 +1429,7 @@ CK_TILE_DEVICE void amd_buffer_store_impl(const typename vector_type::type } else if constexpr(N == 2) { - llvm_amdgcn_raw_buffer_store_i16x2(src_thread_data, + llvm_amdgcn_raw_buffer_store_i16x2(bit_cast(src_thread_data), dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, @@ -1435,7 +1437,7 @@ CK_TILE_DEVICE void amd_buffer_store_impl(const typename vector_type::type } else if constexpr(N == 4) { - llvm_amdgcn_raw_buffer_store_i16x4(src_thread_data, + llvm_amdgcn_raw_buffer_store_i16x4(bit_cast(src_thread_data), dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, @@ -1443,29 +1445,29 @@ CK_TILE_DEVICE void amd_buffer_store_impl(const typename vector_type::type } else if constexpr(N == 8) { - vector_type tmp{src_thread_data}; + llvm_amdgcn_raw_buffer_store_i16x4( + src_thread_data.template get_as()[number<0>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); - llvm_amdgcn_raw_buffer_store_i16x4(tmp.AsType()[number<0>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); - - llvm_amdgcn_raw_buffer_store_i16x4(tmp.AsType()[number<1>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset + 4 * sizeof(bhalf_t), - static_cast(coherence)); + llvm_amdgcn_raw_buffer_store_i16x4( + src_thread_data.template get_as()[number<1>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + 4 * sizeof(bf16_t), + static_cast(coherence)); } } else { - using r_t = typename vector_type::type; + using r_t = array; - amd_buffer_store_impl_raw(bit_cast(src_thread_data), - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset); + amd_buffer_store_impl_with_bytes(bit_cast(src_thread_data), + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset); } } @@ -1473,18 +1475,17 @@ template -CK_TILE_DEVICE void -amd_buffer_store_raw_impl(const typename vector_type::type dst_thread_data, - int32x4_t dst_wave_buffer_resource, - index_t dst_thread_addr_offset, - index_t dst_wave_addr_offset, - index_t is_valid_element = 1) +CK_TILE_DEVICE void amd_buffer_store_raw_impl(const array& dst_thread_data, + int32x4_t dst_wave_buffer_resource, + index_t dst_thread_addr_offset, + index_t dst_wave_addr_offset, + index_t is_valid_element = 1) { constexpr index_t bytes = sizeof(T) * N; static_assert(bytes == 1 || bytes == 2 || bytes == 4 || bytes == 8 || bytes == 16, "wrong! not supported by buffer_store instruction"); - using type = typename vector_type::type; + using type = array; if constexpr(oob_conditional_check) { buffer_store_if{}(dst_thread_data, @@ -1505,22 +1506,21 @@ amd_buffer_store_raw_impl(const typename vector_type::type dst_thread_data } template -CK_TILE_DEVICE void -amd_buffer_atomic_add_impl(const typename vector_type::type src_thread_data, - int32x4_t dst_wave_buffer_resource, - index_t dst_thread_addr_offset, - index_t dst_wave_addr_offset) +CK_TILE_DEVICE void amd_buffer_atomic_add_impl(const array& src_thread_data, + int32x4_t dst_wave_buffer_resource, + index_t dst_thread_addr_offset, + index_t dst_wave_addr_offset) { - static_assert((is_same::value && (N == 1 || N == 2 || N == 4)) || - (is_same::value && (N == 2 || N == 4 || N == 8)) || - (is_same::value && (N == 1 || N == 2 || N == 4)), + static_assert((std::is_same::value && (N == 1 || N == 2 || N == 4)) || + (std::is_same::value && (N == 2 || N == 4 || N == 8)) || + (std::is_same::value && (N == 1 || N == 2 || N == 4)), "wrong! not implemented"); - if constexpr(is_same::value) + if constexpr(std::is_same::value) { if constexpr(N == 1) { - llvm_amdgcn_raw_buffer_atomic_add_fp32(src_thread_data, + llvm_amdgcn_raw_buffer_atomic_add_fp32(bit_cast(src_thread_data), dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, @@ -1528,54 +1528,56 @@ amd_buffer_atomic_add_impl(const typename vector_type::type src_thread_dat } else if constexpr(N == 2) { - vector_type tmp{src_thread_data}; + llvm_amdgcn_raw_buffer_atomic_add_fp32( + src_thread_data.template get_as()[number<0>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); - llvm_amdgcn_raw_buffer_atomic_add_fp32(tmp.AsType()[number<0>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - 0); - - llvm_amdgcn_raw_buffer_atomic_add_fp32(tmp.AsType()[number<1>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset + sizeof(float), - 0); + llvm_amdgcn_raw_buffer_atomic_add_fp32( + src_thread_data.template get_as()[number<1>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + sizeof(float), + 0); } else if constexpr(N == 4) { - vector_type tmp{src_thread_data}; + llvm_amdgcn_raw_buffer_atomic_add_fp32( + src_thread_data.template get_as()[number<0>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); - llvm_amdgcn_raw_buffer_atomic_add_fp32(tmp.AsType()[number<0>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - 0); + llvm_amdgcn_raw_buffer_atomic_add_fp32( + src_thread_data.template get_as()[number<1>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + sizeof(float), + 0); - llvm_amdgcn_raw_buffer_atomic_add_fp32(tmp.AsType()[number<1>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset + sizeof(float), - 0); + llvm_amdgcn_raw_buffer_atomic_add_fp32( + src_thread_data.template get_as()[number<2>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + 2 * sizeof(float), + 0); - llvm_amdgcn_raw_buffer_atomic_add_fp32(tmp.AsType()[number<2>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset + 2 * sizeof(float), - 0); - - llvm_amdgcn_raw_buffer_atomic_add_fp32(tmp.AsType()[number<3>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset + 3 * sizeof(float), - 0); + llvm_amdgcn_raw_buffer_atomic_add_fp32( + src_thread_data.template get_as()[number<3>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + 3 * sizeof(float), + 0); } } - else if constexpr(is_same::value) + else if constexpr(std::is_same::value) { if constexpr(N == 2) { - llvm_amdgcn_raw_buffer_atomic_add_fp16x2(src_thread_data, + llvm_amdgcn_raw_buffer_atomic_add_fp16x2(bit_cast(src_thread_data), dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, @@ -1583,34 +1585,32 @@ amd_buffer_atomic_add_impl(const typename vector_type::type src_thread_dat } else if constexpr(N == 4) { - vector_type tmp{src_thread_data}; - static_for<0, 2, 1>{}([&](auto i) { - llvm_amdgcn_raw_buffer_atomic_add_fp16x2(tmp.AsType()[i], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset + i * sizeof(half2_t), - 0); + llvm_amdgcn_raw_buffer_atomic_add_fp16x2( + src_thread_data.template get_as()[i], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + i * sizeof(fp16x2_t), + 0); }); } else if constexpr(N == 8) { - vector_type tmp{src_thread_data}; - static_for<0, 4, 1>{}([&](auto i) { - llvm_amdgcn_raw_buffer_atomic_add_fp16x2(tmp.AsType()[i], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset + i * sizeof(half2_t), - 0); + llvm_amdgcn_raw_buffer_atomic_add_fp16x2( + src_thread_data.template get_as()[i], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + i * sizeof(fp16x2_t), + 0); }); } } - else if constexpr(is_same::value) + else if constexpr(std::is_same::value) { if constexpr(N == 1) { - llvm_amdgcn_raw_buffer_atomic_add_i32(src_thread_data, + llvm_amdgcn_raw_buffer_atomic_add_i32(bit_cast(src_thread_data), dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, @@ -1618,65 +1618,66 @@ amd_buffer_atomic_add_impl(const typename vector_type::type src_thread_dat } else if constexpr(N == 2) { - vector_type tmp{src_thread_data}; + llvm_amdgcn_raw_buffer_atomic_add_i32( + src_thread_data.template get_as()[number<0>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); - llvm_amdgcn_raw_buffer_atomic_add_i32(tmp.AsType()[number<0>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - 0); - - llvm_amdgcn_raw_buffer_atomic_add_i32(tmp.AsType()[number<1>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset + sizeof(int32_t), - 0); + llvm_amdgcn_raw_buffer_atomic_add_i32( + src_thread_data.template get_as()[number<1>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + sizeof(int32_t), + 0); } else if constexpr(N == 4) { - vector_type tmp{src_thread_data}; + llvm_amdgcn_raw_buffer_atomic_add_i32( + src_thread_data.template get_as()[number<0>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); - llvm_amdgcn_raw_buffer_atomic_add_i32(tmp.AsType()[number<0>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - 0); + llvm_amdgcn_raw_buffer_atomic_add_i32( + src_thread_data.template get_as()[number<1>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + sizeof(int32_t), + 0); - llvm_amdgcn_raw_buffer_atomic_add_i32(tmp.AsType()[number<1>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset + sizeof(int32_t), - 0); + llvm_amdgcn_raw_buffer_atomic_add_i32( + src_thread_data.template get_as()[number<2>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + 2 * sizeof(int32_t), + 0); - llvm_amdgcn_raw_buffer_atomic_add_i32(tmp.AsType()[number<2>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset + 2 * sizeof(int32_t), - 0); - - llvm_amdgcn_raw_buffer_atomic_add_i32(tmp.AsType()[number<3>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset + 3 * sizeof(int32_t), - 0); + llvm_amdgcn_raw_buffer_atomic_add_i32( + src_thread_data.template get_as()[number<3>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + 3 * sizeof(int32_t), + 0); } } } template -CK_TILE_DEVICE void -amd_buffer_atomic_max_impl(const typename vector_type::type src_thread_data, - int32x4_t dst_wave_buffer_resource, - index_t dst_thread_addr_offset, - index_t dst_wave_addr_offset) +CK_TILE_DEVICE void amd_buffer_atomic_max_impl(const array src_thread_data, + int32x4_t dst_wave_buffer_resource, + index_t dst_thread_addr_offset, + index_t dst_wave_addr_offset) { - static_assert((is_same::value && (N == 1 || N == 2 || N == 4)), + static_assert((std::is_same::value && (N == 1 || N == 2 || N == 4)), "wrong! not implemented"); - if constexpr(is_same::value) + if constexpr(std::is_same::value) { if constexpr(N == 1) { - llvm_amdgcn_raw_buffer_atomic_max_fp64(src_thread_data, + llvm_amdgcn_raw_buffer_atomic_max_fp64(bit_cast(src_thread_data), dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, @@ -1684,47 +1685,49 @@ amd_buffer_atomic_max_impl(const typename vector_type::type src_thread_dat } else if constexpr(N == 2) { - vector_type tmp{src_thread_data}; + llvm_amdgcn_raw_buffer_atomic_max_fp64( + src_thread_data.template get_as()[number<0>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); - llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType()[number<0>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - 0); - - llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType()[number<1>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset + sizeof(double), - 0); + llvm_amdgcn_raw_buffer_atomic_max_fp64( + src_thread_data.template get_as()[number<1>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + sizeof(double), + 0); } else if constexpr(N == 4) { - vector_type tmp{src_thread_data}; + llvm_amdgcn_raw_buffer_atomic_max_fp64( + src_thread_data.template get_as()[number<0>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); - llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType()[number<0>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - 0); + llvm_amdgcn_raw_buffer_atomic_max_fp64( + src_thread_data.template get_as()[number<1>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + sizeof(double), + 0); - llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType()[number<1>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset + sizeof(double), - 0); + llvm_amdgcn_raw_buffer_atomic_max_fp64( + src_thread_data.template get_as()[number<2>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + 2 * sizeof(double), + 0); - llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType()[number<2>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset + 2 * sizeof(double), - 0); - - llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType()[number<3>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset + 3 * sizeof(double), - 0); + llvm_amdgcn_raw_buffer_atomic_max_fp64( + src_thread_data.template get_as()[number<3>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + 3 * sizeof(double), + 0); } } } @@ -1738,22 +1741,17 @@ template -CK_TILE_DEVICE typename vector_type_maker::type::type +CK_TILE_DEVICE array amd_buffer_load_invalid_element_return_zero(const T* p_src_wave, index_t src_thread_element_offset, bool src_thread_element_valid, index_t src_element_space_size) { const int32x4_t src_wave_buffer_resource = - make_wave_buffer_resource(p_src_wave, src_element_space_size); + make_wave_buffer_resource(p_src_wave, src_element_space_size * sizeof(T)); index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T); - using vector_t = typename vector_type_maker::type::type; - using scalar_t = typename scalar_type::type; - - constexpr index_t vector_size = scalar_type::vector_size; - #if CK_TILE_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK uint32_t src_addr_shift = [&]() { if constexpr(oob_conditional_check) @@ -1761,13 +1759,13 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave, else return 0; }(); - return amd_buffer_load_impl( + return amd_buffer_load_impl( src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0); #else - vector_t tmp = amd_buffer_load_impl( - src_wave_buffer_resource, src_thread_addr_offset, 0); + array tmp = + amd_buffer_load_impl(src_wave_buffer_resource, src_thread_addr_offset, 0); if constexpr(oob_conditional_check) - return src_thread_element_valid ? tmp : vector_t(0); + return src_thread_element_valid ? tmp : array{0}; else return tmp; #endif @@ -1781,7 +1779,7 @@ template -CK_TILE_DEVICE typename vector_type_maker::type::type +CK_TILE_DEVICE array amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave, index_t src_thread_element_offset, bool src_thread_element_valid, @@ -1789,20 +1787,15 @@ amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave, T customized_value) { const int32x4_t src_wave_buffer_resource = - make_wave_buffer_resource(p_src_wave, src_element_space_size); + make_wave_buffer_resource(p_src_wave, src_element_space_size * sizeof(T)); index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T); - using vector_t = typename vector_type_maker::type::type; - using scalar_t = typename scalar_type::type; - - constexpr index_t vector_size = scalar_type::vector_size; - - vector_t tmp = amd_buffer_load_impl( - src_wave_buffer_resource, src_thread_addr_offset, 0); + array tmp = + amd_buffer_load_impl(src_wave_buffer_resource, src_thread_addr_offset, 0); if constexpr(oob_conditional_check) - return src_thread_element_valid ? tmp : vector_t(customized_value); + return src_thread_element_valid ? tmp : array{customized_value}; else return tmp; } @@ -1811,23 +1804,18 @@ template -CK_TILE_DEVICE void amd_buffer_load_raw(typename vector_type_maker::type::type& dst, +CK_TILE_DEVICE void amd_buffer_load_raw(array& dst, const T* p_src_wave, index_t src_thread_element_offset, index_t src_element_space_size, index_t is_valid_element = 0) { const int32x4_t src_wave_buffer_resource = - make_wave_buffer_resource(p_src_wave, src_element_space_size); + make_wave_buffer_resource(p_src_wave, src_element_space_size * sizeof(T)); index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T); - using vector_t = typename vector_type_maker::type::type; - using scalar_t = typename scalar_type::type; - - constexpr index_t vector_size = scalar_type::vector_size; - - amd_buffer_load_raw_impl( + amd_buffer_load_raw_impl( dst, src_wave_buffer_resource, src_thread_addr_offset, 0, is_valid_element); } @@ -1844,7 +1832,7 @@ CK_TILE_DEVICE void amd_async_buffer_load_with_oob(T* smem, index_t src_element_space_size) { const int32x4_t src_wave_buffer_resource = - make_wave_buffer_resource(p_src_wave, src_element_space_size); + make_wave_buffer_resource(p_src_wave, src_element_space_size * sizeof(T)); index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T); @@ -1860,22 +1848,17 @@ template -CK_TILE_DEVICE void -amd_buffer_store(const typename vector_type_maker::type::type src_thread_data, - T* p_dst_wave, - const index_t dst_thread_element_offset, - const bool dst_thread_element_valid, - const index_t dst_element_space_size) +CK_TILE_DEVICE void amd_buffer_store(const array& src_thread_data, + T* p_dst_wave, + const index_t dst_thread_element_offset, + const bool dst_thread_element_valid, + const index_t dst_element_space_size) { const int32x4_t dst_wave_buffer_resource = - make_wave_buffer_resource(p_dst_wave, dst_element_space_size); + make_wave_buffer_resource(p_dst_wave, dst_element_space_size * sizeof(T)); index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T); - using vector_t = typename vector_type_maker::type::type; - using scalar_t = typename scalar_type::type; - constexpr index_t vector_size = scalar_type::vector_size; - #if CK_TILE_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK uint32_t dst_addr_shift = [&]() { if constexpr(oob_conditional_check) @@ -1883,20 +1866,20 @@ amd_buffer_store(const typename vector_type_maker::type::type src_thread_d else return 0; }(); - amd_buffer_store_impl( + amd_buffer_store_impl( src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0); #else if constexpr(oob_conditional_check) { if(dst_thread_element_valid) { - amd_buffer_store_impl( + amd_buffer_store_impl( src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0); } } else { - amd_buffer_store_impl( + amd_buffer_store_impl( src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0); } #endif @@ -1906,28 +1889,22 @@ template -CK_TILE_DEVICE void -amd_buffer_store_raw(const typename vector_type_maker::type::type src_thread_data, - T* p_dst_wave, - const index_t dst_thread_element_offset, - const bool dst_thread_element_valid, - const index_t dst_element_space_size) +CK_TILE_DEVICE void amd_buffer_store_raw(const array& src_thread_data, + T* p_dst_wave, + const index_t dst_thread_element_offset, + const bool dst_thread_element_valid, + const index_t dst_element_space_size) { const int32x4_t dst_wave_buffer_resource = - make_wave_buffer_resource(p_dst_wave, dst_element_space_size); + make_wave_buffer_resource(p_dst_wave, dst_element_space_size * sizeof(T)); index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T); - using vector_t = typename vector_type_maker::type::type; - using scalar_t = typename scalar_type::type; - constexpr index_t vector_size = scalar_type::vector_size; - - amd_buffer_store_raw_impl( - src_thread_data, - dst_wave_buffer_resource, - dst_thread_addr_offset, - 0, - dst_thread_element_valid); + amd_buffer_store_raw_impl(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + 0, + dst_thread_element_valid); } // buffer_atomic_add requires: @@ -1935,31 +1912,26 @@ amd_buffer_store_raw(const typename vector_type_maker::type::type src_thre // 2) p_dst_wave must be a wavewise pointer. // It is user's responsibility to make sure that is true. template -CK_TILE_DEVICE void -amd_buffer_atomic_add(const typename vector_type_maker::type::type src_thread_data, - T* p_dst_wave, - const index_t dst_thread_element_offset, - const bool dst_thread_element_valid, - const index_t dst_element_space_size) +CK_TILE_DEVICE void amd_buffer_atomic_add(const array& src_thread_data, + T* p_dst_wave, + const index_t dst_thread_element_offset, + const bool dst_thread_element_valid, + const index_t dst_element_space_size) { const int32x4_t dst_wave_buffer_resource = - make_wave_buffer_resource(p_dst_wave, dst_element_space_size); + make_wave_buffer_resource(p_dst_wave, dst_element_space_size * sizeof(T)); index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T); - using vector_t = typename vector_type_maker::type::type; - using scalar_t = typename scalar_type::type; - constexpr index_t vector_size = scalar_type::vector_size; - #if CK_TILE_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000; - amd_buffer_atomic_add_impl( + amd_buffer_atomic_add_impl( src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0); #else if(dst_thread_element_valid) { - amd_buffer_atomic_add_impl( + amd_buffer_atomic_add_impl( src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0); } #endif @@ -1970,31 +1942,26 @@ amd_buffer_atomic_add(const typename vector_type_maker::type::type src_thr // 2) p_dst_wave must be a wavewise pointer. // It is user's responsibility to make sure that is true. template -CK_TILE_DEVICE void -amd_buffer_atomic_max(const typename vector_type_maker::type::type src_thread_data, - T* p_dst_wave, - const index_t dst_thread_element_offset, - const bool dst_thread_element_valid, - const index_t dst_element_space_size) +CK_TILE_DEVICE void amd_buffer_atomic_max(const array& src_thread_data, + T* p_dst_wave, + const index_t dst_thread_element_offset, + const bool dst_thread_element_valid, + const index_t dst_element_space_size) { const int32x4_t dst_wave_buffer_resource = - make_wave_buffer_resource(p_dst_wave, dst_element_space_size); + make_wave_buffer_resource(p_dst_wave, dst_element_space_size * sizeof(T)); index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T); - using vector_t = typename vector_type_maker::type::type; - using scalar_t = typename scalar_type::type; - constexpr index_t vector_size = scalar_type::vector_size; - #if CK_TILE_EXPERIMENTAL_USE_BUFFER_ATOMIC_MAX_OOB_CHECK_OFFSET_TRICK uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000; - amd_buffer_atomic_max_impl( + amd_buffer_atomic_max_impl( src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0); #else if(dst_thread_element_valid) { - amd_buffer_atomic_max_impl( + amd_buffer_atomic_max_impl( src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0); } #endif @@ -2025,7 +1992,8 @@ CK_TILE_DEVICE void amd_direct_load_global_to_lds(const T* global_base_ptr, const uint32_t* global_ptr = reinterpret_cast(reinterpret_cast(global_base_ptr)); - const int32x4_t src_resource = make_wave_buffer_resource(global_ptr, src_element_space_size); + const int32x4_t src_resource = + make_wave_buffer_resource(global_ptr, src_element_space_size * sizeof(T)); const index_t global_offset_bytes = is_valid ? global_offset * sizeof(T) : 0x80000000; #if CK_TILE_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM diff --git a/include/ck_tile/core/arch/arch.hpp b/include/ck_tile/core/arch/arch.hpp new file mode 100644 index 0000000000..333168fd2a --- /dev/null +++ b/include/ck_tile/core/arch/arch.hpp @@ -0,0 +1,61 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +// Address Space for AMDGCN +// https://llvm.org/docs/AMDGPUUsage.html#address-space + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/numeric/integer.hpp" +#include "ck_tile/core/numeric/integral_constant.hpp" + +namespace ck_tile { + +enum struct address_space_enum +{ + generic, + global, + lds, + sgpr, + vgpr, +}; + +enum struct memory_operation_enum +{ + set, + atomic_add, + atomic_max, + add +}; + +CK_TILE_HOST_DEVICE constexpr index_t get_warp_size() +{ + // warpSize is defined by HIP + return warpSize; +} + +CK_TILE_DEVICE index_t get_grid_size() { return gridDim.x; } + +CK_TILE_DEVICE index_t get_block_size() { return blockDim.x; } + +// TODO: deprecate these +CK_TILE_DEVICE index_t get_thread_local_1d_id() { return threadIdx.x; } + +CK_TILE_DEVICE index_t get_thread_global_1d_id() { return blockIdx.x * blockDim.x + threadIdx.x; } + +CK_TILE_DEVICE index_t get_block_1d_id() { return blockIdx.x; } + +// Use these instead +CK_TILE_DEVICE index_t get_lane_id() { return __lane_id(); } + +CK_TILE_DEVICE index_t get_warp_id() +{ + return __builtin_amdgcn_readfirstlane(threadIdx.x / get_warp_size()); +} + +CK_TILE_DEVICE index_t get_thread_id() { return threadIdx.x; } + +CK_TILE_DEVICE index_t get_block_id() { return blockIdx.x; } + +} // namespace ck_tile diff --git a/include/ck_tile/core/arch/utility.hpp b/include/ck_tile/core/arch/utility.hpp new file mode 100644 index 0000000000..1ab2ba1002 --- /dev/null +++ b/include/ck_tile/core/arch/utility.hpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +// Address Space for AMDGCN +// https://llvm.org/docs/AMDGPUUsage.html#address-space + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/numeric/integer.hpp" +#include "ck_tile/core/numeric/integral_constant.hpp" + +namespace ck_tile { + +// TODO: we have "memory" clobber here because this inline asm is used for async copy +CK_TILE_DEVICE void m0_set_with_memory(index_t v) +{ + asm volatile("s_mov_b32 m0, %0" : : "s"(v) : "memory"); +} + +// NOTE: this is an immediate value +CK_TILE_DEVICE void m0_inc_with_memory(index_t v) +{ + asm volatile("s_add_u32 m0, %0, m0" : : "n"(v) : "memory"); +} + +} // namespace ck_tile diff --git a/include/ck_tile/core/config.hpp b/include/ck_tile/core/config.hpp index 55cca88c8c..b655ae0a6c 100644 --- a/include/ck_tile/core/config.hpp +++ b/include/ck_tile/core/config.hpp @@ -120,3 +120,15 @@ #ifndef CK_TILE_DEBUG_LOG #define CK_TILE_DEBUG_LOG 0 #endif + +#ifndef __HIP_DEVICE_COMPILE__ // for host code +#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD -1 +#elif defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || defined(__gfx908__) || \ + defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \ + defined(__gfx942__) // for GPU code +#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x00020000 +#elif defined(__gfx1030__) // for GPU code +#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x31014000 +#elif defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) // for GPU code +#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x31004000 +#endif diff --git a/include/ck_tile/core/container/array.hpp b/include/ck_tile/core/container/array.hpp index fd910abac9..7752f31375 100644 --- a/include/ck_tile/core/container/array.hpp +++ b/include/ck_tile/core/container/array.hpp @@ -44,6 +44,19 @@ struct array data[i] = vlast; } } + CK_TILE_HOST_DEVICE explicit constexpr array(value_type c) + { + for(auto i = 0; i < size(); i++) + data[i] = c; + } + template + CK_TILE_HOST_DEVICE constexpr array(const ArrayType& o) + { + static_assert(ArrayType::size() == size(), "wrong! size not the same"); + for(auto i = 0; i < size(); i++) + data[i] = o.data[i]; + } + CK_TILE_HOST_DEVICE static constexpr auto size() { return N; } CK_TILE_HOST_DEVICE static constexpr bool is_static() { return is_static_v; } @@ -67,18 +80,18 @@ struct array CK_TILE_HOST_DEVICE constexpr const value_type& operator[](index_t i) const { return data[i]; } CK_TILE_HOST_DEVICE constexpr value_type& operator[](index_t i) { return data[i]; } CK_TILE_HOST_DEVICE constexpr value_type& operator()(index_t i) { return data[i]; } // TODO: compatible - - template - CK_TILE_HOST_DEVICE constexpr auto operator=(const T& a) +#if 0 + template + CK_TILE_HOST_DEVICE constexpr auto operator=(const ArrayType& a) { - static_assert(T::size() == size(), "wrong! size not the same"); + static_assert(ArrayType::size() == size(), "wrong! size not the same"); for(index_t i = 0; i < size(); ++i) { data[i] = a[i]; } return *this; } - +#endif // type punning (strict aliasing) member functions for read/write // aliasing this array of type "T", "N" elements // as array of type "Tx", sizeof(T)*N/sizeof(Tx) elements @@ -122,6 +135,17 @@ struct array CK_TILE_HOST_DEVICE void print() const { printf("array{size: 0, data: []}"); } }; +template +struct vector_traits; + +// specialization for array +template +struct vector_traits> +{ + using scalar_type = T; + static constexpr index_t vector_size = N; +}; + template CK_TILE_HOST_DEVICE constexpr auto make_array(T&& x, Ts&&... xs) { diff --git a/include/ck_tile/core/container/container_helper.hpp b/include/ck_tile/core/container/container_helper.hpp index 88405f6fcb..eec15d2538 100644 --- a/include/ck_tile/core/container/container_helper.hpp +++ b/include/ck_tile/core/container/container_helper.hpp @@ -468,6 +468,7 @@ CK_TILE_HOST_DEVICE constexpr auto sequence_to_tuple_of_number(sequence) number{}); } +#if 0 #define TO_TUPLE_OF_SEQUENCE(a_of_b_impl, a_size, bs_sizes) \ [a_of_b_impl, a_size, bs_sizes] { \ return ck_tile::generate_tuple( \ @@ -479,5 +480,21 @@ CK_TILE_HOST_DEVICE constexpr auto sequence_to_tuple_of_number(sequence) }, \ ck_tile::number{}); \ }() +#else +// constexpr index_t can't be captured "-Wunused-lambda-capture" +// TODO: this is ugly +#define TO_TUPLE_OF_SEQUENCE(a_of_b_impl, a_size, bs_sizes) \ + [a_of_b_impl, bs_sizes] { \ + return ck_tile::generate_tuple( \ + [=](auto i) { \ + constexpr auto b_impl = a_of_b_impl[i]; \ + constexpr index_t b_size = bs_sizes[i]; \ + constexpr auto b = TO_SEQUENCE(b_impl, b_size); \ + return b; \ + }, \ + ck_tile::number{}); \ + }() +#endif + } // namespace ck_tile diff --git a/include/ck_tile/core/container/sequence.hpp b/include/ck_tile/core/container/sequence.hpp index 15313b1b65..581e3b8d61 100644 --- a/include/ck_tile/core/container/sequence.hpp +++ b/include/ck_tile/core/container/sequence.hpp @@ -67,13 +67,12 @@ struct sequence CK_TILE_HOST_DEVICE static constexpr auto get(number) { static_assert(I < size(), "wrong! I too large"); - return number{}; + return number()>{}; } CK_TILE_HOST_DEVICE static constexpr index_t at(index_t I) { // the last dummy element is to prevent compiler complain about empty array, when mSize = 0 - static_assert(I < size(), "wrong! I too large"); const index_t mData[size() + 1] = {Is..., 0}; return mData[I]; } @@ -89,7 +88,7 @@ struct sequence CK_TILE_HOST_DEVICE static constexpr auto at(number) { static_assert(I < size(), "wrong! I too large"); - return number{}; + return number()>{}; } template diff --git a/include/ck_tile/core/container/tuple.hpp b/include/ck_tile/core/container/tuple.hpp index f95ddb5435..a47cf94811 100644 --- a/include/ck_tile/core/container/tuple.hpp +++ b/include/ck_tile/core/container/tuple.hpp @@ -16,40 +16,48 @@ namespace ck_tile { namespace impl { +// the place where content is stored template > -struct tuple_element +struct tuple_object { }; template -struct tuple_element +struct tuple_object { - CK_TILE_HOST_DEVICE constexpr tuple_element() {} - CK_TILE_HOST_DEVICE constexpr tuple_element(const T&) {} + CK_TILE_HOST_DEVICE constexpr tuple_object() {} + CK_TILE_HOST_DEVICE constexpr tuple_object(const T&) {} }; template -struct tuple_element +struct tuple_object { - CK_TILE_HOST_DEVICE constexpr tuple_element() {} - CK_TILE_HOST_DEVICE constexpr tuple_element(const T& e) : element(e) {} + CK_TILE_HOST_DEVICE constexpr tuple_object() : element{} {} + CK_TILE_HOST_DEVICE constexpr tuple_object(const T& e) : element(e) {} T element; }; +// NOTE: we return a instance(not a reference) if content is empty template -CK_TILE_HOST_DEVICE constexpr T const& getv(tuple_element const& x) +CK_TILE_HOST_DEVICE constexpr T getv(const tuple_object&) +{ + return {}; +} + +template +CK_TILE_HOST_DEVICE constexpr const T& getv(const tuple_object& x) { return x.element; } template -CK_TILE_HOST_DEVICE constexpr T& getv(tuple_element& x) +CK_TILE_HOST_DEVICE constexpr T& getv(tuple_object& x) { return x.element; } template -CK_TILE_HOST_DEVICE constexpr T&& getv(tuple_element&& x) +CK_TILE_HOST_DEVICE constexpr T&& getv(tuple_object&& x) { return static_cast(x.element); } @@ -58,18 +66,18 @@ template struct tuple_base; template -struct tuple_base, T...> : public tuple_element... +struct tuple_base, T...> : tuple_object... { CK_TILE_HOST_DEVICE constexpr tuple_base() {} template - CK_TILE_HOST_DEVICE constexpr explicit tuple_base(U const&... u) : tuple_element(u)... + CK_TILE_HOST_DEVICE constexpr explicit tuple_base(const U&... u) : tuple_object(u)... { } template - CK_TILE_HOST_DEVICE constexpr tuple_base(tuple_base, U...> const& u) - : tuple_element(getv(static_cast const&>(u)))... + CK_TILE_HOST_DEVICE constexpr tuple_base(const tuple_base, U...>& u) + : tuple_object(getv(static_cast&>(u)))... { } }; @@ -84,15 +92,13 @@ struct tuple : impl::tuple_base, T...> CK_TILE_HOST_DEVICE constexpr tuple() {} template - CK_TILE_HOST_DEVICE constexpr tuple(U const&... u) - : impl::tuple_base, T...>(u...) + CK_TILE_HOST_DEVICE constexpr tuple(const U&... u) : base(u...) { } template - CK_TILE_HOST_DEVICE constexpr tuple(tuple const& u) - : impl::tuple_base, T...>( - static_cast const&>(u)) + CK_TILE_HOST_DEVICE constexpr tuple(const tuple& u) + : base(static_cast, U...>&>(u)) { } @@ -109,19 +115,19 @@ struct tuple : impl::tuple_base, T...> #define TP_COM_() static_assert(I < size(), "wrong! out of range") // clang-format off - template CK_TILE_HOST_DEVICE constexpr const auto & get() const { TP_COM_(); return impl::getv(*this); } - template CK_TILE_HOST_DEVICE constexpr const auto & get(number) const { TP_COM_(); return get(); } - template CK_TILE_HOST_DEVICE constexpr auto & get() { TP_COM_(); return impl::getv(*this); } - template CK_TILE_HOST_DEVICE constexpr auto & get(number) { TP_COM_(); return get(); } + template CK_TILE_HOST_DEVICE constexpr decltype(auto) get() const { TP_COM_(); return impl::getv(*this); } + template CK_TILE_HOST_DEVICE constexpr decltype(auto) get(number) const { TP_COM_(); return get(); } + template CK_TILE_HOST_DEVICE constexpr decltype(auto) get() { TP_COM_(); return impl::getv(*this); } + template CK_TILE_HOST_DEVICE constexpr decltype(auto) get(number) { TP_COM_(); return get(); } - template CK_TILE_HOST_DEVICE constexpr const auto & at() const { TP_COM_(); return impl::getv(*this); } - template CK_TILE_HOST_DEVICE constexpr const auto & at(number) const { TP_COM_(); return get(); } - template CK_TILE_HOST_DEVICE constexpr auto & at() { TP_COM_(); return impl::getv(*this); } - template CK_TILE_HOST_DEVICE constexpr auto & at(number) { TP_COM_(); return get(); } + template CK_TILE_HOST_DEVICE constexpr decltype(auto) at() const { TP_COM_(); return impl::getv(*this); } + template CK_TILE_HOST_DEVICE constexpr decltype(auto) at(number) const { TP_COM_(); return get(); } + template CK_TILE_HOST_DEVICE constexpr decltype(auto) at() { TP_COM_(); return impl::getv(*this); } + template CK_TILE_HOST_DEVICE constexpr decltype(auto) at(number) { TP_COM_(); return get(); } - template CK_TILE_HOST_DEVICE constexpr auto & operator[](number) { TP_COM_(); return get(); } - template CK_TILE_HOST_DEVICE constexpr const auto & operator[](number) const { TP_COM_(); return get(); } - template CK_TILE_HOST_DEVICE constexpr auto & operator()(number) { TP_COM_(); return get(); } // TODO: compatible + template CK_TILE_HOST_DEVICE constexpr decltype(auto) operator[](number) { TP_COM_(); return get(); } + template CK_TILE_HOST_DEVICE constexpr decltype(auto) operator[](number) const { TP_COM_(); return get(); } + template CK_TILE_HOST_DEVICE constexpr decltype(auto) operator()(number) { TP_COM_(); return get(); } // TODO: compatible // clang-format on #undef TP_COM_ }; @@ -250,15 +256,15 @@ CK_TILE_HOST_DEVICE constexpr auto transform_tuples(F f, const X& x, const Y& y, // By default unroll to the flatten template -CK_TILE_HOST_DEVICE constexpr auto unroll_nested_tuple(const tuple<>& element) +CK_TILE_HOST_DEVICE constexpr auto unroll_nested_tuple(const tuple<>& t) { - return element; + return t; } template -CK_TILE_HOST_DEVICE constexpr auto unroll_nested_tuple(const T& element) +CK_TILE_HOST_DEVICE constexpr auto unroll_nested_tuple(const T& t) { - return make_tuple(element); + return make_tuple(t); } template @@ -334,7 +340,7 @@ CK_TILE_HOST_DEVICE constexpr auto to_array_of_array(tuple t_of_s) index_t max_n1_ = 0; static_for<0, n0, 1>{}([&](auto i0) { - constexpr index_t n1 = t_of_s[i0].size()(); + constexpr index_t n1 = t_of_s[i0].size(); max_n1_ = max_n1_ < n1 ? n1 : max_n1_; }); @@ -345,7 +351,7 @@ CK_TILE_HOST_DEVICE constexpr auto to_array_of_array(tuple t_of_s) array, n0> a_of_a{{-1}}; static_for<0, n0, 1>{}([&](auto i0) { - constexpr index_t n1 = t_of_s[i0].size()(); + constexpr index_t n1 = t_of_s[i0].size(); static_for<0, n1, 1>{}([&](auto i1) { a_of_a(i0)(i1) = t_of_s[i0][i1]; }); }); @@ -482,3 +488,60 @@ struct tuple_element> }; } // namespace std + +#if 1 +#define TO_TUPLE_OF_NUMBER(a, n) \ + _Pragma("clang diagnostic push") \ + _Pragma("clang diagnostic ignored \"-Wc++20-extensions\"") \ + [a](ck_tile::sequence) \ + { \ + return ck_tile::tuple{}]>...>{}; \ + } \ + (ck_tile::make_index_sequence{}) \ + _Pragma("clang diagnostic pop") +#else +#define TO_TUPLE_OF_NUMBER(arr, n_) \ + [&arr, n_] { \ + static_assert(arr.size() >= n_, "wrong! out of bound"); \ + \ + static_assert(n_ < 7, "not implemented"); \ + \ + if constexpr(n_ == 0) \ + { \ + return ck_tile::tuple<>{}; \ + } \ + else if constexpr(n_ == 1) \ + { \ + return ck_tile::tuple>{}; \ + } \ + else if constexpr(n_ == 2) \ + { \ + return ck_tile::tuple, number>{}; \ + } \ + else if constexpr(n_ == 3) \ + { \ + return ck_tile::tuple, number, number>{}; \ + } \ + else if constexpr(n_ == 4) \ + { \ + return ck_tile::tuple, number, number, number>{}; \ + } \ + else if constexpr(n_ == 5) \ + { \ + return ck_tile::tuple, \ + number, \ + number, \ + number, \ + number>{}; \ + } \ + else if constexpr(n_ == 6) \ + { \ + return ck_tile::tuple, \ + number, \ + number, \ + number, \ + number, \ + number>{}; \ + } \ + }() +#endif diff --git a/include/ck_tile/core/numeric/bfloat16.hpp b/include/ck_tile/core/numeric/bfloat16.hpp index b19aefc928..6fd433f005 100644 --- a/include/ck_tile/core/numeric/bfloat16.hpp +++ b/include/ck_tile/core/numeric/bfloat16.hpp @@ -20,7 +20,8 @@ enum class bf16_rounding_mode truncate, }; -template +template (CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)> CK_TILE_HOST_DEVICE uint16_t float_to_bf16_raw(float f, constant = {}); CK_TILE_HOST_DEVICE @@ -41,36 +42,42 @@ struct alignas(2) bfloat16_t } // constructor - bfloat16_t() = default; + constexpr bfloat16_t() : data() {} // construct from float CK_TILE_HOST_DEVICE - explicit bfloat16_t(const float& x) { data = float_to_bf16_raw(x); } + explicit constexpr bfloat16_t(const float& x) : data(float_to_bf16_raw(x)) {} // construct from int CK_TILE_HOST_DEVICE - explicit bfloat16_t(const int& x) { data = float_to_bf16_raw(static_cast(x)); } + explicit constexpr bfloat16_t(const int& x) : data(float_to_bf16_raw(static_cast(x))) {} // construct from unsigned int CK_TILE_HOST_DEVICE - explicit bfloat16_t(const unsigned int& x) { data = float_to_bf16_raw(static_cast(x)); } + explicit constexpr bfloat16_t(const unsigned int& x) + : data(float_to_bf16_raw(static_cast(x))) + { + } // cast to float CK_TILE_HOST_DEVICE - explicit operator float() const { return bf16_to_float_raw(data); } + explicit constexpr operator float() const { return bf16_to_float_raw(data); } // cast to int CK_TILE_HOST_DEVICE - explicit operator int() const { return static_cast(bf16_to_float_raw(data)); } + explicit constexpr operator int() const { return static_cast(bf16_to_float_raw(data)); } // internal access CK_TILE_HOST_DEVICE - raw_type& get() { return data; } + constexpr raw_type& get() { return data; } CK_TILE_HOST_DEVICE - raw_type get() const { return data; } + constexpr raw_type get() const { return data; } }; +using bf16_t = bfloat16_t; +using bf16_raw_t = typename bf16_t::raw_type; + // round to nearest CK_TILE_HOST_DEVICE uint16_t float_to_bf16_rtn_raw(float f) @@ -139,8 +146,8 @@ uint16_t float_to_bf16_truc_raw(float f) return uint16_t(u.int32 >> 16); } -template -CK_TILE_HOST_DEVICE uint16_t float_to_bf16_raw(float f, constant = {}) +template +CK_TILE_HOST_DEVICE uint16_t float_to_bf16_raw(float f, constant) { if constexpr(rounding == bf16_rounding_mode::standard) return float_to_bf16_rtn_raw(f); @@ -161,8 +168,9 @@ float bf16_to_float_raw(uint16_t x) return u.fp32; } -template -CK_TILE_HOST_DEVICE bfloat16_t float_to_bf16(float f, constant = {}) +template (CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)> +CK_TILE_HOST_DEVICE bfloat16_t float_to_bf16(float f, constant) { return bfloat16_t::bit_cast(float_to_bf16_raw(f, constant{})); } @@ -170,14 +178,15 @@ CK_TILE_HOST_DEVICE bfloat16_t float_to_bf16(float f, constant = {}) CK_TILE_HOST_DEVICE float bf16_to_float(bfloat16_t x) { return static_cast(x); } -template +template (CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)> CK_TILE_HOST_DEVICE bfloat16_t fp16_to_bf16(half_t f, constant = {}) { return bfloat16_t::bit_cast(float_to_bf16_raw(static_cast(f), constant{})); } CK_TILE_HOST_DEVICE -float bf16_to_fp16(bfloat16_t x) { return float_to_fp16(static_cast(x)); } +half_t bf16_to_fp16(bfloat16_t x) { return float_to_fp16(static_cast(x)); } template struct numeric_limits; @@ -259,6 +268,4 @@ bfloat16_t exp2(bfloat16_t x) { return static_cast(exp2f(static_cast CK_TILE_DEVICE bfloat16_t log(bfloat16_t x) { return static_cast(__logf(static_cast(x))); }; -using bf16_t = bfloat16_t; - } // namespace ck_tile diff --git a/include/ck_tile/core/numeric/float8.hpp b/include/ck_tile/core/numeric/float8.hpp index 8ff6f06a19..11e971661b 100644 --- a/include/ck_tile/core/numeric/float8.hpp +++ b/include/ck_tile/core/numeric/float8.hpp @@ -7,6 +7,7 @@ #include "ck_tile/core/utility/random.hpp" #include "ck_tile/core/numeric/arithmetic.hpp" #include "ck_tile/core/numeric/half.hpp" +#include "ck_tile/core/numeric/math.hpp" #include "ck_tile/core/numeric/integral_constant.hpp" #include "ck_tile/core/utility/limits.hpp" #include @@ -75,20 +76,19 @@ struct alignas(1) float8_e4m3_t // construct from float CK_TILE_HOST_DEVICE - explicit constexpr float8_e4m3_t(const float& x) { data = float_to_fp8_raw(x); } + explicit constexpr float8_e4m3_t(const float& x) : data(float_to_fp8_raw(x)) {} // construct from int CK_TILE_HOST_DEVICE - explicit constexpr float8_e4m3_t(const int& x) + explicit constexpr float8_e4m3_t(const int& x) : data(float_to_fp8_raw(static_cast(x))) { - data = float_to_fp8_raw(static_cast(x)); } // construct from unsigned int CK_TILE_HOST_DEVICE explicit constexpr float8_e4m3_t(const unsigned int& x) + : data(float_to_fp8_raw(static_cast(x))) { - data = float_to_fp8_raw(static_cast(x)); } // cast to float @@ -106,6 +106,8 @@ struct alignas(1) float8_e4m3_t CK_TILE_HOST_DEVICE constexpr raw_type get() const { return data; } }; +using fp8_t = float8_e4m3_t; +using fp8_raw_t = typename fp8_t::raw_type; struct alignas(1) float8_e5m2_t { @@ -132,25 +134,24 @@ struct alignas(1) float8_e5m2_t // construct from float CK_TILE_HOST_DEVICE - explicit constexpr float8_e5m2_t(const float& x) { data = float_to_bf8_raw(x); } + explicit constexpr float8_e5m2_t(const float& x) : data(float_to_bf8_raw(x)) {} // construct from int CK_TILE_HOST_DEVICE - explicit constexpr float8_e5m2_t(const int& x) + explicit constexpr float8_e5m2_t(const int& x) : data(float_to_bf8_raw(static_cast(x))) { - data = float_to_bf8_raw(static_cast(x)); } // construct from unsigned int CK_TILE_HOST_DEVICE explicit constexpr float8_e5m2_t(const unsigned int& x) + : data(float_to_bf8_raw(static_cast(x))) { - data = float_to_bf8_raw(static_cast(x)); } // cast to float CK_TILE_HOST_DEVICE - explicit constexpr constexpr operator float() const { return bf8_to_float_raw(data); } + explicit constexpr operator float() const { return bf8_to_float_raw(data); } // cast to int CK_TILE_HOST_DEVICE @@ -163,6 +164,8 @@ struct alignas(1) float8_e5m2_t CK_TILE_HOST_DEVICE constexpr raw_type get() const { return data; } }; +using bf8_t = float8_e5m2_t; +using bf8_raw_t = typename bf8_t::raw_type; // below is sw fp8 conversion, not utilizing hw instruction namespace impl { @@ -431,10 +434,10 @@ CK_TILE_HOST_DEVICE Y cast_from_f8(X x) } } // namespace impl -CK_TILE_HOST_DEVICE uint8_t float_to_fp8_sr_raw(float x) +CK_TILE_HOST_DEVICE fp8_raw_t float_to_fp8_sr_raw(float x) { constexpr int seed = 42; - uint32_t rng = prand_generator{}(reinterpret_cast(&x), x); + uint32_t rng = prand_generator_t{}(reinterpret_cast(&x), x); #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) float max_fp8 = 240.0f; x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x); @@ -453,16 +456,18 @@ CK_TILE_HOST_DEVICE uint8_t float_to_fp8_sr_raw(float x) constexpr bool negative_zero_nan = true; constexpr bool clip = true; constexpr fp8_rounding_mode rm = fp8_rounding_mode::stochastic; - return impl:: - cast_to_f8( - x, rng); + return bit_cast(impl::cast_to_f8(x, rng)); #endif } -CK_TILE_HOST_DEVICE uint8_t float_to_bf8_sr_raw(float x) +CK_TILE_HOST_DEVICE bf8_raw_t float_to_bf8_sr_raw(float x) { constexpr int seed = 42; - uint32_t rng = prand_generator{}(reinterpret_cast(&x), x); + uint32_t rng = prand_generator_t{}(reinterpret_cast(&x), x); #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) union { @@ -479,13 +484,15 @@ CK_TILE_HOST_DEVICE uint8_t float_to_bf8_sr_raw(float x) constexpr bool negative_zero_nan = true; constexpr bool clip = true; constexpr fp8_rounding_mode rm = fp8_rounding_mode::stochastic; - return impl:: - cast_to_f8( - x, rng); + return bit_cast(impl::cast_to_f8(x, rng)); #endif } -CK_TILE_HOST_DEVICE uint8_t float_to_fp8_rtn_raw(float x) +CK_TILE_HOST_DEVICE fp8_raw_t float_to_fp8_rtn_raw(float x) { #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) float max_fp8 = 240.0f; @@ -506,12 +513,14 @@ CK_TILE_HOST_DEVICE uint8_t float_to_fp8_rtn_raw(float x) constexpr bool clip = true; constexpr fp8_rounding_mode rm = fp8_rounding_mode::standard; constexpr uint32_t rng = 0; - return impl:: - cast_to_f8( - x, rng); + return bit_cast(impl::cast_to_f8(x, rng)); #endif } -CK_TILE_HOST_DEVICE uint8_t float_to_bf8_rtn_raw(float x) +CK_TILE_HOST_DEVICE bf8_raw_t float_to_bf8_rtn_raw(float x) { #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) union @@ -530,30 +539,32 @@ CK_TILE_HOST_DEVICE uint8_t float_to_bf8_rtn_raw(float x) constexpr bool clip = true; constexpr fp8_rounding_mode rm = fp8_rounding_mode::standard; constexpr uint32_t rng = 0; - return impl:: - cast_to_f8( - x, rng); + return bit_cast(impl::cast_to_f8(x, rng)); #endif } // clang-format off -template(CK_TILE_FLOAT_TO_FP8_DEFAULT)> -CK_TILE_HOST_DEVICE uint8_t float_to_fp8_raw(float x, constant = {}) +template +CK_TILE_HOST_DEVICE fp8_raw_t float_to_fp8_raw(float x, constant) { if constexpr (rounding == fp8_rounding_mode::standard) return float_to_fp8_rtn_raw(x); else if constexpr (rounding == fp8_rounding_mode::stochastic) return float_to_fp8_sr_raw(x); - else return uint8_t{0}; + else return fp8_raw_t{0}; } -template(CK_TILE_FLOAT_TO_FP8_DEFAULT)> -CK_TILE_HOST_DEVICE uint8_t float_to_bf8_raw(float x, constant = {}) +template +CK_TILE_HOST_DEVICE bf8_raw_t float_to_bf8_raw(float x, constant) { if constexpr (rounding == fp8_rounding_mode::standard) return float_to_bf8_rtn_raw(x); else if constexpr (rounding == fp8_rounding_mode::stochastic) return float_to_bf8_sr_raw(x); - else return uint8_t{0}; + else return bf8_raw_t{0}; } -CK_TILE_HOST_DEVICE float fp8_to_float_raw(uint8_t x) +CK_TILE_HOST_DEVICE float fp8_to_float_raw(fp8_raw_t x) { #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) float fval; @@ -563,11 +574,11 @@ CK_TILE_HOST_DEVICE float fp8_to_float_raw(uint8_t x) return fval; #else constexpr bool negative_zero_nan = true; - return impl::cast_from_f8(x); + return impl::cast_from_f8(fp8_t::bit_cast(x)); #endif } -CK_TILE_HOST_DEVICE float bf8_to_float_raw(uint8_t x) +CK_TILE_HOST_DEVICE float bf8_to_float_raw(bf8_raw_t x) { #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) float fval; @@ -577,18 +588,18 @@ CK_TILE_HOST_DEVICE float bf8_to_float_raw(uint8_t x) return fval; #else constexpr bool negative_zero_nan = true; - return impl::cast_from_f8(x); + return impl::cast_from_f8(bf8_t::bit_cast(x)); #endif } -template(CK_TILE_FLOAT_TO_FP8_DEFAULT)> -CK_TILE_HOST_DEVICE float8_e4m3_t float_to_fp8(float x, constant = {}) +template +CK_TILE_HOST_DEVICE float8_e4m3_t float_to_fp8(float x, constant) { return float8_e4m3_t::bit_cast(float_to_fp8_raw(x, constant{})); } -template(CK_TILE_FLOAT_TO_FP8_DEFAULT)> -CK_TILE_HOST_DEVICE float8_e5m2_t float_to_bf8(float x, constant = {}) +template +CK_TILE_HOST_DEVICE float8_e5m2_t float_to_bf8(float x, constant) { return float8_e5m2_t::bit_cast(float_to_bf8_raw(x, constant{})); } @@ -604,8 +615,6 @@ CK_TILE_HOST_DEVICE float bf8_to_float(float8_e5m2_t x) } // clang-format on -using fp8_t = float8_e4m3_t; -using bf8_t = float8_e5m2_t; template struct numeric_utils; diff --git a/include/ck_tile/core/numeric/half.hpp b/include/ck_tile/core/numeric/half.hpp index 02cf05a7d1..b22f71c045 100644 --- a/include/ck_tile/core/numeric/half.hpp +++ b/include/ck_tile/core/numeric/half.hpp @@ -76,6 +76,9 @@ struct alignas(2) half_t constexpr raw_type get() const { return data; } }; +using fp16_t = half_t; +using fp16_raw_t = typename fp16_t::raw_type; + // conversions CK_TILE_HOST_DEVICE float fp16_to_float_hip(const fp16_hip_t& x) @@ -282,6 +285,4 @@ half_t exp2(half_t x) { return static_cast(exp2f(static_cast(x))) CK_TILE_DEVICE half_t log(half_t x) { return static_cast(__logf(static_cast(x))); }; -using fp16_t = half_t; - } // namespace ck_tile diff --git a/include/ck_tile/core/numeric/math.hpp b/include/ck_tile/core/numeric/math.hpp index aa7c96b6e6..9615b979d5 100644 --- a/include/ck_tile/core/numeric/math.hpp +++ b/include/ck_tile/core/numeric/math.hpp @@ -147,6 +147,9 @@ CK_TILE_HOST_DEVICE constexpr T clamp(const T& x, const T& lowerbound, const T& return min(max(x, lowerbound), upperbound); } +CK_TILE_HOST inline int clz(uint32_t x) { return __builtin_clz(x); } +CK_TILE_DEVICE inline int clz(uint32_t x) { return __clz(x); } + // greatest common divisor, aka highest common factor CK_TILE_HOST_DEVICE constexpr index_t gcd(index_t x, index_t y) { @@ -222,7 +225,7 @@ struct less CK_TILE_HOST_DEVICE constexpr int32_t next_power_of_two(int32_t x) { // TODO: x need to be 2 ~ 0x7fffffff. 0, 1, or larger than 0x7fffffff will compile fail - return 1 << (32 - __builtin_clz(x - 1)); + return 1 << (32 - clz(x - 1)); } template @@ -243,7 +246,7 @@ CK_TILE_HOST_DEVICE constexpr int32_t integer_log2_floor(int32_t x) { // TODO: x need to be 1 ~ 0x7fffffff // __builtin_clz will produce unexpected result if x is 0; - return 31 - __builtin_clz(x); + return 31 - clz(x); } CK_TILE_HOST_DEVICE constexpr bool is_power_of_two_integer(int32_t x) diff --git a/include/ck_tile/core/numeric/type_convert.hpp b/include/ck_tile/core/numeric/type_convert.hpp index 5eac399bf7..d64f3f4349 100644 --- a/include/ck_tile/core/numeric/type_convert.hpp +++ b/include/ck_tile/core/numeric/type_convert.hpp @@ -3,10 +3,22 @@ #pragma once -#include "ck_tile/core/config.hpp" +#include +#include #include +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/numeric/half.hpp" +#include "ck_tile/core/numeric/bfloat16.hpp" +#include "ck_tile/core/numeric/float8.hpp" namespace ck_tile { + +template +CK_TILE_HOST_DEVICE constexpr remove_cvref_t type_convert(const X& x) +{ + return static_cast(x); +} + #if 0 // Convert X to Y, both X and Y are non-const data types. template && !std::is_reference_v); - return static_cast(x); } -// TODO: const version never called, we may never need // Convert X to Y, either X or Y is a const data type. template && !std::is_reference_v); - using NonConstY = std::remove_const_t; - using NonConstX = std::remove_const_t; - return static_cast(type_convert(x)); + using non_const_y = std::remove_const_t; + using non_const_x = std::remove_const_t; + return static_cast(type_convert(x)); } -#else -// compatible way to call conversion operator and constructor of each custom data type -template -CK_TILE_HOST_DEVICE constexpr Y type_convert(X x) -{ - static_assert(!std::is_reference_v && !std::is_reference_v); - return static_cast(x); -} +#define CK_TILE_TYPE_CONVERT(dtype_, stype_) \ + template <> \ + inline CK_TILE_HOST_DEVICE constexpr dtype_ type_convert(stype_ x) \ + { \ + return stype_##_to_##dtype_(x); \ + } + +CK_TILE_TYPE_CONVERT(float, fp16_t) +CK_TILE_TYPE_CONVERT(float, bf16_t) +CK_TILE_TYPE_CONVERT(float, fp8_t) +CK_TILE_TYPE_CONVERT(float, bf8_t) + +CK_TILE_TYPE_CONVERT(fp16_t, float) +CK_TILE_TYPE_CONVERT(bf16_t, float) +CK_TILE_TYPE_CONVERT(fp8_t, float) +CK_TILE_TYPE_CONVERT(bf8_t, float) + +#undef CK_TILE_TYPE_CONVERT #endif + } // namespace ck_tile diff --git a/include/ck_tile/core/numeric/vector_type.hpp b/include/ck_tile/core/numeric/vector_type.hpp index 56baba0567..f20891296b 100644 --- a/include/ck_tile/core/numeric/vector_type.hpp +++ b/include/ck_tile/core/numeric/vector_type.hpp @@ -10,295 +10,112 @@ #include "ck_tile/core/numeric/float8.hpp" #include "ck_tile/core/numeric/half.hpp" #include "ck_tile/core/numeric/bfloat16.hpp" +#include "ck_tile/core/utility/type_traits.hpp" namespace ck_tile { -// TODO: the whole content of this file should consider deprecated! +// we name this as ext_vector purposely, because clang ext_vector_type extention only accept literay +// basic type to construct a ext_vector_type you must be very careful using this, or will have lot +// of compiler errors e.g. struct A; using Ax2_t = A __attribute__((ext_vector_type(2))); -> will +// have compiler error +namespace impl { template -struct vector_type +struct ext_vector { static constexpr index_t N = N_; using value_type = T_; using type = value_type __attribute__((ext_vector_type(N))); // this is danguous - - CK_HOST_DEVICE constexpr vector_type() - { - for(auto i = 0; i < N; i++) - data[i] = static_cast(0); - } - CK_HOST_DEVICE constexpr vector_type(type v) - { - auto& r = reinterpret_cast&>(v); - for(auto i = 0; i < N; i++) - data[i] = r.get(i); - } - - value_type data[N]; - CK_HOST_DEVICE static constexpr auto size() { return N; } - CK_HOST_DEVICE auto& get() { return data; } - CK_HOST_DEVICE const auto& get() const { return data; } - CK_HOST_DEVICE auto& get(index_t i) { return data[i]; } - CK_HOST_DEVICE const auto& get(index_t i) const { return data[i]; } - - template - CK_HOST_DEVICE auto& operator[](number) - { - return data[I]; - } - template - CK_HOST_DEVICE const auto& operator[](number) const - { - return data[I]; - } - template - CK_HOST_DEVICE auto& operator()(number) - { - return data[I]; - } - - CK_HOST_DEVICE auto& at(index_t i) { return data[i]; } - CK_HOST_DEVICE const auto& at(index_t i) const { return data[i]; } - template - CK_HOST_DEVICE auto& at() - { - return data[I]; - } - template - CK_HOST_DEVICE const auto& at() const - { - return data[I]; - } - template - CK_HOST_DEVICE auto& at(number) - { - return data[I]; - } - template - CK_HOST_DEVICE const auto& at(number) const - { - return data[I]; - } - -#define _VT_COMMON_AS() \ - static_assert(sizeof(value_type) * N % sizeof(Tx) == 0); \ - constexpr int vx = sizeof(value_type) * N / sizeof(Tx) - - template - CK_HOST_DEVICE auto& get_as() - { - _VT_COMMON_AS(); - return reinterpret_cast&>(data); - } - template - CK_HOST_DEVICE const auto& get_as() const - { - _VT_COMMON_AS(); - return reinterpret_cast&>(data); - } - template - CK_HOST_DEVICE auto& get_as(index_t i) - { - _VT_COMMON_AS(); - return reinterpret_cast&>(data).get(i); - } - template - CK_HOST_DEVICE const auto& get_as(index_t i) const - { - _VT_COMMON_AS(); - return reinterpret_cast&>(data).get(i); - } -#undef _VT_COMMON_AS }; +} // namespace impl template -struct vector_type_maker +using ext_vector_t = typename impl::ext_vector::type; + +// by default, any type will result in a vector_size=1 with scalar_type=T traits. +// ... unless we have other vector_traits specialization +template +struct vector_traits { - using type = vector_type; -}; - -template -struct vector_type_maker -{ - using type = vector_type; -}; - -template -struct vector_type_maker, N0> -{ - using type = vector_type; + using scalar_type = remove_cvref_t; + static constexpr index_t vector_size = 1; }; +// specialization for ext_vector_type() template -using vector_type_maker_t = typename vector_type_maker::type; - -template -CK_HOST_DEVICE constexpr auto make_vector_type(number) +struct vector_traits { - return typename vector_type_maker::type{}; -} - -// scalar_type -template -struct scalar_type; - -// is_scalar_type -template -struct is_scalar_type -{ - static constexpr bool value = (scalar_type>::vector_size == 1); -}; - -// has_same_scalar_type -template -using has_same_scalar_type = is_same>::type, - typename scalar_type>::type>; - -template -struct scalar_type -{ - using type = T; + using scalar_type = T; static constexpr index_t vector_size = N; }; -template -struct scalar_type> -{ - using type = T; - static constexpr index_t vector_size = N; -}; - -// -template <> -struct scalar_type -{ - using type = double; - static constexpr index_t vector_size = 1; -}; - -template <> -struct scalar_type -{ - using type = float; - static constexpr index_t vector_size = 1; -}; - -template <> -struct scalar_type -{ - using type = half_t; - static constexpr index_t vector_size = 1; -}; - -template <> -struct scalar_type -{ - using type = bhalf_t; - static constexpr index_t vector_size = 1; -}; - -template <> -struct scalar_type -{ - using type = int64_t; - static constexpr index_t vector_size = 1; -}; - -template <> -struct scalar_type -{ - using type = int32_t; - static constexpr index_t vector_size = 1; -}; - -template <> -struct scalar_type -{ - using type = int8_t; - static constexpr index_t vector_size = 1; -}; - -#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 -template <> -struct scalar_type -{ - using type = int4_t; - static constexpr index_t vector_size = 1; -}; -#endif - -template <> -struct scalar_type -{ - using type = fp8_t; - static constexpr index_t vector_size = 1; -}; - -template <> -struct scalar_type -{ - using type = bf8_t; - static constexpr index_t vector_size = 1; -}; - // below are some pre-defines of ext_vector_type +// attention! 2 vector type could be just the same type // fp64 -using double2_t = typename vector_type::type; -using double4_t = typename vector_type::type; +using fp64x2_t = double __attribute__((ext_vector_type(2))); +using fp64x4_t = double __attribute__((ext_vector_type(4))); // fp32 -using float2_t = typename vector_type::type; -using float4_t = typename vector_type::type; -using float8_t = typename vector_type::type; -using float16_t = typename vector_type::type; -using float32_t = typename vector_type::type; -using float64_t = typename vector_type::type; +using fp32x2_t = float __attribute__((ext_vector_type(2))); +using fp32x4_t = float __attribute__((ext_vector_type(4))); +using fp32x8_t = float __attribute__((ext_vector_type(8))); +using fp32x16_t = float __attribute__((ext_vector_type(16))); +using fp32x32_t = float __attribute__((ext_vector_type(32))); +using fp32x64_t = float __attribute__((ext_vector_type(64))); // fp16 -using half2_t = typename vector_type::type; -using half4_t = typename vector_type::type; -using half8_t = typename vector_type::type; -using half16_t = typename vector_type::type; -using half32_t = typename vector_type::type; -using half64_t = typename vector_type::type; +using fp16x2_t = fp16_raw_t __attribute__((ext_vector_type(2))); +using fp16x4_t = fp16_raw_t __attribute__((ext_vector_type(4))); +using fp16x8_t = fp16_raw_t __attribute__((ext_vector_type(8))); +using fp16x16_t = fp16_raw_t __attribute__((ext_vector_type(16))); +using fp16x32_t = fp16_raw_t __attribute__((ext_vector_type(32))); +using fp16x64_t = fp16_raw_t __attribute__((ext_vector_type(64))); // bfp16 -using bhalf2_t = typename vector_type::type; -using bhalf4_t = typename vector_type::type; -using bhalf8_t = typename vector_type::type; -using bhalf16_t = typename vector_type::type; -using bhalf32_t = typename vector_type::type; -using bhalf64_t = typename vector_type::type; +using bf16x2_t = bf16_raw_t __attribute__((ext_vector_type(2))); +using bf16x4_t = bf16_raw_t __attribute__((ext_vector_type(4))); +using bf16x8_t = bf16_raw_t __attribute__((ext_vector_type(8))); +using bf16x16_t = bf16_raw_t __attribute__((ext_vector_type(16))); +using bf16x32_t = bf16_raw_t __attribute__((ext_vector_type(32))); +using bf16x64_t = bf16_raw_t __attribute__((ext_vector_type(64))); // i32 -using int32x2_t = typename vector_type::type; -using int32x4_t = typename vector_type::type; -using int32x8_t = typename vector_type::type; -using int32x16_t = typename vector_type::type; -using int32x32_t = typename vector_type::type; -using int32x64_t = typename vector_type::type; +using int32x2_t = int32_t __attribute__((ext_vector_type(2))); +using int32x4_t = int32_t __attribute__((ext_vector_type(4))); +using int32x8_t = int32_t __attribute__((ext_vector_type(8))); +using int32x16_t = int32_t __attribute__((ext_vector_type(16))); +using int32x32_t = int32_t __attribute__((ext_vector_type(32))); +using int32x64_t = int32_t __attribute__((ext_vector_type(64))); + +// i16 +using int16x2_t = int16_t __attribute__((ext_vector_type(2))); +using int16x4_t = int16_t __attribute__((ext_vector_type(4))); +using int16x8_t = int16_t __attribute__((ext_vector_type(8))); +using int16x16_t = int16_t __attribute__((ext_vector_type(16))); +using int16x32_t = int16_t __attribute__((ext_vector_type(32))); +using int16x64_t = int16_t __attribute__((ext_vector_type(64))); // i8 -using int8x2_t = typename vector_type::type; -using int8x4_t = typename vector_type::type; -using int8x8_t = typename vector_type::type; -using int8x16_t = typename vector_type::type; -using int8x32_t = typename vector_type::type; -using int8x64_t = typename vector_type::type; +using int8x2_t = int8_t __attribute((ext_vector_type(2))); +using int8x4_t = int8_t __attribute((ext_vector_type(4))); +using int8x8_t = int8_t __attribute((ext_vector_type(8))); +using int8x16_t = int8_t __attribute((ext_vector_type(16))); +using int8x32_t = int8_t __attribute((ext_vector_type(32))); +using int8x64_t = int8_t __attribute((ext_vector_type(64))); // f8 -using fp8x2_t = typename vector_type::type; -using fp8x4_t = typename vector_type::type; -using fp8x8_t = typename vector_type::type; -using fp8x16_t = typename vector_type::type; -using fp8x32_t = typename vector_type::type; -using fp8x64_t = typename vector_type::type; +using fp8x2_t = fp8_raw_t __attribute((ext_vector_type(2))); +using fp8x4_t = fp8_raw_t __attribute((ext_vector_type(4))); +using fp8x8_t = fp8_raw_t __attribute((ext_vector_type(8))); +using fp8x16_t = fp8_raw_t __attribute((ext_vector_type(16))); +using fp8x32_t = fp8_raw_t __attribute((ext_vector_type(32))); +using fp8x64_t = fp8_raw_t __attribute((ext_vector_type(64))); // bf8 -using bf8x2_t = typename vector_type::type; -using bf8x4_t = typename vector_type::type; -using bf8x8_t = typename vector_type::type; -using bf8x16_t = typename vector_type::type; -using bf8x32_t = typename vector_type::type; -using bf8x64_t = typename vector_type::type; +using bf8x2_t = bf8_raw_t __attribute((ext_vector_type(2))); +using bf8x4_t = bf8_raw_t __attribute((ext_vector_type(4))); +using bf8x8_t = bf8_raw_t __attribute((ext_vector_type(8))); +using bf8x16_t = bf8_raw_t __attribute((ext_vector_type(16))); +using bf8x32_t = bf8_raw_t __attribute((ext_vector_type(32))); +using bf8x64_t = bf8_raw_t __attribute((ext_vector_type(64))); } // namespace ck_tile diff --git a/include/ck_tile/core/tensor/buffer_view.hpp b/include/ck_tile/core/tensor/buffer_view.hpp index eca5efdffb..efb4f2ad43 100644 --- a/include/ck_tile/core/tensor/buffer_view.hpp +++ b/include/ck_tile/core/tensor/buffer_view.hpp @@ -4,7 +4,7 @@ #pragma once #include "ck_tile/core/config.hpp" -#include "ck_tile/core/arch/amd_address_space.hpp" +#include "ck_tile/core/arch/arch.hpp" #include "ck_tile/core/arch/amd_buffer_addressing.hpp" #include "ck_tile/core/container/array.hpp" #include "ck_tile/core/numeric/integer.hpp" @@ -12,6 +12,7 @@ #include "ck_tile/core/numeric/float8.hpp" #include "ck_tile/core/numeric/half.hpp" #include "ck_tile/core/numeric/bfloat16.hpp" +#include "ck_tile/core/utility/type_traits.hpp" namespace ck_tile { @@ -22,13 +23,13 @@ namespace ck_tile { // FIXME: InvalidElementUseNumericalZeroValue and invalid_element_value_ should be a property of // transforms of tensor_view/Tensor // FIXME: amd_buffer_coherence_enum is only meaningful for buffer addressing. Need to split -// BufferView definition for different memory address space (Global/GenericLds/Vgpr) +// buffer_view definition for different memory address space (Global/GenericLds/Vgpr) template -struct BufferView; +struct buffer_view; // Address Space: generic // T may be scalar or vector @@ -82,17 +83,18 @@ struct buffer_view>::type, - typename scalar_type>::type>::value, - bool>::type = false> + bool oob_conditional_check = true, + typename std::enable_if< + std::is_same>::scalar_type, + typename vector_traits>::scalar_type>::value, + bool>::type = false> CK_TILE_DEVICE constexpr auto get(index_t i, bool is_valid_element, bool_constant = {}) const { // X contains multiple T - constexpr index_t scalar_per_t_vector = scalar_type>::vector_size; + constexpr index_t scalar_per_t_vector = vector_traits>::vector_size; - constexpr index_t scalar_per_x_vector = scalar_type>::vector_size; + constexpr index_t scalar_per_x_vector = vector_traits>::vector_size; static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, "wrong! X should contain multiple T"); @@ -123,19 +125,20 @@ struct buffer_view>::type, - typename scalar_type>::type>::value, - bool>::type = false> + typename std::enable_if< + std::is_same>::scalar_type, + typename vector_traits>::scalar_type>::value, + bool>::type = false> CK_TILE_DEVICE void update(index_t i, bool is_valid_element, const X& x) { - if constexpr(Op == InMemoryDataOperationEnum::set) + if constexpr(Op == memory_operation_enum::set) { this->template set(i, is_valid_element, x); } - // FIXME: remove InMemoryDataOperationEnum::Add - else if constexpr(Op == InMemoryDataOperationEnum::Add) + // FIXME: remove memory_operation_enum::add + else if constexpr(Op == memory_operation_enum::add) { auto tmp = this->template get(i, is_valid_element); this->template set(i, is_valid_element, x + tmp); @@ -144,15 +147,16 @@ struct buffer_view>::type, - typename scalar_type>::type>::value, - bool>::type = false> + typename std::enable_if< + std::is_same>::scalar_type, + typename vector_traits>::scalar_type>::value, + bool>::type = false> CK_TILE_DEVICE void set(index_t i, bool is_valid_element, const X& x) { // X contains multiple T - constexpr index_t scalar_per_t_vector = scalar_type>::vector_size; + constexpr index_t scalar_per_t_vector = vector_traits>::vector_size; - constexpr index_t scalar_per_x_vector = scalar_type>::vector_size; + constexpr index_t scalar_per_x_vector = vector_traits>::vector_size; static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, "wrong! X should contain multiple T"); @@ -253,17 +257,18 @@ struct buffer_view>::type, - typename scalar_type>::type>::value, - bool>::type = false> + bool oob_conditional_check = true, + typename std::enable_if< + std::is_same>::scalar_type, + typename vector_traits>::scalar_type>::value, + bool>::type = false> CK_TILE_DEVICE constexpr auto get(index_t i, bool is_valid_element, bool_constant = {}) const { // X contains multiple T - constexpr index_t scalar_per_t_vector = scalar_type>::vector_size; + constexpr index_t scalar_per_t_vector = vector_traits>::vector_size; - constexpr index_t scalar_per_x_vector = scalar_type>::vector_size; + constexpr index_t scalar_per_x_vector = vector_traits>::vector_size; static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, "wrong! X should contain multiple T"); @@ -326,16 +331,17 @@ struct buffer_view>::type, - typename scalar_type>::type>::value, - bool>::type = false> + bool oob_conditional_check = true, + typename std::enable_if< + std::is_same>::scalar_type, + typename vector_traits>::scalar_type>::value, + bool>::type = false> CK_TILE_DEVICE constexpr auto get_raw(remove_cvref_t& dst, index_t i, bool is_valid_element) const { - constexpr index_t scalar_per_t_vector = scalar_type>::vector_size; + constexpr index_t scalar_per_t_vector = vector_traits>::vector_size; - constexpr index_t scalar_per_x_vector = scalar_type>::vector_size; + constexpr index_t scalar_per_x_vector = vector_traits>::vector_size; static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, "wrong! X should contain multiple T"); @@ -348,15 +354,16 @@ struct buffer_view>::type, - typename scalar_type>::type>::value, - bool>::type = false> + typename std::enable_if< + std::is_same>::scalar_type, + typename vector_traits>::scalar_type>::value, + bool>::type = false> CK_TILE_DEVICE constexpr auto async_get(remove_cvref_t* smem, index_t i, bool /*is_valid_element*/) const { // X is vector of T - constexpr index_t scalar_per_t_vector = scalar_type>::vector_size; - constexpr index_t scalar_per_x_vector = scalar_type>::vector_size; + constexpr index_t scalar_per_t_vector = vector_traits>::vector_size; + constexpr index_t scalar_per_x_vector = vector_traits>::vector_size; static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, "wrong! X should contain multiple T"); @@ -368,27 +375,28 @@ struct buffer_view>::type, - typename scalar_type>::type>::value, - bool>::type = false> + typename std::enable_if< + std::is_same>::scalar_type, + typename vector_traits>::scalar_type>::value, + bool>::type = false> CK_TILE_DEVICE void update(index_t i, bool is_valid_element, const X& x) { - if constexpr(Op == InMemoryDataOperationEnum::set) + if constexpr(Op == memory_operation_enum::set) { this->template set(i, is_valid_element, x); } - else if constexpr(Op == InMemoryDataOperationEnum::atomic_add) + else if constexpr(Op == memory_operation_enum::atomic_add) { this->template atomic_add(i, is_valid_element, x); } - else if constexpr(Op == InMemoryDataOperationEnum::atomic_max) + else if constexpr(Op == memory_operation_enum::atomic_max) { this->template atomic_max(i, is_valid_element, x); } - // FIXME: remove InMemoryDataOperationEnum::Add - else if constexpr(Op == InMemoryDataOperationEnum::Add) + // FIXME: remove memory_operation_enum::add + else if constexpr(Op == memory_operation_enum::add) { auto tmp = this->template get(i, is_valid_element); this->template set(i, is_valid_element, x + tmp); @@ -399,16 +407,17 @@ struct buffer_view>::type, - typename scalar_type>::type>::value, - bool>::type = false> + bool oob_conditional_check = true, + typename std::enable_if< + std::is_same>::scalar_type, + typename vector_traits>::scalar_type>::value, + bool>::type = false> CK_TILE_DEVICE void set(index_t i, bool is_valid_element, const X& x) { // X contains multiple T - constexpr index_t scalar_per_t_vector = scalar_type>::vector_size; + constexpr index_t scalar_per_t_vector = vector_traits>::vector_size; - constexpr index_t scalar_per_x_vector = scalar_type>::vector_size; + constexpr index_t scalar_per_x_vector = vector_traits>::vector_size; static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, "wrong! X should contain multiple T"); @@ -443,16 +452,17 @@ struct buffer_view>::type, - typename scalar_type>::type>::value, - bool>::type = false> + bool oob_conditional_check = true, + typename std::enable_if< + std::is_same>::scalar_type, + typename vector_traits>::scalar_type>::value, + bool>::type = false> CK_TILE_DEVICE void set_raw(index_t i, bool is_valid_element, const X& x) { // X contains multiple T - constexpr index_t scalar_per_t_vector = scalar_type>::vector_size; + constexpr index_t scalar_per_t_vector = vector_traits>::vector_size; - constexpr index_t scalar_per_x_vector = scalar_type>::vector_size; + constexpr index_t scalar_per_x_vector = vector_traits>::vector_size; static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, "wrong! X should contain multiple T"); @@ -463,17 +473,18 @@ struct buffer_view>::type, - typename scalar_type>::type>::value, - bool>::type = false> + typename std::enable_if< + std::is_same>::scalar_type, + typename vector_traits>::scalar_type>::value, + bool>::type = false> CK_TILE_DEVICE void atomic_add(index_t i, bool is_valid_element, const X& x) { - using scalar_t = typename scalar_type>::type; + using scalar_t = typename vector_traits>::scalar_type; // X contains multiple T - constexpr index_t scalar_per_t_vector = scalar_type>::vector_size; + constexpr index_t scalar_per_t_vector = vector_traits>::vector_size; - constexpr index_t scalar_per_x_vector = scalar_type>::vector_size; + constexpr index_t scalar_per_x_vector = vector_traits>::vector_size; static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, "wrong! X should contain multiple T"); @@ -482,15 +493,16 @@ struct buffer_view, int32_t> || - is_same_v, float> || - (is_same_v, half_t> && scalar_per_x_vector % 2 == 0); + std::is_same_v, int32_t> || + std::is_same_v, float> || + (std::is_same_v, half_t> && scalar_per_x_vector % 2 == 0); #elif CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER && (!CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT) - bool constexpr use_amd_buffer_addressing = is_same_v, int32_t>; + bool constexpr use_amd_buffer_addressing = + std::is_same_v, int32_t>; #elif(!CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER) && CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT bool constexpr use_amd_buffer_addressing = - is_same_v, float> || - (is_same_v, half_t> && scalar_per_x_vector % 2 == 0); + std::is_same_v, float> || + (std::is_same_v, half_t> && scalar_per_x_vector % 2 == 0); #else bool constexpr use_amd_buffer_addressing = false; #endif @@ -512,15 +524,16 @@ struct buffer_view>::type, - typename scalar_type>::type>::value, - bool>::type = false> + typename std::enable_if< + std::is_same>::scalar_type, + typename vector_traits>::scalar_type>::value, + bool>::type = false> CK_TILE_DEVICE void atomic_max(index_t i, bool is_valid_element, const X& x) { // X contains multiple T - constexpr index_t scalar_per_t_vector = scalar_type>::vector_size; + constexpr index_t scalar_per_t_vector = vector_traits>::vector_size; - constexpr index_t scalar_per_x_vector = scalar_type>::vector_size; + constexpr index_t scalar_per_x_vector = vector_traits>::vector_size; static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, "wrong! X should contain multiple T"); @@ -528,8 +541,8 @@ struct buffer_view>::type; - bool constexpr use_amd_buffer_addressing = is_same_v, double>; + using scalar_t = typename vector_traits>::scalar_type; + bool constexpr use_amd_buffer_addressing = std::is_same_v, double>; #else bool constexpr use_amd_buffer_addressing = false; #endif @@ -628,17 +641,18 @@ struct buffer_view>::type, - typename scalar_type>::type>::value, - bool>::type = false> + bool oob_conditional_check = true, + typename std::enable_if< + std::is_same>::scalar_type, + typename vector_traits>::scalar_type>::value, + bool>::type = false> CK_TILE_DEVICE constexpr auto get(index_t i, bool is_valid_element, bool_constant = {}) const { // X contains multiple T - constexpr index_t scalar_per_t_vector = scalar_type>::vector_size; + constexpr index_t scalar_per_t_vector = vector_traits>::vector_size; - constexpr index_t scalar_per_x_vector = scalar_type>::vector_size; + constexpr index_t scalar_per_x_vector = vector_traits>::vector_size; static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, "wrong! X should contain multiple T"); @@ -669,19 +683,20 @@ struct buffer_view>::type, - typename scalar_type>::type>::value, - bool>::type = false> + typename std::enable_if< + std::is_same>::scalar_type, + typename vector_traits>::scalar_type>::value, + bool>::type = false> CK_TILE_DEVICE void update(index_t i, bool is_valid_element, const X& x) { - if constexpr(Op == InMemoryDataOperationEnum::set) + if constexpr(Op == memory_operation_enum::set) { this->template set(i, is_valid_element, x); } - // FIXME: remove InMemoryDataOperationEnum::Add - else if constexpr(Op == InMemoryDataOperationEnum::Add) + // FIXME: remove memory_operation_enum::add + else if constexpr(Op == memory_operation_enum::add) { auto tmp = this->template get(i, is_valid_element); this->template set(i, is_valid_element, x + tmp); @@ -690,15 +705,16 @@ struct buffer_view>::type, - typename scalar_type>::type>::value, - bool>::type = false> + typename std::enable_if< + std::is_same>::scalar_type, + typename vector_traits>::scalar_type>::value, + bool>::type = false> CK_TILE_DEVICE void set(index_t i, bool is_valid_element, const X& x) { // X contains multiple T - constexpr index_t scalar_per_t_vector = scalar_type>::vector_size; + constexpr index_t scalar_per_t_vector = vector_traits>::vector_size; - constexpr index_t scalar_per_x_vector = scalar_type>::vector_size; + constexpr index_t scalar_per_x_vector = vector_traits>::vector_size; static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, "wrong! X should contain multiple T"); @@ -709,7 +725,8 @@ struct buffer_view>::type, int8_t>::value && + if constexpr(std::is_same>::scalar_type, + int8_t>::value && workaround_int8_ds_write_issue) { if(is_valid_element) @@ -718,83 +735,83 @@ struct buffer_view" which would be lower to // ds_write_b128 // TODO: remove this after compiler fix - static_assert((is_same, int8_t>::value && - is_same, int8_t>::value) || - (is_same, int8_t>::value && - is_same, int8x2_t>::value) || - (is_same, int8_t>::value && - is_same, int8x4_t>::value) || - (is_same, int8_t>::value && - is_same, int8x8_t>::value) || - (is_same, int8_t>::value && - is_same, int8x16_t>::value) || - (is_same, int8x4_t>::value && - is_same, int8x4_t>::value) || - (is_same, int8x8_t>::value && - is_same, int8x8_t>::value) || - (is_same, int8x16_t>::value && - is_same, int8x16_t>::value), + static_assert((std::is_same, int8_t>::value && + std::is_same, int8_t>::value) || + (std::is_same, int8_t>::value && + std::is_same, int8x2_t>::value) || + (std::is_same, int8_t>::value && + std::is_same, int8x4_t>::value) || + (std::is_same, int8_t>::value && + std::is_same, int8x8_t>::value) || + (std::is_same, int8_t>::value && + std::is_same, int8x16_t>::value) || + (std::is_same, int8x4_t>::value && + std::is_same, int8x4_t>::value) || + (std::is_same, int8x8_t>::value && + std::is_same, int8x8_t>::value) || + (std::is_same, int8x16_t>::value && + std::is_same, int8x16_t>::value), "wrong! not implemented for this combination, please add " "implementation"); - if constexpr(is_same, int8_t>::value && - is_same, int8_t>::value) + if constexpr(std::is_same, int8_t>::value && + std::is_same, int8_t>::value) { // HACK: cast pointer of x is bad // TODO: remove this after compiler fix *c_style_pointer_cast(&p_data_[i]) = *c_style_pointer_cast(&x); } - else if constexpr(is_same, int8_t>::value && - is_same, int8x2_t>::value) + else if constexpr(std::is_same, int8_t>::value && + std::is_same, int8x2_t>::value) { // HACK: cast pointer of x is bad // TODO: remove this after compiler fix *c_style_pointer_cast(&p_data_[i]) = *c_style_pointer_cast(&x); } - else if constexpr(is_same, int8_t>::value && - is_same, int8x4_t>::value) + else if constexpr(std::is_same, int8_t>::value && + std::is_same, int8x4_t>::value) { // HACK: cast pointer of x is bad // TODO: remove this after compiler fix *c_style_pointer_cast(&p_data_[i]) = *c_style_pointer_cast(&x); } - else if constexpr(is_same, int8_t>::value && - is_same, int8x8_t>::value) + else if constexpr(std::is_same, int8_t>::value && + std::is_same, int8x8_t>::value) { // HACK: cast pointer of x is bad // TODO: remove this after compiler fix *c_style_pointer_cast(&p_data_[i]) = *c_style_pointer_cast(&x); } - else if constexpr(is_same, int8_t>::value && - is_same, int8x16_t>::value) + else if constexpr(std::is_same, int8_t>::value && + std::is_same, int8x16_t>::value) { // HACK: cast pointer of x is bad // TODO: remove this after compiler fix *c_style_pointer_cast(&p_data_[i]) = *c_style_pointer_cast(&x); } - else if constexpr(is_same, int8x4_t>::value && - is_same, int8x4_t>::value) + else if constexpr(std::is_same, int8x4_t>::value && + std::is_same, int8x4_t>::value) { // HACK: cast pointer of x is bad // TODO: remove this after compiler fix *c_style_pointer_cast(&p_data_[i]) = *c_style_pointer_cast(&x); } - else if constexpr(is_same, int8x8_t>::value && - is_same, int8x8_t>::value) + else if constexpr(std::is_same, int8x8_t>::value && + std::is_same, int8x8_t>::value) { // HACK: cast pointer of x is bad // TODO: remove this after compiler fix *c_style_pointer_cast(&p_data_[i]) = *c_style_pointer_cast(&x); } - else if constexpr(is_same, int8x16_t>::value && - is_same, int8x16_t>::value) + else if constexpr(std::is_same, int8x16_t>::value && + std::is_same, int8x16_t>::value) { // HACK: cast pointer of x is bad // TODO: remove this after compiler fix @@ -899,17 +916,18 @@ struct buffer_view>::type, - typename scalar_type>::type>::value, - bool>::type = false> + bool oob_conditional_check = true, + typename std::enable_if< + std::is_same>::scalar_type, + typename vector_traits>::scalar_type>::value, + bool>::type = false> CK_TILE_DEVICE constexpr auto get(index_t i, bool is_valid_element, bool_constant = {}) const { // X contains multiple T - constexpr index_t scalar_per_t_vector = scalar_type>::vector_size; + constexpr index_t scalar_per_t_vector = vector_traits>::vector_size; - constexpr index_t scalar_per_x_vector = scalar_type>::vector_size; + constexpr index_t scalar_per_x_vector = vector_traits>::vector_size; static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, "wrong! X should contain multiple T"); @@ -940,19 +958,20 @@ struct buffer_view>::type, - typename scalar_type>::type>::value, - bool>::type = false> + typename std::enable_if< + std::is_same>::scalar_type, + typename vector_traits>::scalar_type>::value, + bool>::type = false> CK_TILE_DEVICE void update(index_t i, bool is_valid_element, const X& x) { - if constexpr(Op == InMemoryDataOperationEnum::set) + if constexpr(Op == memory_operation_enum::set) { this->template set(i, is_valid_element, x); } - // FIXME: remove InMemoryDataOperationEnum::Add - else if constexpr(Op == InMemoryDataOperationEnum::Add) + // FIXME: remove memory_operation_enum::add + else if constexpr(Op == memory_operation_enum::add) { auto tmp = this->template get(i, is_valid_element); this->template set(i, is_valid_element, x + tmp); @@ -961,15 +980,16 @@ struct buffer_view>::type, - typename scalar_type>::type>::value, - bool>::type = false> + typename std::enable_if< + std::is_same>::scalar_type, + typename vector_traits>::scalar_type>::value, + bool>::type = false> CK_TILE_DEVICE void set(index_t i, bool is_valid_element, const X& x) { // X contains multiple T - constexpr index_t scalar_per_t_vector = scalar_type>::vector_size; + constexpr index_t scalar_per_t_vector = vector_traits>::vector_size; - constexpr index_t scalar_per_x_vector = scalar_type>::vector_size; + constexpr index_t scalar_per_x_vector = vector_traits>::vector_size; static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, "wrong! X should contain multiple T"); @@ -1029,7 +1049,7 @@ template , remove_cvref_t>::value, + typename std::enable_if, remove_cvref_t>::value, bool>::type = false> CK_TILE_HOST_DEVICE constexpr auto make_buffer_view(T* p, BufferSizeType buffer_size, X invalid_element_value) diff --git a/include/ck_tile/core/tensor/load_tile.hpp b/include/ck_tile/core/tensor/load_tile.hpp index 1d9c7c9c79..288a60602a 100644 --- a/include/ck_tile/core/tensor/load_tile.hpp +++ b/include/ck_tile/core/tensor/load_tile.hpp @@ -11,6 +11,9 @@ #include "ck_tile/core/numeric/math.hpp" #include "ck_tile/core/tensor/tile_window.hpp" #include "ck_tile/core/utility/type_traits.hpp" +#include "ck_tile/core/tensor/tile_window.hpp" +#include "ck_tile/core/tensor/null_tile_window.hpp" +#include "ck_tile/core/tensor/null_tensor.hpp" namespace ck_tile { @@ -65,13 +68,13 @@ CK_TILE_DEVICE auto async_load_fence(index_t cnt = 0) } template -CK_TILE_DEVICE auto load_tile(const NullTileWindow&) +CK_TILE_DEVICE auto load_tile(const null_tile_window&) { - return NullTensor{}; + return null_tensor{}; } template -CK_TILE_DEVICE auto load_tile_raw(T& /*null_tile*/, const NullTileWindow&) +CK_TILE_DEVICE auto load_tile_raw(T& /*null_tile*/, const null_tile_window&) { } diff --git a/include/ck_tile/core/tensor/null_tile_window.hpp b/include/ck_tile/core/tensor/null_tile_window.hpp index ad7dd072dc..89806203ab 100644 --- a/include/ck_tile/core/tensor/null_tile_window.hpp +++ b/include/ck_tile/core/tensor/null_tile_window.hpp @@ -9,6 +9,7 @@ #include "ck_tile/core/numeric/math.hpp" #include "ck_tile/core/tensor/tile_window.hpp" #include "ck_tile/core/utility/type_traits.hpp" +#include "ck_tile/core/tensor/tensor_view.hpp" namespace ck_tile { diff --git a/include/ck_tile/core/tensor/shuffle_tile.hpp b/include/ck_tile/core/tensor/shuffle_tile.hpp index 43a8a38e89..edf3e6eebb 100644 --- a/include/ck_tile/core/tensor/shuffle_tile.hpp +++ b/include/ck_tile/core/tensor/shuffle_tile.hpp @@ -8,11 +8,13 @@ #include "ck_tile/core/numeric/integral_constant.hpp" #include "ck_tile/core/utility/functional.hpp" #include "ck_tile/core/algorithm/coordinate_transform.hpp" +#include "ck_tile/core/algorithm/space_filling_curve.hpp" #include "ck_tile/core/container/container_helper.hpp" #include "ck_tile/core/container/statically_indexed_array.hpp" #include "ck_tile/core/numeric/math.hpp" #include "ck_tile/core/utility/type_traits.hpp" #include "ck_tile/core/tensor/tile_elementwise.hpp" +#include "ck_tile/core/utility/transpose_vectors.hpp" namespace ck_tile { namespace detail { @@ -74,8 +76,8 @@ CK_TILE_DEVICE void shuffle_tile_impl_in_thread(OutTensor& out_tensor, const InT constexpr index_t num_vec_in = vec_length_out; constexpr index_t num_vec_out = vec_length_in; - using InVec = vector_type; - using OutVec = vector_type; + using InVec = array; + using OutVec = array; using InVecType = typename InVec::type; using OutVecType = typename OutVec::type; @@ -114,7 +116,7 @@ CK_TILE_DEVICE void shuffle_tile_impl_in_thread(OutTensor& out_tensor, const InT constexpr index_t in_offset = y_in_desc.calculate_offset(idx_y_in); - in_vectors(i).template AsType()(I0) = + in_vectors(i).template get_as()(I0) = in_tensor.get_thread_buffer().template get_as(number{}); }); @@ -134,7 +136,7 @@ CK_TILE_DEVICE void shuffle_tile_impl_in_thread(OutTensor& out_tensor, const InT out_tensor.get_thread_buffer().template set_as( number{}, - out_vectors[i].template AsType()[I0]); + out_vectors[i].template get_as()[I0]); }); }); } diff --git a/include/ck_tile/core/tensor/slice_tile.hpp b/include/ck_tile/core/tensor/slice_tile.hpp index 35ef4ac405..59f94a2796 100644 --- a/include/ck_tile/core/tensor/slice_tile.hpp +++ b/include/ck_tile/core/tensor/slice_tile.hpp @@ -84,7 +84,7 @@ set_slice_tile(static_distributed_tensor(); constexpr auto sliced_y_lengths = sliced_dstr_yidx_ylen.template at<2>(); - static_assert(is_same_v, "wrong!"); + static_assert(std::is_same_v, "wrong!"); dst_tile.SetSlicedThreadData(sliced_y_origins, sliced_y_lengths, src_tile.get_thread_buffer()); } diff --git a/include/ck_tile/core/tensor/static_distributed_tensor.hpp b/include/ck_tile/core/tensor/static_distributed_tensor.hpp index 1f90ff2b95..0c9e0debb1 100644 --- a/include/ck_tile/core/tensor/static_distributed_tensor.hpp +++ b/include/ck_tile/core/tensor/static_distributed_tensor.hpp @@ -11,6 +11,7 @@ #include "ck_tile/core/numeric/math.hpp" #include "ck_tile/core/utility/functional.hpp" #include "ck_tile/core/utility/type_traits.hpp" +#include "ck_tile/core/tensor/tile_distribution.hpp" namespace ck_tile { diff --git a/include/ck_tile/core/tensor/store_tile.hpp b/include/ck_tile/core/tensor/store_tile.hpp index 6563f75a06..c12ad883d9 100644 --- a/include/ck_tile/core/tensor/store_tile.hpp +++ b/include/ck_tile/core/tensor/store_tile.hpp @@ -25,7 +25,7 @@ store_tile(tile_window_with_static_lengths& t using DataType = remove_cvref_t; using TileDstr = remove_cvref_t; - static_assert(is_same_v, DataType>, "wrong!"); + static_assert(std::is_same_v, DataType>, "wrong!"); constexpr auto tile_dstr = TileDstr{}; @@ -48,7 +48,7 @@ store_tile_raw(tile_window_with_static_lengths; using TileDstr = remove_cvref_t; - static_assert(is_same_v, DataType>, "wrong!"); + static_assert(std::is_same_v, DataType>, "wrong!"); constexpr auto tile_dstr = TileDstr{}; diff --git a/include/ck_tile/core/tensor/tensor_adaptor.hpp b/include/ck_tile/core/tensor/tensor_adaptor.hpp index 726696e9d7..2a3ecd8f70 100644 --- a/include/ck_tile/core/tensor/tensor_adaptor.hpp +++ b/include/ck_tile/core/tensor/tensor_adaptor.hpp @@ -717,7 +717,7 @@ CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const X& x, const Xs&.. constexpr auto encoded_top_dims = encoded_tensor_adaptor.template at<4>(); \ constexpr index_t num_top_dim = encoded_tensor_adaptor.template at<5>(); \ \ - constexpr auto trans = [&encoded_transforms, &num_transform]() { \ + constexpr auto trans = [&encoded_transforms]() { \ return generate_tuple( \ [&encoded_transforms](auto i) constexpr { \ constexpr auto name = encoded_transforms[i].template at<0>(); \ @@ -725,7 +725,7 @@ CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const X& x, const Xs&.. constexpr auto num_low_dim = encoded_transforms[i].template at<2>(); \ constexpr auto num_up_dim = encoded_transforms[i].template at<4>(); \ \ - STATIC_ASSERT(name == cood_transform_enum::PassThrough || \ + STATIC_ASSERT(name == cood_transform_enum::pass_through || \ name == cood_transform_enum::pad || \ name == cood_transform_enum::embed || \ name == cood_transform_enum::merge || \ @@ -733,7 +733,7 @@ CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const X& x, const Xs&.. name == cood_transform_enum::replicate, \ ""); \ \ - if constexpr(name == cood_transform_enum::PassThrough) \ + if constexpr(name == cood_transform_enum::pass_through) \ { \ index_t pos = 0; \ auto low_len = meta_data.template pop(pos); \ @@ -841,7 +841,7 @@ CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const X& x, const Xs&.. constexpr auto encoded_top_dims = encoded_tensor_adaptor.template at<4>(); \ constexpr index_t num_top_dim = encoded_tensor_adaptor.template at<5>(); \ \ - constexpr auto trans = [&encoded_transforms, &num_transform]() { \ + constexpr auto trans = [&encoded_transforms]() { \ return generate_tuple( \ [&encoded_transforms](auto i) constexpr { \ constexpr auto name = encoded_transforms[i].template at<0>(); \ @@ -849,7 +849,7 @@ CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const X& x, const Xs&.. constexpr auto num_low_dim = encoded_transforms[i].template at<2>(); \ constexpr auto num_up_dim = encoded_transforms[i].template at<4>(); \ \ - STATIC_ASSERT(name == cood_transform_enum::PassThrough || \ + STATIC_ASSERT(name == cood_transform_enum::pass_through || \ name == cood_transform_enum::pad || \ name == cood_transform_enum::embed || \ name == cood_transform_enum::merge || \ @@ -857,7 +857,7 @@ CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const X& x, const Xs&.. name == cood_transform_enum::replicate, \ ""); \ \ - if constexpr(name == cood_transform_enum::PassThrough) \ + if constexpr(name == cood_transform_enum::pass_through) \ { \ constexpr index_t low_len = meta_data.template get(0); \ \ @@ -912,7 +912,7 @@ CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const X& x, const Xs&.. number{}); \ }(); \ \ - constexpr auto low_dim_idss = [&encoded_transforms, &num_transform]() { \ + constexpr auto low_dim_idss = [&encoded_transforms]() { \ return generate_tuple( \ [&encoded_transforms](auto i) { \ constexpr auto num_low_dim = encoded_transforms[i].template at<2>(); \ @@ -923,7 +923,7 @@ CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const X& x, const Xs&.. number()); \ }(); \ \ - constexpr auto up_dim_idss = [&encoded_transforms, &num_transform] { \ + constexpr auto up_dim_idss = [&encoded_transforms] { \ return generate_tuple( \ [&encoded_transforms](auto i) { \ constexpr auto num_up_dim = encoded_transforms[i].template at<4>(); \ diff --git a/include/ck_tile/core/tensor/tensor_descriptor.hpp b/include/ck_tile/core/tensor/tensor_descriptor.hpp index 0ff9210e5f..aa9cf108c5 100644 --- a/include/ck_tile/core/tensor/tensor_descriptor.hpp +++ b/include/ck_tile/core/tensor/tensor_descriptor.hpp @@ -299,7 +299,7 @@ template & lengths, const tuple& strides, - const offset& offset, + const offset& os, number = number<-1>{}, number = number<-1>{}) { @@ -307,7 +307,7 @@ make_naive_tensor_descriptor_with_offset(const tuple& lengths, const auto element_space_size = detail::calculate_element_space_size_impl( lengths, strides, number<0>{}, long_number<1>{}); - const auto transforms = make_tuple(make_offset_transform(element_space_size, offset)); + const auto transforms = make_tuple(make_offset_transform(element_space_size, os)); constexpr auto low_dim_hidden_idss = make_tuple(sequence<0>{}); @@ -383,12 +383,12 @@ make_naive_tensor_descriptor_packed(const tuple& lengths, template ::type = false> CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_descriptor_packed_with_offset( const tuple& lengths, - const offset& offset, + const Offset& offset, number = number<-1>{}) { const auto desc_0 = [&]() { diff --git a/include/ck_tile/core/tensor/tensor_view.hpp b/include/ck_tile/core/tensor/tensor_view.hpp index 9a10ca3af3..e37bd806de 100644 --- a/include/ck_tile/core/tensor/tensor_view.hpp +++ b/include/ck_tile/core/tensor/tensor_view.hpp @@ -3,12 +3,15 @@ #pragma once +#include "ck_tile/core/arch/arch.hpp" #include "ck_tile/core/config.hpp" #include "ck_tile/core/numeric/integer.hpp" #include "ck_tile/core/numeric/integral_constant.hpp" #include "ck_tile/core/algorithm/coordinate_transform.hpp" #include "ck_tile/core/container/container_helper.hpp" #include "ck_tile/core/numeric/math.hpp" +#include "ck_tile/core/tensor/tensor_descriptor.hpp" +#include "ck_tile/core/utility/functional.hpp" #include "ck_tile/core/utility/type_traits.hpp" namespace ck_tile { @@ -16,15 +19,16 @@ namespace ck_tile { template struct tensor_view { - using BufferView = remove_reference_t; - using DataType = typename BufferView::type; + using buffer_view = remove_reference_t; + using DataType = typename buffer_view::type; using TensorDesc = remove_cvref_t; using TensorIndex = array; using TensorCoord = decltype(make_tensor_coordinate(TensorDesc{}, TensorIndex{})); CK_TILE_HOST_DEVICE constexpr tensor_view() = default; - CK_TILE_HOST_DEVICE constexpr tensor_view(const BufferView& buffer_view, const TensorDesc& desc) + CK_TILE_HOST_DEVICE constexpr tensor_view(const buffer_view& buffer_view, + const TensorDesc& desc) : buf_{buffer_view}, desc_{desc} { } @@ -58,12 +62,12 @@ struct tensor_view #endif // X is vector of DataType. // "coord" is coordinate of DataType, not X. "coord" should be aligned to X - template < - typename X, - bool oob_conditional_check = true, - typename std::enable_if>::type, - typename scalar_type>::type>, - bool>::type = false> + template >::scalar_type, + typename vector_traits>::scalar_type>, + bool>::type = false> CK_TILE_HOST_DEVICE constexpr remove_cvref_t get_vectorized_elements(const TensorCoord& coord, bool_constant = {}) const @@ -76,12 +80,12 @@ struct tensor_view // X is vector of DataType. // "coord" is coordinate of DataType, not X. "coord" should be aligned to X - template < - typename X, - bool oob_conditional_check = true, - typename std::enable_if>::type, - typename scalar_type>::type>, - bool>::type = false> + template >::scalar_type, + typename vector_traits>::scalar_type>, + bool>::type = false> CK_TILE_HOST_DEVICE void get_vectorized_elements_raw(remove_cvref_t& dst, const TensorCoord& coord, @@ -93,11 +97,11 @@ struct tensor_view coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord)); } - template < - typename X, - typename std::enable_if>::type, - typename scalar_type>::type>, - bool>::type = false> + template >::scalar_type, + typename vector_traits>::scalar_type>, + bool>::type = false> CK_TILE_HOST_DEVICE constexpr void async_get_vectorized_elements(remove_cvref_t* smem, const TensorCoord& coord) const { @@ -106,12 +110,12 @@ struct tensor_view // X is vector of DataType. // "coord" is coordinate of DataType, not X. "coord" should be aligned to X - template < - typename X, - bool oob_conditional_check = true, - typename std::enable_if>::type, - typename scalar_type>::type>, - bool>::type = false> + template >::scalar_type, + typename vector_traits>::scalar_type>, + bool>::type = false> CK_TILE_HOST_DEVICE constexpr void set_vectorized_elements( const TensorCoord& coord, const X& x, bool_constant = {}) { @@ -121,12 +125,12 @@ struct tensor_view x); } - template < - typename X, - bool oob_conditional_check = true, - typename std::enable_if>::type, - typename scalar_type>::type>, - bool>::type = false> + template >::scalar_type, + typename vector_traits>::scalar_type>, + bool>::type = false> CK_TILE_HOST_DEVICE constexpr void set_vectorized_elements_raw( const TensorCoord& coord, const X& x, bool_constant = {}) { @@ -153,7 +157,7 @@ struct tensor_view } // member - BufferView buf_; + buffer_view buf_; TensorDesc desc_; }; @@ -162,7 +166,7 @@ struct null_tensor_view { }; -template CK_TILE_HOST_DEVICE constexpr auto make_tensor_view(DataType* p, @@ -173,7 +177,7 @@ CK_TILE_HOST_DEVICE constexpr auto make_tensor_view(DataType* p, return tensor_view{buffer_view, desc}; } -template {buffer_view, desc}; } -template @@ -228,19 +232,19 @@ CK_TILE_HOST_DEVICE constexpr auto transform_tensor_view(const OldTensorView& ol NewLowerDimensionOldVisibleIdss{}, NewUpperDimensionNewVisibleIdss{}); - return tensor_view>{ + return tensor_view>{ old_tensor_view.buf_, new_desc}; } -template typename DoPads> // sequence CK_TILE_HOST_DEVICE constexpr auto -pad_tensor_view(const tensor_view& tensor_view, const TileLengths& tile_lengths, DoPads) +pad_tensor_view(const TensorView& tensor_view, const TileLengths& tile_lengths, DoPads) { constexpr index_t num_dim = DoPads::size(); - static_assert(num_dim == TileLengths::size() && num_dim == tensor_view::get_num_of_dimension(), + static_assert(num_dim == TileLengths::size() && num_dim == TensorView::get_num_of_dimension(), "wrong! inconsistent # of dimensions"); // transforms diff --git a/include/ck_tile/core/tensor/tile_distribution.hpp b/include/ck_tile/core/tensor/tile_distribution.hpp index 8a1a2e4810..38a02acb32 100644 --- a/include/ck_tile/core/tensor/tile_distribution.hpp +++ b/include/ck_tile/core/tensor/tile_distribution.hpp @@ -3,12 +3,14 @@ #pragma once +#include "ck_tile/core/arch/arch.hpp" #include "ck_tile/core/config.hpp" #include "ck_tile/core/container/array.hpp" #include "ck_tile/core/container/sequence.hpp" #include "ck_tile/core/container/tuple.hpp" #include "ck_tile/core/container/container_helper.hpp" #include "ck_tile/core/tensor/tensor_adaptor.hpp" +#include "ck_tile/core/tensor/tile_distribution_encoding.hpp" #include "ck_tile/core/utility/functional.hpp" #include "ck_tile/core/utility/type_traits.hpp" @@ -293,7 +295,9 @@ CK_TILE_HOST_DEVICE constexpr auto &hidden_dim_cnt, &rh_major_minor_to_hidden_ids, &rh_major_minor_to_hidden_lengths](auto idim_x) { - constexpr auto h_minor_lengths = tuple_element_t{}; + // typename HsLengthss::base{}.foo(); + constexpr auto h_minor_lengths = HsLengthss{}.get(idim_x); //std::tuple_element_t{}; + // constexpr auto h_minor_lengths = impl::getv(HsLengthss{}); constexpr index_t ndim_h_minor = h_minor_lengths.size(); diff --git a/include/ck_tile/core/tensor/tile_elementwise.hpp b/include/ck_tile/core/tensor/tile_elementwise.hpp index e2b1f0c385..974cb2ee1e 100644 --- a/include/ck_tile/core/tensor/tile_elementwise.hpp +++ b/include/ck_tile/core/tensor/tile_elementwise.hpp @@ -9,6 +9,7 @@ #include "ck_tile/core/container/tuple.hpp" #include "ck_tile/core/container/container_helper.hpp" #include "ck_tile/core/tensor/tensor_adaptor.hpp" +#include "ck_tile/core/tensor/null_tensor.hpp" #include "ck_tile/core/utility/functional.hpp" #include "ck_tile/core/utility/type_traits.hpp" @@ -18,7 +19,7 @@ namespace ck_tile { template , NullTensor>>...>>> + std::negation, null_tensor>>...>>> CK_TILE_DEVICE void tile_elementwise_inout(const InOutElementFunc& inout_element_func, InOutDstrTensors&... inout_dstr_tensors) { @@ -26,7 +27,7 @@ CK_TILE_DEVICE void tile_elementwise_inout(const InOutElementFunc& inout_element // static_assert(xxx); constexpr index_t thread_buffer_size = - type_pack_element<0, InOutDstrTensors...>::get_thread_buffer_size(); + __type_pack_element<0, InOutDstrTensors...>::get_thread_buffer_size(); static_for<0, thread_buffer_size, 1>{}( [&](auto i) { inout_element_func(inout_dstr_tensors.get_thread_buffer().at(i)...); }); @@ -35,7 +36,7 @@ CK_TILE_DEVICE void tile_elementwise_inout(const InOutElementFunc& inout_element template >...>>> + std::conjunction_v>...>>> CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc& in_element_func, const InDstrTensors&... in_dstr_tensors) { @@ -43,10 +44,10 @@ CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc& in_element_func, // TODO: make sure all distributed tensors have same lengths and distribution // static_assert(xxx); - constexpr auto in_tile_dstr = type_pack_element<0, InDstrTensors...>::get_tile_distribution(); + constexpr auto in_tile_dstr = __type_pack_element<0, InDstrTensors...>::get_tile_distribution(); constexpr index_t thread_buffer_size = - type_pack_element<0, InDstrTensors...>::get_thread_buffer_size(); + __type_pack_element<0, InDstrTensors...>::get_thread_buffer_size(); auto out_dstr_tensor = make_static_distributed_tensor(in_tile_dstr); @@ -69,7 +70,7 @@ CK_TILE_DEVICE void set_tile(DstrTensors& dstr_tensor, const T& value) } template -CK_TILE_DEVICE void set_tile(NullTensor&, const T&) +CK_TILE_DEVICE void set_tile(null_tensor&, const T&) { } @@ -82,7 +83,7 @@ CK_TILE_DEVICE void set_tile(DstrTensors& dstr_tensor, number) DstrTensors::get_thread_buffer_size() * sizeof(typename DstrTensors::DataType); if constexpr(v == 0 && tensor_bytes % 4 == 0) { - using dvec_t = static_buffer_c; + using dvec_t = array; auto& tensor = reinterpret_cast(dstr_tensor.get_thread_buffer()); for(auto i = 0; i < tensor.size(); i++) tensor.get(i) = v; @@ -96,7 +97,7 @@ CK_TILE_DEVICE void set_tile(DstrTensors& dstr_tensor, number) } template -CK_TILE_DEVICE void set_tile(NullTensor&, number) +CK_TILE_DEVICE void set_tile(null_tensor&, number) { } @@ -139,7 +140,7 @@ CK_TILE_DEVICE auto cast_tile_pk_fp8x4(const InDstrTensors& in_dstr_tensors) false); // false -> WORD0 constexpr int32_t m0 = 0x05040100; - using vec_t = typename vector_type::type; + using vec_t = array; vec_t d = bit_cast(__builtin_amdgcn_perm(y, x, m0)); out_dstr_tensor.get_thread_buffer().template set_as(number{}, d); @@ -157,9 +158,9 @@ CK_TILE_DEVICE auto cast_tile_pk_fp8x4(const InDstrTensors& in_dstr_tensors) template CK_TILE_DEVICE auto cast_tile(const SrcDstrTensors& src_tensor) { - if constexpr((ck_tile::is_same_v || - ck_tile::is_same_v)&&ck_tile:: - is_same_v && + if constexpr((std::is_same_v || + std::is_same_v)&&std::is_same_v && (SrcDstrTensors::get_thread_buffer_size() % 4 == 0)) { return cast_tile_pk_fp8x4(src_tensor); @@ -169,23 +170,23 @@ CK_TILE_DEVICE auto cast_tile(const SrcDstrTensors& src_tensor) src_tensor); } -// no-op function for NullTensor arguments +// no-op function for null_tensor arguments template , NullTensor>...>>> + std::disjunction_v, null_tensor>...>>> CK_TILE_DEVICE void tile_elementwise_inout(const InOutElementFunc&, MaybeNullTensor&&...) { } -// no-op function for NullTensor arguments +// no-op function for null_tensor arguments template , NullTensor>...>>> + std::disjunction_v, null_tensor>...>>> CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc&, MaybeNullTensor&&...) { - return NullTensor{}; + return null_tensor{}; } } // namespace ck_tile diff --git a/include/ck_tile/core/tensor/tile_window.hpp b/include/ck_tile/core/tensor/tile_window.hpp index 80e483e6ad..76db527780 100644 --- a/include/ck_tile/core/tensor/tile_window.hpp +++ b/include/ck_tile/core/tensor/tile_window.hpp @@ -3,12 +3,15 @@ #pragma once +#include "ck_tile/core/arch/utility.hpp" #include "ck_tile/core/config.hpp" #include "ck_tile/core/container/array.hpp" #include "ck_tile/core/container/sequence.hpp" #include "ck_tile/core/container/tuple.hpp" #include "ck_tile/core/container/container_helper.hpp" +#include "ck_tile/core/tensor/static_distributed_tensor.hpp" #include "ck_tile/core/tensor/tensor_adaptor.hpp" +#include "ck_tile/core/tensor/tile_distribution.hpp" #include "ck_tile/core/utility/functional.hpp" #include "ck_tile/core/utility/type_traits.hpp" @@ -85,8 +88,9 @@ struct tile_window_with_static_distribution static constexpr index_t ScalarPerVector = get_vector_dim_y_scalar_per_vector().template at<1>(); - using vector_type_t = vector_type_maker_t; - using vector_t = typename vector_type_t::type; + // using vector_type_t = vector_type_maker_t; + // using vector_t = typename vector_type_t::type; + using vector_t = array; private: static constexpr auto scalars_per_access_ = [] { @@ -275,9 +279,8 @@ struct tile_window_with_static_distribution { using Traits = load_store_traits; - using vector_type_t = typename Traits::vector_type_t; - using vector_t = typename vector_type_t::type; - using SFC_Ys = typename Traits::SFC_Ys; + using vector_t = typename Traits::vector_t; + using SFC_Ys = typename Traits::SFC_Ys; constexpr auto tile_dstr = TileDstr{}; @@ -300,8 +303,6 @@ struct tile_window_with_static_distribution get_bottom_tensor_view().template get_vectorized_elements( bottom_tensor_thread_coord, bool_constant{}); - const vector_type_t vec{vec_value}; - // write into distributed tensor static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) { constexpr auto idx_ys = generate_array( @@ -315,7 +316,7 @@ struct tile_window_with_static_distribution tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys); dst_tensor.get_thread_buffer().template at() = - vec.template AsType()[j]; + vec_value.template get_as()[j]; }); // move thread coordinate @@ -341,16 +342,17 @@ struct tile_window_with_static_distribution { using Traits = load_store_traits; - using vector_type_t = typename Traits::vector_type_t; - using vector_t = typename vector_type_t::type; - using SFC_Ys = typename Traits::SFC_Ys; + // using vector_type_t = typename Traits::vector_type_t; + using vector_t = typename Traits::vector_t; + using SFC_Ys = typename Traits::SFC_Ys; static constexpr index_t YElementSize = TileDstr{}.get_ys_to_d_descriptor().get_element_space_size(); static_assert(YElementSize % Traits::ScalarPerVector == 0); - using vectorized_tbuf = StaticBuffer; + using vectorized_tbuf = array; + // StaticBuffer; constexpr auto tile_dstr = TileDstr{}; @@ -426,9 +428,9 @@ struct tile_window_with_static_distribution using Traits = load_store_traits; - using vector_type_t = typename Traits::vector_type_t; - using vector_t = typename vector_type_t::type; - using SFC_Ys = typename Traits::SFC_Ys; + // using vector_type_t = typename Traits::vector_type_t; + using vector_t = typename Traits::vector_t; + using SFC_Ys = typename Traits::SFC_Ys; LdsDataType* smem = lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_; @@ -468,9 +470,9 @@ struct tile_window_with_static_distribution { using Traits = load_store_traits; - using vector_type_t = typename Traits::vector_type_t; - using vector_t = typename vector_type_t::type; - using SFC_Ys = typename Traits::SFC_Ys; + // using vector_type_t = typename Traits::vector_type_t; + using vector_t = typename Traits::vector_t; + using SFC_Ys = typename Traits::SFC_Ys; constexpr auto tile_dstr = TileDstr{}; @@ -487,7 +489,8 @@ struct tile_window_with_static_distribution constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess); // read from distributed tensor - vector_type_t vec; + // vector_type_t vec; + vector_t vec_value; static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) { constexpr auto idx_ys = generate_array( @@ -500,11 +503,11 @@ struct tile_window_with_static_distribution constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys); - vec.template AsType()(j) = + vec_value.template get_as()(j) = dstr_tensor.get_thread_buffer().template at(); }); - const vector_t vec_value = vec.template AsType().template at<0>(); + // const vector_t vec_value = vec.template get_as().template at<0>(); // write into bottom tensor get_bottom_tensor_view().template set_vectorized_elements( @@ -530,9 +533,9 @@ struct tile_window_with_static_distribution { using Traits = load_store_traits; - using vector_type_t = typename Traits::vector_type_t; - using vector_t = typename vector_type_t::type; - using SFC_Ys = typename Traits::SFC_Ys; + // using vector_type_t = typename Traits::vector_type_t; + using vector_t = typename Traits::vector_t; + using SFC_Ys = typename Traits::SFC_Ys; constexpr auto tile_dstr = TileDstr{}; static constexpr bool oob_conditional_check = true; @@ -550,7 +553,8 @@ struct tile_window_with_static_distribution constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess); // read from distributed tensor - vector_type_t vec; + // vector_type_t vec; + vector_t vec_value; static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) { constexpr auto idx_ys = generate_array( @@ -563,11 +567,11 @@ struct tile_window_with_static_distribution constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys); - vec.template AsType()(j) = + vec_value.template get_as()(j) = dstr_tensor.get_thread_buffer().template at(); }); - const vector_t vec_value = vec.template AsType().template at<0>(); + // const vector_t vec_value = vec.template get_as().template at<0>(); // write into bottom tensor get_bottom_tensor_view() diff --git a/include/ck_tile/core/utility/functional.hpp b/include/ck_tile/core/utility/functional.hpp index c246c9c456..2cdce94063 100644 --- a/include/ck_tile/core/utility/functional.hpp +++ b/include/ck_tile/core/utility/functional.hpp @@ -191,4 +191,18 @@ CK_TILE_HOST_DEVICE constexpr auto unpack2(F&& f, X&& x, Y&& y) std::forward(f), std::forward(x), std::forward(y)); } +// z = predicate ? x : y +template +constexpr auto conditional_expr(X&& x, Y&& y) +{ + if constexpr(predicate) + { + return std::forward(x); + } + else + { + return std::forward(y); + } +} + } // namespace ck_tile diff --git a/include/ck_tile/core/utility/to_sequence.hpp b/include/ck_tile/core/utility/to_sequence.hpp index 1d2c73073d..4db6cfd4a0 100644 --- a/include/ck_tile/core/utility/to_sequence.hpp +++ b/include/ck_tile/core/utility/to_sequence.hpp @@ -9,12 +9,13 @@ // clang happen to support this feature (__cpp_generic_lambdas >= 201707) in c++17 mode #define TO_SEQUENCE(a, n) \ _Pragma("clang diagnostic push") \ - _Pragma("clang diagnostic ignored \"-Wc++20-extensions\"")[a]( \ - ck_tile::sequence) \ + _Pragma("clang diagnostic ignored \"-Wc++20-extensions\"") \ + [a](ck_tile::sequence) \ { \ - return ck_tile::sequence{})...>{}; \ + return ck_tile::sequence{})...>{}; \ } \ - (make_index_sequence{}) _Pragma("clang diagnostic pop") + (ck_tile::make_index_sequence{}); \ + _Pragma("clang diagnostic pop") #else // Macro function diff --git a/include/ck_tile/core/utility/transpose_vectors.hpp b/include/ck_tile/core/utility/transpose_vectors.hpp new file mode 100644 index 0000000000..acd5dd7b1d --- /dev/null +++ b/include/ck_tile/core/utility/transpose_vectors.hpp @@ -0,0 +1,123 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/container/array.hpp" +#include "ck_tile/core/utility/bit_cast.hpp" +#include "ck_tile/core/utility/functional.hpp" + +namespace ck_tile { + +// S: scalar type (or it can be non-scalar type) +// NX: # of vector before transpose +// NY: # of vector after transpose +// we got [NX, NY] amount of S data to be transposed into [NY, NX] amount of S data +template +struct transpose_vectors +{ + static constexpr index_t s_per_x = NY; + static constexpr index_t s_per_y = NX; + + using S = remove_cvref_t; + + using VX = array; + using VY = array; + + CK_TILE_DEVICE void operator()(const array& vx_tuple, array& vy_tuple) + { + constexpr auto I1 = number<1>{}; + constexpr auto I2 = number<2>{}; + constexpr auto I3 = number<3>{}; + constexpr auto I4 = number<4>{}; + + if constexpr(sizeof(S) == 2) + { + static_assert((NX % 2 == 0 && NY % 2 == 0), "wrong!"); + + using S2 = array; // typename array::type; + + // loop over 2x2 tile and transpose data from vx_tuple into vy_tuple + static_for<0, NY, 2>{}([&](auto iy) { + static_for<0, NX, 2>{}([&](auto ix) { + // 2 16bitx2 data from vx_tuple to be transposed + const int32_t x_s2_0 = + bit_cast(vx_tuple[ix].template get_as()[iy / I2]); + const int32_t x_s2_1 = + bit_cast(vx_tuple[ix + I1].template get_as()[iy / I2]); + + constexpr int32_t m0 = 0x05040100; + constexpr int32_t m1 = 0x07060302; + + // transpose 2x2 16bit + // ex: v_perm_b32(0x 11 22 33 44, 0x 55 66 77 88, 0x 05 01 04 00) -> 0x33774488 + // -- -- -- -- -- -- -- -- - - - - + // index 7 6 5 4 3 2 1 0 33 77 44 88 + // index is reversed because of little endianness (least significant bits first) + const int32_t y_s2_0 = __builtin_amdgcn_perm(x_s2_1, x_s2_0, m0); + const int32_t y_s2_1 = __builtin_amdgcn_perm(x_s2_1, x_s2_0, m1); + + // 2 16bitx2 data after transposed + vy_tuple(iy).template get_as()(ix / I2) = bit_cast(y_s2_0); + vy_tuple(iy + I1).template get_as()(ix / I2) = bit_cast(y_s2_1); + }); + }); + } + else if constexpr(sizeof(S) == 1) + { + static_assert((NX % 4 == 0 && NY % 4 == 0), "wrong!"); + + using S4 = array; // typename array::type; + + // loop over 4x4 tile and transpose data from vx_tuple into vy_tuple + static_for<0, NY, 4>{}([&](auto iy) { + static_for<0, NX, 4>{}([&](auto ix) { + // 4 int8x4 data from vx_tuple + const int32_t x_s4_0 = + bit_cast(vx_tuple[ix].template get_as()[iy / I4]); + const int32_t x_s4_1 = + bit_cast(vx_tuple[ix + I1].template get_as()[iy / I4]); + const int32_t x_s4_2 = + bit_cast(vx_tuple[ix + I2].template get_as()[iy / I4]); + const int32_t x_s4_3 = + bit_cast(vx_tuple[ix + I3].template get_as()[iy / I4]); + + // transpose + int32_t t_s4_0, t_s4_1; + int32_t y_s4_0, y_s4_1, y_s4_2, y_s4_3; + + constexpr int32_t m0 = 0x05010400; + constexpr int32_t m1 = 0x05040100; + constexpr int32_t m2 = 0x07060302; + constexpr int32_t m3 = 0x07030602; + + // ex: v_perm_b32(0x 11 22 33 44, 0x 55 66 77 88, 0x 05 01 04 00) -> 0x33774488 + // -- -- -- -- -- -- -- -- - - - - + // index 7 6 5 4 3 2 1 0 33 77 44 88 + // index is reversed because of little endianness (least significant bits first) + t_s4_0 = __builtin_amdgcn_perm(x_s4_1, x_s4_0, m0); + t_s4_1 = __builtin_amdgcn_perm(x_s4_3, x_s4_2, m0); + y_s4_0 = __builtin_amdgcn_perm(t_s4_1, t_s4_0, m1); + y_s4_1 = __builtin_amdgcn_perm(t_s4_1, t_s4_0, m2); + t_s4_0 = __builtin_amdgcn_perm(x_s4_1, x_s4_0, m3); + t_s4_1 = __builtin_amdgcn_perm(x_s4_3, x_s4_2, m3); + y_s4_2 = __builtin_amdgcn_perm(t_s4_1, t_s4_0, m1); + y_s4_3 = __builtin_amdgcn_perm(t_s4_1, t_s4_0, m2); + + // 4 int8x4 data from vy_tuple + vy_tuple(iy).template get_as()(ix / I4) = bit_cast(y_s4_0); + vy_tuple(iy + I1).template get_as()(ix / I4) = bit_cast(y_s4_1); + vy_tuple(iy + I2).template get_as()(ix / I4) = bit_cast(y_s4_2); + vy_tuple(iy + I3).template get_as()(ix / I4) = bit_cast(y_s4_3); + }); + }); + } + else + { + static_assert(false, "not implemented"); + } + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/core/utility/type_convert.hpp b/include/ck_tile/core/utility/type_convert.hpp deleted file mode 100644 index 4bc3393fd9..0000000000 --- a/include/ck_tile/core/utility/type_convert.hpp +++ /dev/null @@ -1,57 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include -#include -#include -#include "ck_tile/core/config.hpp" -#include "ck_tile/core/numeric/half.hpp" -#include "ck_tile/core/numeric/bfloat16.hpp" -#include "ck_tile/core/numeric/float8.hpp" - -namespace ck_tile { - -// Convert X to Y, both X and Y are non-const data types. -template || std::is_const_v), bool> = false> -CK_TILE_HOST_DEVICE constexpr Y type_convert(X x) -{ - static_assert(!std::is_reference_v && !std::is_reference_v); - - return static_cast(x); -} - -// Convert X to Y, either X or Y is a const data type. -template || std::is_const_v, bool> = false> -CK_TILE_HOST_DEVICE constexpr Y type_convert(X x) -{ - static_assert(!std::is_reference_v && !std::is_reference_v); - - using non_const_y = std::remove_const_t; - using non_const_x = std::remove_const_t; - return static_cast(type_convert(x)); -} - -#define CK_TILE_TYPE_CONVERT(dtype_, stype_) \ - template <> \ - inline CK_TILE_HOST_DEVICE constexpr dtype_ type_convert(stype_ x) \ - { \ - return stype_##_to_##dtype_(x); \ - } - -CK_TILE_TYPE_CONVERT(float, fp16_t) -CK_TILE_TYPE_CONVERT(float, bf16_t) -CK_TILE_TYPE_CONVERT(float, fp8_t) -CK_TILE_TYPE_CONVERT(float, bf8_t) - -CK_TILE_TYPE_CONVERT(fp16_t, float) -CK_TILE_TYPE_CONVERT(bf16_t, float) -CK_TILE_TYPE_CONVERT(fp8_t, float) -CK_TILE_TYPE_CONVERT(bf8_t, float) - -} // namespace ck_tile diff --git a/include/ck_tile/core/utility/type_traits.hpp b/include/ck_tile/core/utility/type_traits.hpp index 7cdc3d2a28..8b8b01a2ac 100644 --- a/include/ck_tile/core/utility/type_traits.hpp +++ b/include/ck_tile/core/utility/type_traits.hpp @@ -69,4 +69,18 @@ struct nonesuch template