mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 05:31:24 +00:00
[CK TILE] Grouped Conv Explicit Gemm (#3289)
* [CK TILE] Grouped Conv Explicit Gemm * fixes * apply builder fixes
This commit is contained in:
@@ -58,6 +58,8 @@ struct InstanceTraits<ck_tile::GroupedConvolutionBackwardDataKernel<GroupedConvT
|
||||
static constexpr int kNumGroupsToMerge = GroupedConvTraitsType_::NumGroupsToMerge;
|
||||
// Split image (large tensors)
|
||||
static constexpr bool kEnableSplitImage = GroupedConvTraitsType_::EnableSplitImage;
|
||||
// Explicit GEMM
|
||||
static constexpr int kExplicitGemm = GroupedConvTraitsType_::ExplicitGemm;
|
||||
|
||||
// TilePartitioner
|
||||
// Block configuration
|
||||
@@ -109,26 +111,27 @@ struct InstanceTraits<ck_tile::GroupedConvolutionBackwardDataKernel<GroupedConvT
|
||||
oss << "," << kVectorSizeC; // 9. VectorSizeC
|
||||
oss << "," << kNumGroupsToMerge; // 10. NumGroupsToMerge
|
||||
oss << "," << kEnableSplitImage; // 11. EnableSplitImage
|
||||
oss << "," << kMPerBlock; // 12. MPerBlock
|
||||
oss << "," << kNPerBlock; // 13. NPerBlock
|
||||
oss << "," << kKPerBlock; // 14. KPerBlock
|
||||
oss << "," << kMWarp; // 15. MWarp
|
||||
oss << "," << kNWarp; // 16. NWarp
|
||||
oss << "," << kKWarp; // 17. KWarp
|
||||
oss << "," << kMWarpTile; // 18. MWarpTile
|
||||
oss << "," << kNWarpTile; // 19. NWarpTile
|
||||
oss << "," << kKWarpTile; // 20. KWarpTile
|
||||
oss << "," << detail::type_name<ADataType>(); // 21. ADataType
|
||||
oss << "," << detail::type_name<BDataType>(); // 22. BDataType
|
||||
oss << "," << GemmPipeline::GetPipelineName(); // 23. BlkGemmPipelineVer
|
||||
oss << "," << detail::pipeline_scheduler_name(kPipelineScheduler); // 24. BlkGemmPipeSched
|
||||
oss << "," << kDoubleSmemBuffer; // 25. DoubleSmemBuffer
|
||||
oss << "," << kNumWaveGroups; // 26. NumWaveGroups
|
||||
oss << "," << detail::type_name<AccDataType>(); // 27. AccDataType
|
||||
oss << "," << detail::type_name<EDataType>(); // 28. EDataType
|
||||
oss << "," << detail::tuple_name<DsDataType>(); // 29. DsDataType
|
||||
oss << "," << kExplicitGemm; // 12. ExplicitGemm
|
||||
oss << "," << kMPerBlock; // 13. MPerBlock
|
||||
oss << "," << kNPerBlock; // 14. NPerBlock
|
||||
oss << "," << kKPerBlock; // 15. KPerBlock
|
||||
oss << "," << kMWarp; // 16. MWarp
|
||||
oss << "," << kNWarp; // 17. NWarp
|
||||
oss << "," << kKWarp; // 18. KWarp
|
||||
oss << "," << kMWarpTile; // 19. MWarpTile
|
||||
oss << "," << kNWarpTile; // 20. NWarpTile
|
||||
oss << "," << kKWarpTile; // 21. KWarpTile
|
||||
oss << "," << detail::type_name<ADataType>(); // 22. ADataType
|
||||
oss << "," << detail::type_name<BDataType>(); // 23. BDataType
|
||||
oss << "," << GemmPipeline::GetPipelineName(); // 24. BlkGemmPipelineVer
|
||||
oss << "," << detail::pipeline_scheduler_name(kPipelineScheduler); // 25. BlkGemmPipeSched
|
||||
oss << "," << kDoubleSmemBuffer; // 26. DoubleSmemBuffer
|
||||
oss << "," << kNumWaveGroups; // 27. NumWaveGroups
|
||||
oss << "," << detail::type_name<AccDataType>(); // 28. AccDataType
|
||||
oss << "," << detail::type_name<EDataType>(); // 29. EDataType
|
||||
oss << "," << detail::tuple_name<DsDataType>(); // 30. DsDataType
|
||||
oss << ","
|
||||
<< detail::elementwise_op_name<CDEElementwiseOperation>(); // 30.
|
||||
<< detail::elementwise_op_name<CDEElementwiseOperation>(); // 31.
|
||||
// CDEElementwiseOperation
|
||||
oss << ">";
|
||||
|
||||
|
||||
@@ -58,6 +58,8 @@ struct InstanceTraits<ck_tile::GroupedConvolutionBackwardWeightKernel<GroupedCon
|
||||
static constexpr int kNumGroupsToMerge = GroupedConvTraitsType_::NumGroupsToMerge;
|
||||
// Split image (large tensors)
|
||||
static constexpr bool kEnableSplitImage = GroupedConvTraitsType_::EnableSplitImage;
|
||||
// Explicit GEMM
|
||||
static constexpr int kExplicitGemm = GroupedConvTraitsType_::ExplicitGemm;
|
||||
|
||||
// TilePartitioner
|
||||
// Block configuration
|
||||
@@ -109,26 +111,27 @@ struct InstanceTraits<ck_tile::GroupedConvolutionBackwardWeightKernel<GroupedCon
|
||||
oss << "," << kVectorSizeC; // 9. VectorSizeC
|
||||
oss << "," << kNumGroupsToMerge; // 10. NumGroupsToMerge
|
||||
oss << "," << kEnableSplitImage; // 11. EnableSplitImage
|
||||
oss << "," << kMPerBlock; // 12. MPerBlock
|
||||
oss << "," << kNPerBlock; // 13. NPerBlock
|
||||
oss << "," << kKPerBlock; // 14. KPerBlock
|
||||
oss << "," << kMWarp; // 15. MWarp
|
||||
oss << "," << kNWarp; // 16. NWarp
|
||||
oss << "," << kKWarp; // 17. KWarp
|
||||
oss << "," << kMWarpTile; // 18. MWarpTile
|
||||
oss << "," << kNWarpTile; // 19. NWarpTile
|
||||
oss << "," << kKWarpTile; // 20. KWarpTile
|
||||
oss << "," << detail::type_name<ADataType>(); // 21. ADataType
|
||||
oss << "," << detail::type_name<BDataType>(); // 22. BDataType
|
||||
oss << "," << GemmPipeline::GetPipelineName(); // 23. BlkGemmPipelineVer
|
||||
oss << "," << detail::pipeline_scheduler_name(kPipelineScheduler); // 24. BlkGemmPipeSched
|
||||
oss << "," << kDoubleSmemBuffer; // 25. DoubleSmemBuffer
|
||||
oss << "," << kNumWaveGroups; // 26. NumWaveGroups
|
||||
oss << "," << detail::type_name<AccDataType>(); // 27. AccDataType
|
||||
oss << "," << detail::type_name<EDataType>(); // 28. EDataType
|
||||
oss << "," << detail::tuple_name<DsDataType>(); // 29. DsDataType
|
||||
oss << "," << kExplicitGemm; // 12. ExplicitGemm
|
||||
oss << "," << kMPerBlock; // 13. MPerBlock
|
||||
oss << "," << kNPerBlock; // 14. NPerBlock
|
||||
oss << "," << kKPerBlock; // 15. KPerBlock
|
||||
oss << "," << kMWarp; // 16. MWarp
|
||||
oss << "," << kNWarp; // 17. NWarp
|
||||
oss << "," << kKWarp; // 18. KWarp
|
||||
oss << "," << kMWarpTile; // 19. MWarpTile
|
||||
oss << "," << kNWarpTile; // 20. NWarpTile
|
||||
oss << "," << kKWarpTile; // 21. KWarpTile
|
||||
oss << "," << detail::type_name<ADataType>(); // 22. ADataType
|
||||
oss << "," << detail::type_name<BDataType>(); // 23. BDataType
|
||||
oss << "," << GemmPipeline::GetPipelineName(); // 24. BlkGemmPipelineVer
|
||||
oss << "," << detail::pipeline_scheduler_name(kPipelineScheduler); // 25. BlkGemmPipeSched
|
||||
oss << "," << kDoubleSmemBuffer; // 26. DoubleSmemBuffer
|
||||
oss << "," << kNumWaveGroups; // 27. NumWaveGroups
|
||||
oss << "," << detail::type_name<AccDataType>(); // 28. AccDataType
|
||||
oss << "," << detail::type_name<EDataType>(); // 29. EDataType
|
||||
oss << "," << detail::tuple_name<DsDataType>(); // 30. DsDataType
|
||||
oss << ","
|
||||
<< detail::elementwise_op_name<CDEElementwiseOperation>(); // 30.
|
||||
<< detail::elementwise_op_name<CDEElementwiseOperation>(); // 31.
|
||||
// CDEElementwiseOperation
|
||||
oss << ">";
|
||||
|
||||
|
||||
@@ -58,6 +58,8 @@ struct InstanceTraits<ck_tile::GroupedConvolutionForwardKernel<GroupedConvTraits
|
||||
static constexpr int kNumGroupsToMerge = GroupedConvTraitsType_::NumGroupsToMerge;
|
||||
// Split image (large tensors)
|
||||
static constexpr bool kEnableSplitImage = GroupedConvTraitsType_::EnableSplitImage;
|
||||
// Explicit GEMM
|
||||
static constexpr int kExplicitGemm = GroupedConvTraitsType_::ExplicitGemm;
|
||||
|
||||
// TilePartitioner
|
||||
// Block configuration
|
||||
@@ -109,26 +111,27 @@ struct InstanceTraits<ck_tile::GroupedConvolutionForwardKernel<GroupedConvTraits
|
||||
oss << "," << kVectorSizeC; // 9. VectorSizeC
|
||||
oss << "," << kNumGroupsToMerge; // 10. NumGroupsToMerge
|
||||
oss << "," << kEnableSplitImage; // 11. EnableSplitImage
|
||||
oss << "," << kMPerBlock; // 12. MPerBlock
|
||||
oss << "," << kNPerBlock; // 13. NPerBlock
|
||||
oss << "," << kKPerBlock; // 14. KPerBlock
|
||||
oss << "," << kMWarp; // 15. MWarp
|
||||
oss << "," << kNWarp; // 16. NWarp
|
||||
oss << "," << kKWarp; // 17. KWarp
|
||||
oss << "," << kMWarpTile; // 18. MWarpTile
|
||||
oss << "," << kNWarpTile; // 19. NWarpTile
|
||||
oss << "," << kKWarpTile; // 20. KWarpTile
|
||||
oss << "," << detail::type_name<ADataType>(); // 21. ADataType
|
||||
oss << "," << detail::type_name<BDataType>(); // 22. BDataType
|
||||
oss << "," << GemmPipeline::GetPipelineName(); // 23. BlkGemmPipelineVer
|
||||
oss << "," << detail::pipeline_scheduler_name(kPipelineScheduler); // 24. BlkGemmPipeSched
|
||||
oss << "," << kDoubleSmemBuffer; // 25. DoubleSmemBuffer
|
||||
oss << "," << kNumWaveGroups; // 26. NumWaveGroups
|
||||
oss << "," << detail::type_name<AccDataType>(); // 27. AccDataType
|
||||
oss << "," << detail::type_name<EDataType>(); // 28. EDataType
|
||||
oss << "," << detail::tuple_name<DsDataType>(); // 29. DsDataType
|
||||
oss << "," << kExplicitGemm; // 12. ExplicitGemm
|
||||
oss << "," << kMPerBlock; // 13. MPerBlock
|
||||
oss << "," << kNPerBlock; // 14. NPerBlock
|
||||
oss << "," << kKPerBlock; // 15. KPerBlock
|
||||
oss << "," << kMWarp; // 16. MWarp
|
||||
oss << "," << kNWarp; // 17. NWarp
|
||||
oss << "," << kKWarp; // 18. KWarp
|
||||
oss << "," << kMWarpTile; // 19. MWarpTile
|
||||
oss << "," << kNWarpTile; // 20. NWarpTile
|
||||
oss << "," << kKWarpTile; // 21. KWarpTile
|
||||
oss << "," << detail::type_name<ADataType>(); // 22. ADataType
|
||||
oss << "," << detail::type_name<BDataType>(); // 23. BDataType
|
||||
oss << "," << GemmPipeline::GetPipelineName(); // 24. BlkGemmPipelineVer
|
||||
oss << "," << detail::pipeline_scheduler_name(kPipelineScheduler); // 25. BlkGemmPipeSched
|
||||
oss << "," << kDoubleSmemBuffer; // 26. DoubleSmemBuffer
|
||||
oss << "," << kNumWaveGroups; // 27. NumWaveGroups
|
||||
oss << "," << detail::type_name<AccDataType>(); // 28. AccDataType
|
||||
oss << "," << detail::type_name<EDataType>(); // 29. EDataType
|
||||
oss << "," << detail::tuple_name<DsDataType>(); // 30. DsDataType
|
||||
oss << ","
|
||||
<< detail::elementwise_op_name<CDEElementwiseOperation>(); // 30.
|
||||
<< detail::elementwise_op_name<CDEElementwiseOperation>(); // 31.
|
||||
// CDEElementwiseOperation
|
||||
oss << ">";
|
||||
|
||||
|
||||
Reference in New Issue
Block a user