[CK_TILE] Merge multiple fwd convolution groups into a single GEMM batch. (#3136)

* Merge fwd conv groups in CK Tile.

* Fix building CK fwd convs.

* Add number of merged groups to conv fwd kernel name.

* Get number of merged groups from conv config.

* Rename GemmConfig to ConvConfig.

* Clean-up TODOs.

* Check that number of conv groups must be divisible by the number of merged groups.

* Improve error handling in the conv fwd example.

* Fix clang-format.

* Fix group offsets.

* Fix merge problem.

* Address feedback from code review.

* Fix clang-formatting.
This commit is contained in:
Ville Pietilä
2025-12-02 15:23:32 +02:00
committed by GitHub
parent 2d3020e5b0
commit 66832861ad
4 changed files with 111 additions and 58 deletions

View File

@@ -50,9 +50,17 @@ int run_grouped_conv_fwd_example(int argc, char* argv[])
int main(int argc, char* argv[])
{
try
{
#if CK_TILE_USE_WMMA
return !run_grouped_conv_fwd_example<ConvConfigComputeV3_WMMA>(argc, argv);
return !run_grouped_conv_fwd_example<ConvConfigComputeV3_WMMA>(argc, argv);
#else
return !run_grouped_conv_fwd_example<ConvConfigComputeV3>(argc, argv);
return !run_grouped_conv_fwd_example<ConvConfigComputeV3>(argc, argv);
#endif
}
catch(const std::runtime_error& e)
{
std::cerr << "Runtime error: " << e.what() << '\n';
return EXIT_FAILURE;
}
}