[rocm-libraries] ROCm/rocm-libraries#4750 (commit c065793)

[CK_BUILDER] ck builder conv transfer fix

## Motivation

This PR fixes how CK Builder is validating transfer vector size and adds
proper validation for LDS transfer vector size as well.

## Changes:

* [__source vector dim__] -- Before this PR the data transfer validation
logic did not allow setting the source vectorized dimension to 1. However,
there are CK instances that do this when group merging is used. This is
used only by the `DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle` kernel.
* [__valid vector size__] -- Before this PR the validation logic only
checked against the maximum vector size of a single instruction. However,
our buffer-loading logic supports loading more values by issuing multiple
buffer instructions, and some convolution instances were found to rely on
this. The validation logic now reflects that behavior.
* [__valid LDS vector size__] -- Before this PR the LDS vector size
validation was done in the same way as VMEM. This PR adds proper LDS
vector size validation based on the available LDS instruction sizes.

## Test Plan

Run CK BUILDER conv fwd factories tests

## Test Result

All CK BUILDER conv fwd factories work (except the DL and CK Tile ones,
which have not been added yet).

## Submission Checklist

- [ ] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
Adam Osewski
2026-02-27 13:48:58 +00:00
committed by assistant-librarian[bot]
parent 5e06874aae
commit 22de6a19d9
17 changed files with 304 additions and 103 deletions

View File

@@ -36,7 +36,7 @@ repos:
name: Run ck_tile remod.py
entry: python projects/composablekernel/script/remod_for_ck_tile.py
language: python
files: '^(include|example)/ck_tile/.*$'
files: '^projects/composablekernel/(include|example)/ck_tile/.*$'
additional_dependencies:
- dos2unix
- clang-format==18.1.3

View File

