mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-02 13:17:36 +00:00
moved common attributes to helpers
This commit is contained in:
@@ -32,52 +32,13 @@ constexpr ConvTraits instance_to_conv_traits()
|
||||
.gemm_padding = gemm_spec<Instance>(),
|
||||
.conv_specialization = conv_spec<Instance>(),
|
||||
.thread_block_size = InstTraits::kBlockSize,
|
||||
.tile_dims = {.m = InstTraits::kMPerBlock,
|
||||
.n = InstTraits::kNPerBlock,
|
||||
.k = InstTraits::kKPerBlock},
|
||||
.a_tile_transfer =
|
||||
{.tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kAK1,
|
||||
.m_or_n = InstTraits::kMPerBlock,
|
||||
.k1 = InstTraits::kAK1},
|
||||
.transfer_params = {.k1 = InstTraits::kAK1,
|
||||
.thread_cluster_dims = InstTraits::kAThreadClusterLengths,
|
||||
.thread_cluster_order = InstTraits::kAThreadClusterArrangeOrder,
|
||||
.src_access_order = InstTraits::kABlockTransferSrcAccessOrder,
|
||||
.src_vector_dim = InstTraits::kABlockTransferSrcVectorDim,
|
||||
.src_scalar_per_vector =
|
||||
InstTraits::kABlockTransferSrcScalarPerVector,
|
||||
.dst_scalar_per_vector_k1 =
|
||||
InstTraits::kABlockTransferDstScalarPerVectorK1,
|
||||
.lds_padding = static_cast<bool>(InstTraits::kABlockLdsExtraM)}},
|
||||
.b_tile_transfer =
|
||||
{.tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kBK1,
|
||||
.m_or_n = InstTraits::kNPerBlock,
|
||||
.k1 = InstTraits::kBK1},
|
||||
.transfer_params = {.k1 = InstTraits::kBK1,
|
||||
.thread_cluster_dims = InstTraits::kBThreadClusterLengths,
|
||||
.thread_cluster_order = InstTraits::kBThreadClusterArrangeOrder,
|
||||
.src_access_order = InstTraits::kBBlockTransferSrcAccessOrder,
|
||||
.src_vector_dim = InstTraits::kBBlockTransferSrcVectorDim,
|
||||
.src_scalar_per_vector =
|
||||
InstTraits::kBBlockTransferSrcScalarPerVector,
|
||||
.dst_scalar_per_vector_k1 =
|
||||
InstTraits::kBBlockTransferDstScalarPerVectorK1,
|
||||
.lds_padding = static_cast<bool>(InstTraits::kBBlockLdsExtraN)}},
|
||||
.warp_gemm = {.gemm_m = InstTraits::kMPerXDL,
|
||||
.gemm_n = InstTraits::kNPerXDL,
|
||||
.m_iter = InstTraits::kMXdlPerWave,
|
||||
.n_iter = InstTraits::kNXdlPerWave},
|
||||
.c_tile_transfer = {.shuffle_params = {.m_gemms_per_shuffle =
|
||||
InstTraits::kCShuffleMXdlPerWavePerShuffle,
|
||||
.n_gemms_per_shuffle =
|
||||
InstTraits::kCShuffleNXdlPerWavePerShuffle},
|
||||
.thread_cluster_dims = {InstTraits::kCThreadClusterLengths[0],
|
||||
InstTraits::kCThreadClusterLengths[1],
|
||||
InstTraits::kCThreadClusterLengths[2],
|
||||
InstTraits::kCThreadClusterLengths[3]},
|
||||
.scalar_per_vector = InstTraits::kCBlockTransferScalarPerVector},
|
||||
.pipeline_version = get_pipeline_version<InstTraits>(),
|
||||
.pipeline_scheduler = get_pipeline_scheduler<InstTraits>(),
|
||||
.tile_dims = conv_traits_data_tile<InstTraits>(),
|
||||
.a_tile_transfer = conv_traits_xdl_a_transfer_params<InstTraits>(),
|
||||
.b_tile_transfer = conv_traits_xdl_b_transfer_params<InstTraits>(),
|
||||
.warp_gemm = conv_traits_xdl_warp_gemm_params<InstTraits>(),
|
||||
.c_tile_transfer = conv_traits_xdl_c_tile_transfer<InstTraits>(),
|
||||
.pipeline_version = get_pipeline_version<InstTraits>(),
|
||||
.pipeline_scheduler = get_pipeline_scheduler<InstTraits>(),
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
@@ -32,52 +32,13 @@ constexpr ConvTraits instance_to_conv_traits()
|
||||
.gemm_padding = gemm_spec<Instance>(),
|
||||
.conv_specialization = conv_spec<Instance>(),
|
||||
.thread_block_size = InstTraits::kBlockSize,
|
||||
.tile_dims = {.m = InstTraits::kMPerBlock,
|
||||
.n = InstTraits::kNPerBlock,
|
||||
.k = InstTraits::kKPerBlock},
|
||||
.a_tile_transfer =
|
||||
{.tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kAK1,
|
||||
.m_or_n = InstTraits::kMPerBlock,
|
||||
.k1 = InstTraits::kAK1},
|
||||
.transfer_params = {.k1 = InstTraits::kAK1,
|
||||
.thread_cluster_dims = InstTraits::kAThreadClusterLengths,
|
||||
.thread_cluster_order = InstTraits::kAThreadClusterArrangeOrder,
|
||||
.src_access_order = InstTraits::kABlockTransferSrcAccessOrder,
|
||||
.src_vector_dim = InstTraits::kABlockTransferSrcVectorDim,
|
||||
.src_scalar_per_vector =
|
||||
InstTraits::kABlockTransferSrcScalarPerVector,
|
||||
.dst_scalar_per_vector_k1 =
|
||||
InstTraits::kABlockTransferDstScalarPerVectorK1,
|
||||
.lds_padding = static_cast<bool>(InstTraits::kABlockLdsExtraM)}},
|
||||
.b_tile_transfer =
|
||||
{.tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kBK1,
|
||||
.m_or_n = InstTraits::kNPerBlock,
|
||||
.k1 = InstTraits::kBK1},
|
||||
.transfer_params = {.k1 = InstTraits::kBK1,
|
||||
.thread_cluster_dims = InstTraits::kBThreadClusterLengths,
|
||||
.thread_cluster_order = InstTraits::kBThreadClusterArrangeOrder,
|
||||
.src_access_order = InstTraits::kBBlockTransferSrcAccessOrder,
|
||||
.src_vector_dim = InstTraits::kBBlockTransferSrcVectorDim,
|
||||
.src_scalar_per_vector =
|
||||
InstTraits::kBBlockTransferSrcScalarPerVector,
|
||||
.dst_scalar_per_vector_k1 =
|
||||
InstTraits::kBBlockTransferDstScalarPerVectorK1,
|
||||
.lds_padding = static_cast<bool>(InstTraits::kBBlockLdsExtraN)}},
|
||||
.warp_gemm = {.gemm_m = InstTraits::kMPerXDL,
|
||||
.gemm_n = InstTraits::kNPerXDL,
|
||||
.m_iter = InstTraits::kMXdlPerWave,
|
||||
.n_iter = InstTraits::kNXdlPerWave},
|
||||
.c_tile_transfer = {.shuffle_params = {.m_gemms_per_shuffle =
|
||||
InstTraits::kCShuffleMXdlPerWavePerShuffle,
|
||||
.n_gemms_per_shuffle =
|
||||
InstTraits::kCShuffleNXdlPerWavePerShuffle},
|
||||
.thread_cluster_dims = {InstTraits::kCThreadClusterLengths[0],
|
||||
InstTraits::kCThreadClusterLengths[1],
|
||||
InstTraits::kCThreadClusterLengths[2],
|
||||
InstTraits::kCThreadClusterLengths[3]},
|
||||
.scalar_per_vector = InstTraits::kCBlockTransferScalarPerVector},
|
||||
.pipeline_version = get_pipeline_version<InstTraits>(),
|
||||
.pipeline_scheduler = get_pipeline_scheduler<InstTraits>(),
|
||||
.tile_dims = conv_traits_data_tile<InstTraits>(),
|
||||
.a_tile_transfer = conv_traits_xdl_a_transfer_params<InstTraits>(),
|
||||
.b_tile_transfer = conv_traits_xdl_b_transfer_params<InstTraits>(),
|
||||
.warp_gemm = conv_traits_xdl_warp_gemm_params<InstTraits>(),
|
||||
.c_tile_transfer = conv_traits_xdl_c_tile_transfer<InstTraits>(),
|
||||
.pipeline_version = get_pipeline_version<InstTraits>(),
|
||||
.pipeline_scheduler = get_pipeline_scheduler<InstTraits>(),
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
@@ -32,52 +32,13 @@ constexpr ConvTraits instance_to_conv_traits()
|
||||
.gemm_padding = gemm_spec<Instance>(),
|
||||
.conv_specialization = conv_spec<Instance>(),
|
||||
.thread_block_size = InstTraits::kBlockSize,
|
||||
.tile_dims = {.m = InstTraits::kMPerBlock,
|
||||
.n = InstTraits::kNPerBlock,
|
||||
.k = InstTraits::kKPerBlock},
|
||||
.a_tile_transfer =
|
||||
{.tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kAK1,
|
||||
.m_or_n = InstTraits::kMPerBlock,
|
||||
.k1 = InstTraits::kAK1},
|
||||
.transfer_params = {.k1 = InstTraits::kAK1,
|
||||
.thread_cluster_dims = InstTraits::kAThreadClusterLengths,
|
||||
.thread_cluster_order = InstTraits::kAThreadClusterArrangeOrder,
|
||||
.src_access_order = InstTraits::kABlockTransferSrcAccessOrder,
|
||||
.src_vector_dim = InstTraits::kABlockTransferSrcVectorDim,
|
||||
.src_scalar_per_vector =
|
||||
InstTraits::kABlockTransferSrcScalarPerVector,
|
||||
.dst_scalar_per_vector_k1 =
|
||||
InstTraits::kABlockTransferDstScalarPerVectorK1,
|
||||
.lds_padding = static_cast<bool>(InstTraits::kABlockLdsExtraM)}},
|
||||
.b_tile_transfer =
|
||||
{.tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kBK1,
|
||||
.m_or_n = InstTraits::kNPerBlock,
|
||||
.k1 = InstTraits::kBK1},
|
||||
.transfer_params = {.k1 = InstTraits::kBK1,
|
||||
.thread_cluster_dims = InstTraits::kBThreadClusterLengths,
|
||||
.thread_cluster_order = InstTraits::kBThreadClusterArrangeOrder,
|
||||
.src_access_order = InstTraits::kBBlockTransferSrcAccessOrder,
|
||||
.src_vector_dim = InstTraits::kBBlockTransferSrcVectorDim,
|
||||
.src_scalar_per_vector =
|
||||
InstTraits::kBBlockTransferSrcScalarPerVector,
|
||||
.dst_scalar_per_vector_k1 =
|
||||
InstTraits::kBBlockTransferDstScalarPerVectorK1,
|
||||
.lds_padding = static_cast<bool>(InstTraits::kBBlockLdsExtraN)}},
|
||||
.warp_gemm = {.gemm_m = InstTraits::kMPerXDL,
|
||||
.gemm_n = InstTraits::kNPerXDL,
|
||||
.m_iter = InstTraits::kMXdlPerWave,
|
||||
.n_iter = InstTraits::kNXdlPerWave},
|
||||
.c_tile_transfer = {.shuffle_params = {.m_gemms_per_shuffle =
|
||||
InstTraits::kCShuffleMXdlPerWavePerShuffle,
|
||||
.n_gemms_per_shuffle =
|
||||
InstTraits::kCShuffleNXdlPerWavePerShuffle},
|
||||
.thread_cluster_dims = {InstTraits::kCThreadClusterLengths[0],
|
||||
InstTraits::kCThreadClusterLengths[1],
|
||||
InstTraits::kCThreadClusterLengths[2],
|
||||
InstTraits::kCThreadClusterLengths[3]},
|
||||
.scalar_per_vector = InstTraits::kCBlockTransferScalarPerVector},
|
||||
.pipeline_version = get_pipeline_version<InstTraits>(),
|
||||
.pipeline_scheduler = get_pipeline_scheduler<InstTraits>(),
|
||||
.tile_dims = conv_traits_data_tile<InstTraits>(),
|
||||
.a_tile_transfer = conv_traits_xdl_a_transfer_params<InstTraits>(),
|
||||
.b_tile_transfer = conv_traits_xdl_b_transfer_params<InstTraits>(),
|
||||
.warp_gemm = conv_traits_xdl_warp_gemm_params<InstTraits>(),
|
||||
.c_tile_transfer = conv_traits_xdl_c_tile_transfer<InstTraits>(),
|
||||
.pipeline_version = get_pipeline_version<InstTraits>(),
|
||||
.pipeline_scheduler = get_pipeline_scheduler<InstTraits>(),
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
@@ -640,4 +640,91 @@ constexpr auto get_pipeline_scheduler()
|
||||
}
|
||||
}
|
||||
|
||||
template <typename InstTraits>
|
||||
constexpr InputTileTransferInfo conv_traits_xdl_a_transfer_params()
|
||||
{
|
||||
return InputTileTransferInfo{
|
||||
.tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kAK1,
|
||||
.m_or_n = InstTraits::kMPerBlock,
|
||||
.k1 = InstTraits::kAK1},
|
||||
.transfer_params = {.k1 = InstTraits::kAK1,
|
||||
.thread_cluster_dims = InstTraits::kAThreadClusterLengths,
|
||||
.thread_cluster_order = InstTraits::kAThreadClusterArrangeOrder,
|
||||
.src_access_order = InstTraits::kABlockTransferSrcAccessOrder,
|
||||
.src_vector_dim = InstTraits::kABlockTransferSrcVectorDim,
|
||||
.src_scalar_per_vector = InstTraits::kABlockTransferSrcScalarPerVector,
|
||||
.dst_scalar_per_vector_k1 =
|
||||
InstTraits::kABlockTransferDstScalarPerVectorK1,
|
||||
.lds_padding = static_cast<bool>(InstTraits::kABlockLdsExtraM)}};
|
||||
}
|
||||
|
||||
template <typename InstTraits>
|
||||
constexpr InputTileTransferInfo conv_traits_xdl_b_transfer_params()
|
||||
{
|
||||
return InputTileTransferInfo{
|
||||
.tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kBK1,
|
||||
.m_or_n = InstTraits::kNPerBlock,
|
||||
.k1 = InstTraits::kBK1},
|
||||
.transfer_params = {.k1 = InstTraits::kBK1,
|
||||
.thread_cluster_dims = InstTraits::kBThreadClusterLengths,
|
||||
.thread_cluster_order = InstTraits::kBThreadClusterArrangeOrder,
|
||||
.src_access_order = InstTraits::kBBlockTransferSrcAccessOrder,
|
||||
.src_vector_dim = InstTraits::kBBlockTransferSrcVectorDim,
|
||||
.src_scalar_per_vector = InstTraits::kBBlockTransferSrcScalarPerVector,
|
||||
.dst_scalar_per_vector_k1 =
|
||||
InstTraits::kBBlockTransferDstScalarPerVectorK1,
|
||||
.lds_padding = static_cast<bool>(InstTraits::kBBlockLdsExtraN)}};
|
||||
}
|
||||
|
||||
template <typename InstTraits>
|
||||
constexpr OutputTileTransferInfo conv_traits_xdl_c_tile_transfer()
|
||||
{
|
||||
return OutputTileTransferInfo{
|
||||
.shuffle_params = {.m_gemms_per_shuffle = InstTraits::kCShuffleMXdlPerWavePerShuffle,
|
||||
.n_gemms_per_shuffle = InstTraits::kCShuffleNXdlPerWavePerShuffle},
|
||||
.thread_cluster_dims = {InstTraits::kCThreadClusterLengths[0],
|
||||
InstTraits::kCThreadClusterLengths[1],
|
||||
InstTraits::kCThreadClusterLengths[2],
|
||||
InstTraits::kCThreadClusterLengths[3]},
|
||||
.scalar_per_vector = InstTraits::kCBlockTransferScalarPerVector};
|
||||
}
|
||||
|
||||
template <typename InstTraits>
|
||||
constexpr OutputTileTransferInfo conv_traits_wmma_c_tile_transfer()
|
||||
{
|
||||
return OutputTileTransferInfo{
|
||||
.shuffle_params = {.m_gemms_per_shuffle = InstTraits::kCShuffleMRepeatPerShuffle,
|
||||
.n_gemms_per_shuffle = InstTraits::kCShuffleNRepeatPerShuffle},
|
||||
.thread_cluster_dims = {InstTraits::kCThreadClusterLengths[0],
|
||||
InstTraits::kCThreadClusterLengths[1],
|
||||
InstTraits::kCThreadClusterLengths[2],
|
||||
InstTraits::kCThreadClusterLengths[3]},
|
||||
.scalar_per_vector = InstTraits::kCBlockTransferScalarPerVector};
|
||||
}
|
||||
|
||||
template <typename InstTraits>
|
||||
constexpr WarpGemmParams conv_traits_xdl_warp_gemm_params()
|
||||
{
|
||||
return WarpGemmParams{.gemm_m = InstTraits::kMPerXDL,
|
||||
.gemm_n = InstTraits::kNPerXDL,
|
||||
.m_iter = InstTraits::kMXdlPerWave,
|
||||
.n_iter = InstTraits::kNXdlPerWave};
|
||||
}
|
||||
|
||||
template <typename InstTraits>
|
||||
constexpr WarpGemmParams conv_traits_wmma_warp_gemm_params()
|
||||
{
|
||||
return WarpGemmParams{.gemm_m = InstTraits::kMPerWmma,
|
||||
.gemm_n = InstTraits::kNPerWmma,
|
||||
.m_iter = InstTraits::kMRepeat,
|
||||
.n_iter = InstTraits::kNRepeat};
|
||||
}
|
||||
|
||||
template <typename InstTraits>
|
||||
constexpr DataTileInfo conv_traits_data_tile()
|
||||
{
|
||||
return DataTileInfo{
|
||||
.m = InstTraits::kMPerBlock, .n = InstTraits::kNPerBlock, .k = InstTraits::kKPerBlock};
|
||||
}
|
||||
|
||||
} // namespace ck_tile::reflect::conv
|
||||
|
||||
Reference in New Issue
Block a user