mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-03-14 20:27:42 +00:00
[rocm-libraries] ROCm/rocm-libraries#4750 (commit c065793)
[CK_BUILDER] ck builder conv transfer fix ## Motivation This PR fixes how CK Builder is validating transfer vector size and adds proper validation for LDS transfer vector size as well. ## Changes: * [__source vector dim__] -- Before this PR the data transfer validation logic didn't allow to set the source vectorized dimension to 1. However there are CK instances that are doing this when the group merging is used. This is used only for `DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle` kernel. * [__valid vector size__] -- Before this PR the validation logic concerned only single instruction maximum vector size. However our buffer loading logic has implemented support for loading more values through multiple buffer instructions. This again was discovered to be used in some of the convolution instances. Thus this behavior was reflected in validation logic. * [__valid LDS vector size__] -- Before this PR the LDS vector size validation was done in the same way as VMEM. This PR adds proper LDS vector size validation based on the available LDS instruction sizes. ## Test Plan Run CK BUILDER conv fwd factories tests ## Test Result All CK BUILDER conv fwd factories work (except DL one & ck tile since they're not yet added now) ## Submission Checklist - [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
5e06874aae
commit
22de6a19d9
@@ -36,7 +36,7 @@ repos:
|
||||
name: Run ck_tile remod.py
|
||||
entry: python projects/composablekernel/script/remod_for_ck_tile.py
|
||||
language: python
|
||||
files: '^(include|example)/ck_tile/.*$'
|
||||
files: '^projects/composablekernel/(include|example)/ck_tile/.*$'
|
||||
additional_dependencies:
|
||||
- dos2unix
|
||||
- clang-format==18.1.3
|
||||
|
||||
@@ -104,7 +104,7 @@ concept LdsTransferDescriptor = requires(T t) {
|
||||
template <typename T>
|
||||
concept EpilogueDescriptor = requires(T t) {
|
||||
{ t.m_xdl_per_wave_per_shuffle } -> SizeType;
|
||||
{ t.n_per_wave_per_shuffle } -> SizeType;
|
||||
{ t.n_xdl_per_wave_per_shuffle } -> SizeType;
|
||||
{ t.scalar_per_vector } -> SizeType;
|
||||
};
|
||||
|
||||
|
||||
@@ -6,24 +6,36 @@
|
||||
#include <type_traits>
|
||||
#include <concepts>
|
||||
#include <utility>
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
/**
|
||||
* @file conv_algorithm_limits.hpp
|
||||
* @brief Compile-time validation concepts and helpers for convolution algorithm configurations
|
||||
*
|
||||
* This file provides C++20 concepts and compile-time validation functions for validating
|
||||
* block transfer configurations, memory access patterns, and hardware instruction constraints
|
||||
* in convolution algorithms.
|
||||
*
|
||||
* Key features:
|
||||
* - Vector transfer size validation for VMEM and LDS operations
|
||||
* - Access order permutation validation
|
||||
* - Thread cluster dimension validation
|
||||
* - Tile coverage validation for block transfers
|
||||
*/
|
||||
|
||||
namespace ck_tile::builder {
|
||||
|
||||
// Limits for input vector transfer.
|
||||
template <auto Value>
|
||||
concept InputVectorTransferLimits = requires {
|
||||
requires Value.src_vector_dim > 0 && Value.src_scalar_per_vector > 0 &&
|
||||
Value.lds_dst_scalar_per_vector > 0;
|
||||
};
|
||||
|
||||
// Limits for input and output vector transfer (CK Tile).
|
||||
template <auto Value>
|
||||
concept TileInputOutputVectorTransferLimits =
|
||||
requires { requires Value.a > 0 && Value.b > 0 && Value.c > 0; };
|
||||
|
||||
// Limits for output vector transfer.
|
||||
template <auto Value>
|
||||
concept OutputVectorTransferLimits = requires {
|
||||
requires Value.scalar_per_vector > 0 && Value.m_xdl_per_wave_per_shuffle > 0 &&
|
||||
@@ -174,13 +186,70 @@ constexpr auto get_mn_coverage()
|
||||
return mn;
|
||||
}
|
||||
|
||||
template <size_t DataTypeSize>
|
||||
constexpr auto get_data_max_vec_size()
|
||||
template <size_t N, DataType Type>
|
||||
constexpr bool IsVmemVectorSizeValid()
|
||||
{
|
||||
constexpr auto max_vec_inst_size_bytes = get_max_mem_vec_inst_width();
|
||||
static_assert(max_vec_inst_size_bytes % DataTypeSize == 0,
|
||||
"The max vec instruction size is not a multiple of given data type size.");
|
||||
return max_vec_inst_size_bytes / DataTypeSize;
|
||||
using enum builder::DataType;
|
||||
// We have following type & VectorSize pair constraints.
|
||||
//-----------------------------------------------------------------------------------
|
||||
// (std::is_same_v<T, double> && (N == 1 || N == 2 || N == 4 || N == 8)) ||
|
||||
// (std::is_same_v<T, float> && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
// (std::is_same_v<T, fp16_t> &&
|
||||
// (N == 1 || N == 2 || N == 4 || N == 6 || N == 8 || N == 16 || N == 32)) ||
|
||||
// (std::is_same_v<T, bf16_t> &&
|
||||
// (N == 1 || N == 2 || N == 4 || N == 6 || N == 8 || N == 16 || N == 32)) ||
|
||||
// (std::is_same_v<T, int32_t> &&
|
||||
// (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
// (std::is_same_v<T, fp8_t> && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
// (std::is_same_v<T, bf8_t> && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
// (std::is_same_v<T, int8_t> && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
// (std::is_same_v<T, e8m0_t> && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
// (std::is_same_v<T, pk_int4_t> &&
|
||||
// (N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32)) ||
|
||||
// (std::is_same_v<T, pk_fp4_raw_t> &&
|
||||
// (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
// (std::is_same_v<T, pk_fp4_t> && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16))
|
||||
//-----------------------------------------------------------------------------------
|
||||
// explicitly not using switch statement since we do not handle all possible data types
|
||||
// in DataType structure yet, so that I could cover all of them in `else` branch.
|
||||
if constexpr(Type == FP64)
|
||||
{
|
||||
return N == 1 || N == 2 || N == 4 || N == 8;
|
||||
}
|
||||
else if constexpr(Type == FP32)
|
||||
{
|
||||
return N == 1 || N == 2 || N == 4 || N == 8 || N == 16;
|
||||
}
|
||||
else if constexpr(Type == I32)
|
||||
{
|
||||
return N == 1 || N == 2 || N == 4 || N == 8 || N == 16;
|
||||
}
|
||||
else if constexpr(Type == FP16 || Type == BF16)
|
||||
{
|
||||
return N == 1 || N == 2 || N == 4 || N == 6 || N == 8 || N == 16 || N == 32;
|
||||
}
|
||||
else if constexpr(Type == FP8 || Type == BF8)
|
||||
{
|
||||
return N == 1 || N == 2 || N == 4 || N == 8 || N == 16;
|
||||
}
|
||||
else if constexpr(Type == I8)
|
||||
{
|
||||
return N == 1 || N == 2 || N == 4 || N == 8 || N == 16;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(always_false<void>, "Unsupported memory instruction data type!");
|
||||
}
|
||||
}
|
||||
|
||||
// Valid LDS instruction bit sizes based on supported DS_READ/DS_WRITE operations
|
||||
// DS_READ_{B32,B64,B96,B128,U8,I8,U16,I16}
|
||||
// DS_WRITE_{B32,B64,B96,B128,B8,B16}
|
||||
template <size_t N, size_t DataTypeSize>
|
||||
constexpr bool IsLDSVectorSizeValid()
|
||||
{
|
||||
constexpr size_t bits = N * DataTypeSize * 8;
|
||||
return ck_tile::is_any_value_of(bits, 8, 16, 32, 64, 96, 128);
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
@@ -217,52 +286,52 @@ concept ThreadsCoverCTile = requires {
|
||||
CBlockTransfer.scalar_per_vector) == 0;
|
||||
};
|
||||
|
||||
template <size_t Value>
|
||||
concept IsPowerOf2 = (Value > 0) && ((Value & (Value - 1)) == 0);
|
||||
template <size_t N, DataType Type>
|
||||
concept IsVmemVectorSizeValid = detail::IsVmemVectorSizeValid<N, Type>();
|
||||
|
||||
template <size_t ScalarPerVec, size_t DataTypeSize>
|
||||
concept IsVectorSizeValid =
|
||||
IsPowerOf2<ScalarPerVec> && (ScalarPerVec <= detail::get_data_max_vec_size<DataTypeSize>());
|
||||
template <size_t N, size_t DataTypeSize>
|
||||
concept IsLDSVectorSizeValid = detail::IsLDSVectorSizeValid<N, DataTypeSize>();
|
||||
|
||||
// Composite concept for input block transfer validation (A)
|
||||
// Includes all validations: vector transfer limits, access order, cluster size,
|
||||
// vector size validity, and tile coverage
|
||||
template <auto A_BLOCK_TRANSFER,
|
||||
typename DataType,
|
||||
size_t BLOCK_SIZE,
|
||||
auto TILE_SIZE,
|
||||
size_t DIMS = 3>
|
||||
template <auto A_BlockTransfer,
|
||||
DataType Type,
|
||||
size_t TypeSize,
|
||||
size_t BlockSize,
|
||||
auto TileSize,
|
||||
size_t ThreadClusterRank = 3>
|
||||
concept ValidABlockTransfer =
|
||||
InputVectorTransferLimits<A_BLOCK_TRANSFER> &&
|
||||
AccessOrderLimits<A_BLOCK_TRANSFER.thread_cluster_order, DIMS> &&
|
||||
AccessOrderLimits<A_BLOCK_TRANSFER.src_access_order, DIMS> &&
|
||||
ValidBlockTransferClusterSize<A_BLOCK_TRANSFER, BLOCK_SIZE> &&
|
||||
IsVectorSizeValid<A_BLOCK_TRANSFER.src_scalar_per_vector, sizeof(DataType)> &&
|
||||
IsVectorSizeValid<A_BLOCK_TRANSFER.lds_dst_scalar_per_vector, sizeof(DataType)> &&
|
||||
ThreadsCoverATile<A_BLOCK_TRANSFER, TILE_SIZE>;
|
||||
InputVectorTransferLimits<A_BlockTransfer> &&
|
||||
AccessOrderLimits<A_BlockTransfer.thread_cluster_order, ThreadClusterRank> &&
|
||||
AccessOrderLimits<A_BlockTransfer.src_access_order, ThreadClusterRank> &&
|
||||
ValidBlockTransferClusterSize<A_BlockTransfer, BlockSize> &&
|
||||
IsVmemVectorSizeValid<A_BlockTransfer.src_scalar_per_vector, Type> &&
|
||||
IsLDSVectorSizeValid<A_BlockTransfer.lds_dst_scalar_per_vector, TypeSize> &&
|
||||
ThreadsCoverATile<A_BlockTransfer, TileSize>;
|
||||
|
||||
// Composite concept for input block transfer validation (B)
|
||||
template <auto B_BLOCK_TRANSFER,
|
||||
typename DataType,
|
||||
size_t BLOCK_SIZE,
|
||||
auto TILE_SIZE,
|
||||
size_t DIMS = 3>
|
||||
template <auto B_BlockTransfer,
|
||||
DataType Type,
|
||||
size_t TypeSize,
|
||||
size_t BlockSize,
|
||||
auto TileSize,
|
||||
size_t ThreadClusterRank = 3>
|
||||
concept ValidBBlockTransfer =
|
||||
InputVectorTransferLimits<B_BLOCK_TRANSFER> &&
|
||||
AccessOrderLimits<B_BLOCK_TRANSFER.thread_cluster_order, DIMS> &&
|
||||
AccessOrderLimits<B_BLOCK_TRANSFER.src_access_order, DIMS> &&
|
||||
ValidBlockTransferClusterSize<B_BLOCK_TRANSFER, BLOCK_SIZE> &&
|
||||
IsVectorSizeValid<B_BLOCK_TRANSFER.src_scalar_per_vector, sizeof(DataType)> &&
|
||||
IsVectorSizeValid<B_BLOCK_TRANSFER.lds_dst_scalar_per_vector, sizeof(DataType)> &&
|
||||
ThreadsCoverBTile<B_BLOCK_TRANSFER, TILE_SIZE>;
|
||||
InputVectorTransferLimits<B_BlockTransfer> &&
|
||||
AccessOrderLimits<B_BlockTransfer.thread_cluster_order, ThreadClusterRank> &&
|
||||
AccessOrderLimits<B_BlockTransfer.src_access_order, ThreadClusterRank> &&
|
||||
ValidBlockTransferClusterSize<B_BlockTransfer, BlockSize> &&
|
||||
IsVmemVectorSizeValid<B_BlockTransfer.src_scalar_per_vector, Type> &&
|
||||
IsLDSVectorSizeValid<B_BlockTransfer.lds_dst_scalar_per_vector, TypeSize> &&
|
||||
ThreadsCoverBTile<B_BlockTransfer, TileSize>;
|
||||
|
||||
// Composite concept for output block transfer validation (C)
|
||||
template <auto C_BLOCK_TRANSFER, typename DataType, size_t BLOCK_SIZE, auto TILE_SIZE>
|
||||
concept ValidCBlockTransfer =
|
||||
OutputVectorTransferLimits<C_BLOCK_TRANSFER> &&
|
||||
ValidBlockTransferClusterSize<C_BLOCK_TRANSFER, BLOCK_SIZE> &&
|
||||
IsVectorSizeValid<C_BLOCK_TRANSFER.scalar_per_vector, sizeof(DataType)> &&
|
||||
ThreadsCoverCTile<C_BLOCK_TRANSFER, TILE_SIZE>;
|
||||
template <auto C_BlockTransfer, DataType Type, size_t BlockSize, auto TileSize>
|
||||
concept ValidCBlockTransfer = OutputVectorTransferLimits<C_BlockTransfer> &&
|
||||
ValidBlockTransferClusterSize<C_BlockTransfer, BlockSize> &&
|
||||
IsVmemVectorSizeValid<C_BlockTransfer.scalar_per_vector, Type> &&
|
||||
ThreadsCoverCTile<C_BlockTransfer, TileSize>;
|
||||
|
||||
// Usage: IsValidLayout<ACTUAL_LAYOUT, VALID_LAYOUT_1, VALID_LAYOUT_2, ...>
|
||||
template <auto ACTUAL_LAYOUT, auto... VALID_LAYOUTS>
|
||||
|
||||
@@ -48,15 +48,17 @@ struct ConvFwdLargeTensorFactory
|
||||
|
||||
// Check limits for the data transfer parameters.
|
||||
static_assert(ValidABlockTransfer<A_BLOCK_TRANSFER,
|
||||
typename Types::InDataType,
|
||||
Types::input_types.first,
|
||||
sizeof(typename Types::InDataType),
|
||||
BLOCK.block_size,
|
||||
BLOCK.per_block>);
|
||||
static_assert(ValidBBlockTransfer<B_BLOCK_TRANSFER,
|
||||
typename Types::WeiDataType,
|
||||
Types::weight_types.first,
|
||||
sizeof(typename Types::WeiDataType),
|
||||
BLOCK.block_size,
|
||||
BLOCK.per_block>);
|
||||
static_assert(ValidCBlockTransfer<C_BLOCK_TRANSFER,
|
||||
typename Types::OutDataType,
|
||||
Types::output_types.first,
|
||||
BLOCK.block_size,
|
||||
BLOCK.per_block>);
|
||||
|
||||
|
||||
@@ -53,15 +53,17 @@ struct ConvFwdXdlV3Factory
|
||||
|
||||
// Check limits for the algorithm parameters.
|
||||
static_assert(ValidABlockTransfer<A_BLOCK_TRANSFER,
|
||||
typename Types::InDataType,
|
||||
Types::input_types.first,
|
||||
sizeof(typename Types::InDataType),
|
||||
BLOCK.block_size,
|
||||
BLOCK.per_block>);
|
||||
static_assert(ValidBBlockTransfer<B_BLOCK_TRANSFER,
|
||||
typename Types::WeiDataType,
|
||||
Types::weight_types.first,
|
||||
sizeof(typename Types::WeiDataType),
|
||||
BLOCK.block_size,
|
||||
BLOCK.per_block>);
|
||||
static_assert(ValidCBlockTransfer<C_BLOCK_TRANSFER,
|
||||
typename Types::OutDataType,
|
||||
Types::output_types.first,
|
||||
BLOCK.block_size,
|
||||
BLOCK.per_block>);
|
||||
|
||||
|
||||
@@ -49,15 +49,17 @@ struct ConvFwdWmmaFactory
|
||||
|
||||
// Check limits for the algorithm parameters.
|
||||
static_assert(ValidABlockTransfer<A_BLOCK_TRANSFER,
|
||||
typename Types::InDataType,
|
||||
Types::input_types.first,
|
||||
sizeof(typename Types::InDataType),
|
||||
BLOCK.block_size,
|
||||
BLOCK.per_block>);
|
||||
static_assert(ValidBBlockTransfer<B_BLOCK_TRANSFER,
|
||||
typename Types::WeiDataType,
|
||||
Types::weight_types.first,
|
||||
sizeof(typename Types::WeiDataType),
|
||||
BLOCK.block_size,
|
||||
BLOCK.per_block>);
|
||||
static_assert(ValidCBlockTransfer<C_BLOCK_TRANSFER,
|
||||
typename Types::OutDataType,
|
||||
Types::output_types.first,
|
||||
BLOCK.block_size,
|
||||
BLOCK.per_block>);
|
||||
// TODO: verify Ds transfer as well
|
||||
|
||||
@@ -48,15 +48,20 @@ struct ConvFwdXdlFactory
|
||||
|
||||
// Check limits for the algorithm parameters.
|
||||
static_assert(ValidABlockTransfer<A_BLOCK_TRANSFER,
|
||||
typename Types::InDataType,
|
||||
Types::input_types.first,
|
||||
sizeof(typename Types::InDataType),
|
||||
BLOCK.block_size,
|
||||
BLOCK.per_block>);
|
||||
static_assert(A_BLOCK_TRANSFER.src_vector_dim == 2 ||
|
||||
(ALGORITHM.num_conv_groups_to_merge > 1 && A_BLOCK_TRANSFER.src_vector_dim == 1));
|
||||
static_assert(ValidBBlockTransfer<B_BLOCK_TRANSFER,
|
||||
typename Types::WeiDataType,
|
||||
Types::weight_types.first,
|
||||
sizeof(typename Types::WeiDataType),
|
||||
BLOCK.block_size,
|
||||
BLOCK.per_block>);
|
||||
static_assert(B_BLOCK_TRANSFER.src_vector_dim == 2);
|
||||
static_assert(ValidCBlockTransfer<C_BLOCK_TRANSFER,
|
||||
typename Types::OutDataType,
|
||||
Types::output_types.first,
|
||||
BLOCK.block_size,
|
||||
BLOCK.per_block>);
|
||||
|
||||
@@ -74,8 +79,7 @@ struct ConvFwdXdlFactory
|
||||
NDHWGC,
|
||||
NGCW,
|
||||
NGCHW,
|
||||
NGCDHW> &&
|
||||
A_BLOCK_TRANSFER.src_vector_dim == 2);
|
||||
NGCDHW>);
|
||||
|
||||
static_assert(IsValidLayout<SIGNATURE.weight.config.layout,
|
||||
G_K_X_C_strided,
|
||||
@@ -89,8 +93,7 @@ struct ConvFwdXdlFactory
|
||||
KZYXGC,
|
||||
GKCX,
|
||||
GKCYX,
|
||||
GKCZYX> &&
|
||||
B_BLOCK_TRANSFER.src_vector_dim == 2);
|
||||
GKCZYX>);
|
||||
|
||||
static_assert(IsValidLayout<SIGNATURE.output.config.layout,
|
||||
G_NW_K_strided,
|
||||
|
||||
@@ -112,7 +112,7 @@ constexpr CBlockTransfer SetCBlockTransfer()
|
||||
auto& epilogue_config = ALGORITHM.transfer.c.epilogue;
|
||||
return CBlockTransfer{
|
||||
.m_xdl_per_wave_per_shuffle = epilogue_config.m_xdl_per_wave_per_shuffle,
|
||||
.n_xdl_per_wave_per_shuffle = epilogue_config.n_per_wave_per_shuffle,
|
||||
.n_xdl_per_wave_per_shuffle = epilogue_config.n_xdl_per_wave_per_shuffle,
|
||||
.thread_cluster_dims =
|
||||
{
|
||||
thread_cluster_dims.m_block,
|
||||
|
||||
@@ -65,35 +65,46 @@ consteval auto ConvertDataTypeToCK()
|
||||
}
|
||||
|
||||
template <auto Config, DataType SignatureDataType>
|
||||
consteval auto GetTensorDataAndComputeTypes()
|
||||
consteval auto ExtractTensorDataType()
|
||||
{
|
||||
constexpr auto data_type = Config.data_type;
|
||||
constexpr auto compute_type = Config.compute_type;
|
||||
constexpr auto data_type = Config.data_type;
|
||||
|
||||
using enum DataType;
|
||||
|
||||
if constexpr(data_type == UNDEFINED_DATA_TYPE && compute_type == UNDEFINED_DATA_TYPE)
|
||||
if constexpr(data_type == UNDEFINED_DATA_TYPE)
|
||||
{
|
||||
return std::make_pair(ConvertDataTypeToCK<SignatureDataType>(),
|
||||
ConvertDataTypeToCK<SignatureDataType>());
|
||||
}
|
||||
else if constexpr(data_type == UNDEFINED_DATA_TYPE)
|
||||
{
|
||||
return std::make_pair(ConvertDataTypeToCK<SignatureDataType>(),
|
||||
ConvertDataTypeToCK<compute_type>());
|
||||
}
|
||||
else if constexpr(compute_type == UNDEFINED_DATA_TYPE)
|
||||
{
|
||||
return std::make_pair(ConvertDataTypeToCK<data_type>(),
|
||||
ConvertDataTypeToCK<SignatureDataType>());
|
||||
return SignatureDataType;
|
||||
}
|
||||
else
|
||||
{
|
||||
return std::make_pair(ConvertDataTypeToCK<data_type>(),
|
||||
ConvertDataTypeToCK<compute_type>());
|
||||
return data_type;
|
||||
}
|
||||
}
|
||||
|
||||
template <auto Config, DataType SignatureDataType>
|
||||
consteval auto ExtractTensorComputeType()
|
||||
{
|
||||
constexpr auto compute_type = Config.compute_type;
|
||||
|
||||
using enum DataType;
|
||||
if constexpr(compute_type == UNDEFINED_DATA_TYPE)
|
||||
{
|
||||
return SignatureDataType;
|
||||
}
|
||||
else
|
||||
{
|
||||
return compute_type;
|
||||
}
|
||||
}
|
||||
|
||||
template <auto Config, DataType SignatureDataType>
|
||||
consteval auto GetTensorDataAndComputeTypes()
|
||||
{
|
||||
constexpr auto data_type = ExtractTensorDataType<Config, SignatureDataType>();
|
||||
constexpr auto compute_type = ExtractTensorComputeType<Config, SignatureDataType>();
|
||||
|
||||
return std::make_pair(data_type, compute_type);
|
||||
}
|
||||
|
||||
template <DataType SignatureAccDataType, DataType SignatureDataType>
|
||||
consteval auto GetTensorAccumulationType()
|
||||
{
|
||||
@@ -158,6 +169,7 @@ consteval auto GetAuxiliaryTensorDataTypes()
|
||||
template <auto Signature>
|
||||
struct ConvTensorDataTypes
|
||||
{
|
||||
// Builder enumerator types
|
||||
static constexpr auto input_types =
|
||||
GetTensorDataAndComputeTypes<Signature.input.config, Signature.data_type>();
|
||||
static constexpr auto weight_types =
|
||||
@@ -165,12 +177,12 @@ struct ConvTensorDataTypes
|
||||
static constexpr auto output_types =
|
||||
GetTensorDataAndComputeTypes<Signature.output.config, Signature.data_type>();
|
||||
|
||||
using InDataType = typename decltype(input_types.first)::type;
|
||||
using InComputeType = typename decltype(input_types.second)::type;
|
||||
using WeiDataType = typename decltype(weight_types.first)::type;
|
||||
using WeiComputeType = typename decltype(weight_types.second)::type;
|
||||
using OutDataType = typename decltype(output_types.first)::type;
|
||||
using OutComputeType = typename decltype(output_types.second)::type;
|
||||
using InDataType = typename DataTypeToCK<input_types.first>::type;
|
||||
using InComputeType = typename DataTypeToCK<input_types.second>::type;
|
||||
using WeiDataType = typename DataTypeToCK<weight_types.first>::type;
|
||||
using WeiComputeType = typename DataTypeToCK<weight_types.second>::type;
|
||||
using OutDataType = typename DataTypeToCK<output_types.first>::type;
|
||||
using OutComputeType = typename DataTypeToCK<output_types.second>::type;
|
||||
using AccDataType =
|
||||
typename decltype(GetTensorAccumulationType<Signature.accumulation_data_type,
|
||||
Signature.data_type>())::type;
|
||||
|
||||
@@ -29,7 +29,7 @@ TEST(FwdConvInstances,
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle{}
|
||||
.with_thread_block(ThreadBlock_64_64x32x32)
|
||||
.with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave)
|
||||
.with_transfer(Transfer_4x16x1)
|
||||
.with_transfer(Transfer_4x16x1_asrc_vec_dim1)
|
||||
.with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
|
||||
.with_prefetch_config(1, PipelineScheduler::DEFAULT)
|
||||
.with_num_conv_groups_to_merge(2);
|
||||
|
||||
@@ -31,7 +31,7 @@ TEST(FwdConvInstances,
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle{}
|
||||
.with_thread_block(ThreadBlock_128_64x64x64)
|
||||
.with_gemm_config(GemmParams_Wmma_2x1_per_wave)
|
||||
.with_transfer(Transfer_4x32x1)
|
||||
.with_transfer(Transfer_4x16x1)
|
||||
.with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
|
||||
.with_prefetch_config(1, PipelineScheduler::DEFAULT)
|
||||
.with_num_conv_groups_to_merge(2)
|
||||
|
||||
@@ -48,4 +48,81 @@ TEST(FwdConvInstances,
|
||||
"MNKPadding"});
|
||||
}
|
||||
|
||||
// 3D FP32 NGCDHW (channels-first) with Pipeline V1 and FILTER_1X1_PAD0
|
||||
TEST(
|
||||
FwdConvInstances,
|
||||
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_3D_FP32_ChannelsFirst_LargeVecSize)
|
||||
{
|
||||
using enum ck_tile::builder::ConvDirection;
|
||||
using enum ck_tile::builder::DataType;
|
||||
using enum ck_tile::builder::TensorLayout;
|
||||
|
||||
constexpr ConvSignature FwdConvSignature{.spatial_dim = 3,
|
||||
.direction = FORWARD,
|
||||
.data_type = FP32,
|
||||
.accumulation_data_type = FP32,
|
||||
.input = {.config = {.layout = NGCDHW}},
|
||||
.weight = {.config = {.layout = GKCZYX}},
|
||||
.output = {.config = {.layout = NGKDHW}}};
|
||||
|
||||
constexpr Transfer<> Transfer_4x64x1_Vec16{
|
||||
.a =
|
||||
{
|
||||
.block_transfer = {.k0 = 2, .m_n = 128, .k1 = 1},
|
||||
.lds_transfer = {.src_vector_dim = 2,
|
||||
.src_scalar_per_vector = 16,
|
||||
.lds_dst_scalar_per_vector = 4,
|
||||
.is_direct_load = false,
|
||||
.lds_padding = false},
|
||||
.thread_cluster_arrange_order = {1, 0, 2},
|
||||
.src_access_order = {1, 0, 2},
|
||||
},
|
||||
.b =
|
||||
{
|
||||
.block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1},
|
||||
.lds_transfer = {.src_vector_dim = 2,
|
||||
.src_scalar_per_vector = 4,
|
||||
.lds_dst_scalar_per_vector = 4,
|
||||
.is_direct_load = false,
|
||||
.lds_padding = false},
|
||||
.thread_cluster_arrange_order = {1, 0, 2},
|
||||
.src_access_order = {1, 0, 2},
|
||||
},
|
||||
.c =
|
||||
{
|
||||
.thread_cluster_dims =
|
||||
{.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8},
|
||||
.epilogue = {.m_xdl_per_wave_per_shuffle = 1,
|
||||
.n_xdl_per_wave_per_shuffle = 1,
|
||||
.scalar_per_vector = 4},
|
||||
},
|
||||
};
|
||||
|
||||
constexpr GridwiseFwdXdlGemm FwdGemmParams{
|
||||
.ak1 = 16,
|
||||
.bk1 = 8,
|
||||
.xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4}};
|
||||
|
||||
constexpr auto FwdConvAlgorithm =
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{}
|
||||
.with_thread_block(ThreadBlock_256_256x256x32)
|
||||
.with_gemm_config(FwdGemmParams)
|
||||
.with_transfer(Transfer_4x64x1_Vec16)
|
||||
.with_fwd_specializations(ConvSpecialization::FILTER_1X1_PAD0,
|
||||
GemmSpecialization::MNKPadding)
|
||||
.with_block_gemm(BlockGemmDesc_v1_intrawave);
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
|
||||
const auto expected_transfer_parameters = to_string(FwdConvAlgorithm);
|
||||
run_test<Builder>({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3",
|
||||
expected_transfer_parameters,
|
||||
"Filter1x1Pad0",
|
||||
"Intrawave",
|
||||
"v1",
|
||||
"NGCDHW,GKCZYX,EmptyTuple,NGKDHW",
|
||||
"PassThrough,PassThrough,PassThrough",
|
||||
"MNKPadding"});
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
@@ -133,7 +133,7 @@ static_assert(LdsTransferDescriptor<LdsTransfer>);
|
||||
struct Epilogue
|
||||
{
|
||||
size_t m_xdl_per_wave_per_shuffle;
|
||||
size_t n_per_wave_per_shuffle;
|
||||
size_t n_xdl_per_wave_per_shuffle;
|
||||
size_t scalar_per_vector;
|
||||
};
|
||||
static_assert(EpilogueDescriptor<Epilogue>);
|
||||
|
||||
@@ -154,7 +154,7 @@ struct DefaultAlgorithm
|
||||
.thread_cluster_dims =
|
||||
{.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8},
|
||||
.epilogue = {.m_xdl_per_wave_per_shuffle = 1,
|
||||
.n_per_wave_per_shuffle = 1,
|
||||
.n_xdl_per_wave_per_shuffle = 1,
|
||||
.scalar_per_vector = 2},
|
||||
},
|
||||
};
|
||||
|
||||
@@ -78,7 +78,7 @@ constexpr Transfer<> Transfer_4x64x1{
|
||||
.thread_cluster_dims =
|
||||
{.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8},
|
||||
.epilogue = {.m_xdl_per_wave_per_shuffle = 1,
|
||||
.n_per_wave_per_shuffle = 1,
|
||||
.n_xdl_per_wave_per_shuffle = 1,
|
||||
.scalar_per_vector = 4},
|
||||
},
|
||||
};
|
||||
@@ -111,7 +111,7 @@ constexpr Transfer<4> BwdTransfer_4x64x1{
|
||||
.thread_cluster_dims =
|
||||
{.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8},
|
||||
.epilogue = {.m_xdl_per_wave_per_shuffle = 1,
|
||||
.n_per_wave_per_shuffle = 1,
|
||||
.n_xdl_per_wave_per_shuffle = 1,
|
||||
.scalar_per_vector = 8},
|
||||
},
|
||||
};
|
||||
@@ -144,7 +144,7 @@ constexpr Transfer<> BwdTransfer_4x8x1_4x16x1_v3{
|
||||
.thread_cluster_dims =
|
||||
{.m_block = 1, .m_wave_per_xdl = 8, .n_block = 1, .n_wave_per_xdl = 8},
|
||||
.epilogue = {.m_xdl_per_wave_per_shuffle = 1,
|
||||
.n_per_wave_per_shuffle = 1,
|
||||
.n_xdl_per_wave_per_shuffle = 1,
|
||||
.scalar_per_vector = 2},
|
||||
},
|
||||
};
|
||||
@@ -177,7 +177,7 @@ constexpr Transfer<> Transfer_4x64x1_fp8{
|
||||
.thread_cluster_dims =
|
||||
{.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8},
|
||||
.epilogue = {.m_xdl_per_wave_per_shuffle = 1,
|
||||
.n_per_wave_per_shuffle = 1,
|
||||
.n_xdl_per_wave_per_shuffle = 1,
|
||||
.scalar_per_vector = 8},
|
||||
},
|
||||
};
|
||||
@@ -210,12 +210,46 @@ constexpr Transfer<> Transfer_4x16x1{
|
||||
.thread_cluster_dims =
|
||||
{.m_block = 1, .m_wave_per_xdl = 16, .n_block = 1, .n_wave_per_xdl = 4},
|
||||
.epilogue = {.m_xdl_per_wave_per_shuffle = 1,
|
||||
.n_per_wave_per_shuffle = 1,
|
||||
.n_xdl_per_wave_per_shuffle = 1,
|
||||
.scalar_per_vector = 8},
|
||||
|
||||
},
|
||||
};
|
||||
|
||||
constexpr Transfer<> Transfer_4x16x1_asrc_vec_dim1{
|
||||
.a =
|
||||
{
|
||||
.block_transfer = {.k0 = 4, .m_n = 16, .k1 = 1},
|
||||
.lds_transfer = {.src_vector_dim = 1,
|
||||
.src_scalar_per_vector = 4,
|
||||
.lds_dst_scalar_per_vector = 4,
|
||||
.is_direct_load = false,
|
||||
.lds_padding = true},
|
||||
.thread_cluster_arrange_order = {0, 2, 1},
|
||||
.src_access_order = {0, 2, 1},
|
||||
},
|
||||
.b =
|
||||
{
|
||||
.block_transfer = {.k0 = 4, .m_n = 16, .k1 = 1},
|
||||
.lds_transfer = {.src_vector_dim = 2,
|
||||
.src_scalar_per_vector = 1,
|
||||
.lds_dst_scalar_per_vector = 8,
|
||||
.is_direct_load = false,
|
||||
.lds_padding = true},
|
||||
.thread_cluster_arrange_order = {1, 0, 2},
|
||||
.src_access_order = {1, 0, 2},
|
||||
},
|
||||
.c =
|
||||
{
|
||||
.thread_cluster_dims =
|
||||
{.m_block = 1, .m_wave_per_xdl = 16, .n_block = 1, .n_wave_per_xdl = 4},
|
||||
.epilogue = {.m_xdl_per_wave_per_shuffle = 1,
|
||||
.n_xdl_per_wave_per_shuffle = 1,
|
||||
.scalar_per_vector = 1},
|
||||
|
||||
},
|
||||
};
|
||||
|
||||
constexpr Transfer<> Transfer_4x32x1{
|
||||
.a =
|
||||
{
|
||||
@@ -244,7 +278,7 @@ constexpr Transfer<> Transfer_4x32x1{
|
||||
.thread_cluster_dims =
|
||||
{.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 4},
|
||||
.epilogue = {.m_xdl_per_wave_per_shuffle = 1,
|
||||
.n_per_wave_per_shuffle = 1,
|
||||
.n_xdl_per_wave_per_shuffle = 1,
|
||||
.scalar_per_vector = 8},
|
||||
},
|
||||
};
|
||||
|
||||
@@ -194,8 +194,8 @@ template <>
|
||||
inline std::string to_string<OutputTransfer>(OutputTransfer t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << t.epilogue.m_xdl_per_wave_per_shuffle << "," << t.epilogue.n_per_wave_per_shuffle << ","
|
||||
<< to_string(t.thread_cluster_dims) << "," << t.epilogue.scalar_per_vector;
|
||||
oss << t.epilogue.m_xdl_per_wave_per_shuffle << "," << t.epilogue.n_xdl_per_wave_per_shuffle
|
||||
<< "," << to_string(t.thread_cluster_dims) << "," << t.epilogue.scalar_per_vector;
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
|
||||
@@ -4,8 +4,8 @@
|
||||
import os
|
||||
|
||||
root_dir = os.getcwd()
|
||||
ck_tile_include = root_dir + "/include/ck_tile"
|
||||
ck_tile_example = root_dir + "/example/ck_tile"
|
||||
ck_tile_include = root_dir + "/projects/composablekernel/include/ck_tile"
|
||||
ck_tile_example = root_dir + "/projects/composablekernel/example/ck_tile"
|
||||
|
||||
# Run for include
|
||||
os.chdir(ck_tile_include)
|
||||
|
||||
Reference in New Issue
Block a user