[CK TILE] GEMM with packed i4 (#1885)

* [CK TILE] GEMM with packed i4

* Fixes

* fixes

* fixes

* fixes

[ROCm/composable_kernel commit: 4d9973ec8e]
Bartłomiej Kocot
2025-02-20 09:59:49 +01:00
committed by GitHub
parent 4ffdab09ab
commit 5705af0aad
32 changed files with 882 additions and 305 deletions

View File

@@ -35,7 +35,7 @@
#error "unsupported CK_TILE_PIPELINE_DEFAULT value"
#endif
template <typename DataType>
template <typename ADataType, typename BDataType = ADataType, typename CDataType = ADataType>
struct GemmBasicTypeConfig;
template <>
@@ -75,6 +75,15 @@ struct GemmBasicTypeConfig<ck_tile::bf8_t>
using CDataType = ck_tile::half_t;
};
template <>
struct GemmBasicTypeConfig<ck_tile::half_t, ck_tile::pk_int4_t, ck_tile::half_t>
{
using ADataType = ck_tile::half_t;
using BDataType = ck_tile::pk_int4_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
template <typename T>
struct DataTypeTraits;
@@ -114,6 +123,12 @@ struct DataTypeTraits<ck_tile::bf8_t>
static constexpr const char* name = "bf8";
};
template <>
struct DataTypeTraits<ck_tile::pk_int4_t>
{
static constexpr const char* name = "pk_int4_t";
};
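
Editor's sketch, for context: with the widened template above, the new fp16 x packed-int4 path instantiates the config with distinct A and B types (types exactly as in the specialization shown):

using Cfg = GemmBasicTypeConfig<ck_tile::half_t, ck_tile::pk_int4_t, ck_tile::half_t>;
static_assert(std::is_same_v<Cfg::BDataType, ck_tile::pk_int4_t>);
static_assert(std::is_same_v<Cfg::AccDataType, float>); // fp16 x int4 accumulates in fp32
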
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;

View File

@@ -29,6 +29,60 @@ auto calculate_rtol_atol(const ck_tile::index_t K,
// Use higher threshold
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
}
template <typename Tensor>
void permute_tensor_b(Tensor& tensor)
{
const ck_tile::index_t K = tensor.get_length(0);
const ck_tile::index_t N = tensor.get_length(1);
// vector pk_i4x4 permute
for(int i = 0; i < N; i++)
{
for(int j = 0; j < K; j += 8)
{
int8_t input[8];
for(int k = 0; k < 4; k++)
{
int8_t i4x2 = tensor(j + k * 2, i).data;
input[k * 2 + 0] = (i4x2 >> 4) & 0xf;
input[k * 2 + 1] = (i4x2 >> 0) & 0xf;
}
// permute 01234567->20643175
{
int8_t hi = input[2];
int8_t lo = input[0];
int8_t i4x2 = (hi << 4) | lo;
tensor(j + 0, i) = i4x2;
}
{
int8_t hi = input[6];
int8_t lo = input[4];
int8_t i4x2 = (hi << 4) | lo;
tensor(j + 2, i) = i4x2;
}
{
int8_t hi = input[3];
int8_t lo = input[1];
int8_t i4x2 = (hi << 4) | lo;
tensor(j + 4, i) = i4x2;
}
{
int8_t hi = input[7];
int8_t lo = input[5];
int8_t i4x2 = (hi << 4) | lo;
tensor(j + 6, i) = i4x2;
}
}
}
}
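
Editor's sketch of the same shuffle on one group of eight nibbles, standalone and runnable; the nibble values 0..7 are just markers that make the 01234567 -> 20643175 reordering visible (the layout the device kernel expects, per the "Permute data for device implementation" call site below):

#include <cstdint>
#include <cstdio>

int main()
{
    // One K-group of 8 int4 values, two per byte with the high nibble first,
    // mirroring how permute_tensor_b reads tensor(j + k * 2, i).data.
    int8_t bytes[4] = {0x01, 0x23, 0x45, 0x67}; // nibbles 0,1,2,3,4,5,6,7
    int8_t n[8];
    for(int k = 0; k < 4; k++)
    {
        n[k * 2 + 0] = (bytes[k] >> 4) & 0xf;
        n[k * 2 + 1] = (bytes[k] >> 0) & 0xf;
    }
    // Apply the 01234567 -> 20643175 shuffle from permute_tensor_b.
    int8_t out[4] = {static_cast<int8_t>((n[2] << 4) | n[0]),
                     static_cast<int8_t>((n[6] << 4) | n[4]),
                     static_cast<int8_t>((n[3] << 4) | n[1]),
                     static_cast<int8_t>((n[7] << 4) | n[5])};
    for(int b = 0; b < 4; b++)
        printf("%x%x", (out[b] >> 4) & 0xf, out[b] & 0xf); // prints 20643175
    printf("\n");
}
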
template <typename ADataType,
typename BDataType,
@@ -83,7 +137,12 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
return ave_time;
}
template <typename PrecType, typename ALayout, typename BLayout, typename CLayout>
template <typename ADataType,
typename BDataType = ADataType,
typename CDataType = ADataType,
typename ALayout,
typename BLayout,
typename CLayout>
int run_gemm_example_with_layouts(int argc,
char* argv[],
const ALayout a_layout = ALayout{},
@@ -94,10 +153,7 @@ int run_gemm_example_with_layouts(int argc,
if(!result)
return -1;
using ADataType = typename GemmBasicTypeConfig<PrecType>::ADataType;
using BDataType = typename GemmBasicTypeConfig<PrecType>::BDataType;
using CDataType = typename GemmBasicTypeConfig<PrecType>::CDataType;
using AccDataType = typename GemmBasicTypeConfig<PrecType>::AccDataType;
using AccDataType = typename GemmBasicTypeConfig<ADataType, BDataType, CDataType>::AccDataType;
ck_tile::index_t M = arg_parser.get_int("m");
ck_tile::index_t N = arg_parser.get_int("n");
@@ -149,7 +205,17 @@ int run_gemm_example_with_layouts(int argc,
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
a_m_k_dev_buf.ToDevice(a_m_k.data());
b_k_n_dev_buf.ToDevice(b_k_n.data());
if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
{
// Permute data for device implementation
ck_tile::HostTensor<BDataType> b_k_n_dev = b_k_n;
permute_tensor_b(b_k_n_dev);
b_k_n_dev_buf.ToDevice(b_k_n_dev.data());
}
else
{
b_k_n_dev_buf.ToDevice(b_k_n.data());
}
c_m_n_dev_buf.SetZero();
c_m_n_dev_result.SetZero();
@@ -195,6 +261,11 @@ int run_gemm_example_with_layouts(int argc,
}
else if(arg_parser.get_int("v") == 2)
{
if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
{
// Restore the original (unpermuted) B input for the GPU reference
b_k_n_dev_buf.ToDevice(b_k_n.data());
}
ck_tile::HostTensor<CDataType> c_m_n_gpu_ref(
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
ck_tile::DeviceMem c_m_n_gpu_buf_ref(c_m_n_gpu_ref.get_element_space_size_in_bytes());
@@ -205,17 +276,18 @@ int run_gemm_example_with_layouts(int argc,
BDataType* d_B;
CDataType* d_C;
ck_tile::hip_check_error(hipMalloc(&d_A, M * K * sizeof(ADataType)));
ck_tile::hip_check_error(hipMalloc(&d_B, N * K * sizeof(BDataType)));
ck_tile::hip_check_error(hipMalloc(&d_C, M * N * sizeof(CDataType)));
ck_tile::hip_check_error(hipMalloc(&d_A, a_m_k.get_element_space_size_in_bytes()));
ck_tile::hip_check_error(hipMalloc(&d_B, b_k_n.get_element_space_size_in_bytes()));
ck_tile::hip_check_error(
hipMalloc(&d_C, c_m_n_dev_result.get_element_space_size_in_bytes()));
ck_tile::hip_check_error(hipMemcpy(d_A,
a_m_k_dev_buf.GetDeviceBuffer(),
M * K * sizeof(ADataType),
a_m_k.get_element_space_size_in_bytes(),
hipMemcpyHostToDevice));
ck_tile::hip_check_error(hipMemcpy(d_B,
b_k_n_dev_buf.GetDeviceBuffer(),
N * K * sizeof(BDataType),
b_k_n.get_element_space_size_in_bytes(),
hipMemcpyHostToDevice));
ck_tile::reference_gemm_gpu<ADataType,
@@ -228,7 +300,7 @@ int run_gemm_example_with_layouts(int argc,
ck_tile::hip_check_error(hipMemcpy(c_m_n_gpu_buf_ref.GetDeviceBuffer(),
d_C,
M * N * sizeof(CDataType),
c_m_n_dev_result.get_element_space_size_in_bytes(),
hipMemcpyDeviceToHost));
ck_tile::hip_check_error(hipFree(d_A));

View File

@@ -321,6 +321,15 @@ int run_gemm_example(int argc, char* argv[])
{
return run_gemm_example_with_layouts<ck_tile::bf8_t>(argc, argv, Row{}, Col{}, Row{});
}
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3)
else if(data_type == "pk_int4_t")
{
// TODO: Add support for bhalf_t ADataType
return run_gemm_example_with_layouts<ck_tile::half_t,
ck_tile::pk_int4_t,
ck_tile::half_t>(argc, argv, Row{}, Col{}, Row{});
}
#endif
else
{
throw std::runtime_error("Unsupported data_type!");
@@ -344,6 +353,15 @@ int run_gemm_example(int argc, char* argv[])
{
return run_gemm_example_with_layouts<ck_tile::bf8_t>(argc, argv, Col{}, Col{}, Row{});
}
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3)
else if(data_type == "pk_int4_t")
{
// TODO: Add support for bhalf_t ADataType
return run_gemm_example_with_layouts<ck_tile::half_t,
ck_tile::pk_int4_t,
ck_tile::half_t>(argc, argv, Col{}, Col{}, Row{});
}
#endif
else
{
throw std::runtime_error("Unsupported data_type!");

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -1309,7 +1309,9 @@ CK_TILE_DEVICE thread_buffer<T, N> amd_buffer_load_impl(int32x4_t src_wave_buffe
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, fp8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, bf8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
(std::is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, pk_int4_t>::value &&
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32)),
"wrong! not implemented");
using rtn_type = thread_buffer<T, N>;

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -153,12 +153,12 @@ struct array<T, 0>
CK_TILE_HOST_DEVICE void print() const { printf("array{size: 0, data: []}"); }
};
template <typename>
template <typename, typename>
struct vector_traits;
// specialization for array
template <typename T, index_t N>
struct vector_traits<array<T, N>>
struct vector_traits<array<T, N>, void>
{
using scalar_type = T;
static constexpr index_t vector_size = N;

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -149,17 +149,24 @@ struct thread_buffer {
};
// clang-format on
template <typename>
template <typename, typename>
struct vector_traits;
// specialization for array
template <typename T, index_t N>
struct vector_traits<thread_buffer<T, N>>
struct vector_traits<thread_buffer<T, N>, std::enable_if_t<!std::is_class_v<T>>>
{
using scalar_type = T;
static constexpr index_t vector_size = N;
};
template <typename T, index_t N>
struct vector_traits<thread_buffer<T, N>, std::enable_if_t<std::is_class_v<T>>>
{
using scalar_type = typename T::type;
static constexpr index_t vector_size = N;
};
#endif
} // namespace ck_tile
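
A small sketch of how the two specializations above dispatch, assuming the builtin (non-class) _Float16 representation of half_t and pk_int4_t::type == int8_t as defined in pk_int4.hpp:

// Non-class element types keep themselves as scalar_type...
static_assert(std::is_same_v<
    ck_tile::vector_traits<ck_tile::thread_buffer<ck_tile::half_t, 4>>::scalar_type,
    ck_tile::half_t>);
// ...while class-like packed types expose their underlying storage type.
static_assert(std::is_same_v<
    ck_tile::vector_traits<ck_tile::thread_buffer<ck_tile::pk_int4_t, 4>>::scalar_type,
    int8_t>);
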

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -294,7 +294,7 @@ struct tuple : impl::tuple_base<make_index_sequence<sizeof...(T)>, T...>
#undef TP_COM_
};
template <typename>
template <typename, typename = void>
struct vector_traits;
// specialization for array

View File

@@ -376,14 +376,12 @@ struct numeric<bfloat16_t>
}
};
template <typename T>
struct numeric_traits;
template <>
struct numeric_traits<bfloat16_t>
{
static constexpr int exp = 8;
static constexpr int mant = 7;
static constexpr int exp = 8;
static constexpr int mant = 7;
static constexpr int PackedSize = 1;
};
#if CK_TILE_USE_CUSTOM_DATA_TYPE

View File

@@ -207,9 +207,6 @@ using bf8_t = unsigned _BitInt(8);
using bf8_raw_t = uint8_t;
#endif
template <typename T>
struct numeric_traits;
template <>
struct numeric_traits<fp8_t>
{
@@ -225,6 +222,7 @@ struct numeric_traits<fp8_t>
static constexpr fp8_interpretation f8_interpret = fp8_interpretation::E4M3_FNUZ;
#endif
static constexpr uint8_t abs_mask = 0x7F;
static constexpr int PackedSize = 1;
};
template <>
@@ -242,6 +240,7 @@ struct numeric_traits<bf8_t>
static constexpr fp8_interpretation f8_interpret = fp8_interpretation::E5M2_FNUZ;
#endif
static constexpr uint8_t abs_mask = 0x7F;
static constexpr int PackedSize = 1;
};
// below is sw fp8 conversion, not utilizing hw instruction

View File

@@ -223,9 +223,6 @@ struct numeric<half_t>
}
};
template <typename T>
struct numeric_traits;
template <>
struct numeric_traits<half_t>
{
@@ -241,6 +238,7 @@ struct numeric_traits<half_t>
static constexpr uint16_t NegInf = 0xFC00;
static constexpr uint16_t NaN = 0x7C01;
static constexpr uint16_t Neg0 = 0x8000;
static constexpr int PackedSize = 1;
using bitwise_type = uint16_t;
};
@@ -383,4 +381,24 @@ half_t exp2(half_t x) { return static_cast<half_t>(exp2f(static_cast<float>(x)))
CK_TILE_DEVICE
half_t log(half_t x) { return static_cast<half_t>(__logf(static_cast<float>(x))); };
#endif
using fp16x2_t = _Float16 __attribute__((ext_vector_type(2)));
CK_TILE_HOST fp16x2_t pk_add_f16(const fp16x2_t& x, const fp16x2_t& y)
{
fp16x2_t vector_res;
vector_res.x = x.x + y.x;
vector_res.y = x.y + y.y;
return vector_res;
}
CK_TILE_DEVICE fp16x2_t pk_add_f16(const fp16x2_t& x, const fp16x2_t& y)
{
fp16x2_t c;
asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(c) : "v"(x), "v"(y));
return c;
}
} // namespace ck_tile
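
Host-side usage sketch of the overload added above (the device overload lowers to v_pk_add_f16 instead; values are illustrative):

ck_tile::fp16x2_t a, b;
a.x = static_cast<_Float16>(1.5f); a.y = static_cast<_Float16>(2.5f);
b.x = static_cast<_Float16>(0.5f); b.y = static_cast<_Float16>(0.5f);
ck_tile::fp16x2_t c = ck_tile::pk_add_f16(a, b); // c.x == 2.0, c.y == 3.0
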

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/half.hpp"
@@ -74,8 +74,6 @@ struct numeric<int8_t>
};
#if 0
template <typename T>
struct numeric_traits;
template <>
struct numeric_traits<int8_t>
@@ -91,6 +89,7 @@ struct numeric_traits<int8_t>
static constexpr uint32_t NegInf = 0xFC00;
static constexpr uint32_t NaN = 0x7C01;
static constexpr uint32_t Neg0 = 0x8000;
static constexpr int PackedSize = 1;
using bitwise_type = uint16_t;
};
#endif

View File

@@ -77,7 +77,10 @@ struct numeric
};
template <typename T>
struct numeric_traits;
struct numeric_traits
{
static constexpr int PackedSize = 1;
};
template <>
struct numeric_traits<float>
@@ -94,6 +97,7 @@ struct numeric_traits<float>
static constexpr uint32_t NegInf = 0xFF800000;
static constexpr uint32_t NaN = 0x7F800001;
static constexpr uint32_t Neg0 = 0x80000000;
static constexpr int PackedSize = 1;
using bitwise_type = uint32_t;
};
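
With the primary template now defaulting PackedSize to 1, only packed types need an explicit override; per the specializations in this diff:

static_assert(ck_tile::numeric_traits<float>::PackedSize == 1);
static_assert(ck_tile::numeric_traits<ck_tile::half_t>::PackedSize == 1);
static_assert(ck_tile::numeric_traits<ck_tile::pk_int4_t>::PackedSize == 2);
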

View File

@@ -21,8 +21,8 @@ struct pk_int4_t
{
using type = int8_t;
type data;
__host__ __device__ constexpr pk_int4_t() : data{type{}} {}
__host__ __device__ constexpr pk_int4_t(type init) : data{init} {}
CK_TILE_HOST_DEVICE constexpr pk_int4_t() : data{type{}} {}
CK_TILE_HOST_DEVICE constexpr pk_int4_t(type init) : data{init} {}
};
// limits
@@ -91,6 +91,16 @@ struct numeric<pk_int4_t>
CK_TILE_HOST_DEVICE static constexpr pk_int4_t zero() { return 0; }
};
template <>
struct numeric_traits<pk_int4_t>
{
static constexpr int PackedSize = 2;
};
using fp32x2_t = float __attribute__((ext_vector_type(2)));
using fp16x2_t = _Float16 __attribute__((ext_vector_type(2)));
using bf16x2_t = bf16_raw_t __attribute__((ext_vector_type(2)));
CK_TILE_HOST_DEVICE fp32x2_t pk_int4_t_to_fp32x2_t(const pk_int4_t& x)
{
uint8_t x_u8 = ck_tile::bit_cast<uint8_t>(x);

View File

@@ -10,6 +10,7 @@
#include "ck_tile/core/numeric/float8.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/bfloat16.hpp"
#include "ck_tile/core/numeric/pk_int4.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
@@ -30,17 +31,34 @@ struct native_t
// of compiler errors e.g. struct A; using Ax2_t = A __attribute__((ext_vector_type(2))); -> will
// have compiler error
namespace impl {
template <typename T_, index_t N_, typename = void>
struct ext_vector;
template <typename T_, index_t N_>
struct ext_vector
struct ext_vector<T_, N_, std::enable_if_t<!std::is_class_v<typename native_t<T_>::type>>>
{
static constexpr index_t N = N_;
using value_type = typename native_t<remove_cvref_t<T_>>::type;
// struct type is not supported for ext_vector
using value_type = typename native_t<T_>::type;
static_assert(!std::is_class_v<value_type>);
using type = value_type __attribute__((ext_vector_type(N))); // this is dangerous
};
template <typename T_, index_t N_>
struct ext_vector<T_, N_, std::enable_if_t<std::is_class_v<typename native_t<T_>::type>>>
{
static constexpr index_t N = N_;
// struct type is not supported for ext_vector
using value_type = typename native_t<T_>::type::type;
static_assert(!std::is_class_v<value_type>);
using type = value_type __attribute__((ext_vector_type(N))); // this is dangerous
};
template <typename V_, index_t Vs_, index_t N_>
struct ext_vector<V_ __attribute__((ext_vector_type(Vs_))), N_>
struct ext_vector<V_ __attribute__((ext_vector_type(Vs_))),
N_,
std::enable_if_t<!std::is_class_v<typename native_t<V_>::type>>>
{
static constexpr index_t N = Vs_ * N_;
using value_type = typename native_t<remove_cvref_t<V_>>::type;
@@ -48,6 +66,17 @@ struct ext_vector<V_ __attribute__((ext_vector_type(Vs_))), N_>
using type = value_type __attribute__((ext_vector_type(N))); // this is dangerous
};
template <typename V_, index_t Vs_, index_t N_>
struct ext_vector<V_ __attribute__((ext_vector_type(Vs_))),
N_,
std::enable_if_t<std::is_class_v<typename native_t<V_>::type>>>
{
static constexpr index_t N = Vs_ * N_;
using value_type = typename native_t<remove_cvref_t<V_>>::type::type;
static_assert(!std::is_class_v<value_type>);
using type = value_type __attribute__((ext_vector_type(N))); // this is dangerous
};
} // namespace impl
template <typename T, index_t N>
@@ -55,10 +84,11 @@ using ext_vector_t = typename impl::ext_vector<T, N>::type;
// by default, any type will result in a vector_size=1 with scalar_type=T traits.
// ... unless we have other vector_traits specialization
template <typename T>
template <typename T, typename>
struct vector_traits
{
using scalar_type = remove_cvref_t<T>;
using scalar_type =
std::conditional_t<std::is_same_v<remove_cvref_t<T>, pk_int4_t>, int8_t, remove_cvref_t<T>>;
static constexpr index_t vector_size = 1;
};
@@ -66,7 +96,7 @@ struct vector_traits
template <typename T, index_t N>
struct vector_traits<T __attribute__((ext_vector_type(N)))>
{
using scalar_type = T;
using scalar_type = std::conditional_t<std::is_same_v<T, pk_int4_t>, int8_t, T>;
static constexpr index_t vector_size = N;
};
@@ -200,21 +230,11 @@ using bf8x32_t = bf8_t __attribute((ext_vector_type(32)));
using bf8x64_t = bf8_t __attribute((ext_vector_type(64)));
#endif
CK_TILE_HOST fp16x2_t pk_add_f16(const fp16x2_t& x, const fp16x2_t& y)
{
fp16x2_t vector_res;
vector_res.x = x.x + y.x;
vector_res.y = x.y + y.y;
return vector_res;
}
CK_TILE_DEVICE fp16x2_t pk_add_f16(const fp16x2_t& x, const fp16x2_t& y)
{
fp16x2_t c;
asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(c) : "v"(x), "v"(y));
return c;
}
// pk_int4_t
// using pk_int4_t
using pk_int4x2_t = int8_t __attribute((ext_vector_type(2)));
using pk_int4x4_t = int8_t __attribute((ext_vector_type(4)));
using pk_int4x8_t = int8_t __attribute((ext_vector_type(8)));
using pk_int4x16_t = int8_t __attribute((ext_vector_type(16)));
using pk_int4x32_t = int8_t __attribute((ext_vector_type(32)));
} // namespace ck_tile
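
The net effect for pk_int4_t, as a sketch: the default vector_traits report int8_t storage with vector_size 1, and the pk_int4xN_t aliases above are plain int8 ext-vectors:

static_assert(std::is_same_v<ck_tile::vector_traits<ck_tile::pk_int4_t>::scalar_type, int8_t>);
static_assert(ck_tile::vector_traits<ck_tile::pk_int4_t>::vector_size == 1);
using i8x4 = int8_t __attribute__((ext_vector_type(4)));
static_assert(std::is_same_v<ck_tile::pk_int4x4_t, i8x4>);
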

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -231,13 +231,18 @@ struct buffer_view<address_space_enum::global,
int32x4_t cached_buf_res_;
remove_cvref_t<T> invalid_element_value_ = T{0};
static constexpr index_t PackedSize = ck_tile::numeric_traits<remove_cvref_t<T>>::PackedSize;
CK_TILE_HOST_DEVICE constexpr buffer_view()
: p_data_{}, buffer_size_{}, cached_buf_res_{0}, invalid_element_value_{}
{
}
CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data, BufferSizeType buffer_size)
: p_data_{p_data}, buffer_size_{buffer_size}, cached_buf_res_{0}, invalid_element_value_{0}
: p_data_{p_data},
buffer_size_{buffer_size / PackedSize},
cached_buf_res_{0},
invalid_element_value_{0}
{
}
@@ -245,7 +250,7 @@ struct buffer_view<address_space_enum::global,
BufferSizeType buffer_size,
T invalid_element_value)
: p_data_{p_data},
buffer_size_{buffer_size},
buffer_size_{buffer_size / PackedSize},
cached_buf_res_{0},
invalid_element_value_{invalid_element_value}
{
@@ -255,7 +260,7 @@ struct buffer_view<address_space_enum::global,
// Must call for buffers that need *_raw load/store
CK_TILE_HOST_DEVICE void init_raw()
{
cached_buf_res_ = make_wave_buffer_resource(p_data_, buffer_size_ * sizeof(type));
cached_buf_res_ = make_wave_buffer_resource(p_data_, (buffer_size_) * sizeof(type));
}
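
Worked numbers for the division above, as a sketch (dimensions are illustrative): a view over a K x N int4 B matrix is constructed with buffer_size = K * N logical elements, so buffer_size_ holds the packed count and init_raw() sizes the wave buffer resource in actual bytes.

constexpr int logical_elems = 256 * 512;         // int4 values in the view
constexpr int stored_slots  = logical_elems / 2; // buffer_size_ after construction
static_assert(stored_slots == 65536);            // bytes handed to make_wave_buffer_resource,
                                                 // since sizeof(pk_int4_t) == 1
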
CK_TILE_DEVICE static constexpr address_space_enum get_address_space()
@@ -887,8 +892,8 @@ struct buffer_view<address_space_enum::lds,
#endif
i += linear_offset; // simplicity
if constexpr(std::is_same<typename vector_traits<remove_cvref_t<T>>::scalar_type,
int8_t>::value &&
if constexpr(std::is_same_v<typename vector_traits<remove_cvref_t<T>>::scalar_type,
int8_t> &&
workaround_int8_ds_write_issue)
{
if(is_valid_element)
@@ -897,83 +902,117 @@ struct buffer_view<address_space_enum::lds,
// ISA, so I try to let compiler emit IR "store<i32, 4>" which would be lower to
// ds_write_b128
// TODO: remove this after compiler fix
static_assert((std::is_same<remove_cvref_t<T>, int8_t>::value &&
std::is_same<remove_cvref_t<X>, int8_t>::value) ||
(std::is_same<remove_cvref_t<T>, int8_t>::value &&
std::is_same<remove_cvref_t<X>, int8x2_t>::value) ||
(std::is_same<remove_cvref_t<T>, int8_t>::value &&
std::is_same<remove_cvref_t<X>, int8x4_t>::value) ||
(std::is_same<remove_cvref_t<T>, int8_t>::value &&
std::is_same<remove_cvref_t<X>, int8x8_t>::value) ||
(std::is_same<remove_cvref_t<T>, int8_t>::value &&
std::is_same<remove_cvref_t<X>, int8x16_t>::value) ||
(std::is_same<remove_cvref_t<T>, int8x4_t>::value &&
std::is_same<remove_cvref_t<X>, int8x4_t>::value) ||
(std::is_same<remove_cvref_t<T>, int8x8_t>::value &&
std::is_same<remove_cvref_t<X>, int8x8_t>::value) ||
(std::is_same<remove_cvref_t<T>, int8x16_t>::value &&
std::is_same<remove_cvref_t<X>, int8x16_t>::value),
"wrong! not implemented for this combination, please add "
"implementation");
static_assert(
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
std::is_same_v<remove_cvref_t<X>, int8_t>) ||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
std::is_same_v<remove_cvref_t<X>, int8x2_t>) ||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
std::is_same_v<remove_cvref_t<X>, int8x4_t>) ||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
std::is_same_v<remove_cvref_t<X>, int8x8_t>) ||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
std::is_same_v<remove_cvref_t<X>, int8x16_t>) ||
(std::is_same_v<remove_cvref_t<T>, int8x4_t> &&
std::is_same_v<remove_cvref_t<X>, int8x4_t>) ||
(std::is_same_v<remove_cvref_t<T>, int8x8_t> &&
std::is_same_v<remove_cvref_t<X>, int8x8_t>) ||
(std::is_same_v<remove_cvref_t<T>, int8x16_t> &&
std::is_same_v<remove_cvref_t<X>, int8x16_t>) ||
// ext_vector_type for pk_int4 must use int8_t as type
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 1>>) ||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 2>>) ||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 4>>) ||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 8>>) ||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 16>>) ||
(std::is_same_v<remove_cvref_t<T>, pk_int4x4_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 4>>) ||
(std::is_same_v<remove_cvref_t<T>, pk_int4x8_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 8>>) ||
(std::is_same_v<remove_cvref_t<T>, pk_int4x16_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 16>>),
"wrong! not implemented for this combination, please add "
"implementation");
if constexpr(std::is_same<remove_cvref_t<T>, int8_t>::value &&
std::is_same<remove_cvref_t<X>, int8_t>::value)
if constexpr((std::is_same_v<remove_cvref_t<T>, int8_t> &&
std::is_same_v<remove_cvref_t<X>, int8_t>) ||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 1>>))
{
// HACK: cast pointer of x is bad
// TODO: remove this after compiler fix
*c_style_pointer_cast<int8_t*>(&p_data_[i]) =
*c_style_pointer_cast<const int8_t*>(&x);
}
else if constexpr(std::is_same<remove_cvref_t<T>, int8_t>::value &&
std::is_same<remove_cvref_t<X>, int8x2_t>::value)
else if constexpr((std::is_same_v<remove_cvref_t<T>, int8_t> &&
std::is_same_v<remove_cvref_t<X>, int8x2_t>) ||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 2>>))
{
// HACK: cast pointer of x is bad
// TODO: remove this after compiler fix
*c_style_pointer_cast<int16_t*>(&p_data_[i]) =
*c_style_pointer_cast<const int16_t*>(&x);
}
else if constexpr(std::is_same<remove_cvref_t<T>, int8_t>::value &&
std::is_same<remove_cvref_t<X>, int8x4_t>::value)
else if constexpr((std::is_same_v<remove_cvref_t<T>, int8_t> &&
std::is_same_v<remove_cvref_t<X>, int8x4_t>) ||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 4>>))
{
// HACK: cast pointer of x is bad
// TODO: remove this after compiler fix
*c_style_pointer_cast<int32_t*>(&p_data_[i]) =
*c_style_pointer_cast<const int32_t*>(&x);
}
else if constexpr(std::is_same<remove_cvref_t<T>, int8_t>::value &&
std::is_same<remove_cvref_t<X>, int8x8_t>::value)
else if constexpr((std::is_same_v<remove_cvref_t<T>, int8_t> &&
std::is_same_v<remove_cvref_t<X>, int8x8_t>) ||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 8>>))
{
// HACK: cast pointer of x is bad
// TODO: remove this after compiler fix
*c_style_pointer_cast<int32x2_t*>(&p_data_[i]) =
*c_style_pointer_cast<const int32x2_t*>(&x);
}
else if constexpr(std::is_same<remove_cvref_t<T>, int8_t>::value &&
std::is_same<remove_cvref_t<X>, int8x16_t>::value)
else if constexpr((std::is_same_v<remove_cvref_t<T>, int8_t> &&
std::is_same_v<remove_cvref_t<X>, int8x16_t>) ||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 16>>))
{
// HACK: cast pointer of x is bad
// TODO: remove this after compiler fix
*c_style_pointer_cast<int32x4_t*>(&p_data_[i]) =
*c_style_pointer_cast<const int32x4_t*>(&x);
}
else if constexpr(std::is_same<remove_cvref_t<T>, int8x4_t>::value &&
std::is_same<remove_cvref_t<X>, int8x4_t>::value)
else if constexpr((std::is_same_v<remove_cvref_t<T>, int8x4_t> &&
std::is_same_v<remove_cvref_t<X>, int8x4_t>) ||
(std::is_same_v<remove_cvref_t<T>, pk_int4x4_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 4>>))
{
// HACK: cast pointer of x is bad
// TODO: remove this after compiler fix
*c_style_pointer_cast<int32_t*>(&p_data_[i]) =
*c_style_pointer_cast<const int32_t*>(&x);
}
else if constexpr(std::is_same<remove_cvref_t<T>, int8x8_t>::value &&
std::is_same<remove_cvref_t<X>, int8x8_t>::value)
else if constexpr((std::is_same_v<remove_cvref_t<T>, int8x8_t> &&
std::is_same_v<remove_cvref_t<X>, int8x8_t>) ||
(std::is_same_v<remove_cvref_t<T>, pk_int4x8_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 8>>))
{
// HACK: cast pointer of x is bad
// TODO: remove this after compiler fix
*c_style_pointer_cast<int32x2_t*>(&p_data_[i]) =
*c_style_pointer_cast<const int32x2_t*>(&x);
}
else if constexpr(std::is_same<remove_cvref_t<T>, int8x16_t>::value &&
std::is_same<remove_cvref_t<X>, int8x16_t>::value)
else if constexpr((std::is_same_v<remove_cvref_t<T>, int8x16_t> &&
std::is_same_v<remove_cvref_t<X>, int8x16_t>) ||
(std::is_same_v<remove_cvref_t<T>, pk_int4x16_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 16>>))
{
// HACK: cast pointer of x is bad
// TODO: remove this after compiler fix

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -27,6 +27,8 @@ struct static_distributed_tensor
using ThreadTensorDesc =
remove_cvref_t<decltype(StaticTileDistribution{}.get_ys_to_d_descriptor())>;
static constexpr index_t PackedSize =
ck_tile::numeric_traits<remove_cvref_t<DataType>>::PackedSize;
static constexpr index_t kThreadElementSpaceSize = ThreadTensorDesc{}.get_element_space_size();
static_assert(0 < kThreadElementSpaceSize, "Make sure tile distribution is valid");
@@ -59,7 +61,7 @@ struct static_distributed_tensor
CK_TILE_HOST_DEVICE static constexpr index_t get_thread_buffer_size()
{
return kThreadElementSpaceSize;
return kThreadElementSpaceSize / PackedSize;
}
template <index_t... YSliceOrigins, index_t... YSliceLengths>
@@ -79,8 +81,9 @@ struct static_distributed_tensor
static_ford<sequence<YSliceLengths...>>{}([&](auto idx) {
constexpr auto idx_ys = idx + sequence<YSliceOrigins...>{};
sliced_thread_data(number<sliced_thread_tensor_desc.calculate_offset(idx)>{}) =
thread_buf_[number<ThreadTensorDesc{}.calculate_offset(idx_ys)>{}];
sliced_thread_data(
number<sliced_thread_tensor_desc.calculate_offset(idx) / PackedSize>{}) =
thread_buf_[number<ThreadTensorDesc{}.calculate_offset(idx_ys) / PackedSize>{}];
});
return sliced_thread_data;
@@ -101,8 +104,9 @@ struct static_distributed_tensor
static_ford<sequence<YSliceLengths...>>{}([&](auto idx) {
constexpr auto idx_ys = idx + sequence<YSliceOrigins...>{};
thread_buf_(number<ThreadTensorDesc{}.calculate_offset(idx_ys)>{}) =
sliced_thread_data[number<sliced_thread_tensor_desc.calculate_offset(idx)>{}];
thread_buf_(number<ThreadTensorDesc{}.calculate_offset(idx_ys) / PackedSize>{}) =
sliced_thread_data[number<sliced_thread_tensor_desc.calculate_offset(idx) /
PackedSize>{}];
});
}
@@ -115,7 +119,7 @@ struct static_distributed_tensor
constexpr auto y_idx = get_tile_distribution().get_y_indices_from_distributed_indices(
TileDistributedIndices{});
return thread_buf_[number<ThreadTensorDesc{}.calculate_offset(y_idx)>{}];
return thread_buf_[number<ThreadTensorDesc{}.calculate_offset(y_idx) / PackedSize>{}];
}
template <typename TileDistributedIndices>
@@ -127,11 +131,11 @@ struct static_distributed_tensor
constexpr auto y_idx = get_tile_distribution().get_y_indices_from_distributed_indices(
TileDistributedIndices{});
return thread_buf_(number<ThreadTensorDesc{}.calculate_offset(y_idx)>{});
return thread_buf_(number<ThreadTensorDesc{}.calculate_offset(y_idx) / PackedSize>{});
}
//
thread_buffer<DataType, kThreadElementSpaceSize> thread_buf_;
thread_buffer<DataType, get_thread_buffer_size()> thread_buf_;
};
template <typename DataType, typename StaticTileDistribution>

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -45,6 +45,8 @@ struct tensor_view
using TensorIndex = array<index_t, TensorDesc::get_num_of_top_dimension()>;
using TensorCoord = decltype(make_tensor_coordinate(TensorDesc{}, TensorIndex{}));
static constexpr auto DstInMemOp = DstInMemOp_;
static constexpr index_t PackedSize =
ck_tile::numeric_traits<remove_cvref_t<DataType>>::PackedSize;
CK_TILE_HOST_DEVICE constexpr tensor_view() = default;
@@ -81,8 +83,8 @@ struct tensor_view
bool_constant<oob_conditional_check> = {}) const
{
return buf_.template get<X>(
coord.get_offset(),
linear_offset,
coord.get_offset() / PackedSize,
linear_offset / PackedSize,
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
bool_constant<oob_conditional_check>{});
}
@@ -99,8 +101,8 @@ struct tensor_view
bool is_valid_element, // flag
bool_constant<oob_conditional_check> = {}) const
{
return buf_.template get<X>(coord.get_offset(),
linear_offset,
return buf_.template get<X>(coord.get_offset() / PackedSize,
linear_offset / PackedSize,
is_valid_element,
bool_constant<oob_conditional_check>{});
}
@@ -122,8 +124,8 @@ struct tensor_view
{
return buf_.template get_raw<X, oob_conditional_check, pre_nop>(
dst,
coord.get_offset(),
linear_offset,
coord.get_offset() / PackedSize,
linear_offset / PackedSize,
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
bool_constant<pre_nop>{});
}
@@ -142,8 +144,12 @@ struct tensor_view
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {}) const
{
return buf_.template get_raw<X, oob_conditional_check, pre_nop>(
dst, coord.get_offset(), linear_offset, is_valid_element, bool_constant<pre_nop>{});
return buf_.template get_raw<X, oob_conditional_check, pre_nop>(dst,
coord.get_offset() /
PackedSize,
linear_offset / PackedSize,
is_valid_element,
bool_constant<pre_nop>{});
}
template <typename X,
@@ -159,8 +165,8 @@ struct tensor_view
{
return buf_.template async_get<X>(
smem,
coord.get_offset(),
linear_offset,
coord.get_offset() / PackedSize,
linear_offset / PackedSize,
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
bool_constant<oob_conditional_check>{});
}
@@ -178,8 +184,8 @@ struct tensor_view
bool is_valid_element) const
{
return buf_.template async_get<X>(smem,
coord.get_offset(),
linear_offset,
coord.get_offset() / PackedSize,
linear_offset / PackedSize,
is_valid_element,
bool_constant<oob_conditional_check>{});
}
@@ -198,8 +204,8 @@ struct tensor_view
{
return buf_.template async_get_raw<X>(
smem,
coord.get_offset(),
linear_offset,
coord.get_offset() / PackedSize,
linear_offset / PackedSize,
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
bool_constant<pre_nop>{});
}
@@ -217,8 +223,11 @@ struct tensor_view
bool is_valid_element,
bool_constant<pre_nop> = {}) const
{
return buf_.template async_get_raw<X>(
smem, coord.get_offset(), linear_offset, is_valid_element, bool_constant<pre_nop>{});
return buf_.template async_get_raw<X>(smem,
coord.get_offset() / PackedSize,
linear_offset / PackedSize,
is_valid_element,
bool_constant<pre_nop>{});
}
// X is vector of DataType.
@@ -236,8 +245,8 @@ struct tensor_view
bool_constant<oob_conditional_check> = {})
{
buf_.template set<X, oob_conditional_check>(
coord.get_offset(),
linear_offset,
coord.get_offset() / PackedSize,
linear_offset / PackedSize,
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
x);
}
@@ -272,8 +281,8 @@ struct tensor_view
bool_constant<oob_conditional_check> = {})
{
buf_.template set_raw<X, oob_conditional_check>(
coord.get_offset(),
linear_offset,
coord.get_offset() / PackedSize,
linear_offset / PackedSize,
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
x);
}
@@ -292,7 +301,7 @@ struct tensor_view
bool_constant<oob_conditional_check> = {})
{
buf_.template set_raw<X, oob_conditional_check>(
coord.get_offset(), linear_offset, is_valid_element, x);
coord.get_offset() / PackedSize, linear_offset / PackedSize, is_valid_element, x);
}
// X is vector of DataType.
@@ -310,8 +319,8 @@ struct tensor_view
bool_constant<oob_conditional_check> = {})
{
buf_.template update<DstInMemOp, X, oob_conditional_check>(
coord.get_offset(),
linear_offset,
coord.get_offset() / PackedSize,
linear_offset / PackedSize,
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
x);
}
@@ -330,7 +339,7 @@ struct tensor_view
bool_constant<oob_conditional_check> = {})
{
buf_.template update<DstInMemOp, X, oob_conditional_check>(
coord.get_offset(), linear_offset, is_valid_element, x);
coord.get_offset() / PackedSize, linear_offset / PackedSize, is_valid_element, x);
}
// X is vector of DataType.
@@ -350,8 +359,8 @@ struct tensor_view
bool_constant<pre_nop> = {})
{
buf_.template update_raw<DstInMemOp, X, oob_conditional_check, pre_nop>(
coord.get_offset(),
linear_offset,
coord.get_offset() / PackedSize,
linear_offset / PackedSize,
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
x);
}
@@ -372,7 +381,7 @@ struct tensor_view
bool_constant<pre_nop> = {})
{
buf_.template update_raw<DstInMemOp, X, oob_conditional_check, pre_nop>(
coord.get_offset(), linear_offset, is_valid_element, x);
coord.get_offset() / PackedSize, linear_offset / PackedSize, is_valid_element, x);
}
CK_TILE_HOST_DEVICE void print() const
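
A hedged sketch of the offset translation applied throughout this file: logical offsets count DataType elements, while the underlying buffer is indexed in packed storage slots.

constexpr int to_packed(int element_offset, int packed_size)
{
    return element_offset / packed_size; // e.g. int4 element 10 -> packed slot 5
}
static_assert(to_packed(10, 2) == 5);  // pk_int4_t: PackedSize == 2
static_assert(to_packed(10, 1) == 10); // all other types: behavior unchanged
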

View File

@@ -97,13 +97,15 @@ struct tile_window_with_static_distribution
}
public:
static constexpr index_t PackedSize =
ck_tile::numeric_traits<remove_cvref_t<DataType>>::PackedSize;
static constexpr index_t VectorDimY = get_vector_dim_y_scalar_per_vector().template at<0>();
static constexpr index_t ScalarPerVector =
get_vector_dim_y_scalar_per_vector().template at<1>();
// using vector_type_t = vector_type_maker_t<DataType, ScalarPerVector>;
// using vector_t = typename vector_type_t::type;
using vector_t = thread_buffer<DataType, ScalarPerVector>;
using vector_t = thread_buffer<DataType, ScalarPerVector / PackedSize>;
private:
static constexpr auto scalars_per_access_ = [] {
@@ -336,7 +338,7 @@ struct tile_window_with_static_distribution
bottom_tensor_thread_coord, 0, bool_constant<oob_conditional_check>{});
#if 1
// write into distributed tensor
static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
constexpr auto idx_ys = generate_tuple(
[&](auto jj) {
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
@@ -345,10 +347,11 @@ struct tile_window_with_static_distribution
number<NDimY>{});
constexpr index_t d =
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
Traits::PackedSize;
dst_tensor.get_thread_buffer().template at<d>() =
vec_value.template get_as<DataType>()[j];
vec_value.template get_as<DataType>()[j / Traits::PackedSize];
});
#else
constexpr index_t d =
@@ -390,8 +393,9 @@ struct tile_window_with_static_distribution
using SFC_Ys = typename Traits::SFC_Ys;
static constexpr index_t YElementSize =
TileDstr{}.get_ys_to_d_descriptor().get_element_space_size();
static_assert(YElementSize % Traits::ScalarPerVector == 0);
using vectorized_tbuf = array<vector_t, YElementSize / Traits::ScalarPerVector>;
static_assert(YElementSize % (Traits::PackedSize * Traits::ScalarPerVector) == 0);
using vectorized_tbuf =
array<vector_t, YElementSize / (Traits::PackedSize * Traits::ScalarPerVector)>;
// StaticBuffer<address_space_enum::vgpr,
// vector_t,
// YElementSize / Traits::ScalarPerVector,
@@ -419,7 +423,8 @@ struct tile_window_with_static_distribution
// data index [y0, y1, ...]
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
constexpr index_t d =
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start);
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start) /
Traits::PackedSize;
static_assert(d % Traits::ScalarPerVector == 0);
get_bottom_tensor_view().template get_vectorized_elements_raw<vector_t>(
@@ -632,7 +637,7 @@ struct tile_window_with_static_distribution
// vector_type_t vec;
vector_t vec_value;
static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
constexpr auto idx_ys = generate_tuple(
[&](auto jj) {
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
@@ -641,9 +646,10 @@ struct tile_window_with_static_distribution
number<NDimY>{});
constexpr index_t d =
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
Traits::PackedSize;
vec_value.template get_as<DataType>()(j) =
vec_value.template get_as<DataType>()(j / Traits::PackedSize) =
dstr_tensor.get_thread_buffer().template at<d>();
});
@@ -698,7 +704,7 @@ struct tile_window_with_static_distribution
// read from distributed tensor
vector_t vec_value;
static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
constexpr auto idx_ys = generate_tuple(
[&](auto jj) {
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
@@ -706,8 +712,9 @@ struct tile_window_with_static_distribution
},
number<NDimY>{});
constexpr index_t d =
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
vec_value.template get_as<DataType>()(j) =
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
Traits::PackedSize;
vec_value.template get_as<DataType>()(j / Traits::PackedSize) =
dstr_tensor.get_thread_buffer().template at<d>();
});
@@ -759,7 +766,7 @@ struct tile_window_with_static_distribution
// read from distributed tensor
vector_t vec_value;
static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
constexpr auto idx_ys = generate_tuple(
[&](auto jj) {
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
@@ -768,9 +775,10 @@ struct tile_window_with_static_distribution
number<NDimY>{});
constexpr index_t d =
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
Traits::PackedSize;
vec_value.template get_as<DataType>()(j) =
vec_value.template get_as<DataType>()(j / Traits::PackedSize) =
dstr_tensor.get_thread_buffer().template at<d>();
});
@@ -825,7 +833,7 @@ struct tile_window_with_static_distribution
// read from distributed tensor
vector_t vec_value;
static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
constexpr auto idx_ys = generate_tuple(
[&](auto jj) {
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
@@ -834,9 +842,10 @@ struct tile_window_with_static_distribution
number<NDimY>{});
constexpr index_t d =
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
Traits::PackedSize;
vec_value.template get_as<DataType>()(j) =
vec_value.template get_as<DataType>()(j / Traits::PackedSize) =
dstr_tensor.get_thread_buffer().template at<d>();
});
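
To make the loop change above concrete, a minimal sketch (ScalarPerVector = 8 is an assumed example value): static_for now steps j by PackedSize, so for pk_int4_t it visits 0, 2, 4, 6 and j / PackedSize addresses the four packed slots of vector_t.

#include <cstdio>

int main()
{
    constexpr int ScalarPerVector = 8; // int4 elements per vectorized access (assumed)
    constexpr int PackedSize      = 2; // int4 values per byte for pk_int4_t
    // Mirrors static_for<0, ScalarPerVector, Traits::PackedSize>: one step per
    // packed byte, with j / PackedSize selecting the thread_buffer slot.
    for(int j = 0; j < ScalarPerVector; j += PackedSize)
        printf("element j=%d -> vector slot %d\n", j, j / PackedSize);
}
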

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/arch/arch.hpp"
@@ -151,11 +151,13 @@ struct tile_window_linear
}
public:
static constexpr index_t PackedSize =
ck_tile::numeric_traits<remove_cvref_t<DataType>>::PackedSize;
static constexpr index_t VectorDimY = get_vector_dim_y_scalar_per_vector().template at<0>();
static constexpr index_t ScalarPerVector =
get_vector_dim_y_scalar_per_vector().template at<1>();
using vector_t = thread_buffer<DataType, ScalarPerVector>;
using vector_t = thread_buffer<DataType, ScalarPerVector / PackedSize>;
private:
static constexpr auto scalars_per_access_ = [] {
@@ -498,17 +500,18 @@ struct tile_window_linear
// data index [y0, y1, ...]
constexpr auto idx_diff_ys = SFC_Ys::get_index(IAccess);
// write into distributed tensor
static_for<0, traits::ScalarPerVector, 1>{}([&](auto j) {
static_for<0, traits::ScalarPerVector, traits::PackedSize>{}([&](auto j) {
constexpr auto idx_ys = generate_tuple(
[&](auto jj) {
return jj == traits::VectorDimY ? (idx_diff_ys[jj] + j) : idx_diff_ys[jj];
},
number<NDimY>{});
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
traits::PackedSize;
dst_tensor.get_thread_buffer().template at<d>() =
vec_value.template get_as<DataType>()[j];
vec_value.template get_as<DataType>()[j / traits::PackedSize];
});
#else
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start);
@@ -556,17 +559,18 @@ struct tile_window_linear
// data index [y0, y1, ...]
constexpr auto idx_diff_ys = SFC_Ys::get_index(IAccess);
// write into distributed tensor
static_for<0, traits::ScalarPerVector, 1>{}([&](auto j) {
static_for<0, traits::ScalarPerVector, traits::PackedSize>{}([&](auto j) {
constexpr auto idx_ys = generate_tuple(
[&](auto jj) {
return jj == traits::VectorDimY ? (idx_diff_ys[jj] + j) : idx_diff_ys[jj];
},
number<NDimY>{});
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
traits::PackedSize;
dst_tensor.get_thread_buffer().template at<d>() =
vec_value.template get_as<DataType>()[j];
vec_value.template get_as<DataType>()[j / traits::PackedSize];
});
#else
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start);
@@ -595,8 +599,9 @@ struct tile_window_linear
using SFC_Ys = typename traits::SFC_Ys;
static constexpr index_t YElementSize =
TileDstr{}.get_ys_to_d_descriptor().get_element_space_size();
static_assert(YElementSize % traits::ScalarPerVector == 0);
using vectorized_tbuf = array<vector_t, YElementSize / traits::ScalarPerVector>;
static_assert(YElementSize % (traits::PackedSize * traits::ScalarPerVector) == 0);
using vectorized_tbuf =
array<vector_t, YElementSize / (traits::PackedSize * traits::ScalarPerVector)>;
constexpr auto tile_dstr = TileDstr{};
@@ -620,7 +625,9 @@ struct tile_window_linear
// data index [y0, y1, ...]
constexpr auto idx_ys_start = SFC_Ys::get_index(IAccess);
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start);
constexpr index_t d =
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start) /
traits::PackedSize;
static_assert(d % traits::ScalarPerVector == 0);
get_bottom_tensor_view().template get_vectorized_elements_raw<vector_t>(
@@ -804,16 +811,17 @@ struct tile_window_linear
// read from distributed tensor
vector_t vec_value;
static_for<0, traits::ScalarPerVector, 1>{}([&](auto j) {
static_for<0, traits::ScalarPerVector, traits::PackedSize>{}([&](auto j) {
constexpr auto idx_ys = generate_tuple(
[&](auto jj) {
return jj == traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj];
},
number<NDimY>{});
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
traits::PackedSize;
vec_value.template get_as<DataType>()(j) =
vec_value.template get_as<DataType>()(j / traits::PackedSize) =
dstr_tensor.get_thread_buffer().template at<d>();
});
@@ -852,14 +860,15 @@ struct tile_window_linear
// read from distributed tensor
vector_t vec_value;
static_for<0, traits::ScalarPerVector, 1>{}([&](auto j) {
static_for<0, traits::ScalarPerVector, traits::PackedSize>{}([&](auto j) {
constexpr auto idx_ys = generate_tuple(
[&](auto jj) {
return jj == traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj];
},
number<NDimY>{});
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
vec_value.template get_as<DataType>()(j) =
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
traits::PackedSize;
vec_value.template get_as<DataType>()(j / traits::PackedSize) =
dstr_tensor.get_thread_buffer().template at<d>();
});
@@ -897,16 +906,17 @@ struct tile_window_linear
// read from distributed tensor
vector_t vec_value;
static_for<0, traits::ScalarPerVector, 1>{}([&](auto j) {
static_for<0, traits::ScalarPerVector, traits::PackedSize>{}([&](auto j) {
constexpr auto idx_ys = generate_tuple(
[&](auto jj) {
return jj == traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj];
},
number<NDimY>{});
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
traits::PackedSize;
vec_value.template get_as<DataType>()(j) =
vec_value.template get_as<DataType>()(j / traits::PackedSize) =
dstr_tensor.get_thread_buffer().template at<d>();
});
@@ -948,16 +958,17 @@ struct tile_window_linear
// read from distributed tensor
vector_t vec_value;
static_for<0, traits::ScalarPerVector, 1>{}([&](auto j) {
static_for<0, traits::ScalarPerVector, traits::PackedSize>{}([&](auto j) {
constexpr auto idx_ys = generate_tuple(
[&](auto jj) {
return jj == traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj];
},
number<NDimY>{});
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
traits::PackedSize;
vec_value.template get_as<DataType>()(j) =
vec_value.template get_as<DataType>()(j / traits::PackedSize) =
dstr_tensor.get_thread_buffer().template at<d>();
});

View File

@@ -29,11 +29,12 @@ double get_relative_threshold(const int number_of_accumulations = 1)
using I8 = int8_t;
using I32 = int32_t;
static_assert(is_any_of<ComputeDataType, F8, BF8, F16, BF16, F32, I8, I32, int>::value,
"Warning: Unhandled ComputeDataType for setting up the relative threshold!");
static_assert(
is_any_of<ComputeDataType, F8, BF8, F16, BF16, F32, pk_int4_t, I8, I32, int>::value,
"Warning: Unhandled ComputeDataType for setting up the relative threshold!");
double compute_error = 0;
if constexpr(is_any_of<ComputeDataType, I8, I32, int>::value)
if constexpr(is_any_of<ComputeDataType, pk_int4_t, I8, I32, int>::value)
{
return 0;
}
@@ -42,11 +43,11 @@ double get_relative_threshold(const int number_of_accumulations = 1)
compute_error = std::pow(2, -numeric_traits<ComputeDataType>::mant) * 0.5;
}
static_assert(is_any_of<OutDataType, F8, BF8, F16, BF16, F32, I8, I32, int>::value,
static_assert(is_any_of<OutDataType, F8, BF8, F16, BF16, F32, pk_int4_t, I8, I32, int>::value,
"Warning: Unhandled OutDataType for setting up the relative threshold!");
double output_error = 0;
if constexpr(is_any_of<OutDataType, I8, I32, int>::value)
if constexpr(is_any_of<OutDataType, pk_int4_t, I8, I32, int>::value)
{
return 0;
}
@@ -56,11 +57,11 @@ double get_relative_threshold(const int number_of_accumulations = 1)
}
double midway_error = std::max(compute_error, output_error);
static_assert(is_any_of<AccDataType, F8, BF8, F16, BF16, F32, I8, I32, int>::value,
static_assert(is_any_of<AccDataType, F8, BF8, F16, BF16, F32, pk_int4_t, I8, I32, int>::value,
"Warning: Unhandled AccDataType for setting up the relative threshold!");
double acc_error = 0;
if constexpr(is_any_of<AccDataType, I8, I32, int>::value)
if constexpr(is_any_of<AccDataType, pk_int4_t, I8, I32, int>::value)
{
return 0;
}
@@ -82,12 +83,13 @@ double get_absolute_threshold(const double max_possible_num, const int number_of
using I8 = int8_t;
using I32 = int32_t;
static_assert(is_any_of<ComputeDataType, F8, BF8, F16, BF16, F32, I8, I32, int>::value,
"Warning: Unhandled ComputeDataType for setting up the absolute threshold!");
static_assert(
is_any_of<ComputeDataType, F8, BF8, F16, BF16, F32, pk_int4_t, I8, I32, int>::value,
"Warning: Unhandled ComputeDataType for setting up the absolute threshold!");
auto expo = std::log2(std::abs(max_possible_num));
double compute_error = 0;
if constexpr(is_any_of<ComputeDataType, I8, I32, int>::value)
if constexpr(is_any_of<ComputeDataType, pk_int4_t, I8, I32, int>::value)
{
return 0;
}
@@ -96,11 +98,11 @@ double get_absolute_threshold(const double max_possible_num, const int number_of
compute_error = std::pow(2, expo - numeric_traits<ComputeDataType>::mant) * 0.5;
}
static_assert(is_any_of<OutDataType, F8, BF8, F16, BF16, F32, I8, I32, int>::value,
static_assert(is_any_of<OutDataType, F8, BF8, F16, BF16, F32, pk_int4_t, I8, I32, int>::value,
"Warning: Unhandled OutDataType for setting up the absolute threshold!");
double output_error = 0;
if constexpr(is_any_of<OutDataType, I8, I32, int>::value)
if constexpr(is_any_of<OutDataType, pk_int4_t, I8, I32, int>::value)
{
return 0;
}
@@ -110,11 +112,11 @@ double get_absolute_threshold(const double max_possible_num, const int number_of
}
double midway_error = std::max(compute_error, output_error);
static_assert(is_any_of<AccDataType, F8, BF8, F16, BF16, F32, I8, I32, int>::value,
static_assert(is_any_of<AccDataType, F8, BF8, F16, BF16, F32, pk_int4_t, I8, I32, int>::value,
"Warning: Unhandled AccDataType for setting up the absolute threshold!");
double acc_error = 0;
if constexpr(is_any_of<AccDataType, I8, I32, int>::value)
if constexpr(is_any_of<AccDataType, pk_int4_t, I8, I32, int>::value)
{
return 0;
}

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -282,7 +282,14 @@ struct FillMonotonicSeq
{
std::generate(first, last, [=, n = init_value_]() mutable {
auto tmp = n;
n += step_;
if constexpr(std::is_same_v<decltype(tmp), pk_int4_t>)
{
n.data += step_.data;
}
else
{
n += step_;
}
return tmp;
});
}

View File

@@ -281,18 +281,18 @@ struct HostTensor
using Data = std::vector<T>;
template <typename X>
HostTensor(std::initializer_list<X> lens) : mDesc(lens), mData(mDesc.get_element_space_size())
HostTensor(std::initializer_list<X> lens) : mDesc(lens), mData(get_element_space_size())
{
}
template <typename X, typename Y>
HostTensor(std::initializer_list<X> lens, std::initializer_list<Y> strides)
: mDesc(lens, strides), mData(mDesc.get_element_space_size())
: mDesc(lens, strides), mData(get_element_space_size())
{
}
template <typename Lengths>
HostTensor(const Lengths& lens) : mDesc(lens), mData(mDesc.get_element_space_size())
HostTensor(const Lengths& lens) : mDesc(lens), mData(get_element_space_size())
{
}
@@ -302,7 +302,7 @@ struct HostTensor
{
}
HostTensor(const Descriptor& desc) : mDesc(desc), mData(mDesc.get_element_space_size()) {}
HostTensor(const Descriptor& desc) : mDesc(desc), mData(get_element_space_size()) {}
template <typename OutT>
HostTensor<OutT> CopyAsType() const
@@ -340,7 +340,11 @@ struct HostTensor
std::size_t get_element_size() const { return mDesc.get_element_size(); }
std::size_t get_element_space_size() const { return mDesc.get_element_space_size(); }
std::size_t get_element_space_size() const
{
constexpr index_t PackedSize = ck_tile::numeric_traits<remove_cvref_t<T>>::PackedSize;
return mDesc.get_element_space_size() / PackedSize;
}
std::size_t get_element_space_size_in_bytes() const
{
@@ -463,29 +467,27 @@ struct HostTensor
template <typename... Is>
std::size_t GetOffsetFromMultiIndex(Is... is) const
{
return mDesc.GetOffsetFromMultiIndex(is...);
constexpr index_t PackedSize = ck_tile::numeric_traits<remove_cvref_t<T>>::PackedSize;
return mDesc.GetOffsetFromMultiIndex(is...) / PackedSize;
}
template <typename... Is>
T& operator()(Is... is)
{
return mData[mDesc.GetOffsetFromMultiIndex(is...)];
return mData[GetOffsetFromMultiIndex(is...)];
}
template <typename... Is>
const T& operator()(Is... is) const
{
return mData[mDesc.GetOffsetFromMultiIndex(is...)];
return mData[GetOffsetFromMultiIndex(is...)];
}
T& operator()(std::vector<std::size_t> idx)
{
return mData[mDesc.GetOffsetFromMultiIndex(idx)];
}
T& operator()(std::vector<std::size_t> idx) { return mData[GetOffsetFromMultiIndex(idx)]; }
const T& operator()(std::vector<std::size_t> idx) const
{
return mData[mDesc.GetOffsetFromMultiIndex(idx)];
return mData[GetOffsetFromMultiIndex(idx)];
}
HostTensor<T> transpose(std::vector<size_t> axes = {}) const
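
Worked sizes for the change above, as a sketch (assumes default packed strides and sizeof(pk_int4_t) == 1, since it wraps a single int8_t):

ck_tile::HostTensor<ck_tile::pk_int4_t> b({64, 128});
assert(b.get_element_size() == 64 * 128);                    // logical int4 elements
assert(b.get_element_space_size() == 64 * 128 / 2);          // pk_int4_t storage slots
assert(b.get_element_space_size_in_bytes() == 64 * 128 / 2); // one byte per slot
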

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -34,11 +34,35 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
for(std::size_t k = 0; k < K; ++k)
{
ADataType v_a = a_element_op(a_m_k(m, k));
BDataType v_b = b_element_op(b_k_n(k, n));
v_acc +=
ck_tile::type_convert<AccDataType>(v_a) * ck_tile::type_convert<AccDataType>(v_b);
AccDataType v_a;
AccDataType v_b;
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
{
const pk_int4_t pk_val = a_element_op(a_m_k(m, k));
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(pk_val);
if(k % 2 == 1)
v_a = fp32_val.hi;
else
v_a = fp32_val.lo;
}
else
{
v_a = ck_tile::type_convert<AccDataType>(a_element_op(a_m_k(m, k)));
}
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
{
const pk_int4_t pk_val = b_element_op(b_k_n(k, n));
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(pk_val);
if(k % 2 == 1)
v_b = fp32_val.hi;
else
v_b = fp32_val.lo;
}
else
{
v_b = ck_tile::type_convert<AccDataType>(b_element_op(b_k_n(k, n)));
}
v_acc += v_a * v_b;
}
c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
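
The hi/lo selection above follows the packing convention that consecutive K elements share one byte: HostTensor::operator() already divides the offset by PackedSize, so for even K both elements of a pair resolve to the same pk_int4_t, and k % 2 picks the decoded lane. A plain-int index sketch:

#include <cstdio>

int main()
{
    for(int k = 0; k < 8; ++k)
        printf("k=%d -> packed byte %d, %s lane\n", k, k / 2, (k % 2) ? "hi" : "lo");
}
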
@@ -73,6 +97,8 @@ __global__ void naive_gemm_kernel(ADataType* A,
AccDataType acc = 0.0;
for(int k = 0; k < K; ++k)
{
constexpr index_t packed_size_a = ck_tile::numeric_traits<ADataType>::PackedSize;
constexpr index_t packed_size_b = ck_tile::numeric_traits<BDataType>::PackedSize;
// Adjust indexing based on matrix layout
int a_index = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
? row * strideA + k
@@ -80,8 +106,34 @@ __global__ void naive_gemm_kernel(ADataType* A,
int b_index = (std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
? col * strideB + k
: k * strideB + col;
acc += ck_tile::type_convert<AccDataType>(A[a_index]) *
ck_tile::type_convert<AccDataType>(B[b_index]);
AccDataType v_a;
AccDataType v_b;
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
{
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(A[a_index / packed_size_a]);
if(k % 2 == 1)
v_a = fp32_val.hi;
else
v_a = fp32_val.lo;
}
else
{
v_a = ck_tile::type_convert<AccDataType>(A[a_index]);
}
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
{
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(B[b_index / packed_size_b]);
if(k % 2 == 1)
v_b = fp32_val.hi;
else
v_b = fp32_val.lo;
}
else
{
v_b = ck_tile::type_convert<AccDataType>(B[b_index]);
}
acc += v_a * v_b;
}
int c_index = (std::is_same_v<LayoutC, tensor_layout::gemm::RowMajor>)

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -9,20 +9,166 @@
namespace ck_tile {
namespace element_wise {
#if 0
// Fast int4x4 to fp16x8_t data type conversion based on paper
// [Who Says Elephants Can't Run: Bringing Large Scale MoE Models into Cloud Scale Production]
// (https://arxiv.org/abs/2211.10017) and implementation:
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
CK_TILE_DEVICE fp16x4_t i4_to_half4(int q)
{
const int LO = 0x000f000f;
const int HI = 0x00f000f0;
const int EX = 0x64006400;
int lo;
int hi;
// Extract the two int4 values at the low bits and create two fp16 numbers.
asm volatile("v_and_or_b32 %0, %1, %2, %3" : "=v"(lo) : "v"(q), "v"(LO), "v"(EX));
// Extract the two int4 values at the high bits and create two fp16 numbers.
asm volatile("v_and_or_b32 %0, %1, %2, %3" : "=v"(hi) : "v"(q), "v"(HI), "v"(EX));
const int SUB = 0xE408E408; // half2 {-1032, -1032}
const int MUL = 0x2c002c00; // half2 {1 / 16, 1 / 16}
const int ADD = 0xd480d480; // half2 {-72, -72}
fp16x4_t res;
// For the two fp16 values from the low bits, subtract 1032 to get the correct fp16 value.
asm volatile("v_pk_add_f16 %0, %1, %2"
: "=v"(res.lo)
: "v"(bit_cast<fp16x2_t>(lo)), "v"(bit_cast<fp16x2_t>(SUB)));
// For the two fp16 values from the high bits, divide by 16 and subtract 72 to get the correct fp16 value.
asm volatile(
"v_pk_fma_f16 %0, %1, %2, %3"
: "=v"(res.hi)
: "v"(bit_cast<fp16x2_t>(hi)), "v"(bit_cast<fp16x2_t>(MUL)), "v"(bit_cast<fp16x2_t>(ADD)));
return res;
}
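// Sketch of the arithmetic behind the constants above (assuming the bias-8
// int4 encoding they imply): or-ing a nibble n into the mantissa of fp16
// 1024.0 (0x6400) yields exactly 1024 + n, so adding SUB (-1032) leaves
// (1024 + n) - 1032 = n - 8. A high-nibble lane comes out as 1024 + 16 * n,
// hence the fma with MUL (1/16) and ADD (-72):
// (1024 + 16 * n) / 16 - 72 = 64 + n - 72 = n - 8.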
CK_TILE_DEVICE fp16x4_t i4_to_half4_scale(int q, const fp16x2_t& scale)
{
const int LO = 0x000f000f;
const int HI = 0x00f000f0;
const int EX = 0x64006400;
int lo;
int hi;
// Extract the two int4 values at the low bits and create two fp16 numbers.
asm volatile("v_and_or_b32 %0, %1, %2, %3" : "=v"(lo) : "v"(q), "v"(LO), "v"(EX));
// Extract the two int4 values at the high bits and create two fp16 numbers.
asm volatile("v_and_or_b32 %0, %1, %2, %3" : "=v"(hi) : "v"(q), "v"(HI), "v"(EX));
const int SUB = 0xE408E408; // half2 {-1032, -1032}
const int MUL = 0x2c002c00; // half2 {1 / 16, 1 / 16}
const int ADD = 0xd480d480; // half2 {-72, -72}
fp16x4_t res;
asm volatile("v_pk_add_f16 %0, %1, %2"
: "=v"(res.lo)
: "v"(bit_cast<fp16x2_t>(lo)), "v"(bit_cast<fp16x2_t>(SUB)));
asm volatile(
"v_pk_fma_f16 %0, %1, %2, %3"
: "=v"(res.hi)
: "v"(bit_cast<fp16x2_t>(hi)), "v"(bit_cast<fp16x2_t>(MUL)), "v"(bit_cast<fp16x2_t>(ADD)));
asm volatile("v_pk_mul_f16 %0, %1, %2" : "=v"(res.lo) : "v"(res.lo), "v"(scale));
asm volatile("v_pk_mul_f16 %0, %1, %2" : "=v"(res.hi) : "v"(res.hi), "v"(scale));
return res;
}
CK_TILE_DEVICE bf16x4_t i4_to_bhalf4(int q)
{
uint32_t i8s = (q & 0xf) | ((q & 0xf0) << 4) | ((q & 0xf00) << 8) | ((q & 0xf000) << 12);
static constexpr uint32_t fp32_base = 0x4B000000;
float fp32_intermediates[4];
uint32_t* fp32_intermediates_casted = reinterpret_cast<uint32_t*>(fp32_intermediates);
fp32_intermediates_casted[0] = __byte_perm(i8s, fp32_base, 0x7650);
fp32_intermediates_casted[1] = __byte_perm(i8s, fp32_base, 0x7651);
fp32_intermediates_casted[2] = __byte_perm(i8s, fp32_base, 0x7652);
fp32_intermediates_casted[3] = __byte_perm(i8s, fp32_base, 0x7653);
fp32_intermediates[0] -= 8388616.f;
fp32_intermediates[1] -= 8388616.f;
fp32_intermediates[2] -= 8388616.f;
fp32_intermediates[3] -= 8388616.f;
bf16x4_t res;
res.lo = bit_cast<bf16x2_t>(
__byte_perm(fp32_intermediates_casted[1], fp32_intermediates_casted[0], 0x7632));
res.hi = bit_cast<bf16x2_t>(
__byte_perm(fp32_intermediates_casted[3], fp32_intermediates_casted[2], 0x7632));
return res;
}
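// Same exponent trick in fp32 (sketch): fp32_base = 0x4B000000 is 2^23 =
// 8388608.0f; __byte_perm splices one nibble-carrying byte b into the low
// mantissa byte, producing 8388608.0f + b, and subtracting 8388616.f
// (= 2^23 + 8) leaves b - 8. The final __byte_perm pairs keep only the high
// halves of the fp32 results, i.e. a truncating fp32 -> bf16 conversion.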
struct PassThroughPack8
{
template <typename Y, typename X>
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const;
CK_TILE_HOST_DEVICE constexpr void operator()(fp16x8_t& y, const pk_int4x4_t& x) const
{
y.lo = i4_to_half4(bit_cast<int>(x));
y.hi = i4_to_half4(bit_cast<int>(x) >> 8);
}
CK_TILE_HOST_DEVICE constexpr void operator()(bf16x8_t& y, const pk_int4x4_t& x) const
{
y.lo = i4_to_bhalf4(bit_cast<int>(x));
y.hi = i4_to_bhalf4(bit_cast<int>(x) >> 16);
}
constexpr const static bool is_pack8_invocable = true;
};
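// Usage sketch for PassThroughPack8 (illustrative; assumes CK headers are in
// scope): one pk_int4x4_t carries eight 4-bit values in 32 bits and is
// widened to eight halves in a single call, the granularity that
// load_interleaved_pk_type relies on further down in this change.
//
//     pk_int4x4_t packed = /* 8 packed int4 elements */;
//     fp16x8_t unpacked;
//     element_wise::PassThroughPack8{}(unpacked, packed);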
struct DequantPack8
{
template <typename Y, typename X, typename Z>
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x, const Z& z) const;
CK_TILE_HOST_DEVICE constexpr void
operator()(fp16x8_t& y, const pk_int4x4_t& x, const fp16x2_t& z) const
{
y.lo = i4_to_half4_scale(bit_cast<int>(x), z);
y.hi = i4_to_half4_scale(bit_cast<int>(x) >> 8, z);
}
constexpr const static bool is_pack8_invocable = true;
};
struct PassThroughPack2
{
template <typename Y, typename X>
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const;
CK_TILE_HOST_DEVICE constexpr void operator()(ck_tile::half2_t& y, const ck_tile::f8x2_t& x) const
#if 0
CK_TILE_HOST_DEVICE constexpr void operator()(ck_tile::fp16x2_t& y, const ck_tile::f8x2_t& x) const
{
auto t = type_convert<float2_t>(x);
y = type_convert<half2_t>(t);
y = type_convert<fp16x2_t>(t);
}
#endif
CK_TILE_HOST_DEVICE constexpr void operator()(fp16x2_t& y, const pk_int4_t& x) const
{
uint8_t x_u8 = bit_cast<uint8_t>(x);
uint8_t x_l = (x_u8 & 0x0f) >> 0;
uint8_t x_h = (x_u8 & 0xf0) >> 4;
y.lo = type_convert<half_t>(x_l);
y.hi = type_convert<half_t>(x_h);
}
constexpr const static bool is_pack2_invocable = true;
};
#endif
struct PassThrough
{

View File

@@ -1,11 +1,12 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/elementwise.hpp"
namespace ck_tile {
@@ -20,12 +21,13 @@ struct BlockUniversalGemmAsBsCr
template <typename PipelineProblem_, typename GemmPolicy_>
struct GemmTraits_
{
using Problem = remove_cvref_t<PipelineProblem_>;
using Policy = remove_cvref_t<GemmPolicy_>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
using Problem = remove_cvref_t<PipelineProblem_>;
using Policy = remove_cvref_t<GemmPolicy_>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr auto Scheduler = Problem::Scheduler;
@@ -71,10 +73,10 @@ struct BlockUniversalGemmAsBsCr
using BWarpTileDistr = remove_cvref_t<decltype(make_static_tile_distribution(
typename WarpGemm::BWarpDstrEncoding{}))>;
using AWarpTile =
remove_cvref_t<decltype(make_static_distributed_tensor<ADataType>(AWarpTileDistr{}))>;
using BWarpTile =
remove_cvref_t<decltype(make_static_distributed_tensor<BDataType>(BWarpTileDistr{}))>;
using AWarpTile = remove_cvref_t<decltype(make_static_distributed_tensor<ComputeDataType>(
AWarpTileDistr{}))>;
using BWarpTile = remove_cvref_t<decltype(make_static_distributed_tensor<ComputeDataType>(
BWarpTileDistr{}))>;
// TODO: Should we have two policies? Interwave & Intrawave ??
static constexpr index_t InterWaveSchedulingMacClusters = 1;
@@ -90,9 +92,10 @@ struct BlockUniversalGemmAsBsCr
public:
using Traits = GemmTraits_<Problem_, Policy_>;
using ADataType = remove_cvref_t<typename Traits::ADataType>;
using BDataType = remove_cvref_t<typename Traits::BDataType>;
using CDataType = remove_cvref_t<typename Traits::CDataType>;
using ADataType = remove_cvref_t<typename Traits::ADataType>;
using BDataType = remove_cvref_t<typename Traits::BDataType>;
using ComputeDataType = remove_cvref_t<typename Traits::ComputeDataType>;
using CDataType = remove_cvref_t<typename Traits::CDataType>;
using WarpGemm = remove_cvref_t<typename Traits::WarpGemm>;
@@ -105,10 +108,34 @@ struct BlockUniversalGemmAsBsCr
static constexpr auto Scheduler = Traits::Scheduler;
static constexpr index_t APackedSize =
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
static constexpr index_t BPackedSize =
ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
using I0 = number<0>;
using I1 = number<1>;
private:
template <typename WarpWindow, typename WarpTile>
CK_TILE_DEVICE static void load_interleaved_pk_type(const WarpWindow& warp_window,
WarpTile& warp_tile)
{
constexpr index_t UnaryOpSize = 8;
const element_wise::PassThroughPack8 elementwise_op{};
constexpr index_t thread_buffer_size =
Traits::AWarpTile::get_thread_buffer_size() / UnaryOpSize;
const auto in_dstr_tensors = load_tile(warp_window);
static_assert(Traits::AWarpTile::get_thread_buffer_size() % UnaryOpSize == 0);
using ComputeVectorType = ComputeDataType __attribute__((ext_vector_type(UnaryOpSize)));
static_for<0, thread_buffer_size, 1>{}([&](auto i) {
elementwise_op(warp_tile.get_thread_buffer().template get_as<ComputeVectorType>()(i),
in_dstr_tensors.get_thread_buffer().template get_as<pk_int4x4_t>()[i]);
});
}
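// Sketch of the bookkeeping above: with UnaryOpSize = 8, each pk_int4x4_t
// (eight packed nibbles) pulled from the warp window expands into one
// ext_vector of eight ComputeDataType values, so a warp tile holding N
// ComputeDataType elements is filled by N / 8 PassThroughPack8 calls; the
// static_assert guarantees N is divisible by 8.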
template <GemmPipelineScheduler Scheduler, typename GemmTraits>
struct BlockGemmImpl
{
@@ -208,6 +235,8 @@ struct BlockUniversalGemmAsBsCr
});
using CWarpDstr = typename WarpGemm::CWarpDstr;
using AWarpTensor = typename WarpGemm::AWarpTensor;
using BWarpTensor = typename WarpGemm::BWarpTensor;
using CWarpTensor = typename WarpGemm::CWarpTensor;
constexpr auto c_warp_y_lengths =
@@ -217,10 +246,26 @@ struct BlockUniversalGemmAsBsCr
// hot loop:
static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
const auto a_warp_tile = load_tile(a_warp_windows(mIter)(kIter));
AWarpTensor a_warp_tile;
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
{
load_interleaved_pk_type(a_warp_windows(mIter)(kIter), a_warp_tile);
}
else
{
a_warp_tile = load_tile(a_warp_windows(mIter)(kIter));
}
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
const auto b_warp_tile = load_tile(b_warp_windows(nIter)(kIter));
BWarpTensor b_warp_tile;
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
{
load_interleaved_pk_type(b_warp_windows(nIter)(kIter), b_warp_tile);
}
else
{
b_warp_tile = load_tile(b_warp_windows(nIter)(kIter));
}
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
@@ -342,11 +387,27 @@ struct BlockUniversalGemmAsBsCr
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A block window
load_tile(a_warp_tiles_(mIter)(kIter), a_warp_windows(mIter)(kIter));
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
{
load_interleaved_pk_type(a_warp_windows(mIter)(kIter),
a_warp_tiles_(mIter)(kIter));
}
else
{
a_warp_tiles_(mIter)(kIter) = load_tile(a_warp_windows(mIter)(kIter));
}
});
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read B warp tensor from B block window
load_tile(b_warp_tiles_(nIter)(kIter), b_warp_windows(nIter)(kIter));
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
{
load_interleaved_pk_type(b_warp_windows(nIter)(kIter),
b_warp_tiles_(nIter)(kIter));
}
else
{
b_warp_tiles_(nIter)(kIter) = load_tile(b_warp_windows(nIter)(kIter));
}
});
});
}
@@ -504,12 +565,27 @@ struct BlockUniversalGemmAsBsCr
// TODO check if a_warp_tiles has same desc as a_warp_window
static_for<0, KInnerLoopIter, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A block window
load_tile(a_warp_tiles_(mIter)(kIter), a_warp_windows(mIter)(kIter));
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
{
load_interleaved_pk_type(a_warp_windows(mIter)(kIter),
a_warp_tiles_(mIter)(kIter));
}
else
{
a_warp_tiles_(mIter)(kIter) = load_tile(a_warp_windows(mIter)(kIter));
}
});
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read B warp tensor from B block window
load_tile(b_warp_tiles_(nIter)(kIter), b_warp_windows(nIter)(kIter));
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
{
load_interleaved_pk_type(b_warp_windows(nIter)(kIter),
b_warp_tiles_(nIter)(kIter));
}
else
{
b_warp_tiles_(nIter)(kIter) = load_tile(b_warp_windows(nIter)(kIter));
}
});
});
}

View File

@@ -54,6 +54,11 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
static constexpr index_t APackedSize =
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
static constexpr index_t BPackedSize =
ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
using ALayout = remove_cvref_t<typename Problem::ALayout>;
using BLayout = remove_cvref_t<typename Problem::BLayout>;
using CLayout = remove_cvref_t<typename Problem::CLayout>;
@@ -196,12 +201,12 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
// A/B split schedule
// compiler is likely to use ds_read2 when the instruction width is smaller than 16 bytes
constexpr auto num_ds_read_inst_a = A_LDS_Read_Width * sizeof(ADataType) == 16
? A_LDS_Read_Inst_Num
: A_LDS_Read_Inst_Num / 2;
constexpr auto num_ds_read_inst_b = B_LDS_Read_Width * sizeof(BDataType) == 16
? B_LDS_Read_Inst_Num
: B_LDS_Read_Inst_Num / 2;
constexpr auto num_ds_read_inst_a =
A_LDS_Read_Width * sizeof(ADataType) / APackedSize == 16 ? A_LDS_Read_Inst_Num
: A_LDS_Read_Inst_Num / 2;
constexpr auto num_ds_read_inst_b =
B_LDS_Read_Width * sizeof(BDataType) / BPackedSize == 16 ? B_LDS_Read_Inst_Num
: B_LDS_Read_Inst_Num / 2;
constexpr auto num_ds_write_inst_a = A_LDS_Write_Inst_Num;
constexpr auto num_ds_write_inst_b = B_LDS_Write_Inst_Num;
@@ -213,9 +218,9 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32;
constexpr auto ds_read_a_issue_cycle =
A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4;
A_LDS_Read_Width * sizeof(ADataType) / APackedSize == 16 ? 8 : 4;
constexpr auto ds_read_b_issue_cycle =
B_LDS_Read_Width * sizeof(BDataType) == 16 ? 8 : 4;
B_LDS_Read_Width * sizeof(BDataType) / BPackedSize == 16 ? 8 : 4;
constexpr auto ds_read_a_mfma_rate =
(mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle);
constexpr auto ds_read_b_mfma_rate =
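// Worked example for the rates above (illustrative numbers): with
// mfma_cycle = 32 and a 16-byte ds_read (issue cycle 8),
// ds_read_a_mfma_rate = (32 - 4 + 2 * 8 - 1) / (2 * 8) = 43 / 16 = 2,
// i.e. roughly two ds_read issues are scheduled per mfma.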

View File

@@ -60,6 +60,13 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
static_assert(!std::is_same_v<BDataType, pk_int4_t>, "Not implemented");
static constexpr index_t APackedSize =
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
static constexpr index_t BPackedSize =
ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
using ALayout = remove_cvref_t<typename Problem::ALayout>;
using BLayout = remove_cvref_t<typename Problem::BLayout>;
using CLayout = remove_cvref_t<typename Problem::CLayout>;
@@ -139,12 +146,12 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
(BlockSize / WaveSize) /
(MPerXDL * NPerXDL * KPerXDL);
constexpr auto num_ds_read_inst_a = A_LDS_Read_Width * sizeof(ADataType) == 16
? A_LDS_Read_Inst_Num
: A_LDS_Read_Inst_Num / 2;
constexpr auto num_ds_read_inst_b = B_LDS_Read_Width * sizeof(BDataType) == 16
? B_LDS_Read_Inst_Num
: B_LDS_Read_Inst_Num / 2;
constexpr auto num_ds_read_inst_a =
A_LDS_Read_Width * sizeof(ADataType) / APackedSize == 16 ? A_LDS_Read_Inst_Num
: A_LDS_Read_Inst_Num / 2;
constexpr auto num_ds_read_inst_b =
B_LDS_Read_Width * sizeof(BDataType) / BPackedSize == 16 ? B_LDS_Read_Inst_Num
: B_LDS_Read_Inst_Num / 2;
constexpr auto num_ds_read_inst = num_ds_read_inst_a + num_ds_read_inst_b;
constexpr auto num_ds_write_inst = A_LDS_Write_Inst_Num + B_LDS_Write_Inst_Num;

View File

@@ -21,6 +21,13 @@ struct BaseGemmPipelineAgBgCrMem
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
static_assert(!std::is_same_v<BDataType, pk_int4_t>, "Not implemented");
static constexpr index_t APackedSize =
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
static constexpr index_t BPackedSize =
ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
static constexpr index_t BlockSize = Problem::kBlockSize;
@@ -33,9 +40,11 @@ struct BaseGemmPipelineAgBgCrMem
static constexpr index_t WgpPerCU =
(4 * get_warp_size() / BlockSize) >= 1 ? 4 * get_warp_size() / BlockSize : 1;
static constexpr index_t FullMemBandPrefetchStages = integer_divide_ceil(
MinMemInFlyBytes / WgpPerCU,
(MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock);
static constexpr index_t FullMemBandPrefetchStages =
integer_divide_ceil(MinMemInFlyBytes / WgpPerCU,
(MPerBlock * sizeof(ADataType) / APackedSize +
NPerBlock * sizeof(BDataType) / BPackedSize) *
KPerBlock);
static constexpr index_t PrefetchStages =
FullMemBandPrefetchStages >= 2
? FullMemBandPrefetchStages <= 8 ? FullMemBandPrefetchStages : 8
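// Worked example (illustrative tile sizes): with MPerBlock = NPerBlock = 128,
// KPerBlock = 64, fp16 A (sizeof 2, PackedSize 1) and pk_int4 B (sizeof 1,
// PackedSize 2), one prefetch stage now moves
// (128 * 2 / 1 + 128 * 1 / 2) * 64 = (256 + 64) * 64 = 20480 bytes, versus
// 32768 bytes if B were fp16, so more stages fit under MinMemInFlyBytes.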

View File

@@ -67,16 +67,22 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA()
{
constexpr index_t smem_size_a = sizeof(typename Problem::ADataType) *
MakeALdsBlockDescriptor<Problem>().get_element_space_size();
constexpr index_t PackedSize =
ck_tile::numeric_traits<remove_cvref_t<typename Problem::ADataType>>::PackedSize;
constexpr index_t smem_size_a =
sizeof(typename Problem::ADataType) *
MakeALdsBlockDescriptor<Problem>().get_element_space_size() / PackedSize;
return smem_size_a;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeB()
{
constexpr index_t smem_size_b = sizeof(typename Problem::BDataType) *
MakeBLdsBlockDescriptor<Problem>().get_element_space_size();
constexpr index_t PackedSize =
ck_tile::numeric_traits<remove_cvref_t<typename Problem::BDataType>>::PackedSize;
constexpr index_t smem_size_b =
sizeof(typename Problem::BDataType) *
MakeBLdsBlockDescriptor<Problem>().get_element_space_size() / PackedSize;
return smem_size_b;
}
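// Worked example (illustrative): a 128 x 64 pk_int4 tile has 8192 logical
// elements; with sizeof(pk_int4_t) == 1 and PackedSize == 2 the LDS size
// computed above is 8192 * 1 / 2 = 4096 bytes, matching the physical nibble
// packing instead of charging one byte per element.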
@@ -387,8 +393,8 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
using AccDataType = float;
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ADataType,
typename Problem::BDataType,
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ComputeDataType,
typename Problem::ComputeDataType,
AccDataType,
WarpTile::at(I0),
WarpTile::at(I1),

View File

@@ -20,6 +20,11 @@ struct GemmPipelineAGmemBGmemCRegV2
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
static constexpr index_t APackedSize =
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
static constexpr index_t BPackedSize =
ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kMPerBlock = BlockGemmShape::kM;
@@ -37,13 +42,15 @@ struct GemmPipelineAGmemBGmemCRegV2
CK_TILE_HOST_DEVICE static constexpr index_t GetStaticLdsSize()
{
return integer_divide_ceil(
sizeof(ADataType) *
Policy::template MakeALdsBlockDescriptor<Problem>().get_element_space_size(),
16) *
return integer_divide_ceil(sizeof(ADataType) *
Policy::template MakeALdsBlockDescriptor<Problem>()
.get_element_space_size() /
APackedSize,
16) *
16 +
sizeof(BDataType) *
Policy::template MakeBLdsBlockDescriptor<Problem>().get_element_space_size();
Policy::template MakeBLdsBlockDescriptor<Problem>().get_element_space_size() /
BPackedSize;
}
template <typename ADramBlockWindowTmp,
@@ -75,7 +82,8 @@ struct GemmPipelineAGmemBGmemCRegV2
auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);
constexpr index_t a_lds_block_space_size_aligned =
integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), 16) *
integer_divide_ceil(
sizeof(ADataType) * a_lds_block_desc.get_element_space_size() / APackedSize, 16) *
16;
// B tile in LDS

View File

@@ -13,14 +13,16 @@ template <typename ADataType_,
typename BDataType_,
typename CDataType_,
typename BlockGemmShape_,
typename Traits_>
typename Traits_,
typename ComputeDataType_ = ADataType_>
struct GemmPipelineProblemBase
{
using Traits = remove_cvref_t<Traits_>;
using ADataType = remove_cvref_t<ADataType_>;
using BDataType = remove_cvref_t<BDataType_>;
using CDataType = remove_cvref_t<CDataType_>;
using ADataType = remove_cvref_t<ADataType_>;
using BDataType = remove_cvref_t<BDataType_>;
using CDataType = remove_cvref_t<CDataType_>;
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
@@ -53,13 +55,15 @@ struct GemmPipelineProblemBase
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentA()
{
constexpr index_t PackedSize =
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
{
constexpr index_t pixels_per_thread =
BlockGemmShape::kM * BlockGemmShape::kK / kBlockSize;
return pixels_per_thread < VectorLoadSize / sizeof(ADataType)
return pixels_per_thread < PackedSize * VectorLoadSize / sizeof(ADataType)
? pixels_per_thread
: VectorLoadSize / sizeof(ADataType);
: PackedSize * VectorLoadSize / sizeof(ADataType);
}
else
{
@@ -69,17 +73,19 @@ struct GemmPipelineProblemBase
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentB()
{
constexpr index_t PackedSize =
ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
constexpr index_t pixels_per_thread =
BlockGemmShape::kN * BlockGemmShape::kK / kBlockSize;
return pixels_per_thread < VectorLoadSize / sizeof(BDataType)
return pixels_per_thread < PackedSize * VectorLoadSize / sizeof(BDataType)
? pixels_per_thread
: VectorLoadSize / sizeof(BDataType);
: PackedSize * VectorLoadSize / sizeof(BDataType);
}
else
{
return VectorLoadSize / sizeof(BDataType);
return PackedSize * VectorLoadSize / sizeof(BDataType);
}
}
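// Worked example (illustrative, assuming VectorLoadSize = 16 bytes): fp16
// (sizeof 2, PackedSize 1) keeps its old alignment of 16 / 2 = 8 elements,
// while pk_int4 (sizeof 1, PackedSize 2) gets 2 * 16 / 1 = 32 elements per
// vector access, which is still 16 bytes once two elements per byte are
// accounted for.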
@@ -143,9 +149,14 @@ template <typename ADataType_,
typename BDataType_,
typename CDataType_,
typename BlockGemmShape_,
typename Traits_>
using GemmPipelineProblem =
GemmPipelineProblemBase<ADataType_, BDataType_, CDataType_, BlockGemmShape_, Traits_>;
typename Traits_,
typename ComputeDataType_ = ADataType_>
using GemmPipelineProblem = GemmPipelineProblemBase<ADataType_,
BDataType_,
CDataType_,
BlockGemmShape_,
Traits_,
ComputeDataType_>;
template <typename ADataType_,
typename BDataType_,
@@ -154,14 +165,16 @@ template <typename ADataType_,
typename Traits_,
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
bool HasHotLoop_ = true,
TailNumber TailNum_ = TailNumber::Full>
TailNumber TailNum_ = TailNumber::Full,
typename ComputeDataType_ = ADataType_>
struct UniversalGemmPipelineProblem
{
using Traits = remove_cvref_t<Traits_>;
using ADataType = remove_cvref_t<ADataType_>;
using BDataType = remove_cvref_t<BDataType_>;
using CDataType = remove_cvref_t<CDataType_>;
using ADataType = remove_cvref_t<ADataType_>;
using BDataType = remove_cvref_t<BDataType_>;
using CDataType = remove_cvref_t<CDataType_>;
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;

View File

@@ -34,31 +34,41 @@ struct UniversalGemmBasePolicy
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t elements_per_thread = MNPerBlock * KPerBlock / BlockSize;
constexpr index_t PackedSize =
ck_tile::numeric_traits<remove_cvref_t<DataType>>::PackedSize;
// Assume sizeof(DataType) is even!
if constexpr(XPerTile % (16 / sizeof(DataType)) == 0 &&
elements_per_thread % (16 / sizeof(DataType)) == 0)
if constexpr(XPerTile % (PackedSize * 32 / sizeof(DataType)) == 0 &&
elements_per_thread % (PackedSize * 32 / sizeof(DataType)) == 0 &&
PackedSize == 2)
{
return (16 / sizeof(DataType));
return (PackedSize * 32 / sizeof(DataType));
}
else if constexpr(XPerTile % (8 / sizeof(DataType)) == 0 &&
elements_per_thread % (8 / sizeof(DataType)) == 0)
else if constexpr(XPerTile % (PackedSize * 16 / sizeof(DataType)) == 0 &&
elements_per_thread % (PackedSize * 16 / sizeof(DataType)) == 0)
{
return (8 / sizeof(DataType));
return (PackedSize * 16 / sizeof(DataType));
}
else if constexpr(sizeof(DataType) >= 4 && XPerTile % (4 / sizeof(DataType)) == 0 &&
elements_per_thread % (4 / sizeof(DataType)) == 0)
else if constexpr(XPerTile % (PackedSize * 8 / sizeof(DataType)) == 0 &&
elements_per_thread % (PackedSize * 8 / sizeof(DataType)) == 0)
{
return (4 / sizeof(DataType));
return (PackedSize * 8 / sizeof(DataType));
}
else if constexpr(sizeof(DataType) >= 2 && XPerTile % (2 / sizeof(DataType)) == 0 &&
elements_per_thread % (2 / sizeof(DataType)) == 0)
else if constexpr(sizeof(DataType) >= PackedSize * 4 &&
XPerTile % (PackedSize * 4 / sizeof(DataType)) == 0 &&
elements_per_thread % (PackedSize * 4 / sizeof(DataType)) == 0)
{
return (2 / sizeof(DataType));
return (PackedSize * 4 / sizeof(DataType));
}
else if constexpr(sizeof(DataType) >= PackedSize * 2 &&
XPerTile % (PackedSize * 2 / sizeof(DataType)) == 0 &&
elements_per_thread % (PackedSize * 2 / sizeof(DataType)) == 0)
{
return (PackedSize * 2 / sizeof(DataType));
}
else
{
return 1;
return PackedSize;
}
}
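// Worked example of the ladder above (illustrative): fp16 (sizeof 2,
// PackedSize 1) skips the first branch (PackedSize != 2) and takes the
// second, returning 1 * 16 / 2 = 8 elements, i.e. the former 16-byte vector;
// pk_int4 (sizeof 1, PackedSize 2) can take the first branch and return
// 2 * 32 / 1 = 64 elements, which occupy 64 / 2 = 32 bytes.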
@@ -564,8 +574,8 @@ struct UniversalGemmPipelineAgBgCrPolicy
{
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ADataType,
typename Problem::BDataType,
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ComputeDataType,
typename Problem::ComputeDataType,
typename Problem::CDataType,
WarpTile::at(I0),
WarpTile::at(I1),