@@ -104,7 +104,7 @@ concept LdsTransferDescriptor = requires(T t) {
template <typename T>
concept EpilogueDescriptor = requires(T t) {
{ t.m_xdl_per_wave_per_shuffle } -> SizeType;
{ t.n_per_wave_per_shuffle } -> SizeType;
{ t.n_xdl_per_wave_per_shuffle } -> SizeType;
{ t.scalar_per_vector } -> SizeType;
};

View File

@@ -6,24 +6,36 @@
#include <type_traits>
#include <concepts>
#include <utility>
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
/**
* @file conv_algorithm_limits.hpp
* @brief Compile-time validation concepts and helpers for convolution algorithm configurations
*
* This file provides C++20 concepts and compile-time validation functions for validating
* block transfer configurations, memory access patterns, and hardware instruction constraints
* in convolution algorithms.
*
* Key features:
* - Vector transfer size validation for VMEM and LDS operations
* - Access order permutation validation
* - Thread cluster dimension validation
* - Tile coverage validation for block transfers
*/
namespace ck_tile::builder {
// Limits for input vector transfer.
template <auto Value>
concept InputVectorTransferLimits = requires {
requires Value.src_vector_dim > 0 && Value.src_scalar_per_vector > 0 &&
Value.lds_dst_scalar_per_vector > 0;
};
// Limits for input and output vector transfer (CK Tile).
template <auto Value>
concept TileInputOutputVectorTransferLimits =
requires { requires Value.a > 0 && Value.b > 0 && Value.c > 0; };
// Limits for output vector transfer.
template <auto Value>
concept OutputVectorTransferLimits = requires {
requires Value.scalar_per_vector > 0 && Value.m_xdl_per_wave_per_shuffle > 0 &&
@@ -174,13 +186,70 @@ constexpr auto get_mn_coverage()
return mn;
}
template <size_t DataTypeSize>
constexpr auto get_data_max_vec_size()
template <size_t N, DataType Type>
constexpr bool IsVmemVectorSizeValid()
{
constexpr auto max_vec_inst_size_bytes = get_max_mem_vec_inst_width();
static_assert(max_vec_inst_size_bytes % DataTypeSize == 0,
"The max vec instruction size is not a multiple of given data type size.");
return max_vec_inst_size_bytes / DataTypeSize;
using enum builder::DataType;
// We have following type & VectorSize pair constraints.
//-----------------------------------------------------------------------------------
// (std::is_same_v<T, double> && (N == 1 || N == 2 || N == 4 || N == 8)) ||
// (std::is_same_v<T, float> && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
// (std::is_same_v<T, fp16_t> &&
// (N == 1 || N == 2 || N == 4 || N == 6 || N == 8 || N == 16 || N == 32)) ||
// (std::is_same_v<T, bf16_t> &&
// (N == 1 || N == 2 || N == 4 || N == 6 || N == 8 || N == 16 || N == 32)) ||
// (std::is_same_v<T, int32_t> &&
// (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
// (std::is_same_v<T, fp8_t> && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
// (std::is_same_v<T, bf8_t> && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
// (std::is_same_v<T, int8_t> && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
// (std::is_same_v<T, e8m0_t> && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
// (std::is_same_v<T, pk_int4_t> &&
// (N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32)) ||
// (std::is_same_v<T, pk_fp4_raw_t> &&
// (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
// (std::is_same_v<T, pk_fp4_t> && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16))
//-----------------------------------------------------------------------------------
// explicitly not using switch statement since we do not handle all possible data types
// in DataType structure yet, so that I could cover all of them in `else` branch.
if constexpr(Type == FP64)
{
return N == 1 || N == 2 || N == 4 || N == 8;
}
else if constexpr(Type == FP32)
{
return N == 1 || N == 2 || N == 4 || N == 8 || N == 16;
}
else if constexpr(Type == I32)
{
return N == 1 || N == 2 || N == 4 || N == 8 || N == 16;
}
else if constexpr(Type == FP16 || Type == BF16)
{
return N == 1 || N == 2 || N == 4 || N == 6 || N == 8 || N == 16 || N == 32;
}
else if constexpr(Type == FP8 || Type == BF8)
{
return N == 1 || N == 2 || N == 4 || N == 8 || N == 16;
}
else if constexpr(Type == I8)
{
return N == 1 || N == 2 || N == 4 || N == 8 || N == 16;
}
else
{
static_assert(always_false<void>, "Unsupported memory instruction data type!");
}
}
// Valid LDS instruction bit sizes based on supported DS_READ/DS_WRITE operations
// DS_READ_{B32,B64,B96,B128,U8,I8,U16,I16}
// DS_WRITE_{B32,B64,B96,B128,B8,B16}
// Returns true iff an N-element vector of a DataTypeSize-byte type occupies a
// total bit width serviceable by a single LDS instruction (see the
// DS_READ/DS_WRITE widths listed above).
template <size_t N, size_t DataTypeSize>
constexpr bool IsLDSVectorSizeValid()
{
// Total transfer width in bits for one vectorized LDS access.
constexpr size_t bits = N * DataTypeSize * 8;
return ck_tile::is_any_value_of(bits, 8, 16, 32, 64, 96, 128);
}
} // namespace detail
@@ -217,52 +286,52 @@ concept ThreadsCoverCTile = requires {
CBlockTransfer.scalar_per_vector) == 0;
};
template <size_t Value>
concept IsPowerOf2 = (Value > 0) && ((Value & (Value - 1)) == 0);
template <size_t N, DataType Type>
concept IsVmemVectorSizeValid = detail::IsVmemVectorSizeValid<N, Type>();
template <size_t ScalarPerVec, size_t DataTypeSize>
concept IsVectorSizeValid =
IsPowerOf2<ScalarPerVec> && (ScalarPerVec <= detail::get_data_max_vec_size<DataTypeSize>());
template <size_t N, size_t DataTypeSize>
concept IsLDSVectorSizeValid = detail::IsLDSVectorSizeValid<N, DataTypeSize>();
// Composite concept for input block transfer validation (A)
// Includes all validations: vector transfer limits, access order, cluster size,
// vector size validity, and tile coverage
template <auto A_BLOCK_TRANSFER,
typename DataType,
size_t BLOCK_SIZE,
auto TILE_SIZE,
size_t DIMS = 3>
template <auto A_BlockTransfer,
DataType Type,
size_t TypeSize,
size_t BlockSize,
auto TileSize,
size_t ThreadClusterRank = 3>
concept ValidABlockTransfer =
InputVectorTransferLimits<A_BLOCK_TRANSFER> &&
AccessOrderLimits<A_BLOCK_TRANSFER.thread_cluster_order, DIMS> &&
AccessOrderLimits<A_BLOCK_TRANSFER.src_access_order, DIMS> &&
ValidBlockTransferClusterSize<A_BLOCK_TRANSFER, BLOCK_SIZE> &&
IsVectorSizeValid<A_BLOCK_TRANSFER.src_scalar_per_vector, sizeof(DataType)> &&
IsVectorSizeValid<A_BLOCK_TRANSFER.lds_dst_scalar_per_vector, sizeof(DataType)> &&
ThreadsCoverATile<A_BLOCK_TRANSFER, TILE_SIZE>;
InputVectorTransferLimits<A_BlockTransfer> &&
AccessOrderLimits<A_BlockTransfer.thread_cluster_order, ThreadClusterRank> &&
AccessOrderLimits<A_BlockTransfer.src_access_order, ThreadClusterRank> &&
ValidBlockTransferClusterSize<A_BlockTransfer, BlockSize> &&
IsVmemVectorSizeValid<A_BlockTransfer.src_scalar_per_vector, Type> &&
IsLDSVectorSizeValid<A_BlockTransfer.lds_dst_scalar_per_vector, TypeSize> &&
ThreadsCoverATile<A_BlockTransfer, TileSize>;
// Composite concept for input block transfer validation (B)
template <auto B_BLOCK_TRANSFER,
typename DataType,
size_t BLOCK_SIZE,
auto TILE_SIZE,
size_t DIMS = 3>
template <auto B_BlockTransfer,
DataType Type,
size_t TypeSize,
size_t BlockSize,
auto TileSize,
size_t ThreadClusterRank = 3>
concept ValidBBlockTransfer =
InputVectorTransferLimits<B_BLOCK_TRANSFER> &&
AccessOrderLimits<B_BLOCK_TRANSFER.thread_cluster_order, DIMS> &&
AccessOrderLimits<B_BLOCK_TRANSFER.src_access_order, DIMS> &&
ValidBlockTransferClusterSize<B_BLOCK_TRANSFER, BLOCK_SIZE> &&
IsVectorSizeValid<B_BLOCK_TRANSFER.src_scalar_per_vector, sizeof(DataType)> &&
IsVectorSizeValid<B_BLOCK_TRANSFER.lds_dst_scalar_per_vector, sizeof(DataType)> &&
ThreadsCoverBTile<B_BLOCK_TRANSFER, TILE_SIZE>;
InputVectorTransferLimits<B_BlockTransfer> &&
AccessOrderLimits<B_BlockTransfer.thread_cluster_order, ThreadClusterRank> &&
AccessOrderLimits<B_BlockTransfer.src_access_order, ThreadClusterRank> &&
ValidBlockTransferClusterSize<B_BlockTransfer, BlockSize> &&
IsVmemVectorSizeValid<B_BlockTransfer.src_scalar_per_vector, Type> &&
IsLDSVectorSizeValid<B_BlockTransfer.lds_dst_scalar_per_vector, TypeSize> &&
ThreadsCoverBTile<B_BlockTransfer, TileSize>;
// Composite concept for output block transfer validation (C)
template <auto C_BLOCK_TRANSFER, typename DataType, size_t BLOCK_SIZE, auto TILE_SIZE>
concept ValidCBlockTransfer =
OutputVectorTransferLimits<C_BLOCK_TRANSFER> &&
ValidBlockTransferClusterSize<C_BLOCK_TRANSFER, BLOCK_SIZE> &&
IsVectorSizeValid<C_BLOCK_TRANSFER.scalar_per_vector, sizeof(DataType)> &&
ThreadsCoverCTile<C_BLOCK_TRANSFER, TILE_SIZE>;
template <auto C_BlockTransfer, DataType Type, size_t BlockSize, auto TileSize>
concept ValidCBlockTransfer = OutputVectorTransferLimits<C_BlockTransfer> &&
ValidBlockTransferClusterSize<C_BlockTransfer, BlockSize> &&
IsVmemVectorSizeValid<C_BlockTransfer.scalar_per_vector, Type> &&
ThreadsCoverCTile<C_BlockTransfer, TileSize>;
// Usage: IsValidLayout<ACTUAL_LAYOUT, VALID_LAYOUT_1, VALID_LAYOUT_2, ...>
template <auto ACTUAL_LAYOUT, auto... VALID_LAYOUTS>

