mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[CK TILE] Grouped Conv Explicit Gemm (#3289)
* [CK TILE] Grouped Conv Explicit Gemm
* fixes
* apply builder fixes
This commit is contained in:
@@ -58,6 +58,8 @@ struct InstanceTraits<ck_tile::GroupedConvolutionBackwardDataKernel<GroupedConvT
|
||||
static constexpr int kNumGroupsToMerge = GroupedConvTraitsType_::NumGroupsToMerge;
|
||||
// Split image (large tensors)
|
||||
static constexpr bool kEnableSplitImage = GroupedConvTraitsType_::EnableSplitImage;
|
||||
// Explicit GEMM
|
||||
static constexpr int kExplicitGemm = GroupedConvTraitsType_::ExplicitGemm;
|
||||
|
||||
// TilePartitioner
|
||||
// Block configuration
|
||||
@@ -109,26 +111,27 @@ struct InstanceTraits<ck_tile::GroupedConvolutionBackwardDataKernel<GroupedConvT
|
||||
oss << "," << kVectorSizeC; // 9. VectorSizeC
|
||||
oss << "," << kNumGroupsToMerge; // 10. NumGroupsToMerge
|
||||
oss << "," << kEnableSplitImage; // 11. EnableSplitImage
|
||||
oss << "," << kMPerBlock; // 12. MPerBlock
|
||||
oss << "," << kNPerBlock; // 13. NPerBlock
|
||||
oss << "," << kKPerBlock; // 14. KPerBlock
|
||||
oss << "," << kMWarp; // 15. MWarp
|
||||
oss << "," << kNWarp; // 16. NWarp
|
||||
oss << "," << kKWarp; // 17. KWarp
|
||||
oss << "," << kMWarpTile; // 18. MWarpTile
|
||||
oss << "," << kNWarpTile; // 19. NWarpTile
|
||||
oss << "," << kKWarpTile; // 20. KWarpTile
|
||||
oss << "," << detail::type_name<ADataType>(); // 21. ADataType
|
||||
oss << "," << detail::type_name<BDataType>(); // 22. BDataType
|
||||
oss << "," << GemmPipeline::GetPipelineName(); // 23. BlkGemmPipelineVer
|
||||
oss << "," << detail::pipeline_scheduler_name(kPipelineScheduler); // 24. BlkGemmPipeSched
|
||||
oss << "," << kDoubleSmemBuffer; // 25. DoubleSmemBuffer
|
||||
oss << "," << kNumWaveGroups; // 26. NumWaveGroups
|
||||
oss << "," << detail::type_name<AccDataType>(); // 27. AccDataType
|
||||
oss << "," << detail::type_name<EDataType>(); // 28. EDataType
|
||||
oss << "," << detail::tuple_name<DsDataType>(); // 29. DsDataType
|
||||
oss << "," << kExplicitGemm; // 12. ExplicitGemm
|
||||
oss << "," << kMPerBlock; // 13. MPerBlock
|
||||
oss << "," << kNPerBlock; // 14. NPerBlock
|
||||
oss << "," << kKPerBlock; // 15. KPerBlock
|
||||
oss << "," << kMWarp; // 16. MWarp
|
||||
oss << "," << kNWarp; // 17. NWarp
|
||||
oss << "," << kKWarp; // 18. KWarp
|
||||
oss << "," << kMWarpTile; // 19. MWarpTile
|
||||
oss << "," << kNWarpTile; // 20. NWarpTile
|
||||
oss << "," << kKWarpTile; // 21. KWarpTile
|
||||
oss << "," << detail::type_name<ADataType>(); // 22. ADataType
|
||||
oss << "," << detail::type_name<BDataType>(); // 23. BDataType
|
||||
oss << "," << GemmPipeline::GetPipelineName(); // 24. BlkGemmPipelineVer
|
||||
oss << "," << detail::pipeline_scheduler_name(kPipelineScheduler); // 25. BlkGemmPipeSched
|
||||
oss << "," << kDoubleSmemBuffer; // 26. DoubleSmemBuffer
|
||||
oss << "," << kNumWaveGroups; // 27. NumWaveGroups
|
||||
oss << "," << detail::type_name<AccDataType>(); // 28. AccDataType
|
||||
oss << "," << detail::type_name<EDataType>(); // 29. EDataType
|
||||
oss << "," << detail::tuple_name<DsDataType>(); // 30. DsDataType
|
||||
oss << ","
|
||||
<< detail::elementwise_op_name<CDEElementwiseOperation>(); // 30.
|
||||
<< detail::elementwise_op_name<CDEElementwiseOperation>(); // 31.
|
||||
// CDEElementwiseOperation
|
||||
oss << ">";
|
||||
|
||||
|
||||
@@ -58,6 +58,8 @@ struct InstanceTraits<ck_tile::GroupedConvolutionBackwardWeightKernel<GroupedCon
|
||||
static constexpr int kNumGroupsToMerge = GroupedConvTraitsType_::NumGroupsToMerge;
|
||||
// Split image (large tensors)
|
||||
static constexpr bool kEnableSplitImage = GroupedConvTraitsType_::EnableSplitImage;
|
||||
// Explicit GEMM
|
||||
static constexpr int kExplicitGemm = GroupedConvTraitsType_::ExplicitGemm;
|
||||
|
||||
// TilePartitioner
|
||||
// Block configuration
|
||||
@@ -109,26 +111,27 @@ struct InstanceTraits<ck_tile::GroupedConvolutionBackwardWeightKernel<GroupedCon
|
||||
oss << "," << kVectorSizeC; // 9. VectorSizeC
|
||||
oss << "," << kNumGroupsToMerge; // 10. NumGroupsToMerge
|
||||
oss << "," << kEnableSplitImage; // 11. EnableSplitImage
|
||||
oss << "," << kMPerBlock; // 12. MPerBlock
|
||||
oss << "," << kNPerBlock; // 13. NPerBlock
|
||||
oss << "," << kKPerBlock; // 14. KPerBlock
|
||||
oss << "," << kMWarp; // 15. MWarp
|
||||
oss << "," << kNWarp; // 16. NWarp
|
||||
oss << "," << kKWarp; // 17. KWarp
|
||||
oss << "," << kMWarpTile; // 18. MWarpTile
|
||||
oss << "," << kNWarpTile; // 19. NWarpTile
|
||||
oss << "," << kKWarpTile; // 20. KWarpTile
|
||||
oss << "," << detail::type_name<ADataType>(); // 21. ADataType
|
||||
oss << "," << detail::type_name<BDataType>(); // 22. BDataType
|
||||
oss << "," << GemmPipeline::GetPipelineName(); // 23. BlkGemmPipelineVer
|
||||
oss << "," << detail::pipeline_scheduler_name(kPipelineScheduler); // 24. BlkGemmPipeSched
|
||||
oss << "," << kDoubleSmemBuffer; // 25. DoubleSmemBuffer
|
||||
oss << "," << kNumWaveGroups; // 26. NumWaveGroups
|
||||
oss << "," << detail::type_name<AccDataType>(); // 27. AccDataType
|
||||
oss << "," << detail::type_name<EDataType>(); // 28. EDataType
|
||||
oss << "," << detail::tuple_name<DsDataType>(); // 29. DsDataType
|
||||
oss << "," << kExplicitGemm; // 12. ExplicitGemm
|
||||
oss << "," << kMPerBlock; // 13. MPerBlock
|
||||
oss << "," << kNPerBlock; // 14. NPerBlock
|
||||
oss << "," << kKPerBlock; // 15. KPerBlock
|
||||
oss << "," << kMWarp; // 16. MWarp
|
||||
oss << "," << kNWarp; // 17. NWarp
|
||||
oss << "," << kKWarp; // 18. KWarp
|
||||
oss << "," << kMWarpTile; // 19. MWarpTile
|
||||
oss << "," << kNWarpTile; // 20. NWarpTile
|
||||
oss << "," << kKWarpTile; // 21. KWarpTile
|
||||
oss << "," << detail::type_name<ADataType>(); // 22. ADataType
|
||||
oss << "," << detail::type_name<BDataType>(); // 23. BDataType
|
||||
oss << "," << GemmPipeline::GetPipelineName(); // 24. BlkGemmPipelineVer
|
||||
oss << "," << detail::pipeline_scheduler_name(kPipelineScheduler); // 25. BlkGemmPipeSched
|
||||
oss << "," << kDoubleSmemBuffer; // 26. DoubleSmemBuffer
|
||||
oss << "," << kNumWaveGroups; // 27. NumWaveGroups
|
||||
oss << "," << detail::type_name<AccDataType>(); // 28. AccDataType
|
||||
oss << "," << detail::type_name<EDataType>(); // 29. EDataType
|
||||
oss << "," << detail::tuple_name<DsDataType>(); // 30. DsDataType
|
||||
oss << ","
|
||||
<< detail::elementwise_op_name<CDEElementwiseOperation>(); // 30.
|
||||
<< detail::elementwise_op_name<CDEElementwiseOperation>(); // 31.
|
||||
// CDEElementwiseOperation
|
||||
oss << ">";
|
||||
|
||||
|
||||
@@ -58,6 +58,8 @@ struct InstanceTraits<ck_tile::GroupedConvolutionForwardKernel<GroupedConvTraits
|
||||
static constexpr int kNumGroupsToMerge = GroupedConvTraitsType_::NumGroupsToMerge;
|
||||
// Split image (large tensors)
|
||||
static constexpr bool kEnableSplitImage = GroupedConvTraitsType_::EnableSplitImage;
|
||||
// Explicit GEMM
|
||||
static constexpr int kExplicitGemm = GroupedConvTraitsType_::ExplicitGemm;
|
||||
|
||||
// TilePartitioner
|
||||
// Block configuration
|
||||
@@ -109,26 +111,27 @@ struct InstanceTraits<ck_tile::GroupedConvolutionForwardKernel<GroupedConvTraits
|
||||
oss << "," << kVectorSizeC; // 9. VectorSizeC
|
||||
oss << "," << kNumGroupsToMerge; // 10. NumGroupsToMerge
|
||||
oss << "," << kEnableSplitImage; // 11. EnableSplitImage
|
||||
oss << "," << kMPerBlock; // 12. MPerBlock
|
||||
oss << "," << kNPerBlock; // 13. NPerBlock
|
||||
oss << "," << kKPerBlock; // 14. KPerBlock
|
||||
oss << "," << kMWarp; // 15. MWarp
|
||||
oss << "," << kNWarp; // 16. NWarp
|
||||
oss << "," << kKWarp; // 17. KWarp
|
||||
oss << "," << kMWarpTile; // 18. MWarpTile
|
||||
oss << "," << kNWarpTile; // 19. NWarpTile
|
||||
oss << "," << kKWarpTile; // 20. KWarpTile
|
||||
oss << "," << detail::type_name<ADataType>(); // 21. ADataType
|
||||
oss << "," << detail::type_name<BDataType>(); // 22. BDataType
|
||||
oss << "," << GemmPipeline::GetPipelineName(); // 23. BlkGemmPipelineVer
|
||||
oss << "," << detail::pipeline_scheduler_name(kPipelineScheduler); // 24. BlkGemmPipeSched
|
||||
oss << "," << kDoubleSmemBuffer; // 25. DoubleSmemBuffer
|
||||
oss << "," << kNumWaveGroups; // 26. NumWaveGroups
|
||||
oss << "," << detail::type_name<AccDataType>(); // 27. AccDataType
|
||||
oss << "," << detail::type_name<EDataType>(); // 28. EDataType
|
||||
oss << "," << detail::tuple_name<DsDataType>(); // 29. DsDataType
|
||||
oss << "," << kExplicitGemm; // 12. ExplicitGemm
|
||||
oss << "," << kMPerBlock; // 13. MPerBlock
|
||||
oss << "," << kNPerBlock; // 14. NPerBlock
|
||||
oss << "," << kKPerBlock; // 15. KPerBlock
|
||||
oss << "," << kMWarp; // 16. MWarp
|
||||
oss << "," << kNWarp; // 17. NWarp
|
||||
oss << "," << kKWarp; // 18. KWarp
|
||||
oss << "," << kMWarpTile; // 19. MWarpTile
|
||||
oss << "," << kNWarpTile; // 20. NWarpTile
|
||||
oss << "," << kKWarpTile; // 21. KWarpTile
|
||||
oss << "," << detail::type_name<ADataType>(); // 22. ADataType
|
||||
oss << "," << detail::type_name<BDataType>(); // 23. BDataType
|
||||
oss << "," << GemmPipeline::GetPipelineName(); // 24. BlkGemmPipelineVer
|
||||
oss << "," << detail::pipeline_scheduler_name(kPipelineScheduler); // 25. BlkGemmPipeSched
|
||||
oss << "," << kDoubleSmemBuffer; // 26. DoubleSmemBuffer
|
||||
oss << "," << kNumWaveGroups; // 27. NumWaveGroups
|
||||
oss << "," << detail::type_name<AccDataType>(); // 28. AccDataType
|
||||
oss << "," << detail::type_name<EDataType>(); // 29. EDataType
|
||||
oss << "," << detail::tuple_name<DsDataType>(); // 30. DsDataType
|
||||
oss << ","
|
||||
<< detail::elementwise_op_name<CDEElementwiseOperation>(); // 30.
|
||||
<< detail::elementwise_op_name<CDEElementwiseOperation>(); // 31.
|
||||
// CDEElementwiseOperation
|
||||
oss << ">";
|
||||
|
||||
|
||||
@@ -21,7 +21,8 @@ TEST(InstanceTraits, TileInstanceStringReturnsCorrectFormat)
|
||||
4 /*VectorSizeB*/,
|
||||
4 /*VectorSizeC*/,
|
||||
1 /*NumGroupsToMerge*/,
|
||||
false /*EnableSplitImage*/>;
|
||||
false /*EnableSplitImage*/,
|
||||
false /*ExplicitGemm*/>;
|
||||
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<128 /*M_Tile*/, 128 /*N_Tile*/, 32 /*K_Tile*/>,
|
||||
@@ -106,6 +107,7 @@ TEST(InstanceTraits, TileInstanceStringReturnsCorrectFormat)
|
||||
",4" // VectorSizeC
|
||||
",1" // NumGroupsToMerge
|
||||
",0" // EnableSplitImage
|
||||
",0" // ExplicitGemm
|
||||
",128" // MPerBlock
|
||||
",128" // NPerBlock
|
||||
",32" // KPerBlock
|
||||
|
||||
@@ -123,7 +123,8 @@ TEST(InstanceTraits, TileInstanceStringReturnsCorrectFormat)
|
||||
4 /*VectorSizeB*/,
|
||||
4 /*VectorSizeC*/,
|
||||
1 /*NumGroupsToMerge*/,
|
||||
false /*EnableSplitImage*/>;
|
||||
false /*EnableSplitImage*/,
|
||||
false /*ExplicitGemm*/>;
|
||||
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<128 /*M_Tile*/, 128 /*N_Tile*/, 32 /*K_Tile*/>,
|
||||
@@ -208,6 +209,7 @@ TEST(InstanceTraits, TileInstanceStringReturnsCorrectFormat)
|
||||
",4" // VectorSizeC
|
||||
",1" // NumGroupsToMerge
|
||||
",0" // EnableSplitImage
|
||||
",0" // ExplicitGemm
|
||||
",128" // MPerBlock
|
||||
",128" // NPerBlock
|
||||
",32" // KPerBlock
|
||||
|
||||
@@ -734,7 +734,8 @@ TEST(InstanceTraits, TileInstanceStringReturnsCorrectFormat)
|
||||
4 /*VectorSizeB*/,
|
||||
4 /*VectorSizeC*/,
|
||||
1 /*NumGroupsToMerge*/,
|
||||
false /*EnableSplitImage*/>;
|
||||
false /*EnableSplitImage*/,
|
||||
false /*ExplicitGemm*/>;
|
||||
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<128 /*M_Tile*/, 128 /*N_Tile*/, 32 /*K_Tile*/>,
|
||||
@@ -818,6 +819,7 @@ TEST(InstanceTraits, TileInstanceStringReturnsCorrectFormat)
|
||||
",4" // VectorSizeC
|
||||
",1" // NumGroupsToMerge
|
||||
",0" // EnableSplitImage
|
||||
",0" // ExplicitGemm
|
||||
",128" // MPerBlock
|
||||
",128" // NPerBlock
|
||||
",32" // KPerBlock
|
||||
|
||||
Reference in New Issue
Block a user