Fix the CompV4 and CompAsync pipelines when the tile handler is 1 (#3141)

This commit is contained in:
Thomas Ning
2025-11-03 09:37:35 -08:00
committed by GitHub
parent 2ec57a8e70
commit 057b7d43b4
2 changed files with 32 additions and 2 deletions

View File

@@ -25,6 +25,10 @@ struct BaseGemmPipelineAgBgCrCompAsync
CK_TILE_HOST static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop)
{
if(num_loop == 1)
{
return TailNumber::One;
}
if(num_loop % PrefetchStages == 1)
{
return TailNumber::Three;
@@ -65,6 +69,11 @@ struct BaseGemmPipelineAgBgCrCompAsync
return run_func(bool_constant<false>{},
integral_constant<TailNumber, TailNumber::Two>{});
}
else
{
return (run_func(bool_constant<false>{},
integral_constant<TailNumber, TailNumber::One>{}));
}
}
// If execution reaches here, it's an invalid tail_number because it wasn't handled above.
#if defined(__HIP_DEVICE_COMPILE__)
@@ -485,7 +494,7 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync<Prob
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
}
}
else
else if(TailNum == TailNumber::Two)
// 2 block gemms remaining
{
{
@@ -500,6 +509,12 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync<Prob
block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
}
}
else if(TailNum == TailNumber::One)
{
block_sync_lds();
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
__builtin_amdgcn_sched_barrier(0);
}
return c_block_tile;
}
};

View File

@@ -27,6 +27,10 @@ struct BaseGemmPipelineAgBgCrCompV4
CK_TILE_HOST_DEVICE static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop)
{
if(num_loop == 1)
{
return TailNumber::One;
}
if(num_loop % PrefetchStages == 1)
{
return TailNumber::Three;
@@ -67,6 +71,11 @@ struct BaseGemmPipelineAgBgCrCompV4
return run_func(bool_constant<false>{},
integral_constant<TailNumber, TailNumber::Two>{});
}
else
{
return (run_func(bool_constant<false>{},
integral_constant<TailNumber, TailNumber::One>{}));
}
}
// If execution reaches here, it's an invalid tail_number because it wasn't handled above.
#if defined(__HIP_DEVICE_COMPILE__)
@@ -621,7 +630,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
__builtin_amdgcn_sched_barrier(0);
}
}
else
else if(TailNum == TailNumber::Two)
{
// 2
{
@@ -641,6 +650,12 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
__builtin_amdgcn_sched_barrier(0);
}
}
else if(TailNum == TailNumber::One)
{
block_sync_lds();
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
__builtin_amdgcn_sched_barrier(0);
}
return c_block_tile;
}
};