View File

@@ -48,15 +48,17 @@ struct ConvFwdLargeTensorFactory
// Check limits for the data transfer parameters.
static_assert(ValidABlockTransfer<A_BLOCK_TRANSFER,
typename Types::InDataType,
Types::input_types.first,
sizeof(typename Types::InDataType),
BLOCK.block_size,
BLOCK.per_block>);
static_assert(ValidBBlockTransfer<B_BLOCK_TRANSFER,
typename Types::WeiDataType,
Types::weight_types.first,
sizeof(typename Types::WeiDataType),
BLOCK.block_size,
BLOCK.per_block>);
static_assert(ValidCBlockTransfer<C_BLOCK_TRANSFER,
typename Types::OutDataType,
Types::output_types.first,
BLOCK.block_size,
BLOCK.per_block>);

View File

@@ -53,15 +53,17 @@ struct ConvFwdXdlV3Factory
// Check limits for the algorithm parameters.
static_assert(ValidABlockTransfer<A_BLOCK_TRANSFER,
typename Types::InDataType,
Types::input_types.first,
sizeof(typename Types::InDataType),
BLOCK.block_size,
BLOCK.per_block>);
static_assert(ValidBBlockTransfer<B_BLOCK_TRANSFER,
typename Types::WeiDataType,
Types::weight_types.first,
sizeof(typename Types::WeiDataType),
BLOCK.block_size,
BLOCK.per_block>);
static_assert(ValidCBlockTransfer<C_BLOCK_TRANSFER,
typename Types::OutDataType,
Types::output_types.first,
BLOCK.block_size,
BLOCK.per_block>);

