mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-01 12:11:19 +00:00
Add IsSupportedArgument() to gemm_kernel (#1698)
* add IsSupportedArgument to gemm_kernel * add ut and do some refactoring * switched to ck_tile's integral_constant
This commit is contained in:
@@ -66,6 +66,79 @@ struct GemmKernel
|
||||
return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
|
||||
}
|
||||
|
||||
CK_TILE_HOST static bool IsSupportedArgument(const GemmCommonKargs& kargs)
|
||||
{
|
||||
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
if(kargs.K % TilePartitioner::kK != 0 && GemmPipeline::kPadK == false)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if(kargs.K % GemmPipeline::VectorSizeA != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(kargs.M % TilePartitioner::kM != 0 && GemmPipeline::kPadM == false)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if(kargs.M % GemmPipeline::VectorSizeA != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
if(kargs.N % TilePartitioner::kN != 0 && GemmPipeline::kPadN == false)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if(kargs.N % GemmPipeline::VectorSizeB != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(kargs.K % TilePartitioner::kK != 0 && GemmPipeline::kPadK == false)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if(kargs.K % GemmPipeline::VectorSizeB != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
if(kargs.N % TilePartitioner::kN != 0 && GemmPipeline::kPadN == false)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if(kargs.N % GemmPipeline::VectorSizeC != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(kargs.M % TilePartitioner::kM != 0 && GemmPipeline::kPadM == false)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if(kargs.M % GemmPipeline::VectorSizeC != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void operator()(GemmCommonKargs kargs) const
|
||||
{
|
||||
const auto [i_m, i_n] = TilePartitioner{}();
|
||||
|
||||
Reference in New Issue
Block a user