mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
fix the compv4 and async pipeline when tile handler is 1 (#3141)
This commit is contained in:
@@ -25,6 +25,10 @@ struct BaseGemmPipelineAgBgCrCompAsync
|
||||
|
||||
CK_TILE_HOST static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop)
|
||||
{
|
||||
if(num_loop == 1)
|
||||
{
|
||||
return TailNumber::One;
|
||||
}
|
||||
if(num_loop % PrefetchStages == 1)
|
||||
{
|
||||
return TailNumber::Three;
|
||||
@@ -65,6 +69,11 @@ struct BaseGemmPipelineAgBgCrCompAsync
|
||||
return run_func(bool_constant<false>{},
|
||||
integral_constant<TailNumber, TailNumber::Two>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return (run_func(bool_constant<false>{},
|
||||
integral_constant<TailNumber, TailNumber::One>{}));
|
||||
}
|
||||
}
|
||||
// If execution reaches here, it's an invalid tail_number because it wasn't handled above.
|
||||
#if defined(__HIP_DEVICE_COMPILE__)
|
||||
@@ -485,7 +494,7 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync<Prob
|
||||
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
|
||||
}
|
||||
}
|
||||
else
|
||||
else if(TailNum == TailNumber::Two)
|
||||
// 2 block gemms remaining
|
||||
{
|
||||
{
|
||||
@@ -500,6 +509,12 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync<Prob
|
||||
block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
|
||||
}
|
||||
}
|
||||
else if(TailNum == TailNumber::One)
|
||||
{
|
||||
block_sync_lds();
|
||||
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
return c_block_tile;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -27,6 +27,10 @@ struct BaseGemmPipelineAgBgCrCompV4
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop)
|
||||
{
|
||||
if(num_loop == 1)
|
||||
{
|
||||
return TailNumber::One;
|
||||
}
|
||||
if(num_loop % PrefetchStages == 1)
|
||||
{
|
||||
return TailNumber::Three;
|
||||
@@ -67,6 +71,11 @@ struct BaseGemmPipelineAgBgCrCompV4
|
||||
return run_func(bool_constant<false>{},
|
||||
integral_constant<TailNumber, TailNumber::Two>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return (run_func(bool_constant<false>{},
|
||||
integral_constant<TailNumber, TailNumber::One>{}));
|
||||
}
|
||||
}
|
||||
// If execution reaches here, it's an invalid tail_number because it wasn't handled above.
|
||||
#if defined(__HIP_DEVICE_COMPILE__)
|
||||
@@ -621,7 +630,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
}
|
||||
else
|
||||
else if(TailNum == TailNumber::Two)
|
||||
{
|
||||
// 2
|
||||
{
|
||||
@@ -641,6 +650,12 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
}
|
||||
else if(TailNum == TailNumber::One)
|
||||
{
|
||||
block_sync_lds();
|
||||
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
return c_block_tile;
|
||||
}
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user