View File

@@ -49,15 +49,17 @@ struct ConvFwdWmmaFactory
// Check limits for the algorithm parameters.
static_assert(ValidABlockTransfer<A_BLOCK_TRANSFER,
typename Types::InDataType,
Types::input_types.first,
sizeof(typename Types::InDataType),
BLOCK.block_size,
BLOCK.per_block>);
static_assert(ValidBBlockTransfer<B_BLOCK_TRANSFER,
typename Types::WeiDataType,
Types::weight_types.first,
sizeof(typename Types::WeiDataType),
BLOCK.block_size,
BLOCK.per_block>);
static_assert(ValidCBlockTransfer<C_BLOCK_TRANSFER,
typename Types::OutDataType,
Types::output_types.first,
BLOCK.block_size,
BLOCK.per_block>);
// TODO: verify Ds transfer as well

View File

@@ -48,15 +48,20 @@ struct ConvFwdXdlFactory
// Check limits for the algorithm parameters.
static_assert(ValidABlockTransfer<A_BLOCK_TRANSFER,
typename Types::InDataType,
Types::input_types.first,
sizeof(typename Types::InDataType),
BLOCK.block_size,
BLOCK.per_block>);
static_assert(A_BLOCK_TRANSFER.src_vector_dim == 2 ||
(ALGORITHM.num_conv_groups_to_merge > 1 && A_BLOCK_TRANSFER.src_vector_dim == 1));
static_assert(ValidBBlockTransfer<B_BLOCK_TRANSFER,
typename Types::WeiDataType,
Types::weight_types.first,
sizeof(typename Types::WeiDataType),
BLOCK.block_size,
BLOCK.per_block>);
static_assert(B_BLOCK_TRANSFER.src_vector_dim == 2);
static_assert(ValidCBlockTransfer<C_BLOCK_TRANSFER,
typename Types::OutDataType,
Types::output_types.first,
BLOCK.block_size,
BLOCK.per_block>);
@@ -74,8 +79,7 @@ struct ConvFwdXdlFactory
NDHWGC,
NGCW,
NGCHW,
NGCDHW> &&
A_BLOCK_TRANSFER.src_vector_dim == 2);
NGCDHW>);
static_assert(IsValidLayout<SIGNATURE.weight.config.layout,
G_K_X_C_strided,
@@ -89,8 +93,7 @@ struct ConvFwdXdlFactory
KZYXGC,
GKCX,
GKCYX,
GKCZYX> &&
B_BLOCK_TRANSFER.src_vector_dim == 2);
GKCZYX>);
static_assert(IsValidLayout<SIGNATURE.output.config.layout,
G_NW_K_strided,

View File

@@ -112,7 +112,7 @@ constexpr CBlockTransfer SetCBlockTransfer()
auto& epilogue_config = ALGORITHM.transfer.c.epilogue;
return CBlockTransfer{
.m_xdl_per_wave_per_shuffle = epilogue_config.m_xdl_per_wave_per_shuffle,
.n_xdl_per_wave_per_shuffle = epilogue_config.n_per_wave_per_shuffle,
.n_xdl_per_wave_per_shuffle = epilogue_config.n_xdl_per_wave_per_shuffle,
.thread_cluster_dims =
{
thread_cluster_dims.m_block,

View File

@@ -65,35 +65,46 @@ consteval auto ConvertDataTypeToCK()
}
template <auto Config, DataType SignatureDataType>
consteval auto GetTensorDataAndComputeTypes()
consteval auto ExtractTensorDataType()
{
constexpr auto data_type = Config.data_type;
constexpr auto compute_type = Config.compute_type;
constexpr auto data_type = Config.data_type;
using enum DataType;
if constexpr(data_type == UNDEFINED_DATA_TYPE && compute_type == UNDEFINED_DATA_TYPE)
if constexpr(data_type == UNDEFINED_DATA_TYPE)
{
return std::make_pair(ConvertDataTypeToCK<SignatureDataType>(),
ConvertDataTypeToCK<SignatureDataType>());
}
else if constexpr(data_type == UNDEFINED_DATA_TYPE)
{
return std::make_pair(ConvertDataTypeToCK<SignatureDataType>(),
ConvertDataTypeToCK<compute_type>());
}
else if constexpr(compute_type == UNDEFINED_DATA_TYPE)
{
return std::make_pair(ConvertDataTypeToCK<data_type>(),
ConvertDataTypeToCK<SignatureDataType>());
return SignatureDataType;
}
else
{
return std::make_pair(ConvertDataTypeToCK<data_type>(),
ConvertDataTypeToCK<compute_type>());
return data_type;
}
}
// Resolves the compute type for a tensor: use the per-tensor override from
// Config when set, otherwise fall back to the signature-wide data type.
template <auto Config, DataType SignatureDataType>
consteval auto ExtractTensorComputeType()
{
constexpr auto compute_type = Config.compute_type;
using enum DataType;
// UNDEFINED_DATA_TYPE means "no per-tensor override was specified".
if constexpr(compute_type == UNDEFINED_DATA_TYPE)
{
return SignatureDataType;
}
else
{
return compute_type;
}
}
// Returns the (data type, compute type) pair for a tensor, each resolved
// independently from the per-tensor Config with the signature type as fallback.
template <auto Config, DataType SignatureDataType>
consteval auto GetTensorDataAndComputeTypes()
{
constexpr auto data_type = ExtractTensorDataType<Config, SignatureDataType>();
constexpr auto compute_type = ExtractTensorComputeType<Config, SignatureDataType>();
return std::make_pair(data_type, compute_type);
}
template <DataType SignatureAccDataType, DataType SignatureDataType>
consteval auto GetTensorAccumulationType()
{
@@ -158,6 +169,7 @@ consteval auto GetAuxiliaryTensorDataTypes()
template <auto Signature>
struct ConvTensorDataTypes
{
// Builder enumerator types
static constexpr auto input_types =
GetTensorDataAndComputeTypes<Signature.input.config, Signature.data_type>();
static constexpr auto weight_types =
@@ -165,12 +177,12 @@ struct ConvTensorDataTypes
static constexpr auto output_types =
GetTensorDataAndComputeTypes<Signature.output.config, Signature.data_type>();
using InDataType = typename decltype(input_types.first)::type;
using InComputeType = typename decltype(input_types.second)::type;
using WeiDataType = typename decltype(weight_types.first)::type;
using WeiComputeType = typename decltype(weight_types.second)::type;
using OutDataType = typename decltype(output_types.first)::type;
using OutComputeType = typename decltype(output_types.second)::type;
using InDataType = typename DataTypeToCK<input_types.first>::type;
using InComputeType = typename DataTypeToCK<input_types.second>::type;
using WeiDataType = typename DataTypeToCK<weight_types.first>::type;
using WeiComputeType = typename DataTypeToCK<weight_types.second>::type;
using OutDataType = typename DataTypeToCK<output_types.first>::type;
using OutComputeType = typename DataTypeToCK<output_types.second>::type;
using AccDataType =
typename decltype(GetTensorAccumulationType<Signature.accumulation_data_type,
Signature.data_type>())::type;

View File

@@ -29,7 +29,7 @@ TEST(FwdConvInstances,
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle{}
.with_thread_block(ThreadBlock_64_64x32x32)
.with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave)
.with_transfer(Transfer_4x16x1)
.with_transfer(Transfer_4x16x1_asrc_vec_dim1)
.with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
.with_prefetch_config(1, PipelineScheduler::DEFAULT)
.with_num_conv_groups_to_merge(2);

View File

@@ -31,7 +31,7 @@ TEST(FwdConvInstances,
ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle{}
.with_thread_block(ThreadBlock_128_64x64x64)
.with_gemm_config(GemmParams_Wmma_2x1_per_wave)
.with_transfer(Transfer_4x32x1)
.with_transfer(Transfer_4x16x1)
.with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
.with_prefetch_config(1, PipelineScheduler::DEFAULT)
.with_num_conv_groups_to_merge(2)

View File

@@ -48,4 +48,81 @@ TEST(FwdConvInstances,
"MNKPadding"});
}
// 3D FP32 NGCDHW (channels-first) with Pipeline V1 and FILTER_1X1_PAD0
TEST(
FwdConvInstances,
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_3D_FP32_ChannelsFirst_LargeVecSize)
{
using enum ck_tile::builder::ConvDirection;
using enum ck_tile::builder::DataType;
using enum ck_tile::builder::TensorLayout;
// 3D forward conv, FP32 everywhere, channels-first layouts (NGCDHW/GKCZYX/NGKDHW).
constexpr ConvSignature FwdConvSignature{.spatial_dim = 3,
.direction = FORWARD,
.data_type = FP32,
.accumulation_data_type = FP32,
.input = {.config = {.layout = NGCDHW}},
.weight = {.config = {.layout = GKCZYX}},
.output = {.config = {.layout = NGKDHW}}};
// A-tensor uses src_scalar_per_vector = 16: for FP32 this exceeds a single
// buffer instruction and exercises the multi-instruction vector-size validation.
constexpr Transfer<> Transfer_4x64x1_Vec16{
.a =
{
.block_transfer = {.k0 = 2, .m_n = 128, .k1 = 1},
.lds_transfer = {.src_vector_dim = 2,
.src_scalar_per_vector = 16,
.lds_dst_scalar_per_vector = 4,
.is_direct_load = false,
.lds_padding = false},
.thread_cluster_arrange_order = {1, 0, 2},
.src_access_order = {1, 0, 2},
},
.b =
{
.block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1},
.lds_transfer = {.src_vector_dim = 2,
.src_scalar_per_vector = 4,
.lds_dst_scalar_per_vector = 4,
.is_direct_load = false,
.lds_padding = false},
.thread_cluster_arrange_order = {1, 0, 2},
.src_access_order = {1, 0, 2},
},
.c =
{
.thread_cluster_dims =
{.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8},
.epilogue = {.m_xdl_per_wave_per_shuffle = 1,
.n_xdl_per_wave_per_shuffle = 1,
.scalar_per_vector = 4},
},
};
constexpr GridwiseFwdXdlGemm FwdGemmParams{
.ak1 = 16,
.bk1 = 8,
.xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4}};
constexpr auto FwdConvAlgorithm =
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{}
.with_thread_block(ThreadBlock_256_256x256x32)
.with_gemm_config(FwdGemmParams)
.with_transfer(Transfer_4x64x1_Vec16)
.with_fwd_specializations(ConvSpecialization::FILTER_1X1_PAD0,
GemmSpecialization::MNKPadding)
.with_block_gemm(BlockGemmDesc_v1_intrawave);
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
const auto expected_transfer_parameters = to_string(FwdConvAlgorithm);
// Verify that the generated instance name carries the expected kernel,
// specialization, pipeline, layout, and padding tokens.
run_test<Builder>({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3",
expected_transfer_parameters,
"Filter1x1Pad0",
"Intrawave",
"v1",
"NGCDHW,GKCZYX,EmptyTuple,NGKDHW",
"PassThrough,PassThrough,PassThrough",
"MNKPadding"});
}
} // namespace

