ck-builder: group transfer operations per tensor (#3217)

Grouping transfer operations per tensor makes it easier to
constrain on and operate with the transfer operations. As an
example, we can now deduplicate the logic for translating
the transfer operations from the ck-builder interface to the old
ck interface for the A and B tensors.
This commit is contained in:
Robin Voetter
2025-11-20 19:40:48 +01:00
committed by GitHub
parent fb43760c66
commit 245c6011cf
18 changed files with 280 additions and 242 deletions

View File

@@ -25,109 +25,152 @@ constexpr DlBlockTransfer DlBlockTransferAB{.thread_slice_lengths = {8,
.src_vector_tensor_contiguous_dim_order = {1, 2, 0, 3},
.dst_vector_tensor_lengths = {1, 1, 1, 2}};
constexpr DlEpilogue DlEpilogueC{.src_dst_access_order = {0, 1, 2, 3, 4, 5},
.src_dst_vector_dim = 5,
.dst_scalar_per_vector = 4};
constexpr DlTransferABC DlFwdTransfer{.a =
{
.block_transfer = DlBlockTransferAB,
},
.b =
{
.block_transfer = DlBlockTransferAB,
},
.c = {
.epilogue = {.src_dst_access_order = {0, 1, 2, 3, 4, 5},
.src_dst_vector_dim = 5,
.dst_scalar_per_vector = 4},
}};
constexpr BlockTransferABC FwdBlockTransfer_4x64x1{
.block_transfer_a = {.k0 = 4, .m_n = 64, .k1 = 1},
.block_transfer_b = {.k0 = 4, .m_n = 64, .k1 = 1},
.thread_cluster_dims_c = {.m_block = 1,
.m_wave_per_xdl = 32,
.n_block = 1,
.n_wave_per_xdl = 8},
.lds_transfer_a = {.src_vector_dim = 2,
.src_scalar_per_vector = 2,
.lds_dst_scalar_per_vector = 8,
.is_direct_load = false,
.lds_padding = false},
.lds_transfer_b = {.src_vector_dim = 2,
.src_scalar_per_vector = 8,
.lds_dst_scalar_per_vector = 8,
.is_direct_load = false,
.lds_padding = false},
.epilogue_c = {.m_per_wave_per_shuffle = 1,
.n_per_wave_per_shuffle = 1,
.scalar_per_vector = 8},
.block_transfer_access_order_a = {1, 0, 2},
.block_transfer_access_order_b = {1, 0, 2},
.src_access_order_a = {1, 0, 2},
.src_access_order_b = {1, 0, 2}};
constexpr TransferABC FwdTransfer_4x64x1{
.a =
{
.block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1},
.lds_transfer = {.src_vector_dim = 2,
.src_scalar_per_vector = 2,
.lds_dst_scalar_per_vector = 8,
.is_direct_load = false,
.lds_padding = false},
.block_transfer_access_order = {1, 0, 2},
.src_access_order = {1, 0, 2},
},
.b =
{
.block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1},
.lds_transfer = {.src_vector_dim = 2,
.src_scalar_per_vector = 8,
.lds_dst_scalar_per_vector = 8,
.is_direct_load = false,
.lds_padding = false},
.block_transfer_access_order = {1, 0, 2},
.src_access_order = {1, 0, 2},
},
.c =
{
.thread_cluster_dims =
{.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8},
.epilogue = {.m_per_wave_per_shuffle = 1,
.n_per_wave_per_shuffle = 1,
.scalar_per_vector = 8},
},
};
constexpr BlockTransferABC FwdBlockTransfer_4x64x1_fp8{
.block_transfer_a = {.k0 = 4, .m_n = 64, .k1 = 1},
.block_transfer_b = {.k0 = 4, .m_n = 64, .k1 = 1},
.thread_cluster_dims_c = {.m_block = 1,
.m_wave_per_xdl = 32,
.n_block = 1,
.n_wave_per_xdl = 8},
.lds_transfer_a = {.src_vector_dim = 2,
.src_scalar_per_vector = 8,
.lds_dst_scalar_per_vector = 8,
.is_direct_load = false,
.lds_padding = true},
.lds_transfer_b = {.src_vector_dim = 2,
.src_scalar_per_vector = 8,
.lds_dst_scalar_per_vector = 8,
.is_direct_load = false,
.lds_padding = true},
.epilogue_c = {.m_per_wave_per_shuffle = 1,
.n_per_wave_per_shuffle = 1,
.scalar_per_vector = 8},
.block_transfer_access_order_a = {1, 0, 2},
.block_transfer_access_order_b = {1, 0, 2},
.src_access_order_a = {1, 0, 2},
.src_access_order_b = {1, 0, 2}};
constexpr TransferABC FwdTransfer_4x64x1_fp8{
.a =
{
.block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1},
.lds_transfer = {.src_vector_dim = 2,
.src_scalar_per_vector = 8,
.lds_dst_scalar_per_vector = 8,
.is_direct_load = false,
.lds_padding = true},
.block_transfer_access_order = {1, 0, 2},
.src_access_order = {1, 0, 2},
},
.b =
{
.block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1},
.lds_transfer = {.src_vector_dim = 2,
.src_scalar_per_vector = 8,
.lds_dst_scalar_per_vector = 8,
.is_direct_load = false,
.lds_padding = true},
.block_transfer_access_order = {1, 0, 2},
.src_access_order = {1, 0, 2},
},
.c =
{
.thread_cluster_dims =
{.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8},
.epilogue = {.m_per_wave_per_shuffle = 1,
.n_per_wave_per_shuffle = 1,
.scalar_per_vector = 8},
},
};
constexpr BlockTransferABC FwdBlockTransfer_4x16x1{
.block_transfer_a = {.k0 = 4, .m_n = 16, .k1 = 1},
.block_transfer_b = {.k0 = 4, .m_n = 16, .k1 = 1},
.thread_cluster_dims_c = {.m_block = 1,
.m_wave_per_xdl = 16,
.n_block = 1,
.n_wave_per_xdl = 4},
.lds_transfer_a = {.src_vector_dim = 2,
.src_scalar_per_vector = 8,
.lds_dst_scalar_per_vector = 8,
.is_direct_load = false,
.lds_padding = true},
.lds_transfer_b = {.src_vector_dim = 2,
.src_scalar_per_vector = 8,
.lds_dst_scalar_per_vector = 8,
.is_direct_load = false,
.lds_padding = true},
.epilogue_c = {.m_per_wave_per_shuffle = 1,
.n_per_wave_per_shuffle = 1,
.scalar_per_vector = 8},
.block_transfer_access_order_a = {1, 0, 2},
.block_transfer_access_order_b = {1, 0, 2},
.src_access_order_a = {1, 0, 2},
.src_access_order_b = {1, 0, 2}};
constexpr TransferABC FwdTransfer_4x16x1{
.a =
{
.block_transfer = {.k0 = 4, .m_n = 16, .k1 = 1},
.lds_transfer = {.src_vector_dim = 2,
.src_scalar_per_vector = 8,
.lds_dst_scalar_per_vector = 8,
.is_direct_load = false,
.lds_padding = true},
.block_transfer_access_order = {1, 0, 2},
.src_access_order = {1, 0, 2},
},
.b =
{
.block_transfer = {.k0 = 4, .m_n = 16, .k1 = 1},
.lds_transfer = {.src_vector_dim = 2,
.src_scalar_per_vector = 8,
.lds_dst_scalar_per_vector = 8,
.is_direct_load = false,
.lds_padding = true},
.block_transfer_access_order = {1, 0, 2},
.src_access_order = {1, 0, 2},
},
.c =
{
.thread_cluster_dims =
{.m_block = 1, .m_wave_per_xdl = 16, .n_block = 1, .n_wave_per_xdl = 4},
.epilogue = {.m_per_wave_per_shuffle = 1,
.n_per_wave_per_shuffle = 1,
.scalar_per_vector = 8},
constexpr BlockTransferABC FwdBlockTransfer_4x32x1{
.block_transfer_a = {.k0 = 4, .m_n = 32, .k1 = 1},
.block_transfer_b = {.k0 = 4, .m_n = 32, .k1 = 1},
.thread_cluster_dims_c = {.m_block = 1,
.m_wave_per_xdl = 32,
.n_block = 1,
.n_wave_per_xdl = 4},
.lds_transfer_a = {.src_vector_dim = 2,
.src_scalar_per_vector = 16,
.lds_dst_scalar_per_vector = 16,
.is_direct_load = false,
.lds_padding = true},
.lds_transfer_b = {.src_vector_dim = 2,
.src_scalar_per_vector = 16,
.lds_dst_scalar_per_vector = 16,
.is_direct_load = false,
.lds_padding = true},
.epilogue_c = {.m_per_wave_per_shuffle = 1,
.n_per_wave_per_shuffle = 1,
.scalar_per_vector = 8},
.block_transfer_access_order_a = {1, 0, 2},
.block_transfer_access_order_b = {1, 0, 2},
.src_access_order_a = {1, 0, 2},
.src_access_order_b = {1, 0, 2}};
},
};
constexpr TransferABC FwdTransfer_4x32x1{
.a =
{
.block_transfer = {.k0 = 4, .m_n = 32, .k1 = 1},
.lds_transfer = {.src_vector_dim = 2,
.src_scalar_per_vector = 16,
.lds_dst_scalar_per_vector = 16,
.is_direct_load = false,
.lds_padding = true},
.block_transfer_access_order = {1, 0, 2},
.src_access_order = {1, 0, 2},
},
.b =
{
.block_transfer = {.k0 = 4, .m_n = 32, .k1 = 1},
.lds_transfer = {.src_vector_dim = 2,
.src_scalar_per_vector = 16,
.lds_dst_scalar_per_vector = 16,
.is_direct_load = false,
.lds_padding = true},
.block_transfer_access_order = {1, 0, 2},
.src_access_order = {1, 0, 2},
},
.c =
{
.thread_cluster_dims =
{.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 4},
.epilogue = {.m_per_wave_per_shuffle = 1,
.n_per_wave_per_shuffle = 1,
.scalar_per_vector = 8},
},
};
constexpr GridwiseXdlGemm FwdGemmParams_Xdl_4x4_per_wave{
.ak1 = 8, .bk1 = 8, .m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4};