fix bugs, build passed

This commit is contained in:
mtgu0705
2025-05-11 17:49:43 +08:00
parent d3f007c775
commit 0bddd63d9c
2 changed files with 13 additions and 12 deletions

View File

@@ -52,7 +52,6 @@ struct BlockwiseGemmXdlops_mx_pipeline_base
static constexpr index_t B_K1 =
BTileDesc{}.GetLength(Number < BTileDesc{}.GetNumOfDimension() == 4 ? 3 : 2 > {});
static constexpr auto xdlops_gemm =
XdlopsGemm<ComputeTypeA, MPerXDL, NPerXDL, KPack, ComputeTypeB, TransposeC, true>{};

View File

@@ -144,6 +144,8 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1_mx<BlockGemmPipelineScheduler
using Base::KThreadChunk;
using Base::ComputePackedSize;
using Base::APackedSize;
using Base::BPackedSize;
using AccType = typename Base::AccType;
using Tuple4 = typename Base::Tuple4;
@@ -410,7 +412,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1_mx<BlockGemmPipelineScheduler
constexpr auto k_step = k * xdlops_gemm.KPerXdlops * (KPack / xdlops_gemm.K1PerXdlops);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, xdlops_gemm.K1PerXdlops / KThreadChunk, 1>{}([&](auto chunk) {
static_for<0, xdlops_gemm.K1PerXdlops / APackedSize / KThreadChunk, 1>{}([&](auto chunk) {
constexpr auto a_k_step_chunk =
k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
@@ -489,10 +491,10 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1_mx<BlockGemmPipelineScheduler
using mfma_input_type_a =
typename vector_type<ComputeTypeA,
xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.K1PerXdlops / APackedSize>::type;
using mfma_input_type_b =
typename vector_type<ComputeTypeB,
xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.K1PerXdlops / BPackedSize>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
@@ -516,7 +518,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1_mx<BlockGemmPipelineScheduler
k * xdlops_gemm.KPerXdlops * (KPack / xdlops_gemm.K1PerXdlops);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, xdlops_gemm.K1PerXdlops / KThreadChunk, 1>{}(
static_for<0, xdlops_gemm.K1PerXdlops / APackedSize / KThreadChunk, 1>{}(
[&](auto chunk) {
constexpr auto a_k_step_chunk =
k_step + chunk * KThreadChunk *
@@ -645,9 +647,9 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1_mx<BlockGemmPipelineScheduler
});
using mfma_input_type_a =
typename vector_type<ComputeTypeA, xdlops_gemm.K1PerXdlops>::type;
typename vector_type<ComputeTypeA, xdlops_gemm.K1PerXdlops / APackedSize>::type;
using mfma_input_type_b =
typename vector_type<ComputeTypeB, xdlops_gemm.K1PerXdlops>::type;
typename vector_type<ComputeTypeB, xdlops_gemm.K1PerXdlops / BPackedSize>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
@@ -671,7 +673,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1_mx<BlockGemmPipelineScheduler
k * xdlops_gemm.KPerXdlops * (KPack / xdlops_gemm.K1PerXdlops);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, xdlops_gemm.K1PerXdlops / KThreadChunk, 1>{}([&](auto chunk) {
static_for<0, xdlops_gemm.K1PerXdlops / APackedSize / KThreadChunk, 1>{}([&](auto chunk) {
constexpr auto a_k_step_chunk =
k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
@@ -717,9 +719,9 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1_mx<BlockGemmPipelineScheduler
});
using mfma_input_type_a =
typename vector_type<ComputeTypeA, xdlops_gemm.K1PerXdlops>::type;
typename vector_type<ComputeTypeA, xdlops_gemm.K1PerXdlops / APackedSize>::type;
using mfma_input_type_b =
typename vector_type<ComputeTypeB, xdlops_gemm.K1PerXdlops>::type;
typename vector_type<ComputeTypeB, xdlops_gemm.K1PerXdlops / BPackedSize>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
@@ -770,9 +772,9 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1_mx<BlockGemmPipelineScheduler
});
using mfma_input_type_a =
typename vector_type<ComputeTypeA, xdlops_gemm.K1PerXdlops>::type;
typename vector_type<ComputeTypeA, xdlops_gemm.K1PerXdlops / APackedSize>::type;
using mfma_input_type_b =
typename vector_type<ComputeTypeB, xdlops_gemm.K1PerXdlops>::type;
typename vector_type<ComputeTypeB, xdlops_gemm.K1PerXdlops / BPackedSize>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));