fixed prefetch stage gemm

This commit is contained in:
Kevin Abraham
2026-01-13 08:16:04 +00:00
parent c8fda65534
commit 7b15c22e7e

View File

@@ -63,22 +63,22 @@ constexpr ConvTraits instance_to_conv_traits()
.dst_scalar_per_vector_k1 =
InstTraits::kBBlockTransferDstScalarPerVectorK1,
.lds_padding = static_cast<bool>(InstTraits::kBBlockLdsExtraN)}},
.warp_gemm = {.gemm_m = InstTraits::kMPerWmma,
.gemm_n = InstTraits::kNPerWmma,
.m_iter = InstTraits::kMRepeat,
.n_iter = InstTraits::kNRepeat},
.c_tile_transfer = {.shuffle_params = {.m_gemms_per_shuffle =
InstTraits::kCShuffleMRepeatPerShuffle,
.n_gemms_per_shuffle =
InstTraits::kCShuffleNRepeatPerShuffle},
.thread_cluster_dims = {InstTraits::kCDEThreadClusterLengths[0],
InstTraits::kCDEThreadClusterLengths[1],
InstTraits::kCDEThreadClusterLengths[2],
InstTraits::kCDEThreadClusterLengths[3]},
.scalar_per_vector = InstTraits::kCDEBlockTransferScalarPerVector},
// .num_gemm_prefetch_stage = InstTraits::kNumGemmKPrefetchStage,
.pipeline_version = get_pipeline_version<InstTraits>(),
.pipeline_scheduler = get_pipeline_scheduler<InstTraits>(),
.warp_gemm = {.gemm_m = InstTraits::kMPerWmma,
.gemm_n = InstTraits::kNPerWmma,
.m_iter = InstTraits::kMRepeat,
.n_iter = InstTraits::kNRepeat},
.c_tile_transfer = {.shuffle_params = {.m_gemms_per_shuffle =
InstTraits::kCShuffleMRepeatPerShuffle,
.n_gemms_per_shuffle =
InstTraits::kCShuffleNRepeatPerShuffle},
.thread_cluster_dims = {InstTraits::kCDEThreadClusterLengths[0],
InstTraits::kCDEThreadClusterLengths[1],
InstTraits::kCDEThreadClusterLengths[2],
InstTraits::kCDEThreadClusterLengths[3]},
.scalar_per_vector = InstTraits::kCDEBlockTransferScalarPerVector},
.num_gemm_prefetch_stage = InstTraits::kNumGemmKPrefetchStage,
.pipeline_version = get_pipeline_version<InstTraits>(),
.pipeline_scheduler = get_pipeline_scheduler<InstTraits>(),
};
}