mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-17 11:30:02 +00:00
[CK TILE] GEMM with packed i4 (#1885)
* [CK TILE] GEMM with packed i4
* Fixes
* fixes
* fixes
* fixes
[ROCm/composable_kernel commit: 4d9973ec8e]
This commit is contained in:
@@ -35,7 +35,7 @@
|
||||
#error "unsupported CK_TILE_PIPELINE_DEFAULT value"
|
||||
#endif
|
||||
|
||||
template <typename DataType>
|
||||
template <typename ADataType, typename BDataType = ADataType, typename CDataType = ADataType>
|
||||
struct GemmBasicTypeConfig;
|
||||
|
||||
template <>
|
||||
@@ -75,6 +75,15 @@ struct GemmBasicTypeConfig<ck_tile::bf8_t>
|
||||
using CDataType = ck_tile::half_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmBasicTypeConfig<ck_tile::half_t, ck_tile::pk_int4_t, ck_tile::half_t>
|
||||
{
|
||||
using ADataType = ck_tile::half_t;
|
||||
using BDataType = ck_tile::pk_int4_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::half_t;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct DataTypeTraits;
|
||||
|
||||
@@ -114,6 +123,12 @@ struct DataTypeTraits<ck_tile::bf8_t>
|
||||
static constexpr const char* name = "bf8";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::pk_int4_t>
|
||||
{
|
||||
static constexpr const char* name = "pk_int4_t";
|
||||
};
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
|
||||
@@ -29,6 +29,60 @@ auto calculate_rtol_atol(const ck_tile::index_t K,
|
||||
// Use higher threshold
|
||||
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
|
||||
}
|
||||
template <typename Tensor>
|
||||
void permute_tensor_b(Tensor& tensor)
|
||||
{
|
||||
const ck_tile::index_t K = tensor.get_length(0);
|
||||
const ck_tile::index_t N = tensor.get_length(1);
|
||||
// vector pk_i4x4 permute
|
||||
for(int i = 0; i < N; i++)
|
||||
{
|
||||
for(int j = 0; j < K; j += 8)
|
||||
{
|
||||
int8_t input[8];
|
||||
|
||||
for(int k = 0; k < 4; k++)
|
||||
{
|
||||
int8_t i4x2 = tensor(j + k * 2, i).data;
|
||||
input[k * 2 + 0] = (i4x2 >> 4) & 0xf;
|
||||
input[k * 2 + 1] = (i4x2 >> 0) & 0xf;
|
||||
}
|
||||
|
||||
// permute 01234567->20643175
|
||||
{
|
||||
int8_t hi = input[2];
|
||||
int8_t lo = input[0];
|
||||
int8_t i4x2 = (hi << 4) | lo;
|
||||
|
||||
tensor(j + 0, i) = i4x2;
|
||||
}
|
||||
|
||||
{
|
||||
int8_t hi = input[6];
|
||||
int8_t lo = input[4];
|
||||
int8_t i4x2 = (hi << 4) | lo;
|
||||
|
||||
tensor(j + 2, i) = i4x2;
|
||||
}
|
||||
|
||||
{
|
||||
int8_t hi = input[3];
|
||||
int8_t lo = input[1];
|
||||
int8_t i4x2 = (hi << 4) | lo;
|
||||
|
||||
tensor(j + 4, i) = i4x2;
|
||||
}
|
||||
|
||||
{
|
||||
int8_t hi = input[7];
|
||||
int8_t lo = input[5];
|
||||
int8_t i4x2 = (hi << 4) | lo;
|
||||
|
||||
tensor(j + 6, i) = i4x2;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
@@ -83,7 +137,12 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
template <typename PrecType, typename ALayout, typename BLayout, typename CLayout>
|
||||
template <typename ADataType,
|
||||
typename BDataType = ADataType,
|
||||
typename CDataType = ADataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
int run_gemm_example_with_layouts(int argc,
|
||||
char* argv[],
|
||||
const ALayout a_layout = ALayout{},
|
||||
@@ -94,10 +153,7 @@ int run_gemm_example_with_layouts(int argc,
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
using ADataType = typename GemmBasicTypeConfig<PrecType>::ADataType;
|
||||
using BDataType = typename GemmBasicTypeConfig<PrecType>::BDataType;
|
||||
using CDataType = typename GemmBasicTypeConfig<PrecType>::CDataType;
|
||||
using AccDataType = typename GemmBasicTypeConfig<PrecType>::AccDataType;
|
||||
using AccDataType = typename GemmBasicTypeConfig<ADataType, BDataType, CDataType>::AccDataType;
|
||||
|
||||
ck_tile::index_t M = arg_parser.get_int("m");
|
||||
ck_tile::index_t N = arg_parser.get_int("n");
|
||||
@@ -149,7 +205,17 @@ int run_gemm_example_with_layouts(int argc,
|
||||
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
|
||||
|
||||
a_m_k_dev_buf.ToDevice(a_m_k.data());
|
||||
b_k_n_dev_buf.ToDevice(b_k_n.data());
|
||||
if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
|
||||
{
|
||||
// Permute data for device implementation
|
||||
ck_tile::HostTensor<BDataType> b_k_n_dev = b_k_n;
|
||||
permute_tensor_b(b_k_n_dev);
|
||||
b_k_n_dev_buf.ToDevice(b_k_n_dev.data());
|
||||
}
|
||||
else
|
||||
{
|
||||
b_k_n_dev_buf.ToDevice(b_k_n.data());
|
||||
}
|
||||
c_m_n_dev_buf.SetZero();
|
||||
c_m_n_dev_result.SetZero();
|
||||
|
||||
@@ -195,6 +261,11 @@ int run_gemm_example_with_layouts(int argc,
|
||||
}
|
||||
else if(arg_parser.get_int("v") == 2)
|
||||
{
|
||||
if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
|
||||
{
|
||||
// Restore input for B for gpu reference
|
||||
b_k_n_dev_buf.ToDevice(b_k_n.data());
|
||||
}
|
||||
ck_tile::HostTensor<CDataType> c_m_n_gpu_ref(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
|
||||
ck_tile::DeviceMem c_m_n_gpu_buf_ref(c_m_n_gpu_ref.get_element_space_size_in_bytes());
|
||||
@@ -205,17 +276,18 @@ int run_gemm_example_with_layouts(int argc,
|
||||
BDataType* d_B;
|
||||
CDataType* d_C;
|
||||
|
||||
ck_tile::hip_check_error(hipMalloc(&d_A, M * K * sizeof(ADataType)));
|
||||
ck_tile::hip_check_error(hipMalloc(&d_B, N * K * sizeof(BDataType)));
|
||||
ck_tile::hip_check_error(hipMalloc(&d_C, M * N * sizeof(CDataType)));
|
||||
ck_tile::hip_check_error(hipMalloc(&d_A, a_m_k.get_element_space_size_in_bytes()));
|
||||
ck_tile::hip_check_error(hipMalloc(&d_B, b_k_n.get_element_space_size_in_bytes()));
|
||||
ck_tile::hip_check_error(
|
||||
hipMalloc(&d_C, c_m_n_dev_result.get_element_space_size_in_bytes()));
|
||||
|
||||
ck_tile::hip_check_error(hipMemcpy(d_A,
|
||||
a_m_k_dev_buf.GetDeviceBuffer(),
|
||||
M * K * sizeof(ADataType),
|
||||
a_m_k.get_element_space_size_in_bytes(),
|
||||
hipMemcpyHostToDevice));
|
||||
ck_tile::hip_check_error(hipMemcpy(d_B,
|
||||
b_k_n_dev_buf.GetDeviceBuffer(),
|
||||
N * K * sizeof(BDataType),
|
||||
b_k_n.get_element_space_size_in_bytes(),
|
||||
hipMemcpyHostToDevice));
|
||||
|
||||
ck_tile::reference_gemm_gpu<ADataType,
|
||||
@@ -228,7 +300,7 @@ int run_gemm_example_with_layouts(int argc,
|
||||
|
||||
ck_tile::hip_check_error(hipMemcpy(c_m_n_gpu_buf_ref.GetDeviceBuffer(),
|
||||
d_C,
|
||||
M * N * sizeof(CDataType),
|
||||
c_m_n_dev_result.get_element_space_size_in_bytes(),
|
||||
hipMemcpyDeviceToHost));
|
||||
|
||||
ck_tile::hip_check_error(hipFree(d_A));
|
||||
|
||||
@@ -321,6 +321,15 @@ int run_gemm_example(int argc, char* argv[])
|
||||
{
|
||||
return run_gemm_example_with_layouts<ck_tile::bf8_t>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3)
|
||||
else if(data_type == "pk_int4_t")
|
||||
{
|
||||
// TODO: Add support for bhalf_t ADataType
|
||||
return run_gemm_example_with_layouts<ck_tile::half_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
#endif
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data_type!");
|
||||
@@ -344,6 +353,15 @@ int run_gemm_example(int argc, char* argv[])
|
||||
{
|
||||
return run_gemm_example_with_layouts<ck_tile::bf8_t>(argc, argv, Col{}, Col{}, Row{});
|
||||
}
|
||||
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3)
|
||||
else if(data_type == "pk_int4_t")
|
||||
{
|
||||
// TODO: Add support for bhalf_t ADataType
|
||||
return run_gemm_example_with_layouts<ck_tile::half_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t>(argc, argv, Col{}, Col{}, Row{});
|
||||
}
|
||||
#endif
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data_type!");
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -1309,7 +1309,9 @@ CK_TILE_DEVICE thread_buffer<T, N> amd_buffer_load_impl(int32x4_t src_wave_buffe
|
||||
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(std::is_same<T, fp8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(std::is_same<T, bf8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(std::is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
|
||||
(std::is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(std::is_same<T, pk_int4_t>::value &&
|
||||
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32)),
|
||||
"wrong! not implemented");
|
||||
|
||||
using rtn_type = thread_buffer<T, N>;
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -153,12 +153,12 @@ struct array<T, 0>
|
||||
CK_TILE_HOST_DEVICE void print() const { printf("array{size: 0, data: []}"); }
|
||||
};
|
||||
|
||||
template <typename>
|
||||
template <typename, typename>
|
||||
struct vector_traits;
|
||||
|
||||
// specialization for array
|
||||
template <typename T, index_t N>
|
||||
struct vector_traits<array<T, N>>
|
||||
struct vector_traits<array<T, N>, void>
|
||||
{
|
||||
using scalar_type = T;
|
||||
static constexpr index_t vector_size = N;
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -149,17 +149,24 @@ struct thread_buffer {
|
||||
};
|
||||
// clang-format on
|
||||
|
||||
template <typename>
|
||||
template <typename, typename>
|
||||
struct vector_traits;
|
||||
|
||||
// specialization for array
|
||||
template <typename T, index_t N>
|
||||
struct vector_traits<thread_buffer<T, N>>
|
||||
struct vector_traits<thread_buffer<T, N>, std::enable_if_t<!std::is_class_v<T>>>
|
||||
{
|
||||
using scalar_type = T;
|
||||
static constexpr index_t vector_size = N;
|
||||
};
|
||||
|
||||
template <typename T, index_t N>
|
||||
struct vector_traits<thread_buffer<T, N>, std::enable_if_t<std::is_class_v<T>>>
|
||||
{
|
||||
using scalar_type = typename T::type;
|
||||
static constexpr index_t vector_size = N;
|
||||
};
|
||||
|
||||
#endif
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -294,7 +294,7 @@ struct tuple : impl::tuple_base<make_index_sequence<sizeof...(T)>, T...>
|
||||
#undef TP_COM_
|
||||
};
|
||||
|
||||
template <typename>
|
||||
template <typename, typename = void>
|
||||
struct vector_traits;
|
||||
|
||||
// specialization for array
|
||||
|
||||
@@ -376,14 +376,12 @@ struct numeric<bfloat16_t>
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct numeric_traits;
|
||||
|
||||
template <>
|
||||
struct numeric_traits<bfloat16_t>
|
||||
{
|
||||
static constexpr int exp = 8;
|
||||
static constexpr int mant = 7;
|
||||
static constexpr int exp = 8;
|
||||
static constexpr int mant = 7;
|
||||
static constexpr int PackedSize = 1;
|
||||
};
|
||||
|
||||
#if CK_TILE_USE_CUSTOM_DATA_TYPE
|
||||
|
||||
@@ -207,9 +207,6 @@ using bf8_t = unsigned _BitInt(8);
|
||||
using bf8_raw_t = uint8_t;
|
||||
#endif
|
||||
|
||||
template <typename T>
|
||||
struct numeric_traits;
|
||||
|
||||
template <>
|
||||
struct numeric_traits<fp8_t>
|
||||
{
|
||||
@@ -225,6 +222,7 @@ struct numeric_traits<fp8_t>
|
||||
static constexpr fp8_interpretation f8_interpret = fp8_interpretation::E4M3_FNUZ;
|
||||
#endif
|
||||
static constexpr uint8_t abs_mask = 0x7F;
|
||||
static constexpr int PackedSize = 1;
|
||||
};
|
||||
|
||||
template <>
|
||||
@@ -242,6 +240,7 @@ struct numeric_traits<bf8_t>
|
||||
static constexpr fp8_interpretation f8_interpret = fp8_interpretation::E5M2_FNUZ;
|
||||
#endif
|
||||
static constexpr uint8_t abs_mask = 0x7F;
|
||||
static constexpr int PackedSize = 1;
|
||||
};
|
||||
|
||||
// below is sw fp8 conversion, not utilizing hw instruction
|
||||
|
||||
@@ -223,9 +223,6 @@ struct numeric<half_t>
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct numeric_traits;
|
||||
|
||||
template <>
|
||||
struct numeric_traits<half_t>
|
||||
{
|
||||
@@ -241,6 +238,7 @@ struct numeric_traits<half_t>
|
||||
static constexpr uint16_t NegInf = 0xFC00;
|
||||
static constexpr uint16_t NaN = 0x7C01;
|
||||
static constexpr uint16_t Neg0 = 0x8000;
|
||||
static constexpr int PackedSize = 1;
|
||||
using bitwise_type = uint16_t;
|
||||
};
|
||||
|
||||
@@ -383,4 +381,24 @@ half_t exp2(half_t x) { return static_cast<half_t>(exp2f(static_cast<float>(x)))
|
||||
CK_TILE_DEVICE
|
||||
half_t log(half_t x) { return static_cast<half_t>(__logf(static_cast<float>(x))); };
|
||||
#endif
|
||||
|
||||
using fp16x2_t = _Float16 __attribute__((ext_vector_type(2)));
|
||||
|
||||
CK_TILE_HOST fp16x2_t pk_add_f16(const fp16x2_t& x, const fp16x2_t& y)
|
||||
{
|
||||
fp16x2_t vector_res;
|
||||
|
||||
vector_res.x = x.x + y.x;
|
||||
vector_res.y = x.y + y.y;
|
||||
|
||||
return vector_res;
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE fp16x2_t pk_add_f16(const fp16x2_t& x, const fp16x2_t& y)
|
||||
{
|
||||
fp16x2_t c;
|
||||
asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(c) : "v"(x), "v"(y));
|
||||
return c;
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/half.hpp"
|
||||
@@ -74,8 +74,6 @@ struct numeric<int8_t>
|
||||
};
|
||||
|
||||
#if 0
|
||||
template <typename T>
|
||||
struct numeric_traits;
|
||||
|
||||
template <>
|
||||
struct numeric_traits<int8_t>
|
||||
@@ -91,6 +89,7 @@ struct numeric_traits<int8_t>
|
||||
static constexpr uint32_t NegInf = 0xFC00;
|
||||
static constexpr uint32_t NaN = 0x7C01;
|
||||
static constexpr uint32_t Neg0 = 0x8000;
|
||||
static constexpr int PackedSize = 1;
|
||||
using bitwise_type = uint16_t;
|
||||
};
|
||||
#endif
|
||||
|
||||
@@ -77,7 +77,10 @@ struct numeric
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct numeric_traits;
|
||||
struct numeric_traits
|
||||
{
|
||||
static constexpr int PackedSize = 1;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct numeric_traits<float>
|
||||
@@ -94,6 +97,7 @@ struct numeric_traits<float>
|
||||
static constexpr uint32_t NegInf = 0xFF800000;
|
||||
static constexpr uint32_t NaN = 0x7F800001;
|
||||
static constexpr uint32_t Neg0 = 0x80000000;
|
||||
static constexpr int PackedSize = 1;
|
||||
using bitwise_type = uint32_t;
|
||||
};
|
||||
|
||||
|
||||
@@ -21,8 +21,8 @@ struct pk_int4_t
|
||||
{
|
||||
using type = int8_t;
|
||||
type data;
|
||||
__host__ __device__ constexpr pk_int4_t() : data{type{}} {}
|
||||
__host__ __device__ constexpr pk_int4_t(type init) : data{init} {}
|
||||
CK_TILE_HOST_DEVICE constexpr pk_int4_t() : data{type{}} {}
|
||||
CK_TILE_HOST_DEVICE constexpr pk_int4_t(type init) : data{init} {}
|
||||
};
|
||||
|
||||
// limits
|
||||
@@ -91,6 +91,16 @@ struct numeric<pk_int4_t>
|
||||
CK_TILE_HOST_DEVICE static constexpr pk_int4_t zero() { return 0; }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct numeric_traits<pk_int4_t>
|
||||
{
|
||||
static constexpr int PackedSize = 2;
|
||||
};
|
||||
|
||||
using fp32x2_t = float __attribute__((ext_vector_type(2)));
|
||||
using fp16x2_t = _Float16 __attribute__((ext_vector_type(2)));
|
||||
using bf16x2_t = bf16_raw_t __attribute__((ext_vector_type(2)));
|
||||
|
||||
CK_TILE_HOST_DEVICE fp32x2_t pk_int4_t_to_fp32x2_t(const pk_int4_t& x)
|
||||
{
|
||||
uint8_t x_u8 = ck_tile::bit_cast<uint8_t>(x);
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
#include "ck_tile/core/numeric/float8.hpp"
|
||||
#include "ck_tile/core/numeric/half.hpp"
|
||||
#include "ck_tile/core/numeric/bfloat16.hpp"
|
||||
#include "ck_tile/core/numeric/pk_int4.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
@@ -30,17 +31,34 @@ struct native_t
|
||||
// of compiler errors e.g. struct A; using Ax2_t = A __attribute__((ext_vector_type(2))); -> will
|
||||
// have compiler error
|
||||
namespace impl {
|
||||
|
||||
template <typename T_, index_t N_, typename = void>
|
||||
struct ext_vector;
|
||||
|
||||
template <typename T_, index_t N_>
|
||||
struct ext_vector
|
||||
struct ext_vector<T_, N_, std::enable_if_t<!std::is_class_v<typename native_t<T_>::type>>>
|
||||
{
|
||||
static constexpr index_t N = N_;
|
||||
using value_type = typename native_t<remove_cvref_t<T_>>::type;
|
||||
// struct type is not supported for ext_vector
|
||||
using value_type = typename native_t<T_>::type;
|
||||
static_assert(!std::is_class_v<value_type>);
|
||||
using type = value_type __attribute__((ext_vector_type(N))); // this is danguous
|
||||
};
|
||||
|
||||
template <typename T_, index_t N_>
|
||||
struct ext_vector<T_, N_, std::enable_if_t<std::is_class_v<typename native_t<T_>::type>>>
|
||||
{
|
||||
static constexpr index_t N = N_;
|
||||
// struct type is not supported for ext_vector
|
||||
using value_type = typename native_t<T_>::type::type;
|
||||
static_assert(!std::is_class_v<value_type>);
|
||||
using type = value_type __attribute__((ext_vector_type(N))); // this is danguous
|
||||
};
|
||||
|
||||
template <typename V_, index_t Vs_, index_t N_>
|
||||
struct ext_vector<V_ __attribute__((ext_vector_type(Vs_))), N_>
|
||||
struct ext_vector<V_ __attribute__((ext_vector_type(Vs_))),
|
||||
N_,
|
||||
std::enable_if_t<!std::is_class_v<typename native_t<V_>::type>>>
|
||||
{
|
||||
static constexpr index_t N = Vs_ * N_;
|
||||
using value_type = typename native_t<remove_cvref_t<V_>>::type;
|
||||
@@ -48,6 +66,17 @@ struct ext_vector<V_ __attribute__((ext_vector_type(Vs_))), N_>
|
||||
using type = value_type __attribute__((ext_vector_type(N))); // this is danguous
|
||||
};
|
||||
|
||||
template <typename V_, index_t Vs_, index_t N_>
|
||||
struct ext_vector<V_ __attribute__((ext_vector_type(Vs_))),
|
||||
N_,
|
||||
std::enable_if_t<std::is_class_v<typename native_t<V_>::type>>>
|
||||
{
|
||||
static constexpr index_t N = Vs_ * N_;
|
||||
using value_type = typename native_t<remove_cvref_t<V_>>::type::type;
|
||||
static_assert(!std::is_class_v<value_type>);
|
||||
using type = value_type __attribute__((ext_vector_type(N))); // this is danguous
|
||||
};
|
||||
|
||||
} // namespace impl
|
||||
|
||||
template <typename T, index_t N>
|
||||
@@ -55,10 +84,11 @@ using ext_vector_t = typename impl::ext_vector<T, N>::type;
|
||||
|
||||
// by default, any type will result in a vector_size=1 with scalar_type=T traits.
|
||||
// ... unless we have other vector_traits specialization
|
||||
template <typename T>
|
||||
template <typename T, typename>
|
||||
struct vector_traits
|
||||
{
|
||||
using scalar_type = remove_cvref_t<T>;
|
||||
using scalar_type =
|
||||
std::conditional_t<std::is_same_v<remove_cvref_t<T>, pk_int4_t>, int8_t, remove_cvref_t<T>>;
|
||||
static constexpr index_t vector_size = 1;
|
||||
};
|
||||
|
||||
@@ -66,7 +96,7 @@ struct vector_traits
|
||||
template <typename T, index_t N>
|
||||
struct vector_traits<T __attribute__((ext_vector_type(N)))>
|
||||
{
|
||||
using scalar_type = T;
|
||||
using scalar_type = std::conditional_t<std::is_same_v<T, pk_int4_t>, int8_t, T>;
|
||||
static constexpr index_t vector_size = N;
|
||||
};
|
||||
|
||||
@@ -200,21 +230,11 @@ using bf8x32_t = bf8_t __attribute((ext_vector_type(32)));
|
||||
using bf8x64_t = bf8_t __attribute((ext_vector_type(64)));
|
||||
#endif
|
||||
|
||||
CK_TILE_HOST fp16x2_t pk_add_f16(const fp16x2_t& x, const fp16x2_t& y)
|
||||
{
|
||||
fp16x2_t vector_res;
|
||||
|
||||
vector_res.x = x.x + y.x;
|
||||
vector_res.y = x.y + y.y;
|
||||
|
||||
return vector_res;
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE fp16x2_t pk_add_f16(const fp16x2_t& x, const fp16x2_t& y)
|
||||
{
|
||||
fp16x2_t c;
|
||||
asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(c) : "v"(x), "v"(y));
|
||||
return c;
|
||||
}
|
||||
|
||||
// pk_int4_t
|
||||
// using pk_int4_t
|
||||
using pk_int4x2_t = int8_t __attribute((ext_vector_type(2)));
|
||||
using pk_int4x4_t = int8_t __attribute((ext_vector_type(4)));
|
||||
using pk_int4x8_t = int8_t __attribute((ext_vector_type(8)));
|
||||
using pk_int4x16_t = int8_t __attribute((ext_vector_type(16)));
|
||||
using pk_int4x32_t = int8_t __attribute((ext_vector_type(32)));
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -231,13 +231,18 @@ struct buffer_view<address_space_enum::global,
|
||||
int32x4_t cached_buf_res_;
|
||||
remove_cvref_t<T> invalid_element_value_ = T{0};
|
||||
|
||||
static constexpr index_t PackedSize = ck_tile::numeric_traits<remove_cvref_t<T>>::PackedSize;
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr buffer_view()
|
||||
: p_data_{}, buffer_size_{}, cached_buf_res_{0}, invalid_element_value_{}
|
||||
{
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data, BufferSizeType buffer_size)
|
||||
: p_data_{p_data}, buffer_size_{buffer_size}, cached_buf_res_{0}, invalid_element_value_{0}
|
||||
: p_data_{p_data},
|
||||
buffer_size_{buffer_size / PackedSize},
|
||||
cached_buf_res_{0},
|
||||
invalid_element_value_{0}
|
||||
{
|
||||
}
|
||||
|
||||
@@ -245,7 +250,7 @@ struct buffer_view<address_space_enum::global,
|
||||
BufferSizeType buffer_size,
|
||||
T invalid_element_value)
|
||||
: p_data_{p_data},
|
||||
buffer_size_{buffer_size},
|
||||
buffer_size_{buffer_size / PackedSize},
|
||||
cached_buf_res_{0},
|
||||
invalid_element_value_{invalid_element_value}
|
||||
{
|
||||
@@ -255,7 +260,7 @@ struct buffer_view<address_space_enum::global,
|
||||
// Must call for buffers that need *_raw load/store
|
||||
CK_TILE_HOST_DEVICE void init_raw()
|
||||
{
|
||||
cached_buf_res_ = make_wave_buffer_resource(p_data_, buffer_size_ * sizeof(type));
|
||||
cached_buf_res_ = make_wave_buffer_resource(p_data_, (buffer_size_) * sizeof(type));
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr address_space_enum get_address_space()
|
||||
@@ -887,8 +892,8 @@ struct buffer_view<address_space_enum::lds,
|
||||
#endif
|
||||
|
||||
i += linear_offset; // simplicity
|
||||
if constexpr(std::is_same<typename vector_traits<remove_cvref_t<T>>::scalar_type,
|
||||
int8_t>::value &&
|
||||
if constexpr(std::is_same_v<typename vector_traits<remove_cvref_t<T>>::scalar_type,
|
||||
int8_t> &&
|
||||
workaround_int8_ds_write_issue)
|
||||
{
|
||||
if(is_valid_element)
|
||||
@@ -897,83 +902,117 @@ struct buffer_view<address_space_enum::lds,
|
||||
// ISA, so I try to let compiler emit IR "store<i32, 4>" which would be lower to
|
||||
// ds_write_b128
|
||||
// TODO: remove this after compiler fix
|
||||
static_assert((std::is_same<remove_cvref_t<T>, int8_t>::value &&
|
||||
std::is_same<remove_cvref_t<X>, int8_t>::value) ||
|
||||
(std::is_same<remove_cvref_t<T>, int8_t>::value &&
|
||||
std::is_same<remove_cvref_t<X>, int8x2_t>::value) ||
|
||||
(std::is_same<remove_cvref_t<T>, int8_t>::value &&
|
||||
std::is_same<remove_cvref_t<X>, int8x4_t>::value) ||
|
||||
(std::is_same<remove_cvref_t<T>, int8_t>::value &&
|
||||
std::is_same<remove_cvref_t<X>, int8x8_t>::value) ||
|
||||
(std::is_same<remove_cvref_t<T>, int8_t>::value &&
|
||||
std::is_same<remove_cvref_t<X>, int8x16_t>::value) ||
|
||||
(std::is_same<remove_cvref_t<T>, int8x4_t>::value &&
|
||||
std::is_same<remove_cvref_t<X>, int8x4_t>::value) ||
|
||||
(std::is_same<remove_cvref_t<T>, int8x8_t>::value &&
|
||||
std::is_same<remove_cvref_t<X>, int8x8_t>::value) ||
|
||||
(std::is_same<remove_cvref_t<T>, int8x16_t>::value &&
|
||||
std::is_same<remove_cvref_t<X>, int8x16_t>::value),
|
||||
"wrong! not implemented for this combination, please add "
|
||||
"implementation");
|
||||
static_assert(
|
||||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, int8_t>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, int8x2_t>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, int8x4_t>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, int8x8_t>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, int8x16_t>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8x4_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, int8x4_t>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8x8_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, int8x8_t>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8x16_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, int8x16_t>) ||
|
||||
// ext_vector_type for pk_int4 must use int8_t as type
|
||||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 1>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 2>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 4>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 8>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 16>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, pk_int4x4_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 4>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, pk_int4x8_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 8>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, pk_int4x16_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 16>>),
|
||||
"wrong! not implemented for this combination, please add "
|
||||
"implementation");
|
||||
|
||||
if constexpr(std::is_same<remove_cvref_t<T>, int8_t>::value &&
|
||||
std::is_same<remove_cvref_t<X>, int8_t>::value)
|
||||
if constexpr((std::is_same_v<remove_cvref_t<T>, int8_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, int8_t>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 1>>))
|
||||
{
|
||||
// HACK: cast pointer of x is bad
|
||||
// TODO: remove this after compiler fix
|
||||
*c_style_pointer_cast<int8_t*>(&p_data_[i]) =
|
||||
*c_style_pointer_cast<const int8_t*>(&x);
|
||||
}
|
||||
else if constexpr(std::is_same<remove_cvref_t<T>, int8_t>::value &&
|
||||
std::is_same<remove_cvref_t<X>, int8x2_t>::value)
|
||||
else if constexpr((std::is_same_v<remove_cvref_t<T>, int8_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, int8x2_t>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 2>>))
|
||||
{
|
||||
// HACK: cast pointer of x is bad
|
||||
// TODO: remove this after compiler fix
|
||||
*c_style_pointer_cast<int16_t*>(&p_data_[i]) =
|
||||
*c_style_pointer_cast<const int16_t*>(&x);
|
||||
}
|
||||
else if constexpr(std::is_same<remove_cvref_t<T>, int8_t>::value &&
|
||||
std::is_same<remove_cvref_t<X>, int8x4_t>::value)
|
||||
else if constexpr((std::is_same_v<remove_cvref_t<T>, int8_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, int8x4_t>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 4>>))
|
||||
{
|
||||
// HACK: cast pointer of x is bad
|
||||
// TODO: remove this after compiler fix
|
||||
*c_style_pointer_cast<int32_t*>(&p_data_[i]) =
|
||||
*c_style_pointer_cast<const int32_t*>(&x);
|
||||
}
|
||||
else if constexpr(std::is_same<remove_cvref_t<T>, int8_t>::value &&
|
||||
std::is_same<remove_cvref_t<X>, int8x8_t>::value)
|
||||
else if constexpr((std::is_same_v<remove_cvref_t<T>, int8_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, int8x8_t>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 8>>))
|
||||
{
|
||||
// HACK: cast pointer of x is bad
|
||||
// TODO: remove this after compiler fix
|
||||
*c_style_pointer_cast<int32x2_t*>(&p_data_[i]) =
|
||||
*c_style_pointer_cast<const int32x2_t*>(&x);
|
||||
}
|
||||
else if constexpr(std::is_same<remove_cvref_t<T>, int8_t>::value &&
|
||||
std::is_same<remove_cvref_t<X>, int8x16_t>::value)
|
||||
else if constexpr((std::is_same_v<remove_cvref_t<T>, int8_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, int8x16_t>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 16>>))
|
||||
{
|
||||
// HACK: cast pointer of x is bad
|
||||
// TODO: remove this after compiler fix
|
||||
*c_style_pointer_cast<int32x4_t*>(&p_data_[i]) =
|
||||
*c_style_pointer_cast<const int32x4_t*>(&x);
|
||||
}
|
||||
else if constexpr(std::is_same<remove_cvref_t<T>, int8x4_t>::value &&
|
||||
std::is_same<remove_cvref_t<X>, int8x4_t>::value)
|
||||
else if constexpr((std::is_same_v<remove_cvref_t<T>, int8x4_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, int8x4_t>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, pk_int4x4_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 4>>))
|
||||
{
|
||||
// HACK: cast pointer of x is bad
|
||||
// TODO: remove this after compiler fix
|
||||
*c_style_pointer_cast<int32_t*>(&p_data_[i]) =
|
||||
*c_style_pointer_cast<const int32_t*>(&x);
|
||||
}
|
||||
else if constexpr(std::is_same<remove_cvref_t<T>, int8x8_t>::value &&
|
||||
std::is_same<remove_cvref_t<X>, int8x8_t>::value)
|
||||
else if constexpr((std::is_same_v<remove_cvref_t<T>, int8x8_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, int8x8_t>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, pk_int4x8_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 8>>))
|
||||
{
|
||||
// HACK: cast pointer of x is bad
|
||||
// TODO: remove this after compiler fix
|
||||
*c_style_pointer_cast<int32x2_t*>(&p_data_[i]) =
|
||||
*c_style_pointer_cast<const int32x2_t*>(&x);
|
||||
}
|
||||
else if constexpr(std::is_same<remove_cvref_t<T>, int8x16_t>::value &&
|
||||
std::is_same<remove_cvref_t<X>, int8x16_t>::value)
|
||||
else if constexpr((std::is_same_v<remove_cvref_t<T>, int8x16_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, int8x16_t>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, pk_int4x16_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 16>>))
|
||||
{
|
||||
// HACK: cast pointer of x is bad
|
||||
// TODO: remove this after compiler fix
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -27,6 +27,8 @@ struct static_distributed_tensor
|
||||
|
||||
using ThreadTensorDesc =
|
||||
remove_cvref_t<decltype(StaticTileDistribution{}.get_ys_to_d_descriptor())>;
|
||||
static constexpr index_t PackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<DataType>>::PackedSize;
|
||||
|
||||
static constexpr index_t kThreadElementSpaceSize = ThreadTensorDesc{}.get_element_space_size();
|
||||
static_assert(0 < kThreadElementSpaceSize, "Make sure tile distribution is valid");
|
||||
@@ -59,7 +61,7 @@ struct static_distributed_tensor
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t get_thread_buffer_size()
|
||||
{
|
||||
return kThreadElementSpaceSize;
|
||||
return kThreadElementSpaceSize / PackedSize;
|
||||
}
|
||||
|
||||
template <index_t... YSliceOrigins, index_t... YSliceLengths>
|
||||
@@ -79,8 +81,9 @@ struct static_distributed_tensor
|
||||
static_ford<sequence<YSliceLengths...>>{}([&](auto idx) {
|
||||
constexpr auto idx_ys = idx + sequence<YSliceOrigins...>{};
|
||||
|
||||
sliced_thread_data(number<sliced_thread_tensor_desc.calculate_offset(idx)>{}) =
|
||||
thread_buf_[number<ThreadTensorDesc{}.calculate_offset(idx_ys)>{}];
|
||||
sliced_thread_data(
|
||||
number<sliced_thread_tensor_desc.calculate_offset(idx) / PackedSize>{}) =
|
||||
thread_buf_[number<ThreadTensorDesc{}.calculate_offset(idx_ys) / PackedSize>{}];
|
||||
});
|
||||
|
||||
return sliced_thread_data;
|
||||
@@ -101,8 +104,9 @@ struct static_distributed_tensor
|
||||
static_ford<sequence<YSliceLengths...>>{}([&](auto idx) {
|
||||
constexpr auto idx_ys = idx + sequence<YSliceOrigins...>{};
|
||||
|
||||
thread_buf_(number<ThreadTensorDesc{}.calculate_offset(idx_ys)>{}) =
|
||||
sliced_thread_data[number<sliced_thread_tensor_desc.calculate_offset(idx)>{}];
|
||||
thread_buf_(number<ThreadTensorDesc{}.calculate_offset(idx_ys) / PackedSize>{}) =
|
||||
sliced_thread_data[number<sliced_thread_tensor_desc.calculate_offset(idx) /
|
||||
PackedSize>{}];
|
||||
});
|
||||
}
|
||||
|
||||
@@ -115,7 +119,7 @@ struct static_distributed_tensor
|
||||
constexpr auto y_idx = get_tile_distribution().get_y_indices_from_distributed_indices(
|
||||
TileDistributedIndices{});
|
||||
|
||||
return thread_buf_[number<ThreadTensorDesc{}.calculate_offset(y_idx)>{}];
|
||||
return thread_buf_[number<ThreadTensorDesc{}.calculate_offset(y_idx) / PackedSize>{}];
|
||||
}
|
||||
|
||||
template <typename TileDistributedIndices>
|
||||
@@ -127,11 +131,11 @@ struct static_distributed_tensor
|
||||
constexpr auto y_idx = get_tile_distribution().get_y_indices_from_distributed_indices(
|
||||
TileDistributedIndices{});
|
||||
|
||||
return thread_buf_(number<ThreadTensorDesc{}.calculate_offset(y_idx)>{});
|
||||
return thread_buf_(number<ThreadTensorDesc{}.calculate_offset(y_idx) / PackedSize>{});
|
||||
}
|
||||
|
||||
//
|
||||
thread_buffer<DataType, kThreadElementSpaceSize> thread_buf_;
|
||||
thread_buffer<DataType, get_thread_buffer_size()> thread_buf_;
|
||||
};
|
||||
|
||||
template <typename DataType, typename StaticTileDistribution>
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -45,6 +45,8 @@ struct tensor_view
|
||||
using TensorIndex = array<index_t, TensorDesc::get_num_of_top_dimension()>;
|
||||
using TensorCoord = decltype(make_tensor_coordinate(TensorDesc{}, TensorIndex{}));
|
||||
static constexpr auto DstInMemOp = DstInMemOp_;
|
||||
static constexpr index_t PackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<DataType>>::PackedSize;
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr tensor_view() = default;
|
||||
|
||||
@@ -81,8 +83,8 @@ struct tensor_view
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
return buf_.template get<X>(
|
||||
coord.get_offset(),
|
||||
linear_offset,
|
||||
coord.get_offset() / PackedSize,
|
||||
linear_offset / PackedSize,
|
||||
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
|
||||
bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
@@ -99,8 +101,8 @@ struct tensor_view
|
||||
bool is_valid_element, // flag
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
return buf_.template get<X>(coord.get_offset(),
|
||||
linear_offset,
|
||||
return buf_.template get<X>(coord.get_offset() / PackedSize,
|
||||
linear_offset / PackedSize,
|
||||
is_valid_element,
|
||||
bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
@@ -122,8 +124,8 @@ struct tensor_view
|
||||
{
|
||||
return buf_.template get_raw<X, oob_conditional_check, pre_nop>(
|
||||
dst,
|
||||
coord.get_offset(),
|
||||
linear_offset,
|
||||
coord.get_offset() / PackedSize,
|
||||
linear_offset / PackedSize,
|
||||
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
|
||||
bool_constant<pre_nop>{});
|
||||
}
|
||||
@@ -142,8 +144,12 @@ struct tensor_view
|
||||
bool_constant<oob_conditional_check> = {},
|
||||
bool_constant<pre_nop> = {}) const
|
||||
{
|
||||
return buf_.template get_raw<X, oob_conditional_check, pre_nop>(
|
||||
dst, coord.get_offset(), linear_offset, is_valid_element, bool_constant<pre_nop>{});
|
||||
return buf_.template get_raw<X, oob_conditional_check, pre_nop>(dst,
|
||||
coord.get_offset() /
|
||||
PackedSize,
|
||||
linear_offset / PackedSize,
|
||||
is_valid_element,
|
||||
bool_constant<pre_nop>{});
|
||||
}
|
||||
|
||||
template <typename X,
|
||||
@@ -159,8 +165,8 @@ struct tensor_view
|
||||
{
|
||||
return buf_.template async_get<X>(
|
||||
smem,
|
||||
coord.get_offset(),
|
||||
linear_offset,
|
||||
coord.get_offset() / PackedSize,
|
||||
linear_offset / PackedSize,
|
||||
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
|
||||
bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
@@ -178,8 +184,8 @@ struct tensor_view
|
||||
bool is_valid_element) const
|
||||
{
|
||||
return buf_.template async_get<X>(smem,
|
||||
coord.get_offset(),
|
||||
linear_offset,
|
||||
coord.get_offset() / PackedSize,
|
||||
linear_offset / PackedSize,
|
||||
is_valid_element,
|
||||
bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
@@ -198,8 +204,8 @@ struct tensor_view
|
||||
{
|
||||
return buf_.template async_get_raw<X>(
|
||||
smem,
|
||||
coord.get_offset(),
|
||||
linear_offset,
|
||||
coord.get_offset() / PackedSize,
|
||||
linear_offset / PackedSize,
|
||||
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
|
||||
bool_constant<pre_nop>{});
|
||||
}
|
||||
@@ -217,8 +223,11 @@ struct tensor_view
|
||||
bool is_valid_element,
|
||||
bool_constant<pre_nop> = {}) const
|
||||
{
|
||||
return buf_.template async_get_raw<X>(
|
||||
smem, coord.get_offset(), linear_offset, is_valid_element, bool_constant<pre_nop>{});
|
||||
return buf_.template async_get_raw<X>(smem,
|
||||
coord.get_offset() / PackedSize,
|
||||
linear_offset / PackedSize,
|
||||
is_valid_element,
|
||||
bool_constant<pre_nop>{});
|
||||
}
|
||||
|
||||
// X is vector of DataType.
|
||||
@@ -236,8 +245,8 @@ struct tensor_view
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
buf_.template set<X, oob_conditional_check>(
|
||||
coord.get_offset(),
|
||||
linear_offset,
|
||||
coord.get_offset() / PackedSize,
|
||||
linear_offset / PackedSize,
|
||||
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
|
||||
x);
|
||||
}
|
||||
@@ -272,8 +281,8 @@ struct tensor_view
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
buf_.template set_raw<X, oob_conditional_check>(
|
||||
coord.get_offset(),
|
||||
linear_offset,
|
||||
coord.get_offset() / PackedSize,
|
||||
linear_offset / PackedSize,
|
||||
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
|
||||
x);
|
||||
}
|
||||
@@ -292,7 +301,7 @@ struct tensor_view
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
buf_.template set_raw<X, oob_conditional_check>(
|
||||
coord.get_offset(), linear_offset, is_valid_element, x);
|
||||
coord.get_offset() / PackedSize, linear_offset / PackedSize, is_valid_element, x);
|
||||
}
|
||||
|
||||
// X is vector of DataType.
|
||||
@@ -310,8 +319,8 @@ struct tensor_view
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
buf_.template update<DstInMemOp, X, oob_conditional_check>(
|
||||
coord.get_offset(),
|
||||
linear_offset,
|
||||
coord.get_offset() / PackedSize,
|
||||
linear_offset / PackedSize,
|
||||
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
|
||||
x);
|
||||
}
|
||||
@@ -330,7 +339,7 @@ struct tensor_view
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
buf_.template update<DstInMemOp, X, oob_conditional_check>(
|
||||
coord.get_offset(), linear_offset, is_valid_element, x);
|
||||
coord.get_offset() / PackedSize, linear_offset / PackedSize, is_valid_element, x);
|
||||
}
|
||||
|
||||
// X is vector of DataType.
|
||||
@@ -350,8 +359,8 @@ struct tensor_view
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
buf_.template update_raw<DstInMemOp, X, oob_conditional_check, pre_nop>(
|
||||
coord.get_offset(),
|
||||
linear_offset,
|
||||
coord.get_offset() / PackedSize,
|
||||
linear_offset / PackedSize,
|
||||
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
|
||||
x);
|
||||
}
|
||||
@@ -372,7 +381,7 @@ struct tensor_view
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
buf_.template update_raw<DstInMemOp, X, oob_conditional_check, pre_nop>(
|
||||
coord.get_offset(), linear_offset, is_valid_element, x);
|
||||
coord.get_offset() / PackedSize, linear_offset / PackedSize, is_valid_element, x);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void print() const
|
||||
|
||||
@@ -97,13 +97,15 @@ struct tile_window_with_static_distribution
|
||||
}
|
||||
|
||||
public:
|
||||
static constexpr index_t PackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<DataType>>::PackedSize;
|
||||
static constexpr index_t VectorDimY = get_vector_dim_y_scalar_per_vector().template at<0>();
|
||||
static constexpr index_t ScalarPerVector =
|
||||
get_vector_dim_y_scalar_per_vector().template at<1>();
|
||||
|
||||
// using vector_type_t = vector_type_maker_t<DataType, ScalarPerVector>;
|
||||
// using vector_t = typename vector_type_t::type;
|
||||
using vector_t = thread_buffer<DataType, ScalarPerVector>;
|
||||
using vector_t = thread_buffer<DataType, ScalarPerVector / PackedSize>;
|
||||
|
||||
private:
|
||||
static constexpr auto scalars_per_access_ = [] {
|
||||
@@ -336,7 +338,7 @@ struct tile_window_with_static_distribution
|
||||
bottom_tensor_thread_coord, 0, bool_constant<oob_conditional_check>{});
|
||||
#if 1
|
||||
// write into distributed tensor
|
||||
static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
|
||||
static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
|
||||
constexpr auto idx_ys = generate_tuple(
|
||||
[&](auto jj) {
|
||||
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
|
||||
@@ -345,10 +347,11 @@ struct tile_window_with_static_distribution
|
||||
number<NDimY>{});
|
||||
|
||||
constexpr index_t d =
|
||||
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
|
||||
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
|
||||
Traits::PackedSize;
|
||||
|
||||
dst_tensor.get_thread_buffer().template at<d>() =
|
||||
vec_value.template get_as<DataType>()[j];
|
||||
vec_value.template get_as<DataType>()[j / Traits::PackedSize];
|
||||
});
|
||||
#else
|
||||
constexpr index_t d =
|
||||
@@ -390,8 +393,9 @@ struct tile_window_with_static_distribution
|
||||
using SFC_Ys = typename Traits::SFC_Ys;
|
||||
static constexpr index_t YElementSize =
|
||||
TileDstr{}.get_ys_to_d_descriptor().get_element_space_size();
|
||||
static_assert(YElementSize % Traits::ScalarPerVector == 0);
|
||||
using vectorized_tbuf = array<vector_t, YElementSize / Traits::ScalarPerVector>;
|
||||
static_assert(YElementSize % (Traits::PackedSize * Traits::ScalarPerVector) == 0);
|
||||
using vectorized_tbuf =
|
||||
array<vector_t, YElementSize / (Traits::PackedSize * Traits::ScalarPerVector)>;
|
||||
// StaticBuffer<address_space_enum::vgpr,
|
||||
// vector_t,
|
||||
// YElementSize / Traits::ScalarPerVector,
|
||||
@@ -419,7 +423,8 @@ struct tile_window_with_static_distribution
|
||||
// data index [y0, y1, ...]
|
||||
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
|
||||
constexpr index_t d =
|
||||
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start);
|
||||
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start) /
|
||||
Traits::PackedSize;
|
||||
static_assert(d % Traits::ScalarPerVector == 0);
|
||||
|
||||
get_bottom_tensor_view().template get_vectorized_elements_raw<vector_t>(
|
||||
@@ -632,7 +637,7 @@ struct tile_window_with_static_distribution
|
||||
// vector_type_t vec;
|
||||
vector_t vec_value;
|
||||
|
||||
static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
|
||||
static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
|
||||
constexpr auto idx_ys = generate_tuple(
|
||||
[&](auto jj) {
|
||||
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
|
||||
@@ -641,9 +646,10 @@ struct tile_window_with_static_distribution
|
||||
number<NDimY>{});
|
||||
|
||||
constexpr index_t d =
|
||||
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
|
||||
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
|
||||
Traits::PackedSize;
|
||||
|
||||
vec_value.template get_as<DataType>()(j) =
|
||||
vec_value.template get_as<DataType>()(j / Traits::PackedSize) =
|
||||
dstr_tensor.get_thread_buffer().template at<d>();
|
||||
});
|
||||
|
||||
@@ -698,7 +704,7 @@ struct tile_window_with_static_distribution
|
||||
|
||||
// read from distributed tensor
|
||||
vector_t vec_value;
|
||||
static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
|
||||
static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
|
||||
constexpr auto idx_ys = generate_tuple(
|
||||
[&](auto jj) {
|
||||
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
|
||||
@@ -706,8 +712,9 @@ struct tile_window_with_static_distribution
|
||||
},
|
||||
number<NDimY>{});
|
||||
constexpr index_t d =
|
||||
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
|
||||
vec_value.template get_as<DataType>()(j) =
|
||||
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
|
||||
Traits::PackedSize;
|
||||
vec_value.template get_as<DataType>()(j / Traits::PackedSize) =
|
||||
dstr_tensor.get_thread_buffer().template at<d>();
|
||||
});
|
||||
|
||||
@@ -759,7 +766,7 @@ struct tile_window_with_static_distribution
|
||||
// read from distributed tensor
|
||||
vector_t vec_value;
|
||||
|
||||
static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
|
||||
static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
|
||||
constexpr auto idx_ys = generate_tuple(
|
||||
[&](auto jj) {
|
||||
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
|
||||
@@ -768,9 +775,10 @@ struct tile_window_with_static_distribution
|
||||
number<NDimY>{});
|
||||
|
||||
constexpr index_t d =
|
||||
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
|
||||
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
|
||||
Traits::PackedSize;
|
||||
|
||||
vec_value.template get_as<DataType>()(j) =
|
||||
vec_value.template get_as<DataType>()(j / Traits::PackedSize) =
|
||||
dstr_tensor.get_thread_buffer().template at<d>();
|
||||
});
|
||||
|
||||
@@ -825,7 +833,7 @@ struct tile_window_with_static_distribution
|
||||
// read from distributed tensor
|
||||
vector_t vec_value;
|
||||
|
||||
static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
|
||||
static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
|
||||
constexpr auto idx_ys = generate_tuple(
|
||||
[&](auto jj) {
|
||||
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
|
||||
@@ -834,9 +842,10 @@ struct tile_window_with_static_distribution
|
||||
number<NDimY>{});
|
||||
|
||||
constexpr index_t d =
|
||||
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
|
||||
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
|
||||
Traits::PackedSize;
|
||||
|
||||
vec_value.template get_as<DataType>()(j) =
|
||||
vec_value.template get_as<DataType>()(j / Traits::PackedSize) =
|
||||
dstr_tensor.get_thread_buffer().template at<d>();
|
||||
});
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
@@ -151,11 +151,13 @@ struct tile_window_linear
|
||||
}
|
||||
|
||||
public:
|
||||
static constexpr index_t PackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<DataType>>::PackedSize;
|
||||
static constexpr index_t VectorDimY = get_vector_dim_y_scalar_per_vector().template at<0>();
|
||||
static constexpr index_t ScalarPerVector =
|
||||
get_vector_dim_y_scalar_per_vector().template at<1>();
|
||||
|
||||
using vector_t = thread_buffer<DataType, ScalarPerVector>;
|
||||
using vector_t = thread_buffer<DataType, ScalarPerVector / PackedSize>;
|
||||
|
||||
private:
|
||||
static constexpr auto scalars_per_access_ = [] {
|
||||
@@ -498,17 +500,18 @@ struct tile_window_linear
|
||||
// data index [y0, y1, ...]
|
||||
constexpr auto idx_diff_ys = SFC_Ys::get_index(IAccess);
|
||||
// write into distributed tensor
|
||||
static_for<0, traits::ScalarPerVector, 1>{}([&](auto j) {
|
||||
static_for<0, traits::ScalarPerVector, traits::PackedSize>{}([&](auto j) {
|
||||
constexpr auto idx_ys = generate_tuple(
|
||||
[&](auto jj) {
|
||||
return jj == traits::VectorDimY ? (idx_diff_ys[jj] + j) : idx_diff_ys[jj];
|
||||
},
|
||||
number<NDimY>{});
|
||||
|
||||
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
|
||||
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
|
||||
traits::PackedSize;
|
||||
|
||||
dst_tensor.get_thread_buffer().template at<d>() =
|
||||
vec_value.template get_as<DataType>()[j];
|
||||
vec_value.template get_as<DataType>()[j / traits::PackedSize];
|
||||
});
|
||||
#else
|
||||
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start);
|
||||
@@ -556,17 +559,18 @@ struct tile_window_linear
|
||||
// data index [y0, y1, ...]
|
||||
constexpr auto idx_diff_ys = SFC_Ys::get_index(IAccess);
|
||||
// write into distributed tensor
|
||||
static_for<0, traits::ScalarPerVector, 1>{}([&](auto j) {
|
||||
static_for<0, traits::ScalarPerVector, traits::PackedSize>{}([&](auto j) {
|
||||
constexpr auto idx_ys = generate_tuple(
|
||||
[&](auto jj) {
|
||||
return jj == traits::VectorDimY ? (idx_diff_ys[jj] + j) : idx_diff_ys[jj];
|
||||
},
|
||||
number<NDimY>{});
|
||||
|
||||
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
|
||||
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
|
||||
traits::PackedSize;
|
||||
|
||||
dst_tensor.get_thread_buffer().template at<d>() =
|
||||
vec_value.template get_as<DataType>()[j];
|
||||
vec_value.template get_as<DataType>()[j / traits::PackedSize];
|
||||
});
|
||||
#else
|
||||
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start);
|
||||
@@ -595,8 +599,9 @@ struct tile_window_linear
|
||||
using SFC_Ys = typename traits::SFC_Ys;
|
||||
static constexpr index_t YElementSize =
|
||||
TileDstr{}.get_ys_to_d_descriptor().get_element_space_size();
|
||||
static_assert(YElementSize % traits::ScalarPerVector == 0);
|
||||
using vectorized_tbuf = array<vector_t, YElementSize / traits::ScalarPerVector>;
|
||||
static_assert(YElementSize % (traits::PackedSize * traits::ScalarPerVector) == 0);
|
||||
using vectorized_tbuf =
|
||||
array<vector_t, YElementSize / (traits::PackedSize * traits::ScalarPerVector)>;
|
||||
|
||||
constexpr auto tile_dstr = TileDstr{};
|
||||
|
||||
@@ -620,7 +625,9 @@ struct tile_window_linear
|
||||
|
||||
// data index [y0, y1, ...]
|
||||
constexpr auto idx_ys_start = SFC_Ys::get_index(IAccess);
|
||||
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start);
|
||||
constexpr index_t d =
|
||||
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start) /
|
||||
traits::PackedSize;
|
||||
static_assert(d % traits::ScalarPerVector == 0);
|
||||
|
||||
get_bottom_tensor_view().template get_vectorized_elements_raw<vector_t>(
|
||||
@@ -804,16 +811,17 @@ struct tile_window_linear
|
||||
// read from distributed tensor
|
||||
vector_t vec_value;
|
||||
|
||||
static_for<0, traits::ScalarPerVector, 1>{}([&](auto j) {
|
||||
static_for<0, traits::ScalarPerVector, traits::PackedSize>{}([&](auto j) {
|
||||
constexpr auto idx_ys = generate_tuple(
|
||||
[&](auto jj) {
|
||||
return jj == traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj];
|
||||
},
|
||||
number<NDimY>{});
|
||||
|
||||
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
|
||||
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
|
||||
traits::PackedSize;
|
||||
|
||||
vec_value.template get_as<DataType>()(j) =
|
||||
vec_value.template get_as<DataType>()(j / traits::PackedSize) =
|
||||
dstr_tensor.get_thread_buffer().template at<d>();
|
||||
});
|
||||
|
||||
@@ -852,14 +860,15 @@ struct tile_window_linear
|
||||
|
||||
// read from distributed tensor
|
||||
vector_t vec_value;
|
||||
static_for<0, traits::ScalarPerVector, 1>{}([&](auto j) {
|
||||
static_for<0, traits::ScalarPerVector, traits::PackedSize>{}([&](auto j) {
|
||||
constexpr auto idx_ys = generate_tuple(
|
||||
[&](auto jj) {
|
||||
return jj == traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj];
|
||||
},
|
||||
number<NDimY>{});
|
||||
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
|
||||
vec_value.template get_as<DataType>()(j) =
|
||||
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
|
||||
traits::PackedSize;
|
||||
vec_value.template get_as<DataType>()(j / traits::PackedSize) =
|
||||
dstr_tensor.get_thread_buffer().template at<d>();
|
||||
});
|
||||
|
||||
@@ -897,16 +906,17 @@ struct tile_window_linear
|
||||
// read from distributed tensor
|
||||
vector_t vec_value;
|
||||
|
||||
static_for<0, traits::ScalarPerVector, 1>{}([&](auto j) {
|
||||
static_for<0, traits::ScalarPerVector, traits::PackedSize>{}([&](auto j) {
|
||||
constexpr auto idx_ys = generate_tuple(
|
||||
[&](auto jj) {
|
||||
return jj == traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj];
|
||||
},
|
||||
number<NDimY>{});
|
||||
|
||||
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
|
||||
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
|
||||
traits::PackedSize;
|
||||
|
||||
vec_value.template get_as<DataType>()(j) =
|
||||
vec_value.template get_as<DataType>()(j / traits::PackedSize) =
|
||||
dstr_tensor.get_thread_buffer().template at<d>();
|
||||
});
|
||||
|
||||
@@ -948,16 +958,17 @@ struct tile_window_linear
|
||||
// read from distributed tensor
|
||||
vector_t vec_value;
|
||||
|
||||
static_for<0, traits::ScalarPerVector, 1>{}([&](auto j) {
|
||||
static_for<0, traits::ScalarPerVector, traits::PackedSize>{}([&](auto j) {
|
||||
constexpr auto idx_ys = generate_tuple(
|
||||
[&](auto jj) {
|
||||
return jj == traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj];
|
||||
},
|
||||
number<NDimY>{});
|
||||
|
||||
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
|
||||
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
|
||||
traits::PackedSize;
|
||||
|
||||
vec_value.template get_as<DataType>()(j) =
|
||||
vec_value.template get_as<DataType>()(j / traits::PackedSize) =
|
||||
dstr_tensor.get_thread_buffer().template at<d>();
|
||||
});
|
||||
|
||||
|
||||
@@ -29,11 +29,12 @@ double get_relative_threshold(const int number_of_accumulations = 1)
|
||||
using I8 = int8_t;
|
||||
using I32 = int32_t;
|
||||
|
||||
static_assert(is_any_of<ComputeDataType, F8, BF8, F16, BF16, F32, I8, I32, int>::value,
|
||||
"Warning: Unhandled ComputeDataType for setting up the relative threshold!");
|
||||
static_assert(
|
||||
is_any_of<ComputeDataType, F8, BF8, F16, BF16, F32, pk_int4_t, I8, I32, int>::value,
|
||||
"Warning: Unhandled ComputeDataType for setting up the relative threshold!");
|
||||
|
||||
double compute_error = 0;
|
||||
if constexpr(is_any_of<ComputeDataType, I8, I32, int>::value)
|
||||
if constexpr(is_any_of<ComputeDataType, pk_int4_t, I8, I32, int>::value)
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
@@ -42,11 +43,11 @@ double get_relative_threshold(const int number_of_accumulations = 1)
|
||||
compute_error = std::pow(2, -numeric_traits<ComputeDataType>::mant) * 0.5;
|
||||
}
|
||||
|
||||
static_assert(is_any_of<OutDataType, F8, BF8, F16, BF16, F32, I8, I32, int>::value,
|
||||
static_assert(is_any_of<OutDataType, F8, BF8, F16, BF16, F32, pk_int4_t, I8, I32, int>::value,
|
||||
"Warning: Unhandled OutDataType for setting up the relative threshold!");
|
||||
|
||||
double output_error = 0;
|
||||
if constexpr(is_any_of<OutDataType, I8, I32, int>::value)
|
||||
if constexpr(is_any_of<OutDataType, pk_int4_t, I8, I32, int>::value)
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
@@ -56,11 +57,11 @@ double get_relative_threshold(const int number_of_accumulations = 1)
|
||||
}
|
||||
double midway_error = std::max(compute_error, output_error);
|
||||
|
||||
static_assert(is_any_of<AccDataType, F8, BF8, F16, BF16, F32, I8, I32, int>::value,
|
||||
static_assert(is_any_of<AccDataType, F8, BF8, F16, BF16, F32, pk_int4_t, I8, I32, int>::value,
|
||||
"Warning: Unhandled AccDataType for setting up the relative threshold!");
|
||||
|
||||
double acc_error = 0;
|
||||
if constexpr(is_any_of<AccDataType, I8, I32, int>::value)
|
||||
if constexpr(is_any_of<AccDataType, pk_int4_t, I8, I32, int>::value)
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
@@ -82,12 +83,13 @@ double get_absolute_threshold(const double max_possible_num, const int number_of
|
||||
using I8 = int8_t;
|
||||
using I32 = int32_t;
|
||||
|
||||
static_assert(is_any_of<ComputeDataType, F8, BF8, F16, BF16, F32, I8, I32, int>::value,
|
||||
"Warning: Unhandled ComputeDataType for setting up the absolute threshold!");
|
||||
static_assert(
|
||||
is_any_of<ComputeDataType, F8, BF8, F16, BF16, F32, pk_int4_t, I8, I32, int>::value,
|
||||
"Warning: Unhandled ComputeDataType for setting up the absolute threshold!");
|
||||
|
||||
auto expo = std::log2(std::abs(max_possible_num));
|
||||
double compute_error = 0;
|
||||
if constexpr(is_any_of<ComputeDataType, I8, I32, int>::value)
|
||||
if constexpr(is_any_of<ComputeDataType, pk_int4_t, I8, I32, int>::value)
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
@@ -96,11 +98,11 @@ double get_absolute_threshold(const double max_possible_num, const int number_of
|
||||
compute_error = std::pow(2, expo - numeric_traits<ComputeDataType>::mant) * 0.5;
|
||||
}
|
||||
|
||||
static_assert(is_any_of<OutDataType, F8, BF8, F16, BF16, F32, I8, I32, int>::value,
|
||||
static_assert(is_any_of<OutDataType, F8, BF8, F16, BF16, F32, pk_int4_t, I8, I32, int>::value,
|
||||
"Warning: Unhandled OutDataType for setting up the absolute threshold!");
|
||||
|
||||
double output_error = 0;
|
||||
if constexpr(is_any_of<OutDataType, I8, I32, int>::value)
|
||||
if constexpr(is_any_of<OutDataType, pk_int4_t, I8, I32, int>::value)
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
@@ -110,11 +112,11 @@ double get_absolute_threshold(const double max_possible_num, const int number_of
|
||||
}
|
||||
double midway_error = std::max(compute_error, output_error);
|
||||
|
||||
static_assert(is_any_of<AccDataType, F8, BF8, F16, BF16, F32, I8, I32, int>::value,
|
||||
static_assert(is_any_of<AccDataType, F8, BF8, F16, BF16, F32, pk_int4_t, I8, I32, int>::value,
|
||||
"Warning: Unhandled AccDataType for setting up the absolute threshold!");
|
||||
|
||||
double acc_error = 0;
|
||||
if constexpr(is_any_of<AccDataType, I8, I32, int>::value)
|
||||
if constexpr(is_any_of<AccDataType, pk_int4_t, I8, I32, int>::value)
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -282,7 +282,14 @@ struct FillMonotonicSeq
|
||||
{
|
||||
std::generate(first, last, [=, n = init_value_]() mutable {
|
||||
auto tmp = n;
|
||||
n += step_;
|
||||
if constexpr(std::is_same_v<decltype(tmp), pk_int4_t>)
|
||||
{
|
||||
n.data += step_.data;
|
||||
}
|
||||
else
|
||||
{
|
||||
n += step_;
|
||||
}
|
||||
return tmp;
|
||||
});
|
||||
}
|
||||
|
||||
@@ -281,18 +281,18 @@ struct HostTensor
|
||||
using Data = std::vector<T>;
|
||||
|
||||
template <typename X>
|
||||
HostTensor(std::initializer_list<X> lens) : mDesc(lens), mData(mDesc.get_element_space_size())
|
||||
HostTensor(std::initializer_list<X> lens) : mDesc(lens), mData(get_element_space_size())
|
||||
{
|
||||
}
|
||||
|
||||
template <typename X, typename Y>
|
||||
HostTensor(std::initializer_list<X> lens, std::initializer_list<Y> strides)
|
||||
: mDesc(lens, strides), mData(mDesc.get_element_space_size())
|
||||
: mDesc(lens, strides), mData(get_element_space_size())
|
||||
{
|
||||
}
|
||||
|
||||
template <typename Lengths>
|
||||
HostTensor(const Lengths& lens) : mDesc(lens), mData(mDesc.get_element_space_size())
|
||||
HostTensor(const Lengths& lens) : mDesc(lens), mData(get_element_space_size())
|
||||
{
|
||||
}
|
||||
|
||||
@@ -302,7 +302,7 @@ struct HostTensor
|
||||
{
|
||||
}
|
||||
|
||||
HostTensor(const Descriptor& desc) : mDesc(desc), mData(mDesc.get_element_space_size()) {}
|
||||
HostTensor(const Descriptor& desc) : mDesc(desc), mData(get_element_space_size()) {}
|
||||
|
||||
template <typename OutT>
|
||||
HostTensor<OutT> CopyAsType() const
|
||||
@@ -340,7 +340,11 @@ struct HostTensor
|
||||
|
||||
std::size_t get_element_size() const { return mDesc.get_element_size(); }
|
||||
|
||||
std::size_t get_element_space_size() const { return mDesc.get_element_space_size(); }
|
||||
std::size_t get_element_space_size() const
|
||||
{
|
||||
constexpr index_t PackedSize = ck_tile::numeric_traits<remove_cvref_t<T>>::PackedSize;
|
||||
return mDesc.get_element_space_size() / PackedSize;
|
||||
}
|
||||
|
||||
std::size_t get_element_space_size_in_bytes() const
|
||||
{
|
||||
@@ -463,29 +467,27 @@ struct HostTensor
|
||||
template <typename... Is>
|
||||
std::size_t GetOffsetFromMultiIndex(Is... is) const
|
||||
{
|
||||
return mDesc.GetOffsetFromMultiIndex(is...);
|
||||
constexpr index_t PackedSize = ck_tile::numeric_traits<remove_cvref_t<T>>::PackedSize;
|
||||
return mDesc.GetOffsetFromMultiIndex(is...) / PackedSize;
|
||||
}
|
||||
|
||||
template <typename... Is>
|
||||
T& operator()(Is... is)
|
||||
{
|
||||
return mData[mDesc.GetOffsetFromMultiIndex(is...)];
|
||||
return mData[GetOffsetFromMultiIndex(is...)];
|
||||
}
|
||||
|
||||
template <typename... Is>
|
||||
const T& operator()(Is... is) const
|
||||
{
|
||||
return mData[mDesc.GetOffsetFromMultiIndex(is...)];
|
||||
return mData[GetOffsetFromMultiIndex(is...)];
|
||||
}
|
||||
|
||||
T& operator()(std::vector<std::size_t> idx)
|
||||
{
|
||||
return mData[mDesc.GetOffsetFromMultiIndex(idx)];
|
||||
}
|
||||
T& operator()(std::vector<std::size_t> idx) { return mData[GetOffsetFromMultiIndex(idx)]; }
|
||||
|
||||
const T& operator()(std::vector<std::size_t> idx) const
|
||||
{
|
||||
return mData[mDesc.GetOffsetFromMultiIndex(idx)];
|
||||
return mData[GetOffsetFromMultiIndex(idx)];
|
||||
}
|
||||
|
||||
HostTensor<T> transpose(std::vector<size_t> axes = {}) const
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -34,11 +34,35 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
|
||||
|
||||
for(std::size_t k = 0; k < K; ++k)
|
||||
{
|
||||
ADataType v_a = a_element_op(a_m_k(m, k));
|
||||
BDataType v_b = b_element_op(b_k_n(k, n));
|
||||
|
||||
v_acc +=
|
||||
ck_tile::type_convert<AccDataType>(v_a) * ck_tile::type_convert<AccDataType>(v_b);
|
||||
AccDataType v_a;
|
||||
AccDataType v_b;
|
||||
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
|
||||
{
|
||||
const pk_int4_t pk_val = a_element_op(a_m_k(m, k));
|
||||
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(pk_val);
|
||||
if(k % 2 == 1)
|
||||
v_a = fp32_val.hi;
|
||||
else
|
||||
v_a = fp32_val.lo;
|
||||
}
|
||||
else
|
||||
{
|
||||
v_a = ck_tile::type_convert<AccDataType>(a_element_op(a_m_k(m, k)));
|
||||
}
|
||||
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
|
||||
{
|
||||
const pk_int4_t pk_val = b_element_op(b_k_n(k, n));
|
||||
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(pk_val);
|
||||
if(k % 2 == 1)
|
||||
v_b = fp32_val.hi;
|
||||
else
|
||||
v_b = fp32_val.lo;
|
||||
}
|
||||
else
|
||||
{
|
||||
v_b = ck_tile::type_convert<AccDataType>(b_element_op(b_k_n(k, n)));
|
||||
}
|
||||
v_acc += v_a * v_b;
|
||||
}
|
||||
|
||||
c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
|
||||
@@ -73,6 +97,8 @@ __global__ void naive_gemm_kernel(ADataType* A,
|
||||
AccDataType acc = 0.0;
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
constexpr index_t packed_size_a = ck_tile::numeric_traits<ADataType>::PackedSize;
|
||||
constexpr index_t packed_size_b = ck_tile::numeric_traits<BDataType>::PackedSize;
|
||||
// Adjust indexing based on matrix layout
|
||||
int a_index = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
|
||||
? row * strideA + k
|
||||
@@ -80,8 +106,34 @@ __global__ void naive_gemm_kernel(ADataType* A,
|
||||
int b_index = (std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
|
||||
? col * strideB + k
|
||||
: k * strideB + col;
|
||||
acc += ck_tile::type_convert<AccDataType>(A[a_index]) *
|
||||
ck_tile::type_convert<AccDataType>(B[b_index]);
|
||||
|
||||
AccDataType v_a;
|
||||
AccDataType v_b;
|
||||
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
|
||||
{
|
||||
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(A[a_index / packed_size_a]);
|
||||
if(k % 2 == 1)
|
||||
v_a = fp32_val.hi;
|
||||
else
|
||||
v_a = fp32_val.lo;
|
||||
}
|
||||
else
|
||||
{
|
||||
v_a = ck_tile::type_convert<AccDataType>(A[a_index]);
|
||||
}
|
||||
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
|
||||
{
|
||||
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(B[b_index / packed_size_b]);
|
||||
if(k % 2 == 1)
|
||||
v_b = fp32_val.hi;
|
||||
else
|
||||
v_b = fp32_val.lo;
|
||||
}
|
||||
else
|
||||
{
|
||||
v_b = ck_tile::type_convert<AccDataType>(B[b_index]);
|
||||
}
|
||||
acc += v_a * v_b;
|
||||
}
|
||||
|
||||
int c_index = (std::is_same_v<LayoutC, tensor_layout::gemm::RowMajor>)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -9,20 +9,166 @@
|
||||
namespace ck_tile {
|
||||
namespace element_wise {
|
||||
|
||||
#if 0
|
||||
// Fast int4x4 to fp16x8_t data type conversion based on paper
|
||||
// [Who Says Elephants Can't Run: Bringing Large Scale MoE Models into Cloud Scale Production]
|
||||
// (https://arxiv.org/abs/2211.10017) and implementation:
|
||||
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
|
||||
CK_TILE_DEVICE fp16x4_t i4_to_half4(int q)
|
||||
{
|
||||
const int LO = 0x000f000f;
|
||||
const int HI = 0x00f000f0;
|
||||
const int EX = 0x64006400;
|
||||
|
||||
int lo;
|
||||
int hi;
|
||||
// Extract the two int4 at low bit and create two fp16 number.
|
||||
asm volatile("v_and_or_b32 %0, %1, %2, %3" : "=v"(lo) : "v"(q), "v"(LO), "v"(EX));
|
||||
// Extract the two int4 at hight bit and create two fp16 number.
|
||||
asm volatile("v_and_or_b32 %0, %1, %2, %3" : "=v"(hi) : "v"(q), "v"(HI), "v"(EX));
|
||||
|
||||
const int SUB = 0xE408E408; // half2 {-1032, -1032}
|
||||
const int MUL = 0x2c002c00; // half2 {1 / 16, 1 / 16}
|
||||
const int ADD = 0xd480d480; // half2 {-72, -72}
|
||||
|
||||
fp16x4_t res;
|
||||
|
||||
// for two fp16 from lowbit, subtract 1032 to get correct fp16 value
|
||||
asm volatile("v_pk_add_f16 %0, %1, %2"
|
||||
: "=v"(res.lo)
|
||||
: "v"(bit_cast<fp16x2_t>(lo)), "v"(bit_cast<fp16x2_t>(SUB)));
|
||||
|
||||
// for two fp16 from highbit, divide 16 and subtract 72 to get correct fp16 value
|
||||
asm volatile(
|
||||
"v_pk_fma_f16 %0, %1, %2, %3"
|
||||
: "=v"(res.hi)
|
||||
: "v"(bit_cast<fp16x2_t>(hi)), "v"(bit_cast<fp16x2_t>(MUL)), "v"(bit_cast<fp16x2_t>(ADD)));
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE fp16x4_t i4_to_half4_scale(int q, const fp16x2_t& scale)
|
||||
{
|
||||
const int LO = 0x000f000f;
|
||||
const int HI = 0x00f000f0;
|
||||
const int EX = 0x64006400;
|
||||
|
||||
int lo;
|
||||
int hi;
|
||||
// Extract the two int4 at low bit and create two fp16 number.
|
||||
asm volatile("v_and_or_b32 %0, %1, %2, %3" : "=v"(lo) : "v"(q), "v"(LO), "v"(EX));
|
||||
// Extract the two int4 at hight bit and create two fp16 number.
|
||||
asm volatile("v_and_or_b32 %0, %1, %2, %3" : "=v"(hi) : "v"(q), "v"(HI), "v"(EX));
|
||||
|
||||
const int SUB = 0xE408E408; // half2 {-1032, -1032}
|
||||
const int MUL = 0x2c002c00; // half2 {1 / 16, 1 / 16}
|
||||
const int ADD = 0xd480d480; // half2 {-72, -72}
|
||||
|
||||
fp16x4_t res;
|
||||
|
||||
asm volatile("v_pk_add_f16 %0, %1, %2"
|
||||
: "=v"(res.lo)
|
||||
: "v"(bit_cast<fp16x2_t>(lo)), "v"(bit_cast<fp16x2_t>(SUB)));
|
||||
|
||||
asm volatile(
|
||||
"v_pk_fma_f16 %0, %1, %2, %3"
|
||||
: "=v"(res.hi)
|
||||
: "v"(bit_cast<fp16x2_t>(hi)), "v"(bit_cast<fp16x2_t>(MUL)), "v"(bit_cast<fp16x2_t>(ADD)));
|
||||
|
||||
asm volatile("v_pk_mul_f16 %0, %1, %2" : "=v"(res.lo) : "v"(res.lo), "v"(scale));
|
||||
|
||||
asm volatile("v_pk_mul_f16 %0, %1, %2" : "=v"(res.hi) : "v"(res.hi), "v"(scale));
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE bf16x4_t i4_to_bhalf4(int q)
|
||||
{
|
||||
uint32_t i8s = (q & 0xf) | ((q & 0xf0) << 4) | ((q & 0xf00) << 8) | ((q & 0xf000) << 12);
|
||||
|
||||
static constexpr uint32_t fp32_base = 0x4B000000;
|
||||
|
||||
float fp32_intermediates[4];
|
||||
|
||||
uint32_t* fp32_intermediates_casted = reinterpret_cast<uint32_t*>(fp32_intermediates);
|
||||
|
||||
fp32_intermediates_casted[0] = __byte_perm(i8s, fp32_base, 0x7650);
|
||||
fp32_intermediates_casted[1] = __byte_perm(i8s, fp32_base, 0x7651);
|
||||
fp32_intermediates_casted[2] = __byte_perm(i8s, fp32_base, 0x7652);
|
||||
fp32_intermediates_casted[3] = __byte_perm(i8s, fp32_base, 0x7653);
|
||||
|
||||
fp32_intermediates[0] -= 8388616.f;
|
||||
fp32_intermediates[1] -= 8388616.f;
|
||||
fp32_intermediates[2] -= 8388616.f;
|
||||
fp32_intermediates[3] -= 8388616.f;
|
||||
|
||||
bf16x4_t res;
|
||||
res.lo = bit_cast<bf16x2_t>(
|
||||
__byte_perm(fp32_intermediates_casted[1], fp32_intermediates_casted[0], 0x7632));
|
||||
res.hi = bit_cast<bf16x2_t>(
|
||||
__byte_perm(fp32_intermediates_casted[3], fp32_intermediates_casted[2], 0x7632));
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
struct PassThroughPack8
|
||||
{
|
||||
template <typename Y, typename X>
|
||||
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const;
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(fp16x8_t& y, const pk_int4x4_t& x) const
|
||||
{
|
||||
y.lo = i4_to_half4(bit_cast<int>(x));
|
||||
y.hi = i4_to_half4(bit_cast<int>(x) >> 8);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(bf16x8_t& y, const pk_int4x4_t& x) const
|
||||
{
|
||||
y.lo = i4_to_bhalf4(bit_cast<int>(x));
|
||||
y.hi = i4_to_bhalf4(bit_cast<int>(x) >> 16);
|
||||
}
|
||||
constexpr const static bool is_pack8_invocable = true;
|
||||
};
|
||||
|
||||
struct DequantPack8
|
||||
{
|
||||
template <typename Y, typename X, typename Z>
|
||||
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x, const Z& z) const;
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr void
|
||||
operator()(fp16x8_t& y, const pk_int4x4_t& x, const fp16x2_t& z) const
|
||||
{
|
||||
y.lo = i4_to_half4_scale(bit_cast<int>(x), z);
|
||||
y.hi = i4_to_half4_scale(bit_cast<int>(x) >> 8, z);
|
||||
}
|
||||
|
||||
constexpr const static bool is_pack8_invocable = true;
|
||||
};
|
||||
|
||||
struct PassThroughPack2
|
||||
{
|
||||
template <typename Y, typename X>
|
||||
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const;
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(ck_tile::half2_t& y, const ck_tile::f8x2_t& x) const
|
||||
#if 0
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(ck_tile::fp16x2_t& y, const ck_tile::f8x2_t& x) const
|
||||
{
|
||||
auto t = type_convert<float2_t>(x);
|
||||
y = type_convert<half2_t>(t);
|
||||
y = type_convert<fp16x2_t>(t);
|
||||
}
|
||||
#endif
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(fp16x2_t& y, const pk_int4_t& x) const
|
||||
{
|
||||
uint8_t x_u8 = bit_cast<uint8_t>(x);
|
||||
uint8_t x_l = (x_u8 & 0x0f) >> 0;
|
||||
uint8_t x_h = (x_u8 & 0xf0) >> 4;
|
||||
|
||||
y.lo = type_convert<half_t>(x_l);
|
||||
y.hi = type_convert<half_t>(x_h);
|
||||
}
|
||||
|
||||
constexpr const static bool is_pack2_invocable = true;
|
||||
};
|
||||
#endif
|
||||
|
||||
struct PassThrough
|
||||
{
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
|
||||
#include "ck_tile/ops/elementwise.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -20,12 +21,13 @@ struct BlockUniversalGemmAsBsCr
|
||||
template <typename PipelineProblem_, typename GemmPolicy_>
|
||||
struct GemmTraits_
|
||||
{
|
||||
using Problem = remove_cvref_t<PipelineProblem_>;
|
||||
using Policy = remove_cvref_t<GemmPolicy_>;
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
using Problem = remove_cvref_t<PipelineProblem_>;
|
||||
using Policy = remove_cvref_t<GemmPolicy_>;
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
static constexpr auto Scheduler = Problem::Scheduler;
|
||||
@@ -71,10 +73,10 @@ struct BlockUniversalGemmAsBsCr
|
||||
using BWarpTileDistr = remove_cvref_t<decltype(make_static_tile_distribution(
|
||||
typename WarpGemm::BWarpDstrEncoding{}))>;
|
||||
|
||||
using AWarpTile =
|
||||
remove_cvref_t<decltype(make_static_distributed_tensor<ADataType>(AWarpTileDistr{}))>;
|
||||
using BWarpTile =
|
||||
remove_cvref_t<decltype(make_static_distributed_tensor<BDataType>(BWarpTileDistr{}))>;
|
||||
using AWarpTile = remove_cvref_t<decltype(make_static_distributed_tensor<ComputeDataType>(
|
||||
AWarpTileDistr{}))>;
|
||||
using BWarpTile = remove_cvref_t<decltype(make_static_distributed_tensor<ComputeDataType>(
|
||||
BWarpTileDistr{}))>;
|
||||
|
||||
// TODO: Should we have two policies? Interwave & Intrawave ??
|
||||
static constexpr index_t InterWaveSchedulingMacClusters = 1;
|
||||
@@ -90,9 +92,10 @@ struct BlockUniversalGemmAsBsCr
|
||||
public:
|
||||
using Traits = GemmTraits_<Problem_, Policy_>;
|
||||
|
||||
using ADataType = remove_cvref_t<typename Traits::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Traits::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename Traits::CDataType>;
|
||||
using ADataType = remove_cvref_t<typename Traits::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Traits::BDataType>;
|
||||
using ComputeDataType = remove_cvref_t<typename Traits::ComputeDataType>;
|
||||
using CDataType = remove_cvref_t<typename Traits::CDataType>;
|
||||
|
||||
using WarpGemm = remove_cvref_t<typename Traits::WarpGemm>;
|
||||
|
||||
@@ -105,10 +108,34 @@ struct BlockUniversalGemmAsBsCr
|
||||
|
||||
static constexpr auto Scheduler = Traits::Scheduler;
|
||||
|
||||
static constexpr index_t APackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
|
||||
static constexpr index_t BPackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
|
||||
|
||||
using I0 = number<0>;
|
||||
using I1 = number<1>;
|
||||
|
||||
private:
|
||||
template <typename WarpWindow, typename WarpTile>
|
||||
CK_TILE_DEVICE static void load_interleaved_pk_type(const WarpWindow& warp_window,
|
||||
WarpTile& warp_tile)
|
||||
{
|
||||
constexpr index_t UnaryOpSize = 8;
|
||||
const element_wise::PassThroughPack8 elementwise_op{};
|
||||
constexpr index_t thread_buffer_size =
|
||||
Traits::AWarpTile::get_thread_buffer_size() / UnaryOpSize;
|
||||
const auto in_dstr_tensors = load_tile(warp_window);
|
||||
|
||||
static_assert(Traits::AWarpTile::get_thread_buffer_size() % UnaryOpSize == 0);
|
||||
|
||||
using ComputeVectorType = ComputeDataType __attribute__((ext_vector_type(UnaryOpSize)));
|
||||
static_for<0, thread_buffer_size, 1>{}([&](auto i) {
|
||||
elementwise_op(warp_tile.get_thread_buffer().template get_as<ComputeVectorType>()(i),
|
||||
in_dstr_tensors.get_thread_buffer().template get_as<pk_int4x4_t>()[i]);
|
||||
});
|
||||
}
|
||||
|
||||
template <GemmPipelineScheduler Scheduler, typename GemmTraits>
|
||||
struct BlockGemmImpl
|
||||
{
|
||||
@@ -208,6 +235,8 @@ struct BlockUniversalGemmAsBsCr
|
||||
});
|
||||
|
||||
using CWarpDstr = typename WarpGemm::CWarpDstr;
|
||||
using AWarpTensor = typename WarpGemm::AWarpTensor;
|
||||
using BWarpTensor = typename WarpGemm::BWarpTensor;
|
||||
using CWarpTensor = typename WarpGemm::CWarpTensor;
|
||||
|
||||
constexpr auto c_warp_y_lengths =
|
||||
@@ -217,10 +246,26 @@ struct BlockUniversalGemmAsBsCr
|
||||
// hot loop:
|
||||
static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
const auto a_warp_tile = load_tile(a_warp_windows(mIter)(kIter));
|
||||
AWarpTensor a_warp_tile;
|
||||
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
|
||||
{
|
||||
load_interleaved_pk_type(a_warp_windows(mIter)(kIter), a_warp_tile);
|
||||
}
|
||||
else
|
||||
{
|
||||
a_warp_tile = load_tile(a_warp_windows(mIter)(kIter));
|
||||
}
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
const auto b_warp_tile = load_tile(b_warp_windows(nIter)(kIter));
|
||||
BWarpTensor b_warp_tile;
|
||||
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
|
||||
{
|
||||
load_interleaved_pk_type(b_warp_windows(nIter)(kIter), b_warp_tile);
|
||||
}
|
||||
else
|
||||
{
|
||||
b_warp_tile = load_tile(b_warp_windows(nIter)(kIter));
|
||||
}
|
||||
|
||||
// read C warp tensor from C block tensor-
|
||||
CWarpTensor c_warp_tensor;
|
||||
@@ -342,11 +387,27 @@ struct BlockUniversalGemmAsBsCr
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
// read A warp tensor from A block window
|
||||
load_tile(a_warp_tiles_(mIter)(kIter), a_warp_windows(mIter)(kIter));
|
||||
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
|
||||
{
|
||||
load_interleaved_pk_type(a_warp_windows(mIter)(kIter),
|
||||
a_warp_tiles_(mIter)(kIter));
|
||||
}
|
||||
else
|
||||
{
|
||||
a_warp_tiles_(mIter)(kIter) = load_tile(a_warp_windows(mIter)(kIter));
|
||||
}
|
||||
});
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read B warp tensor from B Block window
|
||||
load_tile(b_warp_tiles_(nIter)(kIter), b_warp_windows(nIter)(kIter));
|
||||
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
|
||||
{
|
||||
load_interleaved_pk_type(b_warp_windows(nIter)(kIter),
|
||||
b_warp_tiles_(nIter)(kIter));
|
||||
}
|
||||
else
|
||||
{
|
||||
b_warp_tiles_(nIter)(kIter) = load_tile(b_warp_windows(nIter)(kIter));
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
@@ -504,12 +565,27 @@ struct BlockUniversalGemmAsBsCr
|
||||
// TODO check if a_warp_tiles has same desc as a_warp_window
|
||||
static_for<0, KInnerLoopIter, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
// read A warp tensor from A block window
|
||||
load_tile(a_warp_tiles_(mIter)(kIter), a_warp_windows(mIter)(kIter));
|
||||
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
|
||||
{
|
||||
load_interleaved_pk_type(a_warp_windows(mIter)(kIter),
|
||||
a_warp_tiles_(mIter)(kIter));
|
||||
}
|
||||
else
|
||||
{
|
||||
a_warp_tiles_(mIter)(kIter) = load_tile(a_warp_windows(mIter)(kIter));
|
||||
}
|
||||
});
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read B warp tensor from B Block window
|
||||
load_tile(b_warp_tiles_(nIter)(kIter), b_warp_windows(nIter)(kIter));
|
||||
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
|
||||
{
|
||||
load_interleaved_pk_type(b_warp_windows(nIter)(kIter),
|
||||
b_warp_tiles_(nIter)(kIter));
|
||||
}
|
||||
else
|
||||
{
|
||||
b_warp_tiles_(nIter)(kIter) = load_tile(b_warp_windows(nIter)(kIter));
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
@@ -54,6 +54,11 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
|
||||
static constexpr index_t APackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
|
||||
static constexpr index_t BPackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
|
||||
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
using BLayout = remove_cvref_t<typename Problem::BLayout>;
|
||||
using CLayout = remove_cvref_t<typename Problem::CLayout>;
|
||||
@@ -196,12 +201,12 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
|
||||
// A/B split schedule
|
||||
// compiler is likely to use ds_read2 when instruction width smaller than 16bytes
|
||||
constexpr auto num_ds_read_inst_a = A_LDS_Read_Width * sizeof(ADataType) == 16
|
||||
? A_LDS_Read_Inst_Num
|
||||
: A_LDS_Read_Inst_Num / 2;
|
||||
constexpr auto num_ds_read_inst_b = B_LDS_Read_Width * sizeof(BDataType) == 16
|
||||
? B_LDS_Read_Inst_Num
|
||||
: B_LDS_Read_Inst_Num / 2;
|
||||
constexpr auto num_ds_read_inst_a =
|
||||
A_LDS_Read_Width * sizeof(ADataType) / APackedSize == 16 ? A_LDS_Read_Inst_Num
|
||||
: A_LDS_Read_Inst_Num / 2;
|
||||
constexpr auto num_ds_read_inst_b =
|
||||
B_LDS_Read_Width * sizeof(BDataType) / BPackedSize == 16 ? B_LDS_Read_Inst_Num
|
||||
: B_LDS_Read_Inst_Num / 2;
|
||||
|
||||
constexpr auto num_ds_write_inst_a = A_LDS_Write_Inst_Num;
|
||||
constexpr auto num_ds_write_inst_b = B_LDS_Write_Inst_Num;
|
||||
@@ -213,9 +218,9 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
|
||||
constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32;
|
||||
constexpr auto ds_read_a_issue_cycle =
|
||||
A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4;
|
||||
A_LDS_Read_Width * sizeof(ADataType) / APackedSize == 16 ? 8 : 4;
|
||||
constexpr auto ds_read_b_issue_cycle =
|
||||
B_LDS_Read_Width * sizeof(BDataType) == 16 ? 8 : 4;
|
||||
B_LDS_Read_Width * sizeof(BDataType) / BPackedSize == 16 ? 8 : 4;
|
||||
constexpr auto ds_read_a_mfma_rate =
|
||||
(mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle);
|
||||
constexpr auto ds_read_b_mfma_rate =
|
||||
|
||||
@@ -60,6 +60,13 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
|
||||
static_assert(!std::is_same_v<BDataType, pk_int4_t>, "Not implemented");
|
||||
|
||||
static constexpr index_t APackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
|
||||
static constexpr index_t BPackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
|
||||
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
using BLayout = remove_cvref_t<typename Problem::BLayout>;
|
||||
using CLayout = remove_cvref_t<typename Problem::CLayout>;
|
||||
@@ -139,12 +146,12 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
|
||||
(BlockSize / WaveSize) /
|
||||
(MPerXDL * NPerXDL * KPerXDL);
|
||||
|
||||
constexpr auto num_ds_read_inst_a = A_LDS_Read_Width * sizeof(ADataType) == 16
|
||||
? A_LDS_Read_Inst_Num
|
||||
: A_LDS_Read_Inst_Num / 2;
|
||||
constexpr auto num_ds_read_inst_b = B_LDS_Read_Width * sizeof(BDataType) == 16
|
||||
? B_LDS_Read_Inst_Num
|
||||
: B_LDS_Read_Inst_Num / 2;
|
||||
constexpr auto num_ds_read_inst_a =
|
||||
A_LDS_Read_Width * sizeof(ADataType) / APackedSize == 16 ? A_LDS_Read_Inst_Num
|
||||
: A_LDS_Read_Inst_Num / 2;
|
||||
constexpr auto num_ds_read_inst_b =
|
||||
B_LDS_Read_Width * sizeof(BDataType) / BPackedSize == 16 ? B_LDS_Read_Inst_Num
|
||||
: B_LDS_Read_Inst_Num / 2;
|
||||
|
||||
constexpr auto num_ds_read_inst = num_ds_read_inst_a + num_ds_read_inst_b;
|
||||
constexpr auto num_ds_write_inst = A_LDS_Write_Inst_Num + B_LDS_Write_Inst_Num;
|
||||
|
||||
@@ -21,6 +21,13 @@ struct BaseGemmPipelineAgBgCrMem
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
|
||||
static_assert(!std::is_same_v<BDataType, pk_int4_t>, "Not implemented");
|
||||
|
||||
static constexpr index_t APackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
|
||||
static constexpr index_t BPackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
|
||||
|
||||
static constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
@@ -33,9 +40,11 @@ struct BaseGemmPipelineAgBgCrMem
|
||||
|
||||
static constexpr index_t WgpPerCU =
|
||||
(4 * get_warp_size() / BlockSize) >= 1 ? 4 * get_warp_size() / BlockSize : 1;
|
||||
static constexpr index_t FullMemBandPrefetchStages = integer_divide_ceil(
|
||||
MinMemInFlyBytes / WgpPerCU,
|
||||
(MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock);
|
||||
static constexpr index_t FullMemBandPrefetchStages =
|
||||
integer_divide_ceil(MinMemInFlyBytes / WgpPerCU,
|
||||
(MPerBlock * sizeof(ADataType) / APackedSize +
|
||||
NPerBlock * sizeof(BDataType) / BPackedSize) *
|
||||
KPerBlock);
|
||||
static constexpr index_t PrefetchStages =
|
||||
FullMemBandPrefetchStages >= 2
|
||||
? FullMemBandPrefetchStages <= 8 ? FullMemBandPrefetchStages : 8
|
||||
|
||||
@@ -67,16 +67,22 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA()
|
||||
{
|
||||
constexpr index_t smem_size_a = sizeof(typename Problem::ADataType) *
|
||||
MakeALdsBlockDescriptor<Problem>().get_element_space_size();
|
||||
constexpr index_t PackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<typename Problem::ADataType>>::PackedSize;
|
||||
constexpr index_t smem_size_a =
|
||||
sizeof(typename Problem::ADataType) *
|
||||
MakeALdsBlockDescriptor<Problem>().get_element_space_size() / PackedSize;
|
||||
return smem_size_a;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeB()
|
||||
{
|
||||
constexpr index_t smem_size_b = sizeof(typename Problem::BDataType) *
|
||||
MakeBLdsBlockDescriptor<Problem>().get_element_space_size();
|
||||
constexpr index_t PackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<typename Problem::BDataType>>::PackedSize;
|
||||
constexpr index_t smem_size_b =
|
||||
sizeof(typename Problem::BDataType) *
|
||||
MakeBLdsBlockDescriptor<Problem>().get_element_space_size() / PackedSize;
|
||||
return smem_size_b;
|
||||
}
|
||||
|
||||
@@ -387,8 +393,8 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
|
||||
using AccDataType = float;
|
||||
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
|
||||
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
|
||||
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ADataType,
|
||||
typename Problem::BDataType,
|
||||
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ComputeDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
AccDataType,
|
||||
WarpTile::at(I0),
|
||||
WarpTile::at(I1),
|
||||
|
||||
@@ -20,6 +20,11 @@ struct GemmPipelineAGmemBGmemCRegV2
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
|
||||
static constexpr index_t APackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
|
||||
static constexpr index_t BPackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
static constexpr index_t kMPerBlock = BlockGemmShape::kM;
|
||||
@@ -37,13 +42,15 @@ struct GemmPipelineAGmemBGmemCRegV2
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetStaticLdsSize()
|
||||
{
|
||||
return integer_divide_ceil(
|
||||
sizeof(ADataType) *
|
||||
Policy::template MakeALdsBlockDescriptor<Problem>().get_element_space_size(),
|
||||
16) *
|
||||
return integer_divide_ceil(sizeof(ADataType) *
|
||||
Policy::template MakeALdsBlockDescriptor<Problem>()
|
||||
.get_element_space_size() /
|
||||
APackedSize,
|
||||
16) *
|
||||
16 +
|
||||
sizeof(BDataType) *
|
||||
Policy::template MakeBLdsBlockDescriptor<Problem>().get_element_space_size();
|
||||
Policy::template MakeBLdsBlockDescriptor<Problem>().get_element_space_size() /
|
||||
BPackedSize;
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp,
|
||||
@@ -75,7 +82,8 @@ struct GemmPipelineAGmemBGmemCRegV2
|
||||
auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);
|
||||
|
||||
constexpr index_t a_lds_block_space_size_aligned =
|
||||
integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), 16) *
|
||||
integer_divide_ceil(
|
||||
sizeof(ADataType) * a_lds_block_desc.get_element_space_size() / APackedSize, 16) *
|
||||
16;
|
||||
|
||||
// B tile in LDS
|
||||
|
||||
@@ -13,14 +13,16 @@ template <typename ADataType_,
|
||||
typename BDataType_,
|
||||
typename CDataType_,
|
||||
typename BlockGemmShape_,
|
||||
typename Traits_>
|
||||
typename Traits_,
|
||||
typename ComputeDataType_ = ADataType_>
|
||||
struct GemmPipelineProblemBase
|
||||
{
|
||||
using Traits = remove_cvref_t<Traits_>;
|
||||
|
||||
using ADataType = remove_cvref_t<ADataType_>;
|
||||
using BDataType = remove_cvref_t<BDataType_>;
|
||||
using CDataType = remove_cvref_t<CDataType_>;
|
||||
using ADataType = remove_cvref_t<ADataType_>;
|
||||
using BDataType = remove_cvref_t<BDataType_>;
|
||||
using CDataType = remove_cvref_t<CDataType_>;
|
||||
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
|
||||
|
||||
using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
|
||||
|
||||
@@ -53,13 +55,15 @@ struct GemmPipelineProblemBase
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentA()
|
||||
{
|
||||
constexpr index_t PackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
|
||||
if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
constexpr index_t pixels_per_thread =
|
||||
BlockGemmShape::kM * BlockGemmShape::kK / kBlockSize;
|
||||
return pixels_per_thread < VectorLoadSize / sizeof(ADataType)
|
||||
return pixels_per_thread < PackedSize * VectorLoadSize / sizeof(ADataType)
|
||||
? pixels_per_thread
|
||||
: VectorLoadSize / sizeof(ADataType);
|
||||
: PackedSize * VectorLoadSize / sizeof(ADataType);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -69,17 +73,19 @@ struct GemmPipelineProblemBase
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentB()
|
||||
{
|
||||
constexpr index_t PackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
|
||||
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
constexpr index_t pixels_per_thread =
|
||||
BlockGemmShape::kN * BlockGemmShape::kK / kBlockSize;
|
||||
return pixels_per_thread < VectorLoadSize / sizeof(BDataType)
|
||||
return pixels_per_thread < PackedSize * VectorLoadSize / sizeof(BDataType)
|
||||
? pixels_per_thread
|
||||
: VectorLoadSize / sizeof(BDataType);
|
||||
: PackedSize * VectorLoadSize / sizeof(BDataType);
|
||||
}
|
||||
else
|
||||
{
|
||||
return VectorLoadSize / sizeof(BDataType);
|
||||
return PackedSize * VectorLoadSize / sizeof(BDataType);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -143,9 +149,14 @@ template <typename ADataType_,
|
||||
typename BDataType_,
|
||||
typename CDataType_,
|
||||
typename BlockGemmShape_,
|
||||
typename Traits_>
|
||||
using GemmPipelineProblem =
|
||||
GemmPipelineProblemBase<ADataType_, BDataType_, CDataType_, BlockGemmShape_, Traits_>;
|
||||
typename Traits_,
|
||||
typename ComputeDataType_ = ADataType_>
|
||||
using GemmPipelineProblem = GemmPipelineProblemBase<ADataType_,
|
||||
BDataType_,
|
||||
CDataType_,
|
||||
BlockGemmShape_,
|
||||
Traits_,
|
||||
ComputeDataType_>;
|
||||
|
||||
template <typename ADataType_,
|
||||
typename BDataType_,
|
||||
@@ -154,14 +165,16 @@ template <typename ADataType_,
|
||||
typename Traits_,
|
||||
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
|
||||
bool HasHotLoop_ = true,
|
||||
TailNumber TailNum_ = TailNumber::Full>
|
||||
TailNumber TailNum_ = TailNumber::Full,
|
||||
typename ComputeDataType_ = ADataType_>
|
||||
struct UniversalGemmPipelineProblem
|
||||
{
|
||||
using Traits = remove_cvref_t<Traits_>;
|
||||
|
||||
using ADataType = remove_cvref_t<ADataType_>;
|
||||
using BDataType = remove_cvref_t<BDataType_>;
|
||||
using CDataType = remove_cvref_t<CDataType_>;
|
||||
using ADataType = remove_cvref_t<ADataType_>;
|
||||
using BDataType = remove_cvref_t<BDataType_>;
|
||||
using CDataType = remove_cvref_t<CDataType_>;
|
||||
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
|
||||
|
||||
using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
|
||||
|
||||
|
||||
@@ -34,31 +34,41 @@ struct UniversalGemmBasePolicy
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t elements_per_thread = MNPerBlock * KPerBlock / BlockSize;
|
||||
constexpr index_t PackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<DataType>>::PackedSize;
|
||||
|
||||
// Assume DataType is even!
|
||||
if constexpr(XPerTile % (16 / sizeof(DataType)) == 0 &&
|
||||
elements_per_thread % (16 / sizeof(DataType)) == 0)
|
||||
if constexpr(XPerTile % (PackedSize * 32 / sizeof(DataType)) == 0 &&
|
||||
elements_per_thread % (PackedSize * 32 / sizeof(DataType)) == 0 &&
|
||||
PackedSize == 2)
|
||||
{
|
||||
return (16 / sizeof(DataType));
|
||||
return (PackedSize * 32 / sizeof(DataType));
|
||||
}
|
||||
else if constexpr(XPerTile % (8 / sizeof(DataType)) == 0 &&
|
||||
elements_per_thread % (8 / sizeof(DataType)) == 0)
|
||||
else if constexpr(XPerTile % (PackedSize * 16 / sizeof(DataType)) == 0 &&
|
||||
elements_per_thread % (PackedSize * 16 / sizeof(DataType)) == 0)
|
||||
{
|
||||
return (8 / sizeof(DataType));
|
||||
return (PackedSize * 16 / sizeof(DataType));
|
||||
}
|
||||
else if constexpr(sizeof(DataType) >= 4 && XPerTile % (4 / sizeof(DataType)) == 0 &&
|
||||
elements_per_thread % (4 / sizeof(DataType)) == 0)
|
||||
else if constexpr(XPerTile % (PackedSize * 8 / sizeof(DataType)) == 0 &&
|
||||
elements_per_thread % (PackedSize * 8 / sizeof(DataType)) == 0)
|
||||
{
|
||||
return (4 / sizeof(DataType));
|
||||
return (PackedSize * 8 / sizeof(DataType));
|
||||
}
|
||||
else if constexpr(sizeof(DataType) >= 2 && XPerTile % (2 / sizeof(DataType)) == 0 &&
|
||||
elements_per_thread % (2 / sizeof(DataType)) == 0)
|
||||
else if constexpr(sizeof(DataType) >= PackedSize * 4 &&
|
||||
XPerTile % (PackedSize * 4 / sizeof(DataType)) == 0 &&
|
||||
elements_per_thread % (PackedSize * 4 / sizeof(DataType)) == 0)
|
||||
{
|
||||
return (2 / sizeof(DataType));
|
||||
return (PackedSize * 4 / sizeof(DataType));
|
||||
}
|
||||
else if constexpr(sizeof(DataType) >= PackedSize * 2 &&
|
||||
XPerTile % (PackedSize * 2 / sizeof(DataType)) == 0 &&
|
||||
elements_per_thread % (PackedSize * 2 / sizeof(DataType)) == 0)
|
||||
{
|
||||
return (PackedSize * 2 / sizeof(DataType));
|
||||
}
|
||||
else
|
||||
{
|
||||
return 1;
|
||||
return PackedSize;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -564,8 +574,8 @@ struct UniversalGemmPipelineAgBgCrPolicy
|
||||
{
|
||||
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
|
||||
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
|
||||
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ADataType,
|
||||
typename Problem::BDataType,
|
||||
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ComputeDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::CDataType,
|
||||
WarpTile::at(I0),
|
||||
WarpTile::at(I1),
|
||||
|
||||
Reference in New Issue
Block a user