mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
ck-builder: group transfer operations per tensor (#3217)
Grouping transfer operations per tensor makes it easier to
constrain on and operate with the transfer operations. As an
example, we can now deduplicate the logic for translating
the transfer operations from the ck-builder interface to the old
ck interface for the A and B tensors.
[ROCm/composable_kernel commit: 245c6011cf]
This commit is contained in:
@@ -64,30 +64,39 @@ struct DefaultAlgorithm
|
||||
.m_xdl_per_wave = 4,
|
||||
.n_xdl_per_wave = 4};
|
||||
|
||||
ckb::test::BlockTransferABC block_transfer{
|
||||
.block_transfer_a = {.k0 = 4, .m_n = 256, .k1 = 8},
|
||||
.block_transfer_b = {.k0 = 4, .m_n = 256, .k1 = 8},
|
||||
.thread_cluster_dims_c = {.m_block = 1,
|
||||
.m_wave_per_xdl = 32,
|
||||
.n_block = 1,
|
||||
.n_wave_per_xdl = 8},
|
||||
.lds_transfer_a = {.src_vector_dim = 2,
|
||||
.src_scalar_per_vector = 8,
|
||||
.lds_dst_scalar_per_vector = 8,
|
||||
.is_direct_load = true,
|
||||
.lds_padding = false},
|
||||
.lds_transfer_b = {.src_vector_dim = 2,
|
||||
.src_scalar_per_vector = 8,
|
||||
.lds_dst_scalar_per_vector = 8,
|
||||
.is_direct_load = true,
|
||||
.lds_padding = false},
|
||||
.epilogue_c = {.m_per_wave_per_shuffle = 1,
|
||||
.n_per_wave_per_shuffle = 1,
|
||||
.scalar_per_vector = 8},
|
||||
.block_transfer_access_order_a = {.order = {0, 1, 2}},
|
||||
.block_transfer_access_order_b = {.order = {0, 1, 2}},
|
||||
.src_access_order_a = {.order = {0, 1, 2}},
|
||||
.src_access_order_b = {.order = {0, 1, 2}}};
|
||||
ckb::test::TransferABC transfer{
|
||||
.a =
|
||||
{
|
||||
.block_transfer = {.k0 = 4, .m_n = 256, .k1 = 8},
|
||||
.lds_transfer = {.src_vector_dim = 2,
|
||||
.src_scalar_per_vector = 8,
|
||||
.lds_dst_scalar_per_vector = 8,
|
||||
.is_direct_load = true,
|
||||
.lds_padding = false},
|
||||
.block_transfer_access_order = {.order = {0, 1, 2}},
|
||||
.src_access_order = {.order = {0, 1, 2}},
|
||||
|
||||
},
|
||||
.b =
|
||||
{
|
||||
.block_transfer = {.k0 = 4, .m_n = 256, .k1 = 8},
|
||||
.lds_transfer = {.src_vector_dim = 2,
|
||||
.src_scalar_per_vector = 8,
|
||||
.lds_dst_scalar_per_vector = 8,
|
||||
.is_direct_load = true,
|
||||
.lds_padding = false},
|
||||
.block_transfer_access_order = {.order = {0, 1, 2}},
|
||||
.src_access_order = {.order = {0, 1, 2}},
|
||||
},
|
||||
.c =
|
||||
{
|
||||
.thread_cluster_dims =
|
||||
{.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8},
|
||||
.epilogue = {.m_per_wave_per_shuffle = 1,
|
||||
.n_per_wave_per_shuffle = 1,
|
||||
.scalar_per_vector = 8},
|
||||
},
|
||||
};
|
||||
|
||||
ckb::ConvFwdSpecialization fwd_specialization = ckb::ConvFwdSpecialization::DEFAULT;
|
||||
ckb::GemmSpecialization gemm_specialization = ckb::GemmSpecialization::Default;
|
||||
|
||||
Reference in New Issue
Block a user