[CK TILE] Grouped Conv Explicit Gemm (#3289)

* [CK TILE] Grouped Conv Explicit Gemm

* fixes

* apply builder fixes
This commit is contained in:
Bartłomiej Kocot
2025-11-25 23:28:35 +01:00
committed by GitHub
parent 37ea160088
commit 00dfa2f2ce
13 changed files with 386 additions and 269 deletions

View File

@@ -58,6 +58,8 @@ struct InstanceTraits<ck_tile::GroupedConvolutionBackwardDataKernel<GroupedConvT
static constexpr int kNumGroupsToMerge = GroupedConvTraitsType_::NumGroupsToMerge;
// Split image (large tensors)
static constexpr bool kEnableSplitImage = GroupedConvTraitsType_::EnableSplitImage;
// Explicit GEMM: mirrors GroupedConvTraitsType_::ExplicitGemm so it can be reported with the
// other instance traits.
// NOTE(review): declared as int, while the neighbouring kEnableSplitImage flag is bool and the
// accompanying test instantiates the traits with `false /*ExplicitGemm*/` — confirm whether
// bool was intended here.
static constexpr int kExplicitGemm = GroupedConvTraitsType_::ExplicitGemm;
// TilePartitioner
// Block configuration
@@ -109,26 +111,27 @@ struct InstanceTraits<ck_tile::GroupedConvolutionBackwardDataKernel<GroupedConvT
oss << "," << kVectorSizeC; // 9. VectorSizeC
oss << "," << kNumGroupsToMerge; // 10. NumGroupsToMerge
oss << "," << kEnableSplitImage; // 11. EnableSplitImage
oss << "," << kMPerBlock; // 12. MPerBlock
oss << "," << kNPerBlock; // 13. NPerBlock
oss << "," << kKPerBlock; // 14. KPerBlock
oss << "," << kMWarp; // 15. MWarp
oss << "," << kNWarp; // 16. NWarp
oss << "," << kKWarp; // 17. KWarp
oss << "," << kMWarpTile; // 18. MWarpTile
oss << "," << kNWarpTile; // 19. NWarpTile
oss << "," << kKWarpTile; // 20. KWarpTile
oss << "," << detail::type_name<ADataType>(); // 21. ADataType
oss << "," << detail::type_name<BDataType>(); // 22. BDataType
oss << "," << GemmPipeline::GetPipelineName(); // 23. BlkGemmPipelineVer
oss << "," << detail::pipeline_scheduler_name(kPipelineScheduler); // 24. BlkGemmPipeSched
oss << "," << kDoubleSmemBuffer; // 25. DoubleSmemBuffer
oss << "," << kNumWaveGroups; // 26. NumWaveGroups
oss << "," << detail::type_name<AccDataType>(); // 27. AccDataType
oss << "," << detail::type_name<EDataType>(); // 28. EDataType
oss << "," << detail::tuple_name<DsDataType>(); // 29. DsDataType
oss << "," << kExplicitGemm; // 12. ExplicitGemm
oss << "," << kMPerBlock; // 13. MPerBlock
oss << "," << kNPerBlock; // 14. NPerBlock
oss << "," << kKPerBlock; // 15. KPerBlock
oss << "," << kMWarp; // 16. MWarp
oss << "," << kNWarp; // 17. NWarp
oss << "," << kKWarp; // 18. KWarp
oss << "," << kMWarpTile; // 19. MWarpTile
oss << "," << kNWarpTile; // 20. NWarpTile
oss << "," << kKWarpTile; // 21. KWarpTile
oss << "," << detail::type_name<ADataType>(); // 22. ADataType
oss << "," << detail::type_name<BDataType>(); // 23. BDataType
oss << "," << GemmPipeline::GetPipelineName(); // 24. BlkGemmPipelineVer
oss << "," << detail::pipeline_scheduler_name(kPipelineScheduler); // 25. BlkGemmPipeSched
oss << "," << kDoubleSmemBuffer; // 26. DoubleSmemBuffer
oss << "," << kNumWaveGroups; // 27. NumWaveGroups
oss << "," << detail::type_name<AccDataType>(); // 28. AccDataType
oss << "," << detail::type_name<EDataType>(); // 29. EDataType
oss << "," << detail::tuple_name<DsDataType>(); // 30. DsDataType
oss << ","
<< detail::elementwise_op_name<CDEElementwiseOperation>(); // 30.
<< detail::elementwise_op_name<CDEElementwiseOperation>(); // 31.
// CDEElementwiseOperation
oss << ">";

View File

@@ -58,6 +58,8 @@ struct InstanceTraits<ck_tile::GroupedConvolutionBackwardWeightKernel<GroupedCon
static constexpr int kNumGroupsToMerge = GroupedConvTraitsType_::NumGroupsToMerge;
// Split image (large tensors)
static constexpr bool kEnableSplitImage = GroupedConvTraitsType_::EnableSplitImage;
// Explicit GEMM: mirrors GroupedConvTraitsType_::ExplicitGemm so it can be reported with the
// other instance traits.
// NOTE(review): declared as int, while the neighbouring kEnableSplitImage flag is bool and the
// accompanying test instantiates the traits with `false /*ExplicitGemm*/` — confirm whether
// bool was intended here.
static constexpr int kExplicitGemm = GroupedConvTraitsType_::ExplicitGemm;
// TilePartitioner
// Block configuration
@@ -109,26 +111,27 @@ struct InstanceTraits<ck_tile::GroupedConvolutionBackwardWeightKernel<GroupedCon
oss << "," << kVectorSizeC; // 9. VectorSizeC
oss << "," << kNumGroupsToMerge; // 10. NumGroupsToMerge
oss << "," << kEnableSplitImage; // 11. EnableSplitImage
oss << "," << kMPerBlock; // 12. MPerBlock
oss << "," << kNPerBlock; // 13. NPerBlock
oss << "," << kKPerBlock; // 14. KPerBlock
oss << "," << kMWarp; // 15. MWarp
oss << "," << kNWarp; // 16. NWarp
oss << "," << kKWarp; // 17. KWarp
oss << "," << kMWarpTile; // 18. MWarpTile
oss << "," << kNWarpTile; // 19. NWarpTile
oss << "," << kKWarpTile; // 20. KWarpTile
oss << "," << detail::type_name<ADataType>(); // 21. ADataType
oss << "," << detail::type_name<BDataType>(); // 22. BDataType
oss << "," << GemmPipeline::GetPipelineName(); // 23. BlkGemmPipelineVer
oss << "," << detail::pipeline_scheduler_name(kPipelineScheduler); // 24. BlkGemmPipeSched
oss << "," << kDoubleSmemBuffer; // 25. DoubleSmemBuffer
oss << "," << kNumWaveGroups; // 26. NumWaveGroups
oss << "," << detail::type_name<AccDataType>(); // 27. AccDataType
oss << "," << detail::type_name<EDataType>(); // 28. EDataType
oss << "," << detail::tuple_name<DsDataType>(); // 29. DsDataType
oss << "," << kExplicitGemm; // 12. ExplicitGemm
oss << "," << kMPerBlock; // 13. MPerBlock
oss << "," << kNPerBlock; // 14. NPerBlock
oss << "," << kKPerBlock; // 15. KPerBlock
oss << "," << kMWarp; // 16. MWarp
oss << "," << kNWarp; // 17. NWarp
oss << "," << kKWarp; // 18. KWarp
oss << "," << kMWarpTile; // 19. MWarpTile
oss << "," << kNWarpTile; // 20. NWarpTile
oss << "," << kKWarpTile; // 21. KWarpTile
oss << "," << detail::type_name<ADataType>(); // 22. ADataType
oss << "," << detail::type_name<BDataType>(); // 23. BDataType
oss << "," << GemmPipeline::GetPipelineName(); // 24. BlkGemmPipelineVer
oss << "," << detail::pipeline_scheduler_name(kPipelineScheduler); // 25. BlkGemmPipeSched
oss << "," << kDoubleSmemBuffer; // 26. DoubleSmemBuffer
oss << "," << kNumWaveGroups; // 27. NumWaveGroups
oss << "," << detail::type_name<AccDataType>(); // 28. AccDataType
oss << "," << detail::type_name<EDataType>(); // 29. EDataType
oss << "," << detail::tuple_name<DsDataType>(); // 30. DsDataType
oss << ","
<< detail::elementwise_op_name<CDEElementwiseOperation>(); // 30.
<< detail::elementwise_op_name<CDEElementwiseOperation>(); // 31.
// CDEElementwiseOperation
oss << ">";

View File

@@ -58,6 +58,8 @@ struct InstanceTraits<ck_tile::GroupedConvolutionForwardKernel<GroupedConvTraits
static constexpr int kNumGroupsToMerge = GroupedConvTraitsType_::NumGroupsToMerge;
// Split image (large tensors)
static constexpr bool kEnableSplitImage = GroupedConvTraitsType_::EnableSplitImage;
// Explicit GEMM: mirrors GroupedConvTraitsType_::ExplicitGemm so it can be reported with the
// other instance traits.
// NOTE(review): declared as int, while the neighbouring kEnableSplitImage flag is bool and the
// accompanying test instantiates the traits with `false /*ExplicitGemm*/` — confirm whether
// bool was intended here.
static constexpr int kExplicitGemm = GroupedConvTraitsType_::ExplicitGemm;
// TilePartitioner
// Block configuration
@@ -109,26 +111,27 @@ struct InstanceTraits<ck_tile::GroupedConvolutionForwardKernel<GroupedConvTraits
oss << "," << kVectorSizeC; // 9. VectorSizeC
oss << "," << kNumGroupsToMerge; // 10. NumGroupsToMerge
oss << "," << kEnableSplitImage; // 11. EnableSplitImage
oss << "," << kMPerBlock; // 12. MPerBlock
oss << "," << kNPerBlock; // 13. NPerBlock
oss << "," << kKPerBlock; // 14. KPerBlock
oss << "," << kMWarp; // 15. MWarp
oss << "," << kNWarp; // 16. NWarp
oss << "," << kKWarp; // 17. KWarp
oss << "," << kMWarpTile; // 18. MWarpTile
oss << "," << kNWarpTile; // 19. NWarpTile
oss << "," << kKWarpTile; // 20. KWarpTile
oss << "," << detail::type_name<ADataType>(); // 21. ADataType
oss << "," << detail::type_name<BDataType>(); // 22. BDataType
oss << "," << GemmPipeline::GetPipelineName(); // 23. BlkGemmPipelineVer
oss << "," << detail::pipeline_scheduler_name(kPipelineScheduler); // 24. BlkGemmPipeSched
oss << "," << kDoubleSmemBuffer; // 25. DoubleSmemBuffer
oss << "," << kNumWaveGroups; // 26. NumWaveGroups
oss << "," << detail::type_name<AccDataType>(); // 27. AccDataType
oss << "," << detail::type_name<EDataType>(); // 28. EDataType
oss << "," << detail::tuple_name<DsDataType>(); // 29. DsDataType
oss << "," << kExplicitGemm; // 12. ExplicitGemm
oss << "," << kMPerBlock; // 13. MPerBlock
oss << "," << kNPerBlock; // 14. NPerBlock
oss << "," << kKPerBlock; // 15. KPerBlock
oss << "," << kMWarp; // 16. MWarp
oss << "," << kNWarp; // 17. NWarp
oss << "," << kKWarp; // 18. KWarp
oss << "," << kMWarpTile; // 19. MWarpTile
oss << "," << kNWarpTile; // 20. NWarpTile
oss << "," << kKWarpTile; // 21. KWarpTile
oss << "," << detail::type_name<ADataType>(); // 22. ADataType
oss << "," << detail::type_name<BDataType>(); // 23. BDataType
oss << "," << GemmPipeline::GetPipelineName(); // 24. BlkGemmPipelineVer
oss << "," << detail::pipeline_scheduler_name(kPipelineScheduler); // 25. BlkGemmPipeSched
oss << "," << kDoubleSmemBuffer; // 26. DoubleSmemBuffer
oss << "," << kNumWaveGroups; // 27. NumWaveGroups
oss << "," << detail::type_name<AccDataType>(); // 28. AccDataType
oss << "," << detail::type_name<EDataType>(); // 29. EDataType
oss << "," << detail::tuple_name<DsDataType>(); // 30. DsDataType
oss << ","
<< detail::elementwise_op_name<CDEElementwiseOperation>(); // 30.
<< detail::elementwise_op_name<CDEElementwiseOperation>(); // 31.
// CDEElementwiseOperation
oss << ">";

View File

@@ -21,7 +21,8 @@ TEST(InstanceTraits, TileInstanceStringReturnsCorrectFormat)
4 /*VectorSizeB*/,
4 /*VectorSizeC*/,
1 /*NumGroupsToMerge*/,
false /*EnableSplitImage*/>;
false /*EnableSplitImage*/,
false /*ExplicitGemm*/>;
using GemmShape = ck_tile::TileGemmShape<
ck_tile::sequence<128 /*M_Tile*/, 128 /*N_Tile*/, 32 /*K_Tile*/>,
@@ -106,6 +107,7 @@ TEST(InstanceTraits, TileInstanceStringReturnsCorrectFormat)
",4" // VectorSizeC
",1" // NumGroupsToMerge
",0" // EnableSplitImage
",0" // ExplicitGemm
",128" // MPerBlock
",128" // NPerBlock
",32" // KPerBlock

View File

@@ -123,7 +123,8 @@ TEST(InstanceTraits, TileInstanceStringReturnsCorrectFormat)
4 /*VectorSizeB*/,
4 /*VectorSizeC*/,
1 /*NumGroupsToMerge*/,
false /*EnableSplitImage*/>;
false /*EnableSplitImage*/,
false /*ExplicitGemm*/>;
using GemmShape = ck_tile::TileGemmShape<
ck_tile::sequence<128 /*M_Tile*/, 128 /*N_Tile*/, 32 /*K_Tile*/>,
@@ -208,6 +209,7 @@ TEST(InstanceTraits, TileInstanceStringReturnsCorrectFormat)
",4" // VectorSizeC
",1" // NumGroupsToMerge
",0" // EnableSplitImage
",0" // ExplicitGemm
",128" // MPerBlock
",128" // NPerBlock
",32" // KPerBlock

View File

@@ -734,7 +734,8 @@ TEST(InstanceTraits, TileInstanceStringReturnsCorrectFormat)
4 /*VectorSizeB*/,
4 /*VectorSizeC*/,
1 /*NumGroupsToMerge*/,
false /*EnableSplitImage*/>;
false /*EnableSplitImage*/,
false /*ExplicitGemm*/>;
using GemmShape = ck_tile::TileGemmShape<
ck_tile::sequence<128 /*M_Tile*/, 128 /*N_Tile*/, 32 /*K_Tile*/>,
@@ -818,6 +819,7 @@ TEST(InstanceTraits, TileInstanceStringReturnsCorrectFormat)
",4" // VectorSizeC
",1" // NumGroupsToMerge
",0" // EnableSplitImage
",0" // ExplicitGemm
",128" // MPerBlock
",128" // NPerBlock
",32" // KPerBlock