mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
ck-builder: group transfer operations per tensor (#3217)
Grouping transfer operations per tensor makes it easier to constrain on and operate with the transfer operations. As an example, we can now deduplicate the logic for translating the transfer operations from the ck-builder interface to the old ck interface for the A and B tensors.
This commit is contained in:
@@ -64,30 +64,39 @@ struct DefaultAlgorithm
|
||||
.m_xdl_per_wave = 4,
|
||||
.n_xdl_per_wave = 4};
|
||||
|
||||
ckb::test::BlockTransferABC block_transfer{
|
||||
.block_transfer_a = {.k0 = 4, .m_n = 256, .k1 = 8},
|
||||
.block_transfer_b = {.k0 = 4, .m_n = 256, .k1 = 8},
|
||||
.thread_cluster_dims_c = {.m_block = 1,
|
||||
.m_wave_per_xdl = 32,
|
||||
.n_block = 1,
|
||||
.n_wave_per_xdl = 8},
|
||||
.lds_transfer_a = {.src_vector_dim = 2,
|
||||
.src_scalar_per_vector = 8,
|
||||
.lds_dst_scalar_per_vector = 8,
|
||||
.is_direct_load = true,
|
||||
.lds_padding = false},
|
||||
.lds_transfer_b = {.src_vector_dim = 2,
|
||||
.src_scalar_per_vector = 8,
|
||||
.lds_dst_scalar_per_vector = 8,
|
||||
.is_direct_load = true,
|
||||
.lds_padding = false},
|
||||
.epilogue_c = {.m_per_wave_per_shuffle = 1,
|
||||
.n_per_wave_per_shuffle = 1,
|
||||
.scalar_per_vector = 8},
|
||||
.block_transfer_access_order_a = {.order = {0, 1, 2}},
|
||||
.block_transfer_access_order_b = {.order = {0, 1, 2}},
|
||||
.src_access_order_a = {.order = {0, 1, 2}},
|
||||
.src_access_order_b = {.order = {0, 1, 2}}};
|
||||
ckb::test::TransferABC transfer{
|
||||
.a =
|
||||
{
|
||||
.block_transfer = {.k0 = 4, .m_n = 256, .k1 = 8},
|
||||
.lds_transfer = {.src_vector_dim = 2,
|
||||
.src_scalar_per_vector = 8,
|
||||
.lds_dst_scalar_per_vector = 8,
|
||||
.is_direct_load = true,
|
||||
.lds_padding = false},
|
||||
.block_transfer_access_order = {.order = {0, 1, 2}},
|
||||
.src_access_order = {.order = {0, 1, 2}},
|
||||
|
||||
},
|
||||
.b =
|
||||
{
|
||||
.block_transfer = {.k0 = 4, .m_n = 256, .k1 = 8},
|
||||
.lds_transfer = {.src_vector_dim = 2,
|
||||
.src_scalar_per_vector = 8,
|
||||
.lds_dst_scalar_per_vector = 8,
|
||||
.is_direct_load = true,
|
||||
.lds_padding = false},
|
||||
.block_transfer_access_order = {.order = {0, 1, 2}},
|
||||
.src_access_order = {.order = {0, 1, 2}},
|
||||
},
|
||||
.c =
|
||||
{
|
||||
.thread_cluster_dims =
|
||||
{.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8},
|
||||
.epilogue = {.m_per_wave_per_shuffle = 1,
|
||||
.n_per_wave_per_shuffle = 1,
|
||||
.scalar_per_vector = 8},
|
||||
},
|
||||
};
|
||||
|
||||
ckb::ConvFwdSpecialization fwd_specialization = ckb::ConvFwdSpecialization::DEFAULT;
|
||||
ckb::GemmSpecialization gemm_specialization = ckb::GemmSpecialization::Default;
|
||||
|
||||
Reference in New Issue
Block a user