This commit is contained in:
lalala-sh
2025-05-08 09:48:23 +00:00
parent abff33eaab
commit 960b2bce1c

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -474,7 +474,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, Number<k0 * 2 + kg0>{}, I0, I0),
make_tuple(m0, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
a_block_buf.At(I0),
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
@@ -553,7 +553,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
make_tuple(Number<(m0 + 2) % MRepeat>{},
I0,
I0,
Number<k0 * 2 + kg0>{},
Number<k0 * KGroup + kg0>{},
I0,
I0),
a_block_buf.At(local_read_buf),
@@ -579,7 +579,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
make_tuple(Number<(m0 + 2) % MRepeat>{},
I0,
I0,
Number<k0 * 2 + kg0>{},
Number<k0 * KGroup + kg0>{},
I0,
I0),
a_block_buf.At(local_read_buf),
@@ -605,7 +605,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
make_tuple(Number<(m0 + 2) % MRepeat>{},
I0,
I0,
Number<k0 * 2 + kg0>{},
Number<k0 * KGroup + kg0>{},
I0,
I0),
a_block_buf.At(mfma_reg_buf),
@@ -680,7 +680,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
make_tuple(Number<(m0 + 2) % MRepeat>{},
I0,
I0,
Number<k0 * 2 + kg0>{},
Number<k0 * KGroup + kg0>{},
I0,
I0),
a_block_buf.At(I1),
@@ -700,7 +700,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
make_tuple(Number<(m0 + 2) % MRepeat>{},
I0,
I0,
Number<k0 * 2 + kg0>{},
Number<k0 * KGroup + kg0>{},
I0,
I0),
a_block_buf.At(I1),
@@ -720,7 +720,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
make_tuple(Number<(m0 + 2) % MRepeat>{},
I0,
I0,
Number<k0 * 2 + kg0>{},
Number<k0 * KGroup + kg0>{},
I0,
I0),
a_block_buf.At(I0),
@@ -769,7 +769,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(
Number<m0 + 2>{}, I0, I0, Number<k0 * 2 + kg0>{}, I0, I0),
Number<m0 + 2>{}, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
a_block_buf.At(I1),
a_thread_desc_,
make_tuple(Number<(m0 + 2 + HotloopLocalBufSwitch) % 2>{},
@@ -824,7 +824,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(
Number<m0 + 2>{}, I0, I0, Number<k0 * 2 + kg0>{}, I0, I0),
Number<m0 + 2>{}, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
a_block_buf.At(I0),
a_thread_desc_,
make_tuple(