mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[CK_BUILDER] Refactor builder factory code. (#3276)
Refactor the builder factory code into multiple files and subdirectories and a ck_tile::builder::factory namespace. The factory implements compile-time dispatch from high-level signature and algorithm descriptors to our existing specialized convolution kernel implementations. Major changes in this PR: Dispatch logic is explicit in the function make_conv_instance instead of implicit in template specialization selection. Helper code is moved to a subdirectory builder/factory/helpers. Helpers now have unit tests. Factories are moved to their own files. Code moved to namespaces ck_tile::builder::factory and ck_tile::builder::factory::internal. This does not yet fix the problem of bad error messages, but the make_conv_instance function makes the poor error messages clear. The choice of algorithm must be much more robust (perhaps with explicit enumeration in the algorithm descriptor), so that the dispatch doesn't fail. Quality changes: Making dispatch explicit rather than implicit will improve robustness, readability, maintainability, testability, and extensibility. Separating code into separate files and subdirectories helps readability and extensibility. Adding unit tests for helpers documents behavior and will enable more complex logic and functionality. Separating files (especially unit tests) helps clarify includes and dependencies and makes code easier to refactor.
This commit is contained in:
@@ -67,11 +67,11 @@ struct DefaultAlgorithm
|
||||
ckb::test::TransferABC transfer{
|
||||
.a =
|
||||
{
|
||||
.block_transfer = {.k0 = 4, .m_n = 256, .k1 = 8},
|
||||
.block_transfer = {.k0 = 1, .m_n = 128, .k1 = 2},
|
||||
.lds_transfer = {.src_vector_dim = 2,
|
||||
.src_scalar_per_vector = 8,
|
||||
.lds_dst_scalar_per_vector = 8,
|
||||
.is_direct_load = true,
|
||||
.src_scalar_per_vector = 2,
|
||||
.lds_dst_scalar_per_vector = 2,
|
||||
.is_direct_load = false,
|
||||
.lds_padding = false},
|
||||
.block_transfer_access_order = {.order = {0, 1, 2}},
|
||||
.src_access_order = {.order = {0, 1, 2}},
|
||||
@@ -79,11 +79,11 @@ struct DefaultAlgorithm
|
||||
},
|
||||
.b =
|
||||
{
|
||||
.block_transfer = {.k0 = 4, .m_n = 256, .k1 = 8},
|
||||
.block_transfer = {.k0 = 1, .m_n = 128, .k1 = 2},
|
||||
.lds_transfer = {.src_vector_dim = 2,
|
||||
.src_scalar_per_vector = 8,
|
||||
.lds_dst_scalar_per_vector = 8,
|
||||
.is_direct_load = true,
|
||||
.src_scalar_per_vector = 2,
|
||||
.lds_dst_scalar_per_vector = 2,
|
||||
.is_direct_load = false,
|
||||
.lds_padding = false},
|
||||
.block_transfer_access_order = {.order = {0, 1, 2}},
|
||||
.src_access_order = {.order = {0, 1, 2}},
|
||||
@@ -92,9 +92,9 @@ struct DefaultAlgorithm
|
||||
{
|
||||
.thread_cluster_dims =
|
||||
{.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8},
|
||||
.epilogue = {.m_per_wave_per_shuffle = 1,
|
||||
.n_per_wave_per_shuffle = 1,
|
||||
.scalar_per_vector = 8},
|
||||
.epilogue = {.m_xdl_per_wave_per_shuffle = 1,
|
||||
.n_per_wave_per_shuffle = 1,
|
||||
.scalar_per_vector = 2},
|
||||
},
|
||||
};
|
||||
|
||||
@@ -144,22 +144,22 @@ TEST(ConvDescriptionTest, DefaultInstanceHasDetailedDescription)
|
||||
" │ ├─ Spatial thread distribution over the data tile: 0×1×2\n"
|
||||
" │ ├─ The order of accessing data tile axes: 0×1×2\n"
|
||||
" │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n"
|
||||
" │ ├─ Vector access (GMEM read) instruction size: 8\n"
|
||||
" │ ├─ Vector access (LDS write) instruction size: 8\n"
|
||||
" │ └─ LDS data layout padding (to prevent bank conflicts): 8\n"
|
||||
" │ ├─ Vector access (GMEM read) instruction size: 2\n"
|
||||
" │ ├─ Vector access (LDS write) instruction size: 2\n"
|
||||
" │ └─ LDS data layout padding (to prevent bank conflicts): 2\n"
|
||||
" ├─ B Tile transfer: \n"
|
||||
" │ ├─ Tile dimensions: 4×256×8×\n"
|
||||
" │ ├─ The innermost K subdimension size: 8\n"
|
||||
" │ ├─ Spatial thread distribution over the data tile: 0×1×2\n"
|
||||
" │ ├─ The order of accessing data tile axes: 0×1×2\n"
|
||||
" │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n"
|
||||
" │ ├─ Vector access (GMEM read) instruction size: 8\n"
|
||||
" │ ├─ Vector access (LDS write) instruction size: 8\n"
|
||||
" │ └─ LDS data layout padding (to prevent bank conflicts): 8\n"
|
||||
" │ ├─ Vector access (GMEM read) instruction size: 2\n"
|
||||
" │ ├─ Vector access (LDS write) instruction size: 2\n"
|
||||
" │ └─ LDS data layout padding (to prevent bank conflicts): 2\n"
|
||||
" └─ C Tile transfer: \n"
|
||||
" ├─ Data shuffle (number of gemm instructions per iteration): 1×1\n"
|
||||
" ├─ Spatial thread distribution used to store data: 1×32×1×8\n"
|
||||
" └─ Vector access (GMEM write) instruction size: 8"));
|
||||
" └─ Vector access (GMEM write) instruction size: 2"));
|
||||
}
|
||||
|
||||
// NOTE: BackwardDataInstanceHasDetailedDescription test is disabled because ConvFactory
|
||||
|
||||
Reference in New Issue
Block a user