View File

@@ -133,7 +133,7 @@ static_assert(LdsTransferDescriptor<LdsTransfer>);
struct Epilogue
{
size_t m_xdl_per_wave_per_shuffle;
size_t n_per_wave_per_shuffle;
size_t n_xdl_per_wave_per_shuffle;
size_t scalar_per_vector;
};
static_assert(EpilogueDescriptor<Epilogue>);

View File

@@ -154,7 +154,7 @@ struct DefaultAlgorithm
.thread_cluster_dims =
{.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8},
.epilogue = {.m_xdl_per_wave_per_shuffle = 1,
.n_per_wave_per_shuffle = 1,
.n_xdl_per_wave_per_shuffle = 1,
.scalar_per_vector = 2},
},
};

View File

@@ -78,7 +78,7 @@ constexpr Transfer<> Transfer_4x64x1{
.thread_cluster_dims =
{.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8},
.epilogue = {.m_xdl_per_wave_per_shuffle = 1,
.n_per_wave_per_shuffle = 1,
.n_xdl_per_wave_per_shuffle = 1,
.scalar_per_vector = 4},
},
};
@@ -111,7 +111,7 @@ constexpr Transfer<4> BwdTransfer_4x64x1{
.thread_cluster_dims =
{.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8},
.epilogue = {.m_xdl_per_wave_per_shuffle = 1,
.n_per_wave_per_shuffle = 1,
.n_xdl_per_wave_per_shuffle = 1,
.scalar_per_vector = 8},
},
};
@@ -144,7 +144,7 @@ constexpr Transfer<> BwdTransfer_4x8x1_4x16x1_v3{
.thread_cluster_dims =
{.m_block = 1, .m_wave_per_xdl = 8, .n_block = 1, .n_wave_per_xdl = 8},
.epilogue = {.m_xdl_per_wave_per_shuffle = 1,
.n_per_wave_per_shuffle = 1,
.n_xdl_per_wave_per_shuffle = 1,
.scalar_per_vector = 2},
},
};
@@ -177,7 +177,7 @@ constexpr Transfer<> Transfer_4x64x1_fp8{
.thread_cluster_dims =
{.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8},
.epilogue = {.m_xdl_per_wave_per_shuffle = 1,
.n_per_wave_per_shuffle = 1,
.n_xdl_per_wave_per_shuffle = 1,
.scalar_per_vector = 8},
},
};
@@ -210,12 +210,46 @@ constexpr Transfer<> Transfer_4x16x1{
.thread_cluster_dims =
{.m_block = 1, .m_wave_per_xdl = 16, .n_block = 1, .n_wave_per_xdl = 4},
.epilogue = {.m_xdl_per_wave_per_shuffle = 1,
.n_per_wave_per_shuffle = 1,
.n_xdl_per_wave_per_shuffle = 1,
.scalar_per_vector = 8},
},
};
// Variant of Transfer_4x16x1 with the A-tensor source vector dimension set to 1
// instead of 2 — the configuration allowed when conv groups are merged
// (num_conv_groups_to_merge > 1); see the src_vector_dim static_assert in the
// XDL factory. TODO(review): confirm the intended pairing with group merging.
constexpr Transfer<> Transfer_4x16x1_asrc_vec_dim1{
.a =
{
.block_transfer = {.k0 = 4, .m_n = 16, .k1 = 1},
.lds_transfer = {.src_vector_dim = 1,
.src_scalar_per_vector = 4,
.lds_dst_scalar_per_vector = 4,
.is_direct_load = false,
.lds_padding = true},
.thread_cluster_arrange_order = {0, 2, 1},
.src_access_order = {0, 2, 1},
},
.b =
{
.block_transfer = {.k0 = 4, .m_n = 16, .k1 = 1},
.lds_transfer = {.src_vector_dim = 2,
.src_scalar_per_vector = 1,
.lds_dst_scalar_per_vector = 8,
.is_direct_load = false,
.lds_padding = true},
.thread_cluster_arrange_order = {1, 0, 2},
.src_access_order = {1, 0, 2},
},
.c =
{
.thread_cluster_dims =
{.m_block = 1, .m_wave_per_xdl = 16, .n_block = 1, .n_wave_per_xdl = 4},
.epilogue = {.m_xdl_per_wave_per_shuffle = 1,
.n_xdl_per_wave_per_shuffle = 1,
.scalar_per_vector = 1},
},
};
constexpr Transfer<> Transfer_4x32x1{
.a =
{
@@ -244,7 +278,7 @@ constexpr Transfer<> Transfer_4x32x1{
.thread_cluster_dims =
{.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 4},
.epilogue = {.m_xdl_per_wave_per_shuffle = 1,
.n_per_wave_per_shuffle = 1,
.n_xdl_per_wave_per_shuffle = 1,
.scalar_per_vector = 8},
},
};

View File

@@ -194,8 +194,8 @@ template <>
inline std::string to_string<OutputTransfer>(OutputTransfer t)
{
std::ostringstream oss;
oss << t.epilogue.m_xdl_per_wave_per_shuffle << "," << t.epilogue.n_per_wave_per_shuffle << ","
<< to_string(t.thread_cluster_dims) << "," << t.epilogue.scalar_per_vector;
oss << t.epilogue.m_xdl_per_wave_per_shuffle << "," << t.epilogue.n_xdl_per_wave_per_shuffle
<< "," << to_string(t.thread_cluster_dims) << "," << t.epilogue.scalar_per_vector;
return oss.str();
}

View File

@@ -4,8 +4,8 @@
import os
root_dir = os.getcwd()
ck_tile_include = root_dir + "/include/ck_tile"
ck_tile_example = root_dir + "/example/ck_tile"
ck_tile_include = root_dir + "/projects/composablekernel/include/ck_tile"
ck_tile_example = root_dir + "/projects/composablekernel/example/ck_tile"
# Run for include
os.chdir(ck_tile_include)