moved common attributes to helpers

This commit is contained in:
Kevin Abraham
2026-01-13 13:39:58 +00:00
parent 7b15c22e7e
commit 9b1c8ae951
4 changed files with 108 additions and 138 deletions

View File

@@ -32,52 +32,13 @@ constexpr ConvTraits instance_to_conv_traits()
.gemm_padding = gemm_spec<Instance>(),
.conv_specialization = conv_spec<Instance>(),
.thread_block_size = InstTraits::kBlockSize,
.tile_dims = {.m = InstTraits::kMPerBlock,
.n = InstTraits::kNPerBlock,
.k = InstTraits::kKPerBlock},
.a_tile_transfer =
{.tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kAK1,
.m_or_n = InstTraits::kMPerBlock,
.k1 = InstTraits::kAK1},
.transfer_params = {.k1 = InstTraits::kAK1,
.thread_cluster_dims = InstTraits::kAThreadClusterLengths,
.thread_cluster_order = InstTraits::kAThreadClusterArrangeOrder,
.src_access_order = InstTraits::kABlockTransferSrcAccessOrder,
.src_vector_dim = InstTraits::kABlockTransferSrcVectorDim,
.src_scalar_per_vector =
InstTraits::kABlockTransferSrcScalarPerVector,
.dst_scalar_per_vector_k1 =
InstTraits::kABlockTransferDstScalarPerVectorK1,
.lds_padding = static_cast<bool>(InstTraits::kABlockLdsExtraM)}},
.b_tile_transfer =
{.tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kBK1,
.m_or_n = InstTraits::kNPerBlock,
.k1 = InstTraits::kBK1},
.transfer_params = {.k1 = InstTraits::kBK1,
.thread_cluster_dims = InstTraits::kBThreadClusterLengths,
.thread_cluster_order = InstTraits::kBThreadClusterArrangeOrder,
.src_access_order = InstTraits::kBBlockTransferSrcAccessOrder,
.src_vector_dim = InstTraits::kBBlockTransferSrcVectorDim,
.src_scalar_per_vector =
InstTraits::kBBlockTransferSrcScalarPerVector,
.dst_scalar_per_vector_k1 =
InstTraits::kBBlockTransferDstScalarPerVectorK1,
.lds_padding = static_cast<bool>(InstTraits::kBBlockLdsExtraN)}},
.warp_gemm = {.gemm_m = InstTraits::kMPerXDL,
.gemm_n = InstTraits::kNPerXDL,
.m_iter = InstTraits::kMXdlPerWave,
.n_iter = InstTraits::kNXdlPerWave},
.c_tile_transfer = {.shuffle_params = {.m_gemms_per_shuffle =
InstTraits::kCShuffleMXdlPerWavePerShuffle,
.n_gemms_per_shuffle =
InstTraits::kCShuffleNXdlPerWavePerShuffle},
.thread_cluster_dims = {InstTraits::kCThreadClusterLengths[0],
InstTraits::kCThreadClusterLengths[1],
InstTraits::kCThreadClusterLengths[2],
InstTraits::kCThreadClusterLengths[3]},
.scalar_per_vector = InstTraits::kCBlockTransferScalarPerVector},
.pipeline_version = get_pipeline_version<InstTraits>(),
.pipeline_scheduler = get_pipeline_scheduler<InstTraits>(),
.tile_dims = conv_traits_data_tile<InstTraits>(),
.a_tile_transfer = conv_traits_xdl_a_transfer_params<InstTraits>(),
.b_tile_transfer = conv_traits_xdl_b_transfer_params<InstTraits>(),
.warp_gemm = conv_traits_xdl_warp_gemm_params<InstTraits>(),
.c_tile_transfer = conv_traits_xdl_c_tile_transfer<InstTraits>(),
.pipeline_version = get_pipeline_version<InstTraits>(),
.pipeline_scheduler = get_pipeline_scheduler<InstTraits>(),
};
}

View File

