mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[CK_BUILDER] Convolution forward transfer concepts. (#3535)
* Rename member variable to better reflect its actual meaning.
* Add transfer checks for conv fwd xdl.
* Validate tensor layouts & vector size conv fwd v3.
* Add combined transfer concepts.
* Add transfer concepts for conv fwd factories.
* Fix clang format.
* Add helper to get max mem vector instruction width.
* Apply review comments.
* Rename thread cluster access(->arrange) order concept.
* Fix merge artifacts.
* Add generic access order limits into block transfer concept.
@@ -104,7 +104,7 @@ concept EpilogueDescriptor = requires(T t) {
 
 // Concept for the thread cluster access order
 template <typename T>
-concept AccessOrderDescriptor = requires(T t) {
+concept ThreadClusterOrderDescriptor = requires(T t) {
     { t.order } -> std::convertible_to<std::array<size_t, 3>>;
 } || requires(T t) {
     { t.order } -> std::convertible_to<std::array<size_t, 4>>;
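A minimal sketch of a type satisfying the renamed concept, assuming only the std::array-of-size_t `order` member it requires (the AccessOrder struct further down in this commit is checked against it the same way):

// Hypothetical example type, not part of the commit.
struct MyOrder3
{
    std::array<size_t, 3> order;
};
static_assert(ThreadClusterOrderDescriptor<MyOrder3>);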
@@ -231,16 +231,16 @@ concept SpecifiesLdsTransfer = requires(T t) {
 
 // Concept to check if a struct specifies thread cluster access order info.
 template <typename T>
-concept SpecifiesThreadClusterAccessOrder = requires(T t) {
-    { T::transfer.a.block_transfer_access_order } -> AccessOrderDescriptor;
-    { T::transfer.b.block_transfer_access_order } -> AccessOrderDescriptor;
+concept SpecifiesThreadClusterArrangeOrder = requires(T t) {
+    { T::transfer.a.thread_cluster_arrange_order } -> ThreadClusterOrderDescriptor;
+    { T::transfer.b.thread_cluster_arrange_order } -> ThreadClusterOrderDescriptor;
 };
 
 // Concept to check if a struct specifies source access order info.
 template <typename T>
 concept SpecifiesSourceAccessOrder = requires(T t) {
-    { T::transfer.a.src_access_order } -> AccessOrderDescriptor;
-    { T::transfer.b.src_access_order } -> AccessOrderDescriptor;
+    { T::transfer.a.src_access_order } -> ThreadClusterOrderDescriptor;
+    { T::transfer.b.src_access_order } -> ThreadClusterOrderDescriptor;
 };
 
 // Concept to check if struct specifies block GEMM.
@@ -5,6 +5,9 @@
 
 #include <type_traits>
 #include <concepts>
+#include <utility>
+#include "ck_tile/core/utility/type_traits.hpp"
+#include "ck_tile/core/arch/arch.hpp"
 
 namespace ck_tile::builder {
 
@@ -45,4 +48,224 @@ concept AccessOrderLimits4D = requires {
              (Value.Size() == 4));
 };
 
+namespace detail {
+
+// Helper to check if access order is a valid permutation
+template <auto Value>
+constexpr bool is_valid_permutation()
+{
+    constexpr auto size = Value.Size();
+
+    // Check all values are in range [0, size)
+    for(size_t i = 0; i < size; ++i)
+    {
+        if(Value[i] < 0 || Value[i] >= static_cast<decltype(Value[0])>(size))
+            return false;
+    }
+
+    // Check all values are unique (valid permutation)
+    for(size_t i = 0; i < size; ++i)
+    {
+        for(size_t j = i + 1; j < size; ++j)
+        {
+            if(Value[i] == Value[j])
+                return false;
+        }
+    }
+
+    return true;
+}
+
+} // namespace detail
+
+// Generic access order limits. Must be a valid permutation of {0, 1, ..., Dims-1}.
+// Works with both 3D and 4D (or any dimensionality) access orders.
+template <auto Value, size_t Dims>
+concept AccessOrderLimits = requires {
+    requires Value.Size() == Dims;
+    requires detail::is_valid_permutation<Value>();
+};
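A minimal sketch of AccessOrderLimits in use, assuming a hypothetical Order3 type with the ck::Array-style interface (constexpr Size() and operator[]) that the concept expects of Value:

// Hypothetical structural type usable as a non-type template parameter.
struct Order3
{
    size_t data[3];
    constexpr size_t Size() const { return 3; }
    constexpr size_t operator[](size_t i) const { return data[i]; }
};

static_assert(AccessOrderLimits<Order3{{1, 0, 2}}, 3>);  // valid permutation of {0, 1, 2}
static_assert(!AccessOrderLimits<Order3{{1, 1, 2}}, 3>); // duplicate index is rejected
static_assert(!AccessOrderLimits<Order3{{0, 1, 3}}, 3>); // out-of-range index is rejected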
+
+namespace detail {
+
+// Helper trait to get compile-time size from ck::Array
+template <typename T>
+concept HasStaticSize = requires {
+    { T::Size() } -> std::convertible_to<size_t>;
+};
+
+// Helper trait to get compile-time size from std::array and similar
+template <typename T>
+concept HasTupleSize = requires {
+    { std::tuple_size<T>::value } -> std::convertible_to<size_t>;
+};
+
+// Helper for dependent static_assert
+template <typename>
+constexpr bool always_false = false;
+
+// Get compile-time size of a range
+template <typename Range>
+constexpr size_t get_range_size()
+{
+    if constexpr(HasStaticSize<Range>)
+    {
+        return Range::Size();
+    }
+    else if constexpr(HasTupleSize<Range>)
+    {
+        return std::tuple_size_v<Range>;
+    }
+    else
+    {
+        static_assert(always_false<Range>, "Unsupported type of range object.");
+    }
+}
+
+// Fold expression implementation for product calculation
+template <typename Range, size_t... Is>
+constexpr auto get_cluster_size_impl(const Range& range, std::index_sequence<Is...>)
+{
+    using value_type = std::remove_cvref_t<decltype(range[0])>;
+    return ((range[Is]) * ... * value_type{1});
+}
+
+// Generic function that calculates the product of all elements in a range
+// Works with any indexable range with compile-time size (ck::Array, std::array, etc.)
+template <typename Range>
+    requires requires(Range r) {
+        r[0];                    // Must be indexable
+        get_range_size<Range>(); // Must have compile-time size
+    }
+constexpr auto get_cluster_size(const Range& range)
+{
+    return get_cluster_size_impl(range, std::make_index_sequence<get_range_size<Range>()>{});
+}
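A quick sketch (inside the same namespace) showing that the product helper accepts std::array as well as ck::Array, via the HasTupleSize branch:

constexpr std::array<size_t, 3> cluster{4, 64, 1};
static_assert(detail::get_cluster_size(cluster) == 256); // 4 * 64 * 1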
+
+// Calculate K dimension coverage (k0 * k1, with vectorization if applicable)
+template <auto BlockTransfer>
+constexpr auto get_k_coverage()
+{
+    auto k0 = BlockTransfer.thread_cluster_dims[0];
+    auto k1 = BlockTransfer.thread_cluster_dims[2];
+    auto k_total = k0 * k1;
+
+    // If vectorization is on k0 (dim 0) or k1 (dim 2), multiply by vector size
+    if constexpr(BlockTransfer.src_vector_dim == 0 || BlockTransfer.src_vector_dim == 2)
+    {
+        k_total *= BlockTransfer.src_scalar_per_vector;
+    }
+
+    return k_total;
+}
+
+// Calculate M/N dimension coverage (m_n, with vectorization if applicable)
+template <auto BlockTransfer>
+constexpr auto get_mn_coverage()
+{
+    auto mn = BlockTransfer.thread_cluster_dims[1];
+
+    // If vectorization is on m_n (dim 1), multiply by vector size
+    if constexpr(BlockTransfer.src_vector_dim == 1)
+    {
+        mn *= BlockTransfer.src_scalar_per_vector;
+    }
+
+    return mn;
+}
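A worked example of the coverage arithmetic: for thread_cluster_dims = {4, 64, 1} (k0, m_n, k1) with src_vector_dim == 2 and src_scalar_per_vector == 8, the K coverage is 4 * 1 * 8 = 32 and the M/N coverage is 64, so the coverage concepts below require TileSize.k % 32 == 0 and TileSize.m % 64 == 0 (TileSize.n for B).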
+
+template <size_t DataTypeSize>
+constexpr auto get_data_max_vec_size()
+{
+    constexpr auto max_vec_inst_size_bytes = get_max_mem_vec_inst_width();
+    static_assert(max_vec_inst_size_bytes % DataTypeSize == 0,
+                  "The max vec instruction size is not a multiple of given data type size.");
+    return max_vec_inst_size_bytes / DataTypeSize;
+}
+
+} // namespace detail
+
+// product of thread cluster lengths must be <= workgroup size
+template <auto BlockTransfer, size_t BlockSize>
+concept ValidBlockTransferClusterSize =
+    requires { requires detail::get_cluster_size(BlockTransfer.thread_cluster_dims) <= BlockSize; };
+
+// Check that thread cluster covers the K and M dimensions for A transfer
+template <auto ABlockTransfer, auto TileSize>
+concept ThreadsCoverATile = requires {
+    // K dimension: k0 * k1 * (vectorization) must divide K
+    requires TileSize.k % detail::get_k_coverage<ABlockTransfer>() == 0;
+    // M dimension: m_n * (vectorization) must divide M
+    requires TileSize.m % detail::get_mn_coverage<ABlockTransfer>() == 0;
+};
+
+// Check that thread cluster covers the K and N dimensions for B transfer
+template <auto BBlockTransfer, auto TileSize>
+concept ThreadsCoverBTile = requires {
+    // K dimension: k0 * k1 * (vectorization) must divide K
+    requires TileSize.k % detail::get_k_coverage<BBlockTransfer>() == 0;
+    // N dimension: m_n * (vectorization) must divide N
+    requires TileSize.n % detail::get_mn_coverage<BBlockTransfer>() == 0;
+};
+
+template <auto CBlockTransfer, auto TileSize>
+concept ThreadsCoverCTile = requires {
+    // M dimension: m_wave_per_xdl must divide M
+    requires TileSize.m % CBlockTransfer.thread_cluster_dims[1] == 0;
+    // N dimension: n_wave_per_xdl * (vectorization) must divide N
+    requires TileSize.n % (CBlockTransfer.thread_cluster_dims[3] *
+                           CBlockTransfer.scalar_per_vector) == 0;
+};
+
+template <size_t Value>
+concept IsPowerOf2 = (Value > 0) && ((Value & (Value - 1)) == 0);
+
+template <size_t ScalarPerVec, size_t DataTypeSize>
+concept IsVectorSizeValid =
+    IsPowerOf2<ScalarPerVec> && (ScalarPerVec <= detail::get_data_max_vec_size<DataTypeSize>());
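With the 16-byte width returned by get_max_mem_vec_inst_width(), the limit for a 2-byte type (fp16/bf16) is 16 / 2 = 8 scalars per vector, so for example:

static_assert(IsVectorSizeValid<8, 2>);   // 8 * 2 bytes fills one 16-byte vector load
static_assert(!IsVectorSizeValid<16, 2>); // 32 bytes exceeds the 16-byte instruction width
static_assert(!IsVectorSizeValid<6, 2>);  // 6 is not a power of 2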
+
+// Composite concept for input block transfer validation (A)
+// Includes all validations: vector transfer limits, access order, cluster size,
+// vector size validity, and tile coverage
+template <auto A_BLOCK_TRANSFER,
+          typename DataType,
+          size_t BLOCK_SIZE,
+          auto TILE_SIZE,
+          size_t DIMS = 3>
+concept ValidABlockTransfer =
+    InputVectorTransferLimits<A_BLOCK_TRANSFER> &&
+    AccessOrderLimits<A_BLOCK_TRANSFER.thread_cluster_order, DIMS> &&
+    AccessOrderLimits<A_BLOCK_TRANSFER.src_access_order, DIMS> &&
+    ValidBlockTransferClusterSize<A_BLOCK_TRANSFER, BLOCK_SIZE> &&
+    IsVectorSizeValid<A_BLOCK_TRANSFER.src_scalar_per_vector, sizeof(DataType)> &&
+    IsVectorSizeValid<A_BLOCK_TRANSFER.lds_dst_scalar_per_vector, sizeof(DataType)> &&
+    ThreadsCoverATile<A_BLOCK_TRANSFER, TILE_SIZE>;
+
+// Composite concept for input block transfer validation (B)
+template <auto B_BLOCK_TRANSFER,
+          typename DataType,
+          size_t BLOCK_SIZE,
+          auto TILE_SIZE,
+          size_t DIMS = 3>
+concept ValidBBlockTransfer =
+    InputVectorTransferLimits<B_BLOCK_TRANSFER> &&
+    AccessOrderLimits<B_BLOCK_TRANSFER.thread_cluster_order, DIMS> &&
+    AccessOrderLimits<B_BLOCK_TRANSFER.src_access_order, DIMS> &&
+    ValidBlockTransferClusterSize<B_BLOCK_TRANSFER, BLOCK_SIZE> &&
+    IsVectorSizeValid<B_BLOCK_TRANSFER.src_scalar_per_vector, sizeof(DataType)> &&
+    IsVectorSizeValid<B_BLOCK_TRANSFER.lds_dst_scalar_per_vector, sizeof(DataType)> &&
+    ThreadsCoverBTile<B_BLOCK_TRANSFER, TILE_SIZE>;
+
+// Composite concept for output block transfer validation (C)
+template <auto C_BLOCK_TRANSFER, typename DataType, size_t BLOCK_SIZE, auto TILE_SIZE>
+concept ValidCBlockTransfer =
+    OutputVectorTransferLimits<C_BLOCK_TRANSFER> &&
+    ValidBlockTransferClusterSize<C_BLOCK_TRANSFER, BLOCK_SIZE> &&
+    IsVectorSizeValid<C_BLOCK_TRANSFER.scalar_per_vector, sizeof(DataType)> &&
+    ThreadsCoverCTile<C_BLOCK_TRANSFER, TILE_SIZE>;
+
+// Usage: IsValidLayout<ACTUAL_LAYOUT, VALID_LAYOUT_1, VALID_LAYOUT_2, ...>
+template <auto ACTUAL_LAYOUT, auto... VALID_LAYOUTS>
+concept IsValidLayout = ck_tile::is_any_value_of(ACTUAL_LAYOUT, VALID_LAYOUTS...);
 
 } // namespace ck_tile::builder
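A minimal sketch of the layout check, assuming ck_tile::is_any_value_of is the membership test its usage comment describes, and using TensorLayout enumerators from the factories below:

using enum TensorLayout;
static_assert(IsValidLayout<GNHWC, GNWC, GNHWC, GNDHWC>);  // GNHWC is in the allowed set
static_assert(!IsValidLayout<NHWGC, GNWC, GNHWC, GNDHWC>); // NHWGC is not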
@@ -11,7 +11,7 @@ namespace ck_tile::builder::factory {
 template <typename T, size_t ThreadClusterRank = 3>
 concept TileTransferParameters =
     SpecifiesBlockTransfer<T, ThreadClusterRank> && SpecifiesLdsTransfer<T> &&
-    SpecifiesThreadClusterAccessOrder<T> && SpecifiesSourceAccessOrder<T>;
+    SpecifiesThreadClusterArrangeOrder<T> && SpecifiesSourceAccessOrder<T>;
 
 template <typename T>
 concept SpecifiesTileTransferParameters3D = TileTransferParameters<T, 3>;
@@ -46,14 +46,55 @@ struct ConvFwdLargeTensorFactory
         internal::SetFwdConvBlockTransfer<ALGORITHM.transfer.b>();
     static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer<SIGNATURE, ALGORITHM>();
 
-    // Check limits for the algorithm parameters.
-    static_assert(InputVectorTransferLimits<A_BLOCK_TRANSFER>);
-    static_assert(InputVectorTransferLimits<B_BLOCK_TRANSFER>);
-    static_assert(OutputVectorTransferLimits<C_BLOCK_TRANSFER>);
-    static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.thread_cluster_order>);
-    static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.thread_cluster_order>);
-    static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.src_access_order>);
-    static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.src_access_order>);
+    // Check limits for the data transfer parameters.
+    static_assert(ValidABlockTransfer<A_BLOCK_TRANSFER,
+                                      typename Types::InDataType,
+                                      BLOCK.block_size,
+                                      BLOCK.per_block>);
+    static_assert(ValidBBlockTransfer<B_BLOCK_TRANSFER,
+                                      typename Types::WeiDataType,
+                                      BLOCK.block_size,
+                                      BLOCK.per_block>);
+    static_assert(ValidCBlockTransfer<C_BLOCK_TRANSFER,
+                                      typename Types::OutDataType,
+                                      BLOCK.block_size,
+                                      BLOCK.per_block>);
+
+    using enum TensorLayout;
+    static_assert(IsValidLayout<SIGNATURE.input.config.layout,
+                                G_NW_C_strided,
+                                G_NHW_C_strided,
+                                G_NDHW_C_strided,
+                                GNWC,
+                                GNHWC,
+                                GNDHWC,
+                                NWGC,
+                                NHWGC,
+                                NDHWGC> &&
+                  A_BLOCK_TRANSFER.src_vector_dim == 2);
+
+    static_assert(IsValidLayout<SIGNATURE.weight.config.layout,
+                                G_K_X_C_strided,
+                                G_K_YX_C_strided,
+                                G_K_ZYX_C_strided,
+                                GKXC,
+                                GKYXC,
+                                GKZYXC,
+                                KXGC,
+                                KYXGC,
+                                KZYXGC> &&
+                  B_BLOCK_TRANSFER.src_vector_dim == 2);
+
+    static_assert(IsValidLayout<SIGNATURE.output.config.layout,
+                                G_NW_K_strided,
+                                G_NHW_K_strided,
+                                G_NDHW_K_strided,
+                                GNWK,
+                                GNHWK,
+                                GNDHWK,
+                                NWGK,
+                                NHWGK,
+                                NDHWGK>);
 
     // The forward convolution kernel class instance with large tensor support.
     using Instance =
@@ -52,14 +52,64 @@ struct ConvFwdXdlV3Factory
     static constexpr auto BLOCK_GEMM = internal::SetBlockGemm<ALGORITHM>();
 
     // Check limits for the algorithm parameters.
     // TODO: Add more limits checks as needed.
-    static_assert(InputVectorTransferLimits<A_BLOCK_TRANSFER>);
-    static_assert(InputVectorTransferLimits<B_BLOCK_TRANSFER>);
-    static_assert(OutputVectorTransferLimits<C_BLOCK_TRANSFER>);
-    static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.thread_cluster_order>);
-    static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.thread_cluster_order>);
-    static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.src_access_order>);
-    static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.src_access_order>);
+    static_assert(ValidABlockTransfer<A_BLOCK_TRANSFER,
+                                      typename Types::InDataType,
+                                      BLOCK.block_size,
+                                      BLOCK.per_block>);
+    static_assert(ValidBBlockTransfer<B_BLOCK_TRANSFER,
+                                      typename Types::WeiDataType,
+                                      BLOCK.block_size,
+                                      BLOCK.per_block>);
+    static_assert(ValidCBlockTransfer<C_BLOCK_TRANSFER,
+                                      typename Types::OutDataType,
+                                      BLOCK.block_size,
+                                      BLOCK.per_block>);
+
+    // Layout validations
+    using enum TensorLayout;
+    static_assert(IsValidLayout<SIGNATURE.input.config.layout,
+                                G_NW_C_strided,
+                                G_NHW_C_strided,
+                                G_NDHW_C_strided,
+                                GNWC,
+                                GNHWC,
+                                GNDHWC,
+                                NWGC,
+                                NHWGC,
+                                NDHWGC,
+                                NGCW,
+                                NGCHW,
+                                NGCDHW> &&
+                  A_BLOCK_TRANSFER.src_vector_dim == 2);
+
+    static_assert(IsValidLayout<SIGNATURE.weight.config.layout,
+                                G_K_X_C_strided,
+                                G_K_YX_C_strided,
+                                G_K_ZYX_C_strided,
+                                GKXC,
+                                GKYXC,
+                                GKZYXC,
+                                KXGC,
+                                KYXGC,
+                                KZYXGC,
+                                GKCX,
+                                GKCYX,
+                                GKCZYX> &&
+                  B_BLOCK_TRANSFER.src_vector_dim == 2);
+
+    static_assert(IsValidLayout<SIGNATURE.output.config.layout,
+                                G_NW_K_strided,
+                                G_NHW_K_strided,
+                                G_NDHW_K_strided,
+                                GNWK,
+                                GNHWK,
+                                GNDHWK,
+                                NWGK,
+                                NHWGK,
+                                NDHWGK,
+                                NGKW,
+                                NGKHW,
+                                NGKDHW>);
 
     // The forward convolution kernel class instance.
     using Instance = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<
@@ -48,14 +48,56 @@ struct ConvFwdWmmaFactory
     static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer<SIGNATURE, ALGORITHM>();
 
     // Check limits for the algorithm parameters.
    // TODO: Add more limits checks as needed.
-    static_assert(InputVectorTransferLimits<A_BLOCK_TRANSFER>);
-    static_assert(InputVectorTransferLimits<B_BLOCK_TRANSFER>);
-    static_assert(OutputVectorTransferLimits<C_BLOCK_TRANSFER>);
-    static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.thread_cluster_order>);
-    static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.thread_cluster_order>);
-    static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.src_access_order>);
-    static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.src_access_order>);
+    static_assert(ValidABlockTransfer<A_BLOCK_TRANSFER,
+                                      typename Types::InDataType,
+                                      BLOCK.block_size,
+                                      BLOCK.per_block>);
+    static_assert(ValidBBlockTransfer<B_BLOCK_TRANSFER,
+                                      typename Types::WeiDataType,
+                                      BLOCK.block_size,
+                                      BLOCK.per_block>);
+    static_assert(ValidCBlockTransfer<C_BLOCK_TRANSFER,
+                                      typename Types::OutDataType,
+                                      BLOCK.block_size,
+                                      BLOCK.per_block>);
+    // TODO: verify Ds transfer as well
+
+    // Layout validations (same as DeviceGroupedConvFwdMultipleD_Wmma_CShuffle)
+    using enum TensorLayout;
+    static_assert(IsValidLayout<SIGNATURE.input.config.layout,
+                                G_NW_C_strided,
+                                G_NHW_C_strided,
+                                G_NDHW_C_strided,
+                                GNWC,
+                                GNHWC,
+                                GNDHWC,
+                                NWGC,
+                                NHWGC,
+                                NDHWGC> &&
+                  A_BLOCK_TRANSFER.src_vector_dim == 2);
+
+    static_assert(IsValidLayout<SIGNATURE.weight.config.layout,
+                                G_K_X_C_strided,
+                                G_K_YX_C_strided,
+                                G_K_ZYX_C_strided,
+                                GKXC,
+                                GKYXC,
+                                GKZYXC,
+                                KXGC,
+                                KYXGC,
+                                KZYXGC> &&
+                  B_BLOCK_TRANSFER.src_vector_dim == 2);
+
+    static_assert(IsValidLayout<SIGNATURE.output.config.layout,
+                                G_NW_K_strided,
+                                G_NHW_K_strided,
+                                G_NDHW_K_strided,
+                                GNWK,
+                                GNHWK,
+                                GNDHWK,
+                                NWGK,
+                                NHWGK,
+                                NDHWGK>);
 
     // The forward convolution kernel class instance.
     using Instance = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<
@@ -47,14 +47,64 @@ struct ConvFwdXdlFactory
     static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer<SIGNATURE, ALGORITHM>();
 
     // Check limits for the algorithm parameters.
     // TODO: Add more limits checks as needed.
-    static_assert(InputVectorTransferLimits<A_BLOCK_TRANSFER>);
-    static_assert(InputVectorTransferLimits<B_BLOCK_TRANSFER>);
-    static_assert(OutputVectorTransferLimits<C_BLOCK_TRANSFER>);
-    static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.thread_cluster_order>);
-    static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.thread_cluster_order>);
-    static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.src_access_order>);
-    static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.src_access_order>);
+    static_assert(ValidABlockTransfer<A_BLOCK_TRANSFER,
+                                      typename Types::InDataType,
+                                      BLOCK.block_size,
+                                      BLOCK.per_block>);
+    static_assert(ValidBBlockTransfer<B_BLOCK_TRANSFER,
+                                      typename Types::WeiDataType,
+                                      BLOCK.block_size,
+                                      BLOCK.per_block>);
+    static_assert(ValidCBlockTransfer<C_BLOCK_TRANSFER,
+                                      typename Types::OutDataType,
+                                      BLOCK.block_size,
+                                      BLOCK.per_block>);
+
+    // Layout validations (same as DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle)
+    using enum TensorLayout;
+    static_assert(IsValidLayout<SIGNATURE.input.config.layout,
+                                G_NW_C_strided,
+                                G_NHW_C_strided,
+                                G_NDHW_C_strided,
+                                GNWC,
+                                GNHWC,
+                                GNDHWC,
+                                NWGC,
+                                NHWGC,
+                                NDHWGC,
+                                NGCW,
+                                NGCHW,
+                                NGCDHW> &&
+                  A_BLOCK_TRANSFER.src_vector_dim == 2);
+
+    static_assert(IsValidLayout<SIGNATURE.weight.config.layout,
+                                G_K_X_C_strided,
+                                G_K_YX_C_strided,
+                                G_K_ZYX_C_strided,
+                                GKXC,
+                                GKYXC,
+                                GKZYXC,
+                                KXGC,
+                                KYXGC,
+                                KZYXGC,
+                                GKCX,
+                                GKCYX,
+                                GKCZYX> &&
+                  B_BLOCK_TRANSFER.src_vector_dim == 2);
+
+    static_assert(IsValidLayout<SIGNATURE.output.config.layout,
+                                G_NW_K_strided,
+                                G_NHW_K_strided,
+                                G_NDHW_K_strided,
+                                GNWK,
+                                GNHWK,
+                                GNDHWK,
+                                NWGK,
+                                NHWGK,
+                                NDHWGK,
+                                NGKW,
+                                NGKHW,
+                                NGKDHW>);
 
     // The forward convolution kernel class instance.
     using Instance = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<
@@ -27,7 +27,7 @@ template <auto TRANSFER>
 constexpr BlockTransfer<> SetFwdConvBlockTransfer()
 {
     auto& block_xfer = TRANSFER.block_transfer;
-    auto& block_order = TRANSFER.block_transfer_access_order;
+    auto& block_order = TRANSFER.thread_cluster_arrange_order;
     auto& src_order = TRANSFER.src_access_order;
     auto& lds_cfg = TRANSFER.lds_transfer;
 
@@ -47,7 +47,7 @@ template <auto TRANSFER>
 constexpr auto SetBwdConvBlockTransfer()
 {
     auto& block_xfer = TRANSFER.block_transfer;
-    auto& block_order = TRANSFER.block_transfer_access_order;
+    auto& block_order = TRANSFER.thread_cluster_arrange_order;
     auto& src_order = TRANSFER.src_access_order;
     auto& lds_cfg = TRANSFER.lds_transfer;
 
@@ -126,15 +126,15 @@ struct AccessOrder
 {
     std::array<size_t, ThreadSliceLength> order;
 };
-static_assert(AccessOrderDescriptor<AccessOrder<>>);
-static_assert(AccessOrderDescriptor<AccessOrder<4>>);
+static_assert(ThreadClusterOrderDescriptor<AccessOrder<>>);
+static_assert(ThreadClusterOrderDescriptor<AccessOrder<4>>);
 
 template <size_t ThreadSliceLength = 3>
 struct InputTransfer
 {
     BlockTransfer<ThreadSliceLength> block_transfer;
     LdsTransfer lds_transfer;
-    AccessOrder<ThreadSliceLength> block_transfer_access_order;
+    AccessOrder<ThreadSliceLength> thread_cluster_arrange_order;
     AccessOrder<ThreadSliceLength> src_access_order;
 };
 
@@ -128,26 +128,26 @@ struct DefaultAlgorithm
     ckb::test::Transfer<> transfer{
         .a =
             {
-                .block_transfer = {.k0 = 1, .m_n = 128, .k1 = 2},
-                .lds_transfer = {.src_vector_dim = 2,
-                                 .src_scalar_per_vector = 2,
-                                 .lds_dst_scalar_per_vector = 2,
-                                 .is_direct_load = false,
-                                 .lds_padding = false},
-                .block_transfer_access_order = {.order = {0, 1, 2}},
-                .src_access_order = {.order = {0, 1, 2}},
+                .block_transfer = {.k0 = 1, .m_n = 128, .k1 = 2},
+                .lds_transfer = {.src_vector_dim = 2,
+                                 .src_scalar_per_vector = 2,
+                                 .lds_dst_scalar_per_vector = 2,
+                                 .is_direct_load = false,
+                                 .lds_padding = false},
+                .thread_cluster_arrange_order = {.order = {0, 1, 2}},
+                .src_access_order = {.order = {0, 1, 2}},
 
             },
         .b =
             {
-                .block_transfer = {.k0 = 1, .m_n = 128, .k1 = 2},
-                .lds_transfer = {.src_vector_dim = 2,
-                                 .src_scalar_per_vector = 2,
-                                 .lds_dst_scalar_per_vector = 2,
-                                 .is_direct_load = false,
-                                 .lds_padding = false},
-                .block_transfer_access_order = {.order = {0, 1, 2}},
-                .src_access_order = {.order = {0, 1, 2}},
+                .block_transfer = {.k0 = 1, .m_n = 128, .k1 = 2},
+                .lds_transfer = {.src_vector_dim = 2,
+                                 .src_scalar_per_vector = 2,
+                                 .lds_dst_scalar_per_vector = 2,
+                                 .is_direct_load = false,
+                                 .lds_padding = false},
+                .thread_cluster_arrange_order = {.order = {0, 1, 2}},
+                .src_access_order = {.order = {0, 1, 2}},
             },
         .c =
             {
@@ -53,25 +53,25 @@ constexpr DlTransfer<5> DlTransfer5D{.a = DlBlockTransfer_1x8x1x1x1,
 constexpr Transfer<> Transfer_4x64x1{
     .a =
         {
-            .block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1},
-            .lds_transfer = {.src_vector_dim = 2,
-                             .src_scalar_per_vector = 2,
-                             .lds_dst_scalar_per_vector = 8,
-                             .is_direct_load = false,
-                             .lds_padding = false},
-            .block_transfer_access_order = {1, 0, 2},
-            .src_access_order = {1, 0, 2},
+            .block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1},
+            .lds_transfer = {.src_vector_dim = 2,
+                             .src_scalar_per_vector = 2,
+                             .lds_dst_scalar_per_vector = 4,
+                             .is_direct_load = false,
+                             .lds_padding = false},
+            .thread_cluster_arrange_order = {1, 0, 2},
+            .src_access_order = {1, 0, 2},
         },
     .b =
         {
-            .block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1},
-            .lds_transfer = {.src_vector_dim = 2,
-                             .src_scalar_per_vector = 8,
-                             .lds_dst_scalar_per_vector = 8,
-                             .is_direct_load = false,
-                             .lds_padding = false},
-            .block_transfer_access_order = {1, 0, 2},
-            .src_access_order = {1, 0, 2},
+            .block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1},
+            .lds_transfer = {.src_vector_dim = 2,
+                             .src_scalar_per_vector = 4,
+                             .lds_dst_scalar_per_vector = 4,
+                             .is_direct_load = false,
+                             .lds_padding = false},
+            .thread_cluster_arrange_order = {1, 0, 2},
+            .src_access_order = {1, 0, 2},
         },
     .c =
         {
@@ -86,25 +86,25 @@ constexpr Transfer<> Transfer_4x64x1{
 constexpr Transfer<4> BwdTransfer_4x64x1{
     .a =
         {
-            .block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1, .k_batch_size = 1},
-            .lds_transfer = {.src_vector_dim = 2,
-                             .src_scalar_per_vector = 2,
-                             .lds_dst_scalar_per_vector = 4,
-                             .is_direct_load = false,
-                             .lds_padding = true},
-            .block_transfer_access_order = {0, 3, 1, 2},
-            .src_access_order = {0, 2, 1, 3},
+            .block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1, .k_batch_size = 1},
+            .lds_transfer = {.src_vector_dim = 2,
+                             .src_scalar_per_vector = 2,
+                             .lds_dst_scalar_per_vector = 4,
+                             .is_direct_load = false,
+                             .lds_padding = true},
+            .thread_cluster_arrange_order = {0, 3, 1, 2},
+            .src_access_order = {0, 2, 1, 3},
         },
    .b =
         {
-            .block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1, .k_batch_size = 1},
-            .lds_transfer = {.src_vector_dim = 2,
-                             .src_scalar_per_vector = 2,
-                             .lds_dst_scalar_per_vector = 4,
-                             .is_direct_load = false,
-                             .lds_padding = true},
-            .block_transfer_access_order = {0, 3, 1, 2},
-            .src_access_order = {0, 2, 1, 3},
+            .block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1, .k_batch_size = 1},
+            .lds_transfer = {.src_vector_dim = 2,
+                             .src_scalar_per_vector = 2,
+                             .lds_dst_scalar_per_vector = 4,
+                             .is_direct_load = false,
+                             .lds_padding = true},
+            .thread_cluster_arrange_order = {0, 3, 1, 2},
+            .src_access_order = {0, 2, 1, 3},
         },
     .c =
         {
@@ -119,25 +119,25 @@ constexpr Transfer<4> BwdTransfer_4x64x1{
 constexpr Transfer<> BwdTransfer_4x8x1_4x16x1_v3{
     .a =
         {
-            .block_transfer = {.k0 = 4, .m_n = 8, .k1 = 1},
-            .lds_transfer = {.src_vector_dim = 1,
-                             .src_scalar_per_vector = 2,
-                             .lds_dst_scalar_per_vector = 2,
-                             .is_direct_load = false,
-                             .lds_padding = false},
-            .block_transfer_access_order = {2, 0, 1},
-            .src_access_order = {1, 0, 2},
+            .block_transfer = {.k0 = 4, .m_n = 8, .k1 = 1},
+            .lds_transfer = {.src_vector_dim = 1,
+                             .src_scalar_per_vector = 2,
+                             .lds_dst_scalar_per_vector = 2,
+                             .is_direct_load = false,
+                             .lds_padding = false},
+            .thread_cluster_arrange_order = {2, 0, 1},
+            .src_access_order = {1, 0, 2},
         },
     .b =
         {
-            .block_transfer = {.k0 = 4, .m_n = 16, .k1 = 1},
-            .lds_transfer = {.src_vector_dim = 1,
-                             .src_scalar_per_vector = 2,
-                             .lds_dst_scalar_per_vector = 2,
-                             .is_direct_load = false,
-                             .lds_padding = false},
-            .block_transfer_access_order = {2, 0, 1},
-            .src_access_order = {1, 0, 2},
+            .block_transfer = {.k0 = 4, .m_n = 16, .k1 = 1},
+            .lds_transfer = {.src_vector_dim = 1,
+                             .src_scalar_per_vector = 2,
+                             .lds_dst_scalar_per_vector = 2,
+                             .is_direct_load = false,
+                             .lds_padding = false},
+            .thread_cluster_arrange_order = {2, 0, 1},
+            .src_access_order = {1, 0, 2},
         },
     .c =
         {
@@ -152,25 +152,25 @@ constexpr Transfer<> BwdTransfer_4x8x1_4x16x1_v3{
 constexpr Transfer<> Transfer_4x64x1_fp8{
     .a =
         {
-            .block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1},
-            .lds_transfer = {.src_vector_dim = 2,
-                             .src_scalar_per_vector = 8,
-                             .lds_dst_scalar_per_vector = 8,
-                             .is_direct_load = false,
-                             .lds_padding = true},
-            .block_transfer_access_order = {1, 0, 2},
-            .src_access_order = {1, 0, 2},
+            .block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1},
+            .lds_transfer = {.src_vector_dim = 2,
+                             .src_scalar_per_vector = 8,
+                             .lds_dst_scalar_per_vector = 8,
+                             .is_direct_load = false,
+                             .lds_padding = true},
+            .thread_cluster_arrange_order = {1, 0, 2},
+            .src_access_order = {1, 0, 2},
         },
     .b =
         {
-            .block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1},
-            .lds_transfer = {.src_vector_dim = 2,
-                             .src_scalar_per_vector = 8,
-                             .lds_dst_scalar_per_vector = 8,
-                             .is_direct_load = false,
-                             .lds_padding = true},
-            .block_transfer_access_order = {1, 0, 2},
-            .src_access_order = {1, 0, 2},
+            .block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1},
+            .lds_transfer = {.src_vector_dim = 2,
+                             .src_scalar_per_vector = 8,
+                             .lds_dst_scalar_per_vector = 8,
+                             .is_direct_load = false,
+                             .lds_padding = true},
+            .thread_cluster_arrange_order = {1, 0, 2},
+            .src_access_order = {1, 0, 2},
         },
     .c =
         {
@@ -185,25 +185,25 @@ constexpr Transfer<> Transfer_4x64x1_fp8{
 constexpr Transfer<> Transfer_4x16x1{
     .a =
         {
-            .block_transfer = {.k0 = 4, .m_n = 16, .k1 = 1},
-            .lds_transfer = {.src_vector_dim = 2,
-                             .src_scalar_per_vector = 8,
-                             .lds_dst_scalar_per_vector = 8,
-                             .is_direct_load = false,
-                             .lds_padding = true},
-            .block_transfer_access_order = {1, 0, 2},
-            .src_access_order = {1, 0, 2},
+            .block_transfer = {.k0 = 4, .m_n = 16, .k1 = 1},
+            .lds_transfer = {.src_vector_dim = 2,
+                             .src_scalar_per_vector = 8,
+                             .lds_dst_scalar_per_vector = 8,
+                             .is_direct_load = false,
+                             .lds_padding = true},
+            .thread_cluster_arrange_order = {1, 0, 2},
+            .src_access_order = {1, 0, 2},
         },
     .b =
         {
-            .block_transfer = {.k0 = 4, .m_n = 16, .k1 = 1},
-            .lds_transfer = {.src_vector_dim = 2,
-                             .src_scalar_per_vector = 8,
-                             .lds_dst_scalar_per_vector = 8,
-                             .is_direct_load = false,
-                             .lds_padding = true},
-            .block_transfer_access_order = {1, 0, 2},
-            .src_access_order = {1, 0, 2},
+            .block_transfer = {.k0 = 4, .m_n = 16, .k1 = 1},
+            .lds_transfer = {.src_vector_dim = 2,
+                             .src_scalar_per_vector = 8,
+                             .lds_dst_scalar_per_vector = 8,
+                             .is_direct_load = false,
+                             .lds_padding = true},
+            .thread_cluster_arrange_order = {1, 0, 2},
+            .src_access_order = {1, 0, 2},
         },
     .c =
         {
@@ -219,25 +219,25 @@ constexpr Transfer<> Transfer_4x16x1{
 constexpr Transfer<> Transfer_4x32x1{
     .a =
         {
-            .block_transfer = {.k0 = 4, .m_n = 32, .k1 = 1},
-            .lds_transfer = {.src_vector_dim = 2,
-                             .src_scalar_per_vector = 16,
-                             .lds_dst_scalar_per_vector = 16,
-                             .is_direct_load = false,
-                             .lds_padding = true},
-            .block_transfer_access_order = {1, 0, 2},
-            .src_access_order = {1, 0, 2},
+            .block_transfer = {.k0 = 4, .m_n = 32, .k1 = 1},
+            .lds_transfer = {.src_vector_dim = 2,
+                             .src_scalar_per_vector = 16,
+                             .lds_dst_scalar_per_vector = 16,
+                             .is_direct_load = false,
+                             .lds_padding = true},
+            .thread_cluster_arrange_order = {1, 0, 2},
+            .src_access_order = {1, 0, 2},
         },
     .b =
         {
-            .block_transfer = {.k0 = 4, .m_n = 32, .k1 = 1},
-            .lds_transfer = {.src_vector_dim = 2,
-                             .src_scalar_per_vector = 16,
-                             .lds_dst_scalar_per_vector = 16,
-                             .is_direct_load = false,
-                             .lds_padding = true},
-            .block_transfer_access_order = {1, 0, 2},
-            .src_access_order = {1, 0, 2},
+            .block_transfer = {.k0 = 4, .m_n = 32, .k1 = 1},
+            .lds_transfer = {.src_vector_dim = 2,
+                             .src_scalar_per_vector = 16,
+                             .lds_dst_scalar_per_vector = 16,
+                             .is_direct_load = false,
+                             .lds_padding = true},
+            .thread_cluster_arrange_order = {1, 0, 2},
+            .src_access_order = {1, 0, 2},
         },
     .c =
        {
@@ -165,7 +165,7 @@ template <size_t N = 3>
 inline std::string to_string(InputTransfer<N> t)
 {
     std::ostringstream oss;
-    oss << to_string(t.block_transfer) << "," << to_string(t.block_transfer_access_order) << ","
+    oss << to_string(t.block_transfer) << "," << to_string(t.thread_cluster_arrange_order) << ","
         << to_string(t.src_access_order) << "," << t.lds_transfer.src_vector_dim << ","
         << t.lds_transfer.src_scalar_per_vector << "," << t.lds_transfer.lds_dst_scalar_per_vector
         << "," << (t.lds_transfer.lds_padding ? "true" : "false");
@@ -1173,4 +1173,11 @@ enum LLVMSchedGroupMask : int32_t
     DS_WRITE = 1 << 9,
     ALL = (DS_WRITE << 1) - 1,
 };
+
+CK_TILE_HOST_DEVICE static constexpr auto get_max_mem_vec_inst_width()
+{
+    // Currently on all arch max memory vector instruction width is 16 bytes.
+    return 16;
+}
 
 } // namespace ck_tile
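Combined with detail::get_data_max_vec_size from the concepts header above, this 16-byte width caps the vector size per element type, e.g.:

static_assert(ck_tile::builder::detail::get_data_max_vec_size<4>() == 4);  // fp32
static_assert(ck_tile::builder::detail::get_data_max_vec_size<2>() == 8);  // fp16/bf16
static_assert(ck_tile::builder::detail::get_data_max_vec_size<1>() == 16); // fp8/int8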