diff --git a/example/ck_tile/03_gemm/gemm_basic.hpp b/example/ck_tile/03_gemm/gemm_basic.hpp index 636b34981f..dbc582e5a3 100644 --- a/example/ck_tile/03_gemm/gemm_basic.hpp +++ b/example/ck_tile/03_gemm/gemm_basic.hpp @@ -35,7 +35,7 @@ #error "unsupported CK_TILE_PIPELINE_DEFAULT value" #endif -template +template struct GemmBasicTypeConfig; template <> @@ -75,6 +75,15 @@ struct GemmBasicTypeConfig using CDataType = ck_tile::half_t; }; +template <> +struct GemmBasicTypeConfig +{ + using ADataType = ck_tile::half_t; + using BDataType = ck_tile::pk_int4_t; + using AccDataType = float; + using CDataType = ck_tile::half_t; +}; + template struct DataTypeTraits; @@ -114,6 +123,12 @@ struct DataTypeTraits static constexpr const char* name = "bf8"; }; +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "pk_int4_t"; +}; + auto create_args(int argc, char* argv[]) { ck_tile::ArgParser arg_parser; diff --git a/example/ck_tile/03_gemm/run_gemm_example.inc b/example/ck_tile/03_gemm/run_gemm_example.inc index c9a1b8fc30..f068cbc1da 100644 --- a/example/ck_tile/03_gemm/run_gemm_example.inc +++ b/example/ck_tile/03_gemm/run_gemm_example.inc @@ -29,6 +29,60 @@ auto calculate_rtol_atol(const ck_tile::index_t K, // Use higher threshold return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); } +template +void permute_tensor_b(Tensor& tensor) +{ + const ck_tile::index_t K = tensor.get_length(0); + const ck_tile::index_t N = tensor.get_length(1); + // vector pk_i4x4 permute + for(int i = 0; i < N; i++) + { + for(int j = 0; j < K; j += 8) + { + int8_t input[8]; + + for(int k = 0; k < 4; k++) + { + int8_t i4x2 = tensor(j + k * 2, i).data; + input[k * 2 + 0] = (i4x2 >> 4) & 0xf; + input[k * 2 + 1] = (i4x2 >> 0) & 0xf; + } + + // permute 01234567->20643175 + { + int8_t hi = input[2]; + int8_t lo = input[0]; + int8_t i4x2 = (hi << 4) | lo; + + tensor(j + 0, i) = i4x2; + } + + { + int8_t hi = input[6]; + int8_t lo = input[4]; + int8_t i4x2 = (hi << 4) | lo; + + tensor(j + 2, i) = i4x2; + } + + { + int8_t hi = input[3]; + int8_t lo = input[1]; + int8_t i4x2 = (hi << 4) | lo; + + tensor(j + 4, i) = i4x2; + } + + { + int8_t hi = input[7]; + int8_t lo = input[5]; + int8_t i4x2 = (hi << 4) | lo; + + tensor(j + 6, i) = i4x2; + } + } + } +} template +template int run_gemm_example_with_layouts(int argc, char* argv[], const ALayout a_layout = ALayout{}, @@ -94,10 +153,7 @@ int run_gemm_example_with_layouts(int argc, if(!result) return -1; - using ADataType = typename GemmBasicTypeConfig::ADataType; - using BDataType = typename GemmBasicTypeConfig::BDataType; - using CDataType = typename GemmBasicTypeConfig::CDataType; - using AccDataType = typename GemmBasicTypeConfig::AccDataType; + using AccDataType = typename GemmBasicTypeConfig::AccDataType; ck_tile::index_t M = arg_parser.get_int("m"); ck_tile::index_t N = arg_parser.get_int("n"); @@ -149,7 +205,17 @@ int run_gemm_example_with_layouts(int argc, ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes()); a_m_k_dev_buf.ToDevice(a_m_k.data()); - b_k_n_dev_buf.ToDevice(b_k_n.data()); + if constexpr(std::is_same_v) + { + // Permute data for device implementation + ck_tile::HostTensor b_k_n_dev = b_k_n; + permute_tensor_b(b_k_n_dev); + b_k_n_dev_buf.ToDevice(b_k_n_dev.data()); + } + else + { + b_k_n_dev_buf.ToDevice(b_k_n.data()); + } c_m_n_dev_buf.SetZero(); c_m_n_dev_result.SetZero(); @@ -195,6 +261,11 @@ int run_gemm_example_with_layouts(int argc, } else if(arg_parser.get_int("v") == 2) { + if constexpr(std::is_same_v) + { + // Restore input for B for gpu reference + b_k_n_dev_buf.ToDevice(b_k_n.data()); + } ck_tile::HostTensor c_m_n_gpu_ref( ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); ck_tile::DeviceMem c_m_n_gpu_buf_ref(c_m_n_gpu_ref.get_element_space_size_in_bytes()); @@ -205,17 +276,18 @@ int run_gemm_example_with_layouts(int argc, BDataType* d_B; CDataType* d_C; - ck_tile::hip_check_error(hipMalloc(&d_A, M * K * sizeof(ADataType))); - ck_tile::hip_check_error(hipMalloc(&d_B, N * K * sizeof(BDataType))); - ck_tile::hip_check_error(hipMalloc(&d_C, M * N * sizeof(CDataType))); + ck_tile::hip_check_error(hipMalloc(&d_A, a_m_k.get_element_space_size_in_bytes())); + ck_tile::hip_check_error(hipMalloc(&d_B, b_k_n.get_element_space_size_in_bytes())); + ck_tile::hip_check_error( + hipMalloc(&d_C, c_m_n_dev_result.get_element_space_size_in_bytes())); ck_tile::hip_check_error(hipMemcpy(d_A, a_m_k_dev_buf.GetDeviceBuffer(), - M * K * sizeof(ADataType), + a_m_k.get_element_space_size_in_bytes(), hipMemcpyHostToDevice)); ck_tile::hip_check_error(hipMemcpy(d_B, b_k_n_dev_buf.GetDeviceBuffer(), - N * K * sizeof(BDataType), + b_k_n.get_element_space_size_in_bytes(), hipMemcpyHostToDevice)); ck_tile::reference_gemm_gpu(argc, argv, Row{}, Col{}, Row{}); } +#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3) + else if(data_type == "pk_int4_t") + { + // TODO: Add support for bhalf_t ADataType + return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); + } +#endif else { throw std::runtime_error("Unsupported data_type!"); @@ -344,6 +353,15 @@ int run_gemm_example(int argc, char* argv[]) { return run_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{}); } +#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3) + else if(data_type == "pk_int4_t") + { + // TODO: Add support for bhalf_t ADataType + return run_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{}); + } +#endif else { throw std::runtime_error("Unsupported data_type!"); diff --git a/include/ck_tile/core/arch/amd_buffer_addressing.hpp b/include/ck_tile/core/arch/amd_buffer_addressing.hpp index 107aae5516..4e0deb1547 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -1309,7 +1309,9 @@ CK_TILE_DEVICE thread_buffer amd_buffer_load_impl(int32x4_t src_wave_buffe (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || - (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)), + (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (std::is_same::value && + (N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32)), "wrong! not implemented"); using rtn_type = thread_buffer; diff --git a/include/ck_tile/core/container/array.hpp b/include/ck_tile/core/container/array.hpp index 78768bbbfc..fa63597db4 100644 --- a/include/ck_tile/core/container/array.hpp +++ b/include/ck_tile/core/container/array.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -153,12 +153,12 @@ struct array CK_TILE_HOST_DEVICE void print() const { printf("array{size: 0, data: []}"); } }; -template +template struct vector_traits; // specialization for array template -struct vector_traits> +struct vector_traits, void> { using scalar_type = T; static constexpr index_t vector_size = N; diff --git a/include/ck_tile/core/container/thread_buffer.hpp b/include/ck_tile/core/container/thread_buffer.hpp index 279a48acb3..77c46e1b8c 100644 --- a/include/ck_tile/core/container/thread_buffer.hpp +++ b/include/ck_tile/core/container/thread_buffer.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -149,17 +149,24 @@ struct thread_buffer { }; // clang-format on -template +template struct vector_traits; // specialization for array template -struct vector_traits> +struct vector_traits, std::enable_if_t>> { using scalar_type = T; static constexpr index_t vector_size = N; }; +template +struct vector_traits, std::enable_if_t>> +{ + using scalar_type = typename T::type; + static constexpr index_t vector_size = N; +}; + #endif } // namespace ck_tile diff --git a/include/ck_tile/core/container/tuple.hpp b/include/ck_tile/core/container/tuple.hpp index 74575f4c6e..fd02177e25 100644 --- a/include/ck_tile/core/container/tuple.hpp +++ b/include/ck_tile/core/container/tuple.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -294,7 +294,7 @@ struct tuple : impl::tuple_base, T...> #undef TP_COM_ }; -template +template struct vector_traits; // specialization for array diff --git a/include/ck_tile/core/numeric/bfloat16.hpp b/include/ck_tile/core/numeric/bfloat16.hpp index 6ad38b1f7c..6f31468809 100644 --- a/include/ck_tile/core/numeric/bfloat16.hpp +++ b/include/ck_tile/core/numeric/bfloat16.hpp @@ -376,14 +376,12 @@ struct numeric } }; -template -struct numeric_traits; - template <> struct numeric_traits { - static constexpr int exp = 8; - static constexpr int mant = 7; + static constexpr int exp = 8; + static constexpr int mant = 7; + static constexpr int PackedSize = 1; }; #if CK_TILE_USE_CUSTOM_DATA_TYPE diff --git a/include/ck_tile/core/numeric/float8.hpp b/include/ck_tile/core/numeric/float8.hpp index c4fc6890c6..facc3e45ee 100644 --- a/include/ck_tile/core/numeric/float8.hpp +++ b/include/ck_tile/core/numeric/float8.hpp @@ -207,9 +207,6 @@ using bf8_t = unsigned _BitInt(8); using bf8_raw_t = uint8_t; #endif -template -struct numeric_traits; - template <> struct numeric_traits { @@ -225,6 +222,7 @@ struct numeric_traits static constexpr fp8_interpretation f8_interpret = fp8_interpretation::E4M3_FNUZ; #endif static constexpr uint8_t abs_mask = 0x7F; + static constexpr int PackedSize = 1; }; template <> @@ -242,6 +240,7 @@ struct numeric_traits static constexpr fp8_interpretation f8_interpret = fp8_interpretation::E5M2_FNUZ; #endif static constexpr uint8_t abs_mask = 0x7F; + static constexpr int PackedSize = 1; }; // below is sw fp8 conversion, not utilizing hw instruction diff --git a/include/ck_tile/core/numeric/half.hpp b/include/ck_tile/core/numeric/half.hpp index 5779b170b7..8479b33f8f 100644 --- a/include/ck_tile/core/numeric/half.hpp +++ b/include/ck_tile/core/numeric/half.hpp @@ -223,9 +223,6 @@ struct numeric } }; -template -struct numeric_traits; - template <> struct numeric_traits { @@ -241,6 +238,7 @@ struct numeric_traits static constexpr uint16_t NegInf = 0xFC00; static constexpr uint16_t NaN = 0x7C01; static constexpr uint16_t Neg0 = 0x8000; + static constexpr int PackedSize = 1; using bitwise_type = uint16_t; }; @@ -383,4 +381,24 @@ half_t exp2(half_t x) { return static_cast(exp2f(static_cast(x))) CK_TILE_DEVICE half_t log(half_t x) { return static_cast(__logf(static_cast(x))); }; #endif + +using fp16x2_t = _Float16 __attribute__((ext_vector_type(2))); + +CK_TILE_HOST fp16x2_t pk_add_f16(const fp16x2_t& x, const fp16x2_t& y) +{ + fp16x2_t vector_res; + + vector_res.x = x.x + y.x; + vector_res.y = x.y + y.y; + + return vector_res; +} + +CK_TILE_DEVICE fp16x2_t pk_add_f16(const fp16x2_t& x, const fp16x2_t& y) +{ + fp16x2_t c; + asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(c) : "v"(x), "v"(y)); + return c; +} + } // namespace ck_tile diff --git a/include/ck_tile/core/numeric/int8.hpp b/include/ck_tile/core/numeric/int8.hpp index 9ca3333c39..34d9a1c4b9 100644 --- a/include/ck_tile/core/numeric/int8.hpp +++ b/include/ck_tile/core/numeric/int8.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck_tile/core/config.hpp" #include "ck_tile/core/numeric/half.hpp" @@ -74,8 +74,6 @@ struct numeric }; #if 0 -template -struct numeric_traits; template <> struct numeric_traits @@ -91,6 +89,7 @@ struct numeric_traits static constexpr uint32_t NegInf = 0xFC00; static constexpr uint32_t NaN = 0x7C01; static constexpr uint32_t Neg0 = 0x8000; + static constexpr int PackedSize = 1; using bitwise_type = uint16_t; }; #endif diff --git a/include/ck_tile/core/numeric/numeric.hpp b/include/ck_tile/core/numeric/numeric.hpp index 6b16485b48..f125fbf2ce 100644 --- a/include/ck_tile/core/numeric/numeric.hpp +++ b/include/ck_tile/core/numeric/numeric.hpp @@ -77,7 +77,10 @@ struct numeric }; template -struct numeric_traits; +struct numeric_traits +{ + static constexpr int PackedSize = 1; +}; template <> struct numeric_traits @@ -94,6 +97,7 @@ struct numeric_traits static constexpr uint32_t NegInf = 0xFF800000; static constexpr uint32_t NaN = 0x7F800001; static constexpr uint32_t Neg0 = 0x80000000; + static constexpr int PackedSize = 1; using bitwise_type = uint32_t; }; diff --git a/include/ck_tile/core/numeric/pk_int4.hpp b/include/ck_tile/core/numeric/pk_int4.hpp index 2ffcc36ced..541093e337 100644 --- a/include/ck_tile/core/numeric/pk_int4.hpp +++ b/include/ck_tile/core/numeric/pk_int4.hpp @@ -21,8 +21,8 @@ struct pk_int4_t { using type = int8_t; type data; - __host__ __device__ constexpr pk_int4_t() : data{type{}} {} - __host__ __device__ constexpr pk_int4_t(type init) : data{init} {} + CK_TILE_HOST_DEVICE constexpr pk_int4_t() : data{type{}} {} + CK_TILE_HOST_DEVICE constexpr pk_int4_t(type init) : data{init} {} }; // limits @@ -91,6 +91,16 @@ struct numeric CK_TILE_HOST_DEVICE static constexpr pk_int4_t zero() { return 0; } }; +template <> +struct numeric_traits +{ + static constexpr int PackedSize = 2; +}; + +using fp32x2_t = float __attribute__((ext_vector_type(2))); +using fp16x2_t = _Float16 __attribute__((ext_vector_type(2))); +using bf16x2_t = bf16_raw_t __attribute__((ext_vector_type(2))); + CK_TILE_HOST_DEVICE fp32x2_t pk_int4_t_to_fp32x2_t(const pk_int4_t& x) { uint8_t x_u8 = ck_tile::bit_cast(x); diff --git a/include/ck_tile/core/numeric/vector_type.hpp b/include/ck_tile/core/numeric/vector_type.hpp index 480da96596..b165275a8c 100644 --- a/include/ck_tile/core/numeric/vector_type.hpp +++ b/include/ck_tile/core/numeric/vector_type.hpp @@ -10,6 +10,7 @@ #include "ck_tile/core/numeric/float8.hpp" #include "ck_tile/core/numeric/half.hpp" #include "ck_tile/core/numeric/bfloat16.hpp" +#include "ck_tile/core/numeric/pk_int4.hpp" #include "ck_tile/core/utility/type_traits.hpp" namespace ck_tile { @@ -30,17 +31,34 @@ struct native_t // of compiler errors e.g. struct A; using Ax2_t = A __attribute__((ext_vector_type(2))); -> will // have compiler error namespace impl { + +template +struct ext_vector; + template -struct ext_vector +struct ext_vector::type>>> { static constexpr index_t N = N_; - using value_type = typename native_t>::type; + // struct type is not supported for ext_vector + using value_type = typename native_t::type; + static_assert(!std::is_class_v); + using type = value_type __attribute__((ext_vector_type(N))); // this is danguous +}; + +template +struct ext_vector::type>>> +{ + static constexpr index_t N = N_; + // struct type is not supported for ext_vector + using value_type = typename native_t::type::type; static_assert(!std::is_class_v); using type = value_type __attribute__((ext_vector_type(N))); // this is danguous }; template -struct ext_vector +struct ext_vector::type>>> { static constexpr index_t N = Vs_ * N_; using value_type = typename native_t>::type; @@ -48,6 +66,17 @@ struct ext_vector using type = value_type __attribute__((ext_vector_type(N))); // this is danguous }; +template +struct ext_vector::type>>> +{ + static constexpr index_t N = Vs_ * N_; + using value_type = typename native_t>::type::type; + static_assert(!std::is_class_v); + using type = value_type __attribute__((ext_vector_type(N))); // this is danguous +}; + } // namespace impl template @@ -55,10 +84,11 @@ using ext_vector_t = typename impl::ext_vector::type; // by default, any type will result in a vector_size=1 with scalar_type=T traits. // ... unless we have other vector_traits specialization -template +template struct vector_traits { - using scalar_type = remove_cvref_t; + using scalar_type = + std::conditional_t, pk_int4_t>, int8_t, remove_cvref_t>; static constexpr index_t vector_size = 1; }; @@ -66,7 +96,7 @@ struct vector_traits template struct vector_traits { - using scalar_type = T; + using scalar_type = std::conditional_t, int8_t, T>; static constexpr index_t vector_size = N; }; @@ -200,21 +230,11 @@ using bf8x32_t = bf8_t __attribute((ext_vector_type(32))); using bf8x64_t = bf8_t __attribute((ext_vector_type(64))); #endif -CK_TILE_HOST fp16x2_t pk_add_f16(const fp16x2_t& x, const fp16x2_t& y) -{ - fp16x2_t vector_res; - - vector_res.x = x.x + y.x; - vector_res.y = x.y + y.y; - - return vector_res; -} - -CK_TILE_DEVICE fp16x2_t pk_add_f16(const fp16x2_t& x, const fp16x2_t& y) -{ - fp16x2_t c; - asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(c) : "v"(x), "v"(y)); - return c; -} - +// pk_int4_t +// using pk_int4_t +using pk_int4x2_t = int8_t __attribute((ext_vector_type(2))); +using pk_int4x4_t = int8_t __attribute((ext_vector_type(4))); +using pk_int4x8_t = int8_t __attribute((ext_vector_type(8))); +using pk_int4x16_t = int8_t __attribute((ext_vector_type(16))); +using pk_int4x32_t = int8_t __attribute((ext_vector_type(32))); } // namespace ck_tile diff --git a/include/ck_tile/core/tensor/buffer_view.hpp b/include/ck_tile/core/tensor/buffer_view.hpp index 7dffa0e555..c2a093f1ab 100644 --- a/include/ck_tile/core/tensor/buffer_view.hpp +++ b/include/ck_tile/core/tensor/buffer_view.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -231,13 +231,18 @@ struct buffer_view invalid_element_value_ = T{0}; + static constexpr index_t PackedSize = ck_tile::numeric_traits>::PackedSize; + CK_TILE_HOST_DEVICE constexpr buffer_view() : p_data_{}, buffer_size_{}, cached_buf_res_{0}, invalid_element_value_{} { } CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data, BufferSizeType buffer_size) - : p_data_{p_data}, buffer_size_{buffer_size}, cached_buf_res_{0}, invalid_element_value_{0} + : p_data_{p_data}, + buffer_size_{buffer_size / PackedSize}, + cached_buf_res_{0}, + invalid_element_value_{0} { } @@ -245,7 +250,7 @@ struct buffer_view>::scalar_type, - int8_t>::value && + if constexpr(std::is_same_v>::scalar_type, + int8_t> && workaround_int8_ds_write_issue) { if(is_valid_element) @@ -897,83 +902,117 @@ struct buffer_view" which would be lower to // ds_write_b128 // TODO: remove this after compiler fix - static_assert((std::is_same, int8_t>::value && - std::is_same, int8_t>::value) || - (std::is_same, int8_t>::value && - std::is_same, int8x2_t>::value) || - (std::is_same, int8_t>::value && - std::is_same, int8x4_t>::value) || - (std::is_same, int8_t>::value && - std::is_same, int8x8_t>::value) || - (std::is_same, int8_t>::value && - std::is_same, int8x16_t>::value) || - (std::is_same, int8x4_t>::value && - std::is_same, int8x4_t>::value) || - (std::is_same, int8x8_t>::value && - std::is_same, int8x8_t>::value) || - (std::is_same, int8x16_t>::value && - std::is_same, int8x16_t>::value), - "wrong! not implemented for this combination, please add " - "implementation"); + static_assert( + (std::is_same_v, int8_t> && + std::is_same_v, int8_t>) || + (std::is_same_v, int8_t> && + std::is_same_v, int8x2_t>) || + (std::is_same_v, int8_t> && + std::is_same_v, int8x4_t>) || + (std::is_same_v, int8_t> && + std::is_same_v, int8x8_t>) || + (std::is_same_v, int8_t> && + std::is_same_v, int8x16_t>) || + (std::is_same_v, int8x4_t> && + std::is_same_v, int8x4_t>) || + (std::is_same_v, int8x8_t> && + std::is_same_v, int8x8_t>) || + (std::is_same_v, int8x16_t> && + std::is_same_v, int8x16_t>) || + // ext_vector_type for pk_int4 must use int8_t as type + (std::is_same_v, pk_int4_t> && + std::is_same_v, thread_buffer>) || + (std::is_same_v, pk_int4_t> && + std::is_same_v, thread_buffer>) || + (std::is_same_v, pk_int4_t> && + std::is_same_v, thread_buffer>) || + (std::is_same_v, pk_int4_t> && + std::is_same_v, thread_buffer>) || + (std::is_same_v, pk_int4_t> && + std::is_same_v, thread_buffer>) || + (std::is_same_v, pk_int4x4_t> && + std::is_same_v, thread_buffer>) || + (std::is_same_v, pk_int4x8_t> && + std::is_same_v, thread_buffer>) || + (std::is_same_v, pk_int4x16_t> && + std::is_same_v, thread_buffer>), + "wrong! not implemented for this combination, please add " + "implementation"); - if constexpr(std::is_same, int8_t>::value && - std::is_same, int8_t>::value) + if constexpr((std::is_same_v, int8_t> && + std::is_same_v, int8_t>) || + (std::is_same_v, pk_int4_t> && + std::is_same_v, thread_buffer>)) { // HACK: cast pointer of x is bad // TODO: remove this after compiler fix *c_style_pointer_cast(&p_data_[i]) = *c_style_pointer_cast(&x); } - else if constexpr(std::is_same, int8_t>::value && - std::is_same, int8x2_t>::value) + else if constexpr((std::is_same_v, int8_t> && + std::is_same_v, int8x2_t>) || + (std::is_same_v, pk_int4_t> && + std::is_same_v, thread_buffer>)) { // HACK: cast pointer of x is bad // TODO: remove this after compiler fix *c_style_pointer_cast(&p_data_[i]) = *c_style_pointer_cast(&x); } - else if constexpr(std::is_same, int8_t>::value && - std::is_same, int8x4_t>::value) + else if constexpr((std::is_same_v, int8_t> && + std::is_same_v, int8x4_t>) || + (std::is_same_v, pk_int4_t> && + std::is_same_v, thread_buffer>)) { // HACK: cast pointer of x is bad // TODO: remove this after compiler fix *c_style_pointer_cast(&p_data_[i]) = *c_style_pointer_cast(&x); } - else if constexpr(std::is_same, int8_t>::value && - std::is_same, int8x8_t>::value) + else if constexpr((std::is_same_v, int8_t> && + std::is_same_v, int8x8_t>) || + (std::is_same_v, pk_int4_t> && + std::is_same_v, thread_buffer>)) { // HACK: cast pointer of x is bad // TODO: remove this after compiler fix *c_style_pointer_cast(&p_data_[i]) = *c_style_pointer_cast(&x); } - else if constexpr(std::is_same, int8_t>::value && - std::is_same, int8x16_t>::value) + else if constexpr((std::is_same_v, int8_t> && + std::is_same_v, int8x16_t>) || + (std::is_same_v, pk_int4_t> && + std::is_same_v, thread_buffer>)) { // HACK: cast pointer of x is bad // TODO: remove this after compiler fix *c_style_pointer_cast(&p_data_[i]) = *c_style_pointer_cast(&x); } - else if constexpr(std::is_same, int8x4_t>::value && - std::is_same, int8x4_t>::value) + else if constexpr((std::is_same_v, int8x4_t> && + std::is_same_v, int8x4_t>) || + (std::is_same_v, pk_int4x4_t> && + std::is_same_v, thread_buffer>)) { // HACK: cast pointer of x is bad // TODO: remove this after compiler fix *c_style_pointer_cast(&p_data_[i]) = *c_style_pointer_cast(&x); } - else if constexpr(std::is_same, int8x8_t>::value && - std::is_same, int8x8_t>::value) + else if constexpr((std::is_same_v, int8x8_t> && + std::is_same_v, int8x8_t>) || + (std::is_same_v, pk_int4x8_t> && + std::is_same_v, thread_buffer>)) { // HACK: cast pointer of x is bad // TODO: remove this after compiler fix *c_style_pointer_cast(&p_data_[i]) = *c_style_pointer_cast(&x); } - else if constexpr(std::is_same, int8x16_t>::value && - std::is_same, int8x16_t>::value) + else if constexpr((std::is_same_v, int8x16_t> && + std::is_same_v, int8x16_t>) || + (std::is_same_v, pk_int4x16_t> && + std::is_same_v, thread_buffer>)) { // HACK: cast pointer of x is bad // TODO: remove this after compiler fix diff --git a/include/ck_tile/core/tensor/static_distributed_tensor.hpp b/include/ck_tile/core/tensor/static_distributed_tensor.hpp index 8d2f88af39..b73a27c8d5 100644 --- a/include/ck_tile/core/tensor/static_distributed_tensor.hpp +++ b/include/ck_tile/core/tensor/static_distributed_tensor.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -27,6 +27,8 @@ struct static_distributed_tensor using ThreadTensorDesc = remove_cvref_t; + static constexpr index_t PackedSize = + ck_tile::numeric_traits>::PackedSize; static constexpr index_t kThreadElementSpaceSize = ThreadTensorDesc{}.get_element_space_size(); static_assert(0 < kThreadElementSpaceSize, "Make sure tile distribution is valid"); @@ -59,7 +61,7 @@ struct static_distributed_tensor CK_TILE_HOST_DEVICE static constexpr index_t get_thread_buffer_size() { - return kThreadElementSpaceSize; + return kThreadElementSpaceSize / PackedSize; } template @@ -79,8 +81,9 @@ struct static_distributed_tensor static_ford>{}([&](auto idx) { constexpr auto idx_ys = idx + sequence{}; - sliced_thread_data(number{}) = - thread_buf_[number{}]; + sliced_thread_data( + number{}) = + thread_buf_[number{}]; }); return sliced_thread_data; @@ -101,8 +104,9 @@ struct static_distributed_tensor static_ford>{}([&](auto idx) { constexpr auto idx_ys = idx + sequence{}; - thread_buf_(number{}) = - sliced_thread_data[number{}]; + thread_buf_(number{}) = + sliced_thread_data[number{}]; }); } @@ -115,7 +119,7 @@ struct static_distributed_tensor constexpr auto y_idx = get_tile_distribution().get_y_indices_from_distributed_indices( TileDistributedIndices{}); - return thread_buf_[number{}]; + return thread_buf_[number{}]; } template @@ -127,11 +131,11 @@ struct static_distributed_tensor constexpr auto y_idx = get_tile_distribution().get_y_indices_from_distributed_indices( TileDistributedIndices{}); - return thread_buf_(number{}); + return thread_buf_(number{}); } // - thread_buffer thread_buf_; + thread_buffer thread_buf_; }; template diff --git a/include/ck_tile/core/tensor/tensor_view.hpp b/include/ck_tile/core/tensor/tensor_view.hpp index 4c72ed0859..336793c5b1 100644 --- a/include/ck_tile/core/tensor/tensor_view.hpp +++ b/include/ck_tile/core/tensor/tensor_view.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -45,6 +45,8 @@ struct tensor_view using TensorIndex = array; using TensorCoord = decltype(make_tensor_coordinate(TensorDesc{}, TensorIndex{})); static constexpr auto DstInMemOp = DstInMemOp_; + static constexpr index_t PackedSize = + ck_tile::numeric_traits>::PackedSize; CK_TILE_HOST_DEVICE constexpr tensor_view() = default; @@ -81,8 +83,8 @@ struct tensor_view bool_constant = {}) const { return buf_.template get( - coord.get_offset(), - linear_offset, + coord.get_offset() / PackedSize, + linear_offset / PackedSize, coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord), bool_constant{}); } @@ -99,8 +101,8 @@ struct tensor_view bool is_valid_element, // flag bool_constant = {}) const { - return buf_.template get(coord.get_offset(), - linear_offset, + return buf_.template get(coord.get_offset() / PackedSize, + linear_offset / PackedSize, is_valid_element, bool_constant{}); } @@ -122,8 +124,8 @@ struct tensor_view { return buf_.template get_raw( dst, - coord.get_offset(), - linear_offset, + coord.get_offset() / PackedSize, + linear_offset / PackedSize, coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord), bool_constant{}); } @@ -142,8 +144,12 @@ struct tensor_view bool_constant = {}, bool_constant = {}) const { - return buf_.template get_raw( - dst, coord.get_offset(), linear_offset, is_valid_element, bool_constant{}); + return buf_.template get_raw(dst, + coord.get_offset() / + PackedSize, + linear_offset / PackedSize, + is_valid_element, + bool_constant{}); } template ( smem, - coord.get_offset(), - linear_offset, + coord.get_offset() / PackedSize, + linear_offset / PackedSize, coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord), bool_constant{}); } @@ -178,8 +184,8 @@ struct tensor_view bool is_valid_element) const { return buf_.template async_get(smem, - coord.get_offset(), - linear_offset, + coord.get_offset() / PackedSize, + linear_offset / PackedSize, is_valid_element, bool_constant{}); } @@ -198,8 +204,8 @@ struct tensor_view { return buf_.template async_get_raw( smem, - coord.get_offset(), - linear_offset, + coord.get_offset() / PackedSize, + linear_offset / PackedSize, coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord), bool_constant{}); } @@ -217,8 +223,11 @@ struct tensor_view bool is_valid_element, bool_constant = {}) const { - return buf_.template async_get_raw( - smem, coord.get_offset(), linear_offset, is_valid_element, bool_constant{}); + return buf_.template async_get_raw(smem, + coord.get_offset() / PackedSize, + linear_offset / PackedSize, + is_valid_element, + bool_constant{}); } // X is vector of DataType. @@ -236,8 +245,8 @@ struct tensor_view bool_constant = {}) { buf_.template set( - coord.get_offset(), - linear_offset, + coord.get_offset() / PackedSize, + linear_offset / PackedSize, coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord), x); } @@ -272,8 +281,8 @@ struct tensor_view bool_constant = {}) { buf_.template set_raw( - coord.get_offset(), - linear_offset, + coord.get_offset() / PackedSize, + linear_offset / PackedSize, coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord), x); } @@ -292,7 +301,7 @@ struct tensor_view bool_constant = {}) { buf_.template set_raw( - coord.get_offset(), linear_offset, is_valid_element, x); + coord.get_offset() / PackedSize, linear_offset / PackedSize, is_valid_element, x); } // X is vector of DataType. @@ -310,8 +319,8 @@ struct tensor_view bool_constant = {}) { buf_.template update( - coord.get_offset(), - linear_offset, + coord.get_offset() / PackedSize, + linear_offset / PackedSize, coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord), x); } @@ -330,7 +339,7 @@ struct tensor_view bool_constant = {}) { buf_.template update( - coord.get_offset(), linear_offset, is_valid_element, x); + coord.get_offset() / PackedSize, linear_offset / PackedSize, is_valid_element, x); } // X is vector of DataType. @@ -350,8 +359,8 @@ struct tensor_view bool_constant = {}) { buf_.template update_raw( - coord.get_offset(), - linear_offset, + coord.get_offset() / PackedSize, + linear_offset / PackedSize, coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord), x); } @@ -372,7 +381,7 @@ struct tensor_view bool_constant = {}) { buf_.template update_raw( - coord.get_offset(), linear_offset, is_valid_element, x); + coord.get_offset() / PackedSize, linear_offset / PackedSize, is_valid_element, x); } CK_TILE_HOST_DEVICE void print() const diff --git a/include/ck_tile/core/tensor/tile_window.hpp b/include/ck_tile/core/tensor/tile_window.hpp index 27c2c24ad5..3bb728df23 100644 --- a/include/ck_tile/core/tensor/tile_window.hpp +++ b/include/ck_tile/core/tensor/tile_window.hpp @@ -97,13 +97,15 @@ struct tile_window_with_static_distribution } public: + static constexpr index_t PackedSize = + ck_tile::numeric_traits>::PackedSize; static constexpr index_t VectorDimY = get_vector_dim_y_scalar_per_vector().template at<0>(); static constexpr index_t ScalarPerVector = get_vector_dim_y_scalar_per_vector().template at<1>(); // using vector_type_t = vector_type_maker_t; // using vector_t = typename vector_type_t::type; - using vector_t = thread_buffer; + using vector_t = thread_buffer; private: static constexpr auto scalars_per_access_ = [] { @@ -336,7 +338,7 @@ struct tile_window_with_static_distribution bottom_tensor_thread_coord, 0, bool_constant{}); #if 1 // write into distributed tensor - static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) { + static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) { constexpr auto idx_ys = generate_tuple( [&](auto jj) { return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j) @@ -345,10 +347,11 @@ struct tile_window_with_static_distribution number{}); constexpr index_t d = - tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys); + tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) / + Traits::PackedSize; dst_tensor.get_thread_buffer().template at() = - vec_value.template get_as()[j]; + vec_value.template get_as()[j / Traits::PackedSize]; }); #else constexpr index_t d = @@ -390,8 +393,9 @@ struct tile_window_with_static_distribution using SFC_Ys = typename Traits::SFC_Ys; static constexpr index_t YElementSize = TileDstr{}.get_ys_to_d_descriptor().get_element_space_size(); - static_assert(YElementSize % Traits::ScalarPerVector == 0); - using vectorized_tbuf = array; + static_assert(YElementSize % (Traits::PackedSize * Traits::ScalarPerVector) == 0); + using vectorized_tbuf = + array; // StaticBuffer( @@ -632,7 +637,7 @@ struct tile_window_with_static_distribution // vector_type_t vec; vector_t vec_value; - static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) { + static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) { constexpr auto idx_ys = generate_tuple( [&](auto jj) { return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j) @@ -641,9 +646,10 @@ struct tile_window_with_static_distribution number{}); constexpr index_t d = - tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys); + tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) / + Traits::PackedSize; - vec_value.template get_as()(j) = + vec_value.template get_as()(j / Traits::PackedSize) = dstr_tensor.get_thread_buffer().template at(); }); @@ -698,7 +704,7 @@ struct tile_window_with_static_distribution // read from distributed tensor vector_t vec_value; - static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) { + static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) { constexpr auto idx_ys = generate_tuple( [&](auto jj) { return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j) @@ -706,8 +712,9 @@ struct tile_window_with_static_distribution }, number{}); constexpr index_t d = - tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys); - vec_value.template get_as()(j) = + tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) / + Traits::PackedSize; + vec_value.template get_as()(j / Traits::PackedSize) = dstr_tensor.get_thread_buffer().template at(); }); @@ -759,7 +766,7 @@ struct tile_window_with_static_distribution // read from distributed tensor vector_t vec_value; - static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) { + static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) { constexpr auto idx_ys = generate_tuple( [&](auto jj) { return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j) @@ -768,9 +775,10 @@ struct tile_window_with_static_distribution number{}); constexpr index_t d = - tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys); + tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) / + Traits::PackedSize; - vec_value.template get_as()(j) = + vec_value.template get_as()(j / Traits::PackedSize) = dstr_tensor.get_thread_buffer().template at(); }); @@ -825,7 +833,7 @@ struct tile_window_with_static_distribution // read from distributed tensor vector_t vec_value; - static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) { + static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) { constexpr auto idx_ys = generate_tuple( [&](auto jj) { return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j) @@ -834,9 +842,10 @@ struct tile_window_with_static_distribution number{}); constexpr index_t d = - tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys); + tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) / + Traits::PackedSize; - vec_value.template get_as()(j) = + vec_value.template get_as()(j / Traits::PackedSize) = dstr_tensor.get_thread_buffer().template at(); }); diff --git a/include/ck_tile/core/tensor/tile_window_linear.hpp b/include/ck_tile/core/tensor/tile_window_linear.hpp index 96a8352c04..1e24e660f6 100644 --- a/include/ck_tile/core/tensor/tile_window_linear.hpp +++ b/include/ck_tile/core/tensor/tile_window_linear.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "ck_tile/core/arch/arch.hpp" @@ -151,11 +151,13 @@ struct tile_window_linear } public: + static constexpr index_t PackedSize = + ck_tile::numeric_traits>::PackedSize; static constexpr index_t VectorDimY = get_vector_dim_y_scalar_per_vector().template at<0>(); static constexpr index_t ScalarPerVector = get_vector_dim_y_scalar_per_vector().template at<1>(); - using vector_t = thread_buffer; + using vector_t = thread_buffer; private: static constexpr auto scalars_per_access_ = [] { @@ -498,17 +500,18 @@ struct tile_window_linear // data index [y0, y1, ...] constexpr auto idx_diff_ys = SFC_Ys::get_index(IAccess); // write into distributed tensor - static_for<0, traits::ScalarPerVector, 1>{}([&](auto j) { + static_for<0, traits::ScalarPerVector, traits::PackedSize>{}([&](auto j) { constexpr auto idx_ys = generate_tuple( [&](auto jj) { return jj == traits::VectorDimY ? (idx_diff_ys[jj] + j) : idx_diff_ys[jj]; }, number{}); - constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys); + constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) / + traits::PackedSize; dst_tensor.get_thread_buffer().template at() = - vec_value.template get_as()[j]; + vec_value.template get_as()[j / traits::PackedSize]; }); #else constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start); @@ -556,17 +559,18 @@ struct tile_window_linear // data index [y0, y1, ...] constexpr auto idx_diff_ys = SFC_Ys::get_index(IAccess); // write into distributed tensor - static_for<0, traits::ScalarPerVector, 1>{}([&](auto j) { + static_for<0, traits::ScalarPerVector, traits::PackedSize>{}([&](auto j) { constexpr auto idx_ys = generate_tuple( [&](auto jj) { return jj == traits::VectorDimY ? (idx_diff_ys[jj] + j) : idx_diff_ys[jj]; }, number{}); - constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys); + constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) / + traits::PackedSize; dst_tensor.get_thread_buffer().template at() = - vec_value.template get_as()[j]; + vec_value.template get_as()[j / traits::PackedSize]; }); #else constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start); @@ -595,8 +599,9 @@ struct tile_window_linear using SFC_Ys = typename traits::SFC_Ys; static constexpr index_t YElementSize = TileDstr{}.get_ys_to_d_descriptor().get_element_space_size(); - static_assert(YElementSize % traits::ScalarPerVector == 0); - using vectorized_tbuf = array; + static_assert(YElementSize % (traits::PackedSize * traits::ScalarPerVector) == 0); + using vectorized_tbuf = + array; constexpr auto tile_dstr = TileDstr{}; @@ -620,7 +625,9 @@ struct tile_window_linear // data index [y0, y1, ...] constexpr auto idx_ys_start = SFC_Ys::get_index(IAccess); - constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start); + constexpr index_t d = + tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start) / + traits::PackedSize; static_assert(d % traits::ScalarPerVector == 0); get_bottom_tensor_view().template get_vectorized_elements_raw( @@ -804,16 +811,17 @@ struct tile_window_linear // read from distributed tensor vector_t vec_value; - static_for<0, traits::ScalarPerVector, 1>{}([&](auto j) { + static_for<0, traits::ScalarPerVector, traits::PackedSize>{}([&](auto j) { constexpr auto idx_ys = generate_tuple( [&](auto jj) { return jj == traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj]; }, number{}); - constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys); + constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) / + traits::PackedSize; - vec_value.template get_as()(j) = + vec_value.template get_as()(j / traits::PackedSize) = dstr_tensor.get_thread_buffer().template at(); }); @@ -852,14 +860,15 @@ struct tile_window_linear // read from distributed tensor vector_t vec_value; - static_for<0, traits::ScalarPerVector, 1>{}([&](auto j) { + static_for<0, traits::ScalarPerVector, traits::PackedSize>{}([&](auto j) { constexpr auto idx_ys = generate_tuple( [&](auto jj) { return jj == traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj]; }, number{}); - constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys); - vec_value.template get_as()(j) = + constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) / + traits::PackedSize; + vec_value.template get_as()(j / traits::PackedSize) = dstr_tensor.get_thread_buffer().template at(); }); @@ -897,16 +906,17 @@ struct tile_window_linear // read from distributed tensor vector_t vec_value; - static_for<0, traits::ScalarPerVector, 1>{}([&](auto j) { + static_for<0, traits::ScalarPerVector, traits::PackedSize>{}([&](auto j) { constexpr auto idx_ys = generate_tuple( [&](auto jj) { return jj == traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj]; }, number{}); - constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys); + constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) / + traits::PackedSize; - vec_value.template get_as()(j) = + vec_value.template get_as()(j / traits::PackedSize) = dstr_tensor.get_thread_buffer().template at(); }); @@ -948,16 +958,17 @@ struct tile_window_linear // read from distributed tensor vector_t vec_value; - static_for<0, traits::ScalarPerVector, 1>{}([&](auto j) { + static_for<0, traits::ScalarPerVector, traits::PackedSize>{}([&](auto j) { constexpr auto idx_ys = generate_tuple( [&](auto jj) { return jj == traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj]; }, number{}); - constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys); + constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) / + traits::PackedSize; - vec_value.template get_as()(j) = + vec_value.template get_as()(j / traits::PackedSize) = dstr_tensor.get_thread_buffer().template at(); }); diff --git a/include/ck_tile/host/check_err.hpp b/include/ck_tile/host/check_err.hpp index ea70563d58..745c18d6dd 100644 --- a/include/ck_tile/host/check_err.hpp +++ b/include/ck_tile/host/check_err.hpp @@ -29,11 +29,12 @@ double get_relative_threshold(const int number_of_accumulations = 1) using I8 = int8_t; using I32 = int32_t; - static_assert(is_any_of::value, - "Warning: Unhandled ComputeDataType for setting up the relative threshold!"); + static_assert( + is_any_of::value, + "Warning: Unhandled ComputeDataType for setting up the relative threshold!"); double compute_error = 0; - if constexpr(is_any_of::value) + if constexpr(is_any_of::value) { return 0; } @@ -42,11 +43,11 @@ double get_relative_threshold(const int number_of_accumulations = 1) compute_error = std::pow(2, -numeric_traits::mant) * 0.5; } - static_assert(is_any_of::value, + static_assert(is_any_of::value, "Warning: Unhandled OutDataType for setting up the relative threshold!"); double output_error = 0; - if constexpr(is_any_of::value) + if constexpr(is_any_of::value) { return 0; } @@ -56,11 +57,11 @@ double get_relative_threshold(const int number_of_accumulations = 1) } double midway_error = std::max(compute_error, output_error); - static_assert(is_any_of::value, + static_assert(is_any_of::value, "Warning: Unhandled AccDataType for setting up the relative threshold!"); double acc_error = 0; - if constexpr(is_any_of::value) + if constexpr(is_any_of::value) { return 0; } @@ -82,12 +83,13 @@ double get_absolute_threshold(const double max_possible_num, const int number_of using I8 = int8_t; using I32 = int32_t; - static_assert(is_any_of::value, - "Warning: Unhandled ComputeDataType for setting up the absolute threshold!"); + static_assert( + is_any_of::value, + "Warning: Unhandled ComputeDataType for setting up the absolute threshold!"); auto expo = std::log2(std::abs(max_possible_num)); double compute_error = 0; - if constexpr(is_any_of::value) + if constexpr(is_any_of::value) { return 0; } @@ -96,11 +98,11 @@ double get_absolute_threshold(const double max_possible_num, const int number_of compute_error = std::pow(2, expo - numeric_traits::mant) * 0.5; } - static_assert(is_any_of::value, + static_assert(is_any_of::value, "Warning: Unhandled OutDataType for setting up the absolute threshold!"); double output_error = 0; - if constexpr(is_any_of::value) + if constexpr(is_any_of::value) { return 0; } @@ -110,11 +112,11 @@ double get_absolute_threshold(const double max_possible_num, const int number_of } double midway_error = std::max(compute_error, output_error); - static_assert(is_any_of::value, + static_assert(is_any_of::value, "Warning: Unhandled AccDataType for setting up the absolute threshold!"); double acc_error = 0; - if constexpr(is_any_of::value) + if constexpr(is_any_of::value) { return 0; } diff --git a/include/ck_tile/host/fill.hpp b/include/ck_tile/host/fill.hpp index f24c338755..006026470b 100644 --- a/include/ck_tile/host/fill.hpp +++ b/include/ck_tile/host/fill.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -282,7 +282,14 @@ struct FillMonotonicSeq { std::generate(first, last, [=, n = init_value_]() mutable { auto tmp = n; - n += step_; + if constexpr(std::is_same_v) + { + n.data += step_.data; + } + else + { + n += step_; + } return tmp; }); } diff --git a/include/ck_tile/host/host_tensor.hpp b/include/ck_tile/host/host_tensor.hpp index 2047ad7793..a43877c6da 100644 --- a/include/ck_tile/host/host_tensor.hpp +++ b/include/ck_tile/host/host_tensor.hpp @@ -281,18 +281,18 @@ struct HostTensor using Data = std::vector; template - HostTensor(std::initializer_list lens) : mDesc(lens), mData(mDesc.get_element_space_size()) + HostTensor(std::initializer_list lens) : mDesc(lens), mData(get_element_space_size()) { } template HostTensor(std::initializer_list lens, std::initializer_list strides) - : mDesc(lens, strides), mData(mDesc.get_element_space_size()) + : mDesc(lens, strides), mData(get_element_space_size()) { } template - HostTensor(const Lengths& lens) : mDesc(lens), mData(mDesc.get_element_space_size()) + HostTensor(const Lengths& lens) : mDesc(lens), mData(get_element_space_size()) { } @@ -302,7 +302,7 @@ struct HostTensor { } - HostTensor(const Descriptor& desc) : mDesc(desc), mData(mDesc.get_element_space_size()) {} + HostTensor(const Descriptor& desc) : mDesc(desc), mData(get_element_space_size()) {} template HostTensor CopyAsType() const @@ -340,7 +340,11 @@ struct HostTensor std::size_t get_element_size() const { return mDesc.get_element_size(); } - std::size_t get_element_space_size() const { return mDesc.get_element_space_size(); } + std::size_t get_element_space_size() const + { + constexpr index_t PackedSize = ck_tile::numeric_traits>::PackedSize; + return mDesc.get_element_space_size() / PackedSize; + } std::size_t get_element_space_size_in_bytes() const { @@ -463,29 +467,27 @@ struct HostTensor template std::size_t GetOffsetFromMultiIndex(Is... is) const { - return mDesc.GetOffsetFromMultiIndex(is...); + constexpr index_t PackedSize = ck_tile::numeric_traits>::PackedSize; + return mDesc.GetOffsetFromMultiIndex(is...) / PackedSize; } template T& operator()(Is... is) { - return mData[mDesc.GetOffsetFromMultiIndex(is...)]; + return mData[GetOffsetFromMultiIndex(is...)]; } template const T& operator()(Is... is) const { - return mData[mDesc.GetOffsetFromMultiIndex(is...)]; + return mData[GetOffsetFromMultiIndex(is...)]; } - T& operator()(std::vector idx) - { - return mData[mDesc.GetOffsetFromMultiIndex(idx)]; - } + T& operator()(std::vector idx) { return mData[GetOffsetFromMultiIndex(idx)]; } const T& operator()(std::vector idx) const { - return mData[mDesc.GetOffsetFromMultiIndex(idx)]; + return mData[GetOffsetFromMultiIndex(idx)]; } HostTensor transpose(std::vector axes = {}) const diff --git a/include/ck_tile/host/reference/reference_gemm.hpp b/include/ck_tile/host/reference/reference_gemm.hpp index da0de457d4..fe5077083c 100644 --- a/include/ck_tile/host/reference/reference_gemm.hpp +++ b/include/ck_tile/host/reference/reference_gemm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -34,11 +34,35 @@ CK_TILE_HOST void reference_gemm(const HostTensor& a_m_k, for(std::size_t k = 0; k < K; ++k) { - ADataType v_a = a_element_op(a_m_k(m, k)); - BDataType v_b = b_element_op(b_k_n(k, n)); - - v_acc += - ck_tile::type_convert(v_a) * ck_tile::type_convert(v_b); + AccDataType v_a; + AccDataType v_b; + if constexpr(std::is_same_v) + { + const pk_int4_t pk_val = a_element_op(a_m_k(m, k)); + const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(pk_val); + if(k % 2 == 1) + v_a = fp32_val.hi; + else + v_a = fp32_val.lo; + } + else + { + v_a = ck_tile::type_convert(a_element_op(a_m_k(m, k))); + } + if constexpr(std::is_same_v) + { + const pk_int4_t pk_val = b_element_op(b_k_n(k, n)); + const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(pk_val); + if(k % 2 == 1) + v_b = fp32_val.hi; + else + v_b = fp32_val.lo; + } + else + { + v_b = ck_tile::type_convert(b_element_op(b_k_n(k, n))); + } + v_acc += v_a * v_b; } c_m_n(m, n) = ck_tile::type_convert(acc_element_op(v_acc)); @@ -73,6 +97,8 @@ __global__ void naive_gemm_kernel(ADataType* A, AccDataType acc = 0.0; for(int k = 0; k < K; ++k) { + constexpr index_t packed_size_a = ck_tile::numeric_traits::PackedSize; + constexpr index_t packed_size_b = ck_tile::numeric_traits::PackedSize; // Adjust indexing based on matrix layout int a_index = (std::is_same_v) ? row * strideA + k @@ -80,8 +106,34 @@ __global__ void naive_gemm_kernel(ADataType* A, int b_index = (std::is_same_v) ? col * strideB + k : k * strideB + col; - acc += ck_tile::type_convert(A[a_index]) * - ck_tile::type_convert(B[b_index]); + + AccDataType v_a; + AccDataType v_b; + if constexpr(std::is_same_v) + { + const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(A[a_index / packed_size_a]); + if(k % 2 == 1) + v_a = fp32_val.hi; + else + v_a = fp32_val.lo; + } + else + { + v_a = ck_tile::type_convert(A[a_index]); + } + if constexpr(std::is_same_v) + { + const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(B[b_index / packed_size_b]); + if(k % 2 == 1) + v_b = fp32_val.hi; + else + v_b = fp32_val.lo; + } + else + { + v_b = ck_tile::type_convert(B[b_index]); + } + acc += v_a * v_b; } int c_index = (std::is_same_v) diff --git a/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp b/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp index 3e8dac30ef..a3a0df996d 100644 --- a/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp +++ b/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -9,20 +9,166 @@ namespace ck_tile { namespace element_wise { -#if 0 +// Fast int4x4 to fp16x8_t data type conversion based on paper +// [Who Says Elephants Can't Run: Bringing Large Scale MoE Models into Cloud Scale Production] +// (https://arxiv.org/abs/2211.10017) and implementation: +// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h +CK_TILE_DEVICE fp16x4_t i4_to_half4(int q) +{ + const int LO = 0x000f000f; + const int HI = 0x00f000f0; + const int EX = 0x64006400; + + int lo; + int hi; + // Extract the two int4 at low bit and create two fp16 number. + asm volatile("v_and_or_b32 %0, %1, %2, %3" : "=v"(lo) : "v"(q), "v"(LO), "v"(EX)); + // Extract the two int4 at hight bit and create two fp16 number. + asm volatile("v_and_or_b32 %0, %1, %2, %3" : "=v"(hi) : "v"(q), "v"(HI), "v"(EX)); + + const int SUB = 0xE408E408; // half2 {-1032, -1032} + const int MUL = 0x2c002c00; // half2 {1 / 16, 1 / 16} + const int ADD = 0xd480d480; // half2 {-72, -72} + + fp16x4_t res; + + // for two fp16 from lowbit, subtract 1032 to get correct fp16 value + asm volatile("v_pk_add_f16 %0, %1, %2" + : "=v"(res.lo) + : "v"(bit_cast(lo)), "v"(bit_cast(SUB))); + + // for two fp16 from highbit, divide 16 and subtract 72 to get correct fp16 value + asm volatile( + "v_pk_fma_f16 %0, %1, %2, %3" + : "=v"(res.hi) + : "v"(bit_cast(hi)), "v"(bit_cast(MUL)), "v"(bit_cast(ADD))); + + return res; +} + +CK_TILE_DEVICE fp16x4_t i4_to_half4_scale(int q, const fp16x2_t& scale) +{ + const int LO = 0x000f000f; + const int HI = 0x00f000f0; + const int EX = 0x64006400; + + int lo; + int hi; + // Extract the two int4 at low bit and create two fp16 number. + asm volatile("v_and_or_b32 %0, %1, %2, %3" : "=v"(lo) : "v"(q), "v"(LO), "v"(EX)); + // Extract the two int4 at hight bit and create two fp16 number. + asm volatile("v_and_or_b32 %0, %1, %2, %3" : "=v"(hi) : "v"(q), "v"(HI), "v"(EX)); + + const int SUB = 0xE408E408; // half2 {-1032, -1032} + const int MUL = 0x2c002c00; // half2 {1 / 16, 1 / 16} + const int ADD = 0xd480d480; // half2 {-72, -72} + + fp16x4_t res; + + asm volatile("v_pk_add_f16 %0, %1, %2" + : "=v"(res.lo) + : "v"(bit_cast(lo)), "v"(bit_cast(SUB))); + + asm volatile( + "v_pk_fma_f16 %0, %1, %2, %3" + : "=v"(res.hi) + : "v"(bit_cast(hi)), "v"(bit_cast(MUL)), "v"(bit_cast(ADD))); + + asm volatile("v_pk_mul_f16 %0, %1, %2" : "=v"(res.lo) : "v"(res.lo), "v"(scale)); + + asm volatile("v_pk_mul_f16 %0, %1, %2" : "=v"(res.hi) : "v"(res.hi), "v"(scale)); + + return res; +} + +CK_TILE_DEVICE bf16x4_t i4_to_bhalf4(int q) +{ + uint32_t i8s = (q & 0xf) | ((q & 0xf0) << 4) | ((q & 0xf00) << 8) | ((q & 0xf000) << 12); + + static constexpr uint32_t fp32_base = 0x4B000000; + + float fp32_intermediates[4]; + + uint32_t* fp32_intermediates_casted = reinterpret_cast(fp32_intermediates); + + fp32_intermediates_casted[0] = __byte_perm(i8s, fp32_base, 0x7650); + fp32_intermediates_casted[1] = __byte_perm(i8s, fp32_base, 0x7651); + fp32_intermediates_casted[2] = __byte_perm(i8s, fp32_base, 0x7652); + fp32_intermediates_casted[3] = __byte_perm(i8s, fp32_base, 0x7653); + + fp32_intermediates[0] -= 8388616.f; + fp32_intermediates[1] -= 8388616.f; + fp32_intermediates[2] -= 8388616.f; + fp32_intermediates[3] -= 8388616.f; + + bf16x4_t res; + res.lo = bit_cast( + __byte_perm(fp32_intermediates_casted[1], fp32_intermediates_casted[0], 0x7632)); + res.hi = bit_cast( + __byte_perm(fp32_intermediates_casted[3], fp32_intermediates_casted[2], 0x7632)); + + return res; +} + +struct PassThroughPack8 +{ + template + CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const; + + CK_TILE_HOST_DEVICE constexpr void operator()(fp16x8_t& y, const pk_int4x4_t& x) const + { + y.lo = i4_to_half4(bit_cast(x)); + y.hi = i4_to_half4(bit_cast(x) >> 8); + } + + CK_TILE_HOST_DEVICE constexpr void operator()(bf16x8_t& y, const pk_int4x4_t& x) const + { + y.lo = i4_to_bhalf4(bit_cast(x)); + y.hi = i4_to_bhalf4(bit_cast(x) >> 16); + } + constexpr const static bool is_pack8_invocable = true; +}; + +struct DequantPack8 +{ + template + CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x, const Z& z) const; + + CK_TILE_HOST_DEVICE constexpr void + operator()(fp16x8_t& y, const pk_int4x4_t& x, const fp16x2_t& z) const + { + y.lo = i4_to_half4_scale(bit_cast(x), z); + y.hi = i4_to_half4_scale(bit_cast(x) >> 8, z); + } + + constexpr const static bool is_pack8_invocable = true; +}; + struct PassThroughPack2 { template CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const; - CK_TILE_HOST_DEVICE constexpr void operator()(ck_tile::half2_t& y, const ck_tile::f8x2_t& x) const +#if 0 + CK_TILE_HOST_DEVICE constexpr void operator()(ck_tile::fp16x2_t& y, const ck_tile::f8x2_t& x) const { auto t = type_convert(x); - y = type_convert(t); + y = type_convert(t); } +#endif + + CK_TILE_HOST_DEVICE constexpr void operator()(fp16x2_t& y, const pk_int4_t& x) const + { + uint8_t x_u8 = bit_cast(x); + uint8_t x_l = (x_u8 & 0x0f) >> 0; + uint8_t x_h = (x_u8 & 0xf0) >> 4; + + y.lo = type_convert(x_l); + y.hi = type_convert(x_h); + } + constexpr const static bool is_pack2_invocable = true; }; -#endif struct PassThrough { diff --git a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp index ab21398b99..d9d6739fb5 100644 --- a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp +++ b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp @@ -1,11 +1,12 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "ck_tile/core.hpp" #include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" +#include "ck_tile/ops/elementwise.hpp" namespace ck_tile { @@ -20,12 +21,13 @@ struct BlockUniversalGemmAsBsCr template struct GemmTraits_ { - using Problem = remove_cvref_t; - using Policy = remove_cvref_t; - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; - using CDataType = remove_cvref_t; - using BlockGemmShape = remove_cvref_t; + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using ComputeDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; static constexpr index_t kBlockSize = Problem::kBlockSize; static constexpr auto Scheduler = Problem::Scheduler; @@ -71,10 +73,10 @@ struct BlockUniversalGemmAsBsCr using BWarpTileDistr = remove_cvref_t; - using AWarpTile = - remove_cvref_t(AWarpTileDistr{}))>; - using BWarpTile = - remove_cvref_t(BWarpTileDistr{}))>; + using AWarpTile = remove_cvref_t( + AWarpTileDistr{}))>; + using BWarpTile = remove_cvref_t( + BWarpTileDistr{}))>; // TODO: Should we have two policies? Interwave & Intrawave ?? static constexpr index_t InterWaveSchedulingMacClusters = 1; @@ -90,9 +92,10 @@ struct BlockUniversalGemmAsBsCr public: using Traits = GemmTraits_; - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; - using CDataType = remove_cvref_t; + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using ComputeDataType = remove_cvref_t; + using CDataType = remove_cvref_t; using WarpGemm = remove_cvref_t; @@ -105,10 +108,34 @@ struct BlockUniversalGemmAsBsCr static constexpr auto Scheduler = Traits::Scheduler; + static constexpr index_t APackedSize = + ck_tile::numeric_traits>::PackedSize; + static constexpr index_t BPackedSize = + ck_tile::numeric_traits>::PackedSize; + using I0 = number<0>; using I1 = number<1>; private: + template + CK_TILE_DEVICE static void load_interleaved_pk_type(const WarpWindow& warp_window, + WarpTile& warp_tile) + { + constexpr index_t UnaryOpSize = 8; + const element_wise::PassThroughPack8 elementwise_op{}; + constexpr index_t thread_buffer_size = + Traits::AWarpTile::get_thread_buffer_size() / UnaryOpSize; + const auto in_dstr_tensors = load_tile(warp_window); + + static_assert(Traits::AWarpTile::get_thread_buffer_size() % UnaryOpSize == 0); + + using ComputeVectorType = ComputeDataType __attribute__((ext_vector_type(UnaryOpSize))); + static_for<0, thread_buffer_size, 1>{}([&](auto i) { + elementwise_op(warp_tile.get_thread_buffer().template get_as()(i), + in_dstr_tensors.get_thread_buffer().template get_as()[i]); + }); + } + template struct BlockGemmImpl { @@ -208,6 +235,8 @@ struct BlockUniversalGemmAsBsCr }); using CWarpDstr = typename WarpGemm::CWarpDstr; + using AWarpTensor = typename WarpGemm::AWarpTensor; + using BWarpTensor = typename WarpGemm::BWarpTensor; using CWarpTensor = typename WarpGemm::CWarpTensor; constexpr auto c_warp_y_lengths = @@ -217,10 +246,26 @@ struct BlockUniversalGemmAsBsCr // hot loop: static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) { static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - const auto a_warp_tile = load_tile(a_warp_windows(mIter)(kIter)); + AWarpTensor a_warp_tile; + if constexpr(std::is_same_v) + { + load_interleaved_pk_type(a_warp_windows(mIter)(kIter), a_warp_tile); + } + else + { + a_warp_tile = load_tile(a_warp_windows(mIter)(kIter)); + } static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - const auto b_warp_tile = load_tile(b_warp_windows(nIter)(kIter)); + BWarpTensor b_warp_tile; + if constexpr(std::is_same_v) + { + load_interleaved_pk_type(b_warp_windows(nIter)(kIter), b_warp_tile); + } + else + { + b_warp_tile = load_tile(b_warp_windows(nIter)(kIter)); + } // read C warp tensor from C block tensor- CWarpTensor c_warp_tensor; @@ -342,11 +387,27 @@ struct BlockUniversalGemmAsBsCr static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { // read A warp tensor from A block window - load_tile(a_warp_tiles_(mIter)(kIter), a_warp_windows(mIter)(kIter)); + if constexpr(std::is_same_v) + { + load_interleaved_pk_type(a_warp_windows(mIter)(kIter), + a_warp_tiles_(mIter)(kIter)); + } + else + { + a_warp_tiles_(mIter)(kIter) = load_tile(a_warp_windows(mIter)(kIter)); + } }); static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { // read B warp tensor from B Block window - load_tile(b_warp_tiles_(nIter)(kIter), b_warp_windows(nIter)(kIter)); + if constexpr(std::is_same_v) + { + load_interleaved_pk_type(b_warp_windows(nIter)(kIter), + b_warp_tiles_(nIter)(kIter)); + } + else + { + b_warp_tiles_(nIter)(kIter) = load_tile(b_warp_windows(nIter)(kIter)); + } }); }); } @@ -504,12 +565,27 @@ struct BlockUniversalGemmAsBsCr // TODO check if a_warp_tiles has same desc as a_warp_window static_for<0, KInnerLoopIter, 1>{}([&](auto kIter) { static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - // read A warp tensor from A block window - load_tile(a_warp_tiles_(mIter)(kIter), a_warp_windows(mIter)(kIter)); + if constexpr(std::is_same_v) + { + load_interleaved_pk_type(a_warp_windows(mIter)(kIter), + a_warp_tiles_(mIter)(kIter)); + } + else + { + a_warp_tiles_(mIter)(kIter) = load_tile(a_warp_windows(mIter)(kIter)); + } }); static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { // read B warp tensor from B Block window - load_tile(b_warp_tiles_(nIter)(kIter), b_warp_windows(nIter)(kIter)); + if constexpr(std::is_same_v) + { + load_interleaved_pk_type(b_warp_windows(nIter)(kIter), + b_warp_tiles_(nIter)(kIter)); + } + else + { + b_warp_tiles_(nIter)(kIter) = load_tile(b_warp_windows(nIter)(kIter)); + } }); }); } diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp index 69c50c7cd0..73d5ce8f81 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp @@ -54,6 +54,11 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 using CDataType = remove_cvref_t; using BlockGemmShape = remove_cvref_t; + static constexpr index_t APackedSize = + ck_tile::numeric_traits>::PackedSize; + static constexpr index_t BPackedSize = + ck_tile::numeric_traits>::PackedSize; + using ALayout = remove_cvref_t; using BLayout = remove_cvref_t; using CLayout = remove_cvref_t; @@ -196,12 +201,12 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 // A/B split schedule // compiler is likely to use ds_read2 when instruction width smaller than 16bytes - constexpr auto num_ds_read_inst_a = A_LDS_Read_Width * sizeof(ADataType) == 16 - ? A_LDS_Read_Inst_Num - : A_LDS_Read_Inst_Num / 2; - constexpr auto num_ds_read_inst_b = B_LDS_Read_Width * sizeof(BDataType) == 16 - ? B_LDS_Read_Inst_Num - : B_LDS_Read_Inst_Num / 2; + constexpr auto num_ds_read_inst_a = + A_LDS_Read_Width * sizeof(ADataType) / APackedSize == 16 ? A_LDS_Read_Inst_Num + : A_LDS_Read_Inst_Num / 2; + constexpr auto num_ds_read_inst_b = + B_LDS_Read_Width * sizeof(BDataType) / BPackedSize == 16 ? B_LDS_Read_Inst_Num + : B_LDS_Read_Inst_Num / 2; constexpr auto num_ds_write_inst_a = A_LDS_Write_Inst_Num; constexpr auto num_ds_write_inst_b = B_LDS_Write_Inst_Num; @@ -213,9 +218,9 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32; constexpr auto ds_read_a_issue_cycle = - A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4; + A_LDS_Read_Width * sizeof(ADataType) / APackedSize == 16 ? 8 : 4; constexpr auto ds_read_b_issue_cycle = - B_LDS_Read_Width * sizeof(BDataType) == 16 ? 8 : 4; + B_LDS_Read_Width * sizeof(BDataType) / BPackedSize == 16 ? 8 : 4; constexpr auto ds_read_a_mfma_rate = (mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle); constexpr auto ds_read_b_mfma_rate = diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp index ea8d063fd5..b679f8c8aa 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp @@ -60,6 +60,13 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 using CDataType = remove_cvref_t; using BlockGemmShape = remove_cvref_t; + static_assert(!std::is_same_v, "Not implemented"); + + static constexpr index_t APackedSize = + ck_tile::numeric_traits>::PackedSize; + static constexpr index_t BPackedSize = + ck_tile::numeric_traits>::PackedSize; + using ALayout = remove_cvref_t; using BLayout = remove_cvref_t; using CLayout = remove_cvref_t; @@ -139,12 +146,12 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 (BlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL); - constexpr auto num_ds_read_inst_a = A_LDS_Read_Width * sizeof(ADataType) == 16 - ? A_LDS_Read_Inst_Num - : A_LDS_Read_Inst_Num / 2; - constexpr auto num_ds_read_inst_b = B_LDS_Read_Width * sizeof(BDataType) == 16 - ? B_LDS_Read_Inst_Num - : B_LDS_Read_Inst_Num / 2; + constexpr auto num_ds_read_inst_a = + A_LDS_Read_Width * sizeof(ADataType) / APackedSize == 16 ? A_LDS_Read_Inst_Num + : A_LDS_Read_Inst_Num / 2; + constexpr auto num_ds_read_inst_b = + B_LDS_Read_Width * sizeof(BDataType) / BPackedSize == 16 ? B_LDS_Read_Inst_Num + : B_LDS_Read_Inst_Num / 2; constexpr auto num_ds_read_inst = num_ds_read_inst_a + num_ds_read_inst_b; constexpr auto num_ds_write_inst = A_LDS_Write_Inst_Num + B_LDS_Write_Inst_Num; diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp index cde31f087b..b8b2d5b1c9 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp @@ -21,6 +21,13 @@ struct BaseGemmPipelineAgBgCrMem using BDataType = remove_cvref_t; using BlockGemmShape = remove_cvref_t; + static_assert(!std::is_same_v, "Not implemented"); + + static constexpr index_t APackedSize = + ck_tile::numeric_traits>::PackedSize; + static constexpr index_t BPackedSize = + ck_tile::numeric_traits>::PackedSize; + CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; } static constexpr index_t BlockSize = Problem::kBlockSize; @@ -33,9 +40,11 @@ struct BaseGemmPipelineAgBgCrMem static constexpr index_t WgpPerCU = (4 * get_warp_size() / BlockSize) >= 1 ? 4 * get_warp_size() / BlockSize : 1; - static constexpr index_t FullMemBandPrefetchStages = integer_divide_ceil( - MinMemInFlyBytes / WgpPerCU, - (MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock); + static constexpr index_t FullMemBandPrefetchStages = + integer_divide_ceil(MinMemInFlyBytes / WgpPerCU, + (MPerBlock * sizeof(ADataType) / APackedSize + + NPerBlock * sizeof(BDataType) / BPackedSize) * + KPerBlock); static constexpr index_t PrefetchStages = FullMemBandPrefetchStages >= 2 ? FullMemBandPrefetchStages <= 8 ? FullMemBandPrefetchStages : 8 diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp index 2d9f95627c..c7115c8eb4 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp @@ -67,16 +67,22 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy template CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA() { - constexpr index_t smem_size_a = sizeof(typename Problem::ADataType) * - MakeALdsBlockDescriptor().get_element_space_size(); + constexpr index_t PackedSize = + ck_tile::numeric_traits>::PackedSize; + constexpr index_t smem_size_a = + sizeof(typename Problem::ADataType) * + MakeALdsBlockDescriptor().get_element_space_size() / PackedSize; return smem_size_a; } template CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeB() { - constexpr index_t smem_size_b = sizeof(typename Problem::BDataType) * - MakeBLdsBlockDescriptor().get_element_space_size(); + constexpr index_t PackedSize = + ck_tile::numeric_traits>::PackedSize; + constexpr index_t smem_size_b = + sizeof(typename Problem::BDataType) * + MakeBLdsBlockDescriptor().get_element_space_size() / PackedSize; return smem_size_b; } @@ -387,8 +393,8 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy using AccDataType = float; using BlockWarps = typename Problem::BlockGemmShape::BlockWarps; using WarpTile = typename Problem::BlockGemmShape::WarpTile; - using WarpGemm = WarpGemmMfmaDispatcher; using BlockGemmShape = remove_cvref_t; + static constexpr index_t APackedSize = + ck_tile::numeric_traits>::PackedSize; + static constexpr index_t BPackedSize = + ck_tile::numeric_traits>::PackedSize; + static constexpr index_t kBlockSize = Problem::kBlockSize; static constexpr index_t kMPerBlock = BlockGemmShape::kM; @@ -37,13 +42,15 @@ struct GemmPipelineAGmemBGmemCRegV2 CK_TILE_HOST_DEVICE static constexpr index_t GetStaticLdsSize() { - return integer_divide_ceil( - sizeof(ADataType) * - Policy::template MakeALdsBlockDescriptor().get_element_space_size(), - 16) * + return integer_divide_ceil(sizeof(ADataType) * + Policy::template MakeALdsBlockDescriptor() + .get_element_space_size() / + APackedSize, + 16) * 16 + sizeof(BDataType) * - Policy::template MakeBLdsBlockDescriptor().get_element_space_size(); + Policy::template MakeBLdsBlockDescriptor().get_element_space_size() / + BPackedSize; } template (p_a_lds, a_lds_block_desc); constexpr index_t a_lds_block_space_size_aligned = - integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), 16) * + integer_divide_ceil( + sizeof(ADataType) * a_lds_block_desc.get_element_space_size() / APackedSize, 16) * 16; // B tile in LDS diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp index 771662f566..f833ccc849 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp @@ -13,14 +13,16 @@ template + typename Traits_, + typename ComputeDataType_ = ADataType_> struct GemmPipelineProblemBase { using Traits = remove_cvref_t; - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; - using CDataType = remove_cvref_t; + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using ComputeDataType = remove_cvref_t; using BlockGemmShape = remove_cvref_t; @@ -53,13 +55,15 @@ struct GemmPipelineProblemBase CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentA() { + constexpr index_t PackedSize = + ck_tile::numeric_traits>::PackedSize; if constexpr(std::is_same_v) { constexpr index_t pixels_per_thread = BlockGemmShape::kM * BlockGemmShape::kK / kBlockSize; - return pixels_per_thread < VectorLoadSize / sizeof(ADataType) + return pixels_per_thread < PackedSize * VectorLoadSize / sizeof(ADataType) ? pixels_per_thread - : VectorLoadSize / sizeof(ADataType); + : PackedSize * VectorLoadSize / sizeof(ADataType); } else { @@ -69,17 +73,19 @@ struct GemmPipelineProblemBase CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentB() { + constexpr index_t PackedSize = + ck_tile::numeric_traits>::PackedSize; if constexpr(std::is_same_v) { constexpr index_t pixels_per_thread = BlockGemmShape::kN * BlockGemmShape::kK / kBlockSize; - return pixels_per_thread < VectorLoadSize / sizeof(BDataType) + return pixels_per_thread < PackedSize * VectorLoadSize / sizeof(BDataType) ? pixels_per_thread - : VectorLoadSize / sizeof(BDataType); + : PackedSize * VectorLoadSize / sizeof(BDataType); } else { - return VectorLoadSize / sizeof(BDataType); + return PackedSize * VectorLoadSize / sizeof(BDataType); } } @@ -143,9 +149,14 @@ template -using GemmPipelineProblem = - GemmPipelineProblemBase; + typename Traits_, + typename ComputeDataType_ = ADataType_> +using GemmPipelineProblem = GemmPipelineProblemBase; template + TailNumber TailNum_ = TailNumber::Full, + typename ComputeDataType_ = ADataType_> struct UniversalGemmPipelineProblem { using Traits = remove_cvref_t; - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; - using CDataType = remove_cvref_t; + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using ComputeDataType = remove_cvref_t; using BlockGemmShape = remove_cvref_t; diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp index c20d09cea4..fd1e76a02b 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp @@ -34,31 +34,41 @@ struct UniversalGemmBasePolicy constexpr index_t BlockSize = Problem::kBlockSize; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; constexpr index_t elements_per_thread = MNPerBlock * KPerBlock / BlockSize; + constexpr index_t PackedSize = + ck_tile::numeric_traits>::PackedSize; // Assume DataType is even! - if constexpr(XPerTile % (16 / sizeof(DataType)) == 0 && - elements_per_thread % (16 / sizeof(DataType)) == 0) + if constexpr(XPerTile % (PackedSize * 32 / sizeof(DataType)) == 0 && + elements_per_thread % (PackedSize * 32 / sizeof(DataType)) == 0 && + PackedSize == 2) { - return (16 / sizeof(DataType)); + return (PackedSize * 32 / sizeof(DataType)); } - else if constexpr(XPerTile % (8 / sizeof(DataType)) == 0 && - elements_per_thread % (8 / sizeof(DataType)) == 0) + else if constexpr(XPerTile % (PackedSize * 16 / sizeof(DataType)) == 0 && + elements_per_thread % (PackedSize * 16 / sizeof(DataType)) == 0) { - return (8 / sizeof(DataType)); + return (PackedSize * 16 / sizeof(DataType)); } - else if constexpr(sizeof(DataType) >= 4 && XPerTile % (4 / sizeof(DataType)) == 0 && - elements_per_thread % (4 / sizeof(DataType)) == 0) + else if constexpr(XPerTile % (PackedSize * 8 / sizeof(DataType)) == 0 && + elements_per_thread % (PackedSize * 8 / sizeof(DataType)) == 0) { - return (4 / sizeof(DataType)); + return (PackedSize * 8 / sizeof(DataType)); } - else if constexpr(sizeof(DataType) >= 2 && XPerTile % (2 / sizeof(DataType)) == 0 && - elements_per_thread % (2 / sizeof(DataType)) == 0) + else if constexpr(sizeof(DataType) >= PackedSize * 4 && + XPerTile % (PackedSize * 4 / sizeof(DataType)) == 0 && + elements_per_thread % (PackedSize * 4 / sizeof(DataType)) == 0) { - return (2 / sizeof(DataType)); + return (PackedSize * 4 / sizeof(DataType)); + } + else if constexpr(sizeof(DataType) >= PackedSize * 2 && + XPerTile % (PackedSize * 2 / sizeof(DataType)) == 0 && + elements_per_thread % (PackedSize * 2 / sizeof(DataType)) == 0) + { + return (PackedSize * 2 / sizeof(DataType)); } else { - return 1; + return PackedSize; } } @@ -564,8 +574,8 @@ struct UniversalGemmPipelineAgBgCrPolicy { using BlockWarps = typename Problem::BlockGemmShape::BlockWarps; using WarpTile = typename Problem::BlockGemmShape::WarpTile; - using WarpGemm = WarpGemmMfmaDispatcher