@@ -32,52 +32,13 @@ constexpr ConvTraits instance_to_conv_traits()
.gemm_padding = gemm_spec<Instance>(),
.conv_specialization = conv_spec<Instance>(),
.thread_block_size = InstTraits::kBlockSize,
.tile_dims = {.m = InstTraits::kMPerBlock,
.n = InstTraits::kNPerBlock,
.k = InstTraits::kKPerBlock},
.a_tile_transfer =
{.tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kAK1,
.m_or_n = InstTraits::kMPerBlock,
.k1 = InstTraits::kAK1},
.transfer_params = {.k1 = InstTraits::kAK1,
.thread_cluster_dims = InstTraits::kAThreadClusterLengths,
.thread_cluster_order = InstTraits::kAThreadClusterArrangeOrder,
.src_access_order = InstTraits::kABlockTransferSrcAccessOrder,
.src_vector_dim = InstTraits::kABlockTransferSrcVectorDim,
.src_scalar_per_vector =
InstTraits::kABlockTransferSrcScalarPerVector,
.dst_scalar_per_vector_k1 =
InstTraits::kABlockTransferDstScalarPerVectorK1,
.lds_padding = static_cast<bool>(InstTraits::kABlockLdsExtraM)}},
.b_tile_transfer =
{.tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kBK1,
.m_or_n = InstTraits::kNPerBlock,
.k1 = InstTraits::kBK1},
.transfer_params = {.k1 = InstTraits::kBK1,
.thread_cluster_dims = InstTraits::kBThreadClusterLengths,
.thread_cluster_order = InstTraits::kBThreadClusterArrangeOrder,
.src_access_order = InstTraits::kBBlockTransferSrcAccessOrder,
.src_vector_dim = InstTraits::kBBlockTransferSrcVectorDim,
.src_scalar_per_vector =
InstTraits::kBBlockTransferSrcScalarPerVector,
.dst_scalar_per_vector_k1 =
InstTraits::kBBlockTransferDstScalarPerVectorK1,
.lds_padding = static_cast<bool>(InstTraits::kBBlockLdsExtraN)}},
.warp_gemm = {.gemm_m = InstTraits::kMPerXDL,
.gemm_n = InstTraits::kNPerXDL,
.m_iter = InstTraits::kMXdlPerWave,
.n_iter = InstTraits::kNXdlPerWave},
.c_tile_transfer = {.shuffle_params = {.m_gemms_per_shuffle =
InstTraits::kCShuffleMXdlPerWavePerShuffle,
.n_gemms_per_shuffle =
InstTraits::kCShuffleNXdlPerWavePerShuffle},
.thread_cluster_dims = {InstTraits::kCThreadClusterLengths[0],
InstTraits::kCThreadClusterLengths[1],
InstTraits::kCThreadClusterLengths[2],
InstTraits::kCThreadClusterLengths[3]},
.scalar_per_vector = InstTraits::kCBlockTransferScalarPerVector},
.pipeline_version = get_pipeline_version<InstTraits>(),
.pipeline_scheduler = get_pipeline_scheduler<InstTraits>(),
.tile_dims = conv_traits_data_tile<InstTraits>(),
.a_tile_transfer = conv_traits_xdl_a_transfer_params<InstTraits>(),
.b_tile_transfer = conv_traits_xdl_b_transfer_params<InstTraits>(),
.warp_gemm = conv_traits_xdl_warp_gemm_params<InstTraits>(),
.c_tile_transfer = conv_traits_xdl_c_tile_transfer<InstTraits>(),
.pipeline_version = get_pipeline_version<InstTraits>(),
.pipeline_scheduler = get_pipeline_scheduler<InstTraits>(),
};
}

View File

@@ -32,52 +32,13 @@ constexpr ConvTraits instance_to_conv_traits()
.gemm_padding = gemm_spec<Instance>(),
.conv_specialization = conv_spec<Instance>(),
.thread_block_size = InstTraits::kBlockSize,
.tile_dims = {.m = InstTraits::kMPerBlock,
.n = InstTraits::kNPerBlock,
.k = InstTraits::kKPerBlock},
.a_tile_transfer =
{.tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kAK1,
.m_or_n = InstTraits::kMPerBlock,
.k1 = InstTraits::kAK1},
.transfer_params = {.k1 = InstTraits::kAK1,
.thread_cluster_dims = InstTraits::kAThreadClusterLengths,
.thread_cluster_order = InstTraits::kAThreadClusterArrangeOrder,
.src_access_order = InstTraits::kABlockTransferSrcAccessOrder,
.src_vector_dim = InstTraits::kABlockTransferSrcVectorDim,
.src_scalar_per_vector =
InstTraits::kABlockTransferSrcScalarPerVector,
.dst_scalar_per_vector_k1 =
InstTraits::kABlockTransferDstScalarPerVectorK1,
.lds_padding = static_cast<bool>(InstTraits::kABlockLdsExtraM)}},
.b_tile_transfer =
{.tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kBK1,
.m_or_n = InstTraits::kNPerBlock,
.k1 = InstTraits::kBK1},
.transfer_params = {.k1 = InstTraits::kBK1,
.thread_cluster_dims = InstTraits::kBThreadClusterLengths,
.thread_cluster_order = InstTraits::kBThreadClusterArrangeOrder,
.src_access_order = InstTraits::kBBlockTransferSrcAccessOrder,
.src_vector_dim = InstTraits::kBBlockTransferSrcVectorDim,
.src_scalar_per_vector =
InstTraits::kBBlockTransferSrcScalarPerVector,
.dst_scalar_per_vector_k1 =
InstTraits::kBBlockTransferDstScalarPerVectorK1,
.lds_padding = static_cast<bool>(InstTraits::kBBlockLdsExtraN)}},
.warp_gemm = {.gemm_m = InstTraits::kMPerXDL,
.gemm_n = InstTraits::kNPerXDL,
.m_iter = InstTraits::kMXdlPerWave,
.n_iter = InstTraits::kNXdlPerWave},
.c_tile_transfer = {.shuffle_params = {.m_gemms_per_shuffle =
InstTraits::kCShuffleMXdlPerWavePerShuffle,
.n_gemms_per_shuffle =
InstTraits::kCShuffleNXdlPerWavePerShuffle},
.thread_cluster_dims = {InstTraits::kCThreadClusterLengths[0],
InstTraits::kCThreadClusterLengths[1],
InstTraits::kCThreadClusterLengths[2],
InstTraits::kCThreadClusterLengths[3]},
.scalar_per_vector = InstTraits::kCBlockTransferScalarPerVector},
.pipeline_version = get_pipeline_version<InstTraits>(),
.pipeline_scheduler = get_pipeline_scheduler<InstTraits>(),
.tile_dims = conv_traits_data_tile<InstTraits>(),
.a_tile_transfer = conv_traits_xdl_a_transfer_params<InstTraits>(),
.b_tile_transfer = conv_traits_xdl_b_transfer_params<InstTraits>(),
.warp_gemm = conv_traits_xdl_warp_gemm_params<InstTraits>(),
.c_tile_transfer = conv_traits_xdl_c_tile_transfer<InstTraits>(),
.pipeline_version = get_pipeline_version<InstTraits>(),
.pipeline_scheduler = get_pipeline_scheduler<InstTraits>(),
};
}

View File

@@ -640,4 +640,91 @@ constexpr auto get_pipeline_scheduler()
}
}
template <typename InstTraits>
constexpr InputTileTransferInfo conv_traits_xdl_a_transfer_params()
{
return InputTileTransferInfo{
.tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kAK1,
.m_or_n = InstTraits::kMPerBlock,
.k1 = InstTraits::kAK1},
.transfer_params = {.k1 = InstTraits::kAK1,
.thread_cluster_dims = InstTraits::kAThreadClusterLengths,
.thread_cluster_order = InstTraits::kAThreadClusterArrangeOrder,
.src_access_order = InstTraits::kABlockTransferSrcAccessOrder,
.src_vector_dim = InstTraits::kABlockTransferSrcVectorDim,
.src_scalar_per_vector = InstTraits::kABlockTransferSrcScalarPerVector,
.dst_scalar_per_vector_k1 =
InstTraits::kABlockTransferDstScalarPerVectorK1,
.lds_padding = static_cast<bool>(InstTraits::kABlockLdsExtraM)}};
}
template <typename InstTraits>
constexpr InputTileTransferInfo conv_traits_xdl_b_transfer_params()
{
return InputTileTransferInfo{
.tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kBK1,
.m_or_n = InstTraits::kNPerBlock,
.k1 = InstTraits::kBK1},
.transfer_params = {.k1 = InstTraits::kBK1,
.thread_cluster_dims = InstTraits::kBThreadClusterLengths,
.thread_cluster_order = InstTraits::kBThreadClusterArrangeOrder,
.src_access_order = InstTraits::kBBlockTransferSrcAccessOrder,
.src_vector_dim = InstTraits::kBBlockTransferSrcVectorDim,
.src_scalar_per_vector = InstTraits::kBBlockTransferSrcScalarPerVector,
.dst_scalar_per_vector_k1 =
InstTraits::kBBlockTransferDstScalarPerVectorK1,
.lds_padding = static_cast<bool>(InstTraits::kBBlockLdsExtraN)}};
}
template <typename InstTraits>
constexpr OutputTileTransferInfo conv_traits_xdl_c_tile_transfer()
{
return OutputTileTransferInfo{
.shuffle_params = {.m_gemms_per_shuffle = InstTraits::kCShuffleMXdlPerWavePerShuffle,
.n_gemms_per_shuffle = InstTraits::kCShuffleNXdlPerWavePerShuffle},
.thread_cluster_dims = {InstTraits::kCThreadClusterLengths[0],
InstTraits::kCThreadClusterLengths[1],
InstTraits::kCThreadClusterLengths[2],
InstTraits::kCThreadClusterLengths[3]},
.scalar_per_vector = InstTraits::kCBlockTransferScalarPerVector};
}
template <typename InstTraits>
constexpr OutputTileTransferInfo conv_traits_wmma_c_tile_transfer()
{
return OutputTileTransferInfo{
.shuffle_params = {.m_gemms_per_shuffle = InstTraits::kCShuffleMRepeatPerShuffle,
.n_gemms_per_shuffle = InstTraits::kCShuffleNRepeatPerShuffle},
.thread_cluster_dims = {InstTraits::kCThreadClusterLengths[0],
InstTraits::kCThreadClusterLengths[1],
InstTraits::kCThreadClusterLengths[2],
InstTraits::kCThreadClusterLengths[3]},
.scalar_per_vector = InstTraits::kCBlockTransferScalarPerVector};
}
template <typename InstTraits>
constexpr WarpGemmParams conv_traits_xdl_warp_gemm_params()
{
return WarpGemmParams{.gemm_m = InstTraits::kMPerXDL,
.gemm_n = InstTraits::kNPerXDL,
.m_iter = InstTraits::kMXdlPerWave,
.n_iter = InstTraits::kNXdlPerWave};
}
template <typename InstTraits>
constexpr WarpGemmParams conv_traits_wmma_warp_gemm_params()
{
return WarpGemmParams{.gemm_m = InstTraits::kMPerWmma,
.gemm_n = InstTraits::kNPerWmma,
.m_iter = InstTraits::kMRepeat,
.n_iter = InstTraits::kNRepeat};
}
template <typename InstTraits>
constexpr DataTileInfo conv_traits_data_tile()
{
return DataTileInfo{
.m = InstTraits::kMPerBlock, .n = InstTraits::kNPerBlock, .k = InstTraits::kKPerBlock};
}
} // namespace ck_tile::reflect::conv