Build conv bwd weight v3 instances successfully.

This commit is contained in:
Ville Pietilä
2025-12-29 09:30:58 -05:00
parent 80f44824f5
commit a83790e9da
4 changed files with 32 additions and 29 deletions

View File

@@ -65,7 +65,7 @@ concept BlockTransferDescriptor = requires(T t) {
};
template <typename T>
concept BlockTransferDescriptorBwd = requires(T t) {
concept BlockTransferDescriptor4D = requires(T t) {
{ t.k0 } -> SizeType;
{ t.m_n } -> SizeType;
{ t.k1 } -> SizeType;
@@ -211,11 +211,12 @@ concept SpecifiesBlockTransfer = requires(T t) {
{ T::transfer.c.thread_cluster_dims } -> ThreadClusterDescriptor;
};
// Concept to check if a struct specifies convolution input and output block transfer info (Bwd direction).
// Concept to check if a struct specifies convolution input and output block transfer info
// for 4D thread slices.
template <typename T>
concept SpecifiesBlockTransferBwd = requires(T t) {
{ T::transfer.a.block_transfer } -> BlockTransferDescriptorBwd;
{ T::transfer.b.block_transfer } -> BlockTransferDescriptorBwd;
concept SpecifiesBlockTransfer4D = requires(T t) {
{ T::transfer.a.block_transfer } -> BlockTransferDescriptor4D;
{ T::transfer.b.block_transfer } -> BlockTransferDescriptor4D;
{ T::transfer.c.thread_cluster_dims } -> ThreadClusterDescriptor;
};

View File

@@ -184,9 +184,9 @@ consteval auto diagnose_block_transfer(const char* prefix) -> std::string {
return msg;
}
// BlockTransferDescriptorBwd diagnostics (requires k_batch_size)
// BlockTransferDescriptor4D diagnostics (requires k_batch_size)
template <typename T, typename BT>
consteval auto diagnose_block_transfer_bwd(const char* prefix) -> std::string {
consteval auto diagnose_block_transfer_4d(const char* prefix) -> std::string {
std::string msg;
if constexpr (requires(BT t) { t.k0; }) {
@@ -500,7 +500,7 @@ consteval auto detailed_diagnostic_SpecifiesBlockTransfer() -> std::string {
}
template <typename T>
consteval auto detailed_diagnostic_SpecifiesBlockTransferBwd() -> std::string {
consteval auto detailed_diagnostic_SpecifiesBlockTransfer4D() -> std::string {
std::string msg;
constexpr bool has_transfer = requires { T::transfer; };
@@ -510,16 +510,20 @@ consteval auto detailed_diagnostic_SpecifiesBlockTransferBwd() -> std::string {
return msg;
}
constexpr bool has_a = requires { { T::transfer.a.block_transfer } -> BlockTransferDescriptorBwd; };
constexpr bool has_a = requires { { T::transfer.a.block_transfer } -> BlockTransferDescriptor4D; };
msg += " → T::transfer.a: " + std::string(CHECK_MARK(has_a)) + "\n";
if constexpr (!has_a) {
msg += " → T::transfer.a.block_transfer: [✗] (missing or wrong type)\n";
} else {
msg += detail::diagnose_block_transfer_4d<T, decltype(T::transfer.a.block_transfer)>("transfer.a.block_transfer");
}
constexpr bool has_b = requires { { T::transfer.b.block_transfer } -> BlockTransferDescriptorBwd; };
constexpr bool has_b = requires { { T::transfer.b.block_transfer } -> BlockTransferDescriptor4D; };
msg += " → T::transfer.b: " + std::string(CHECK_MARK(has_b)) + "\n";
if constexpr (!has_b) {
msg += " → T::transfer.b.block_transfer: [✗] (missing or wrong type)\n";
} else {
msg += detail::diagnose_block_transfer_4d<T, decltype(T::transfer.b.block_transfer)>("transfer.b.block_transfer");
}
constexpr bool has_c = requires { { T::transfer.c.thread_cluster_dims } -> ThreadClusterDescriptor; };

View File

@@ -242,7 +242,7 @@ template <typename T>
struct BwdXdlAlgorithm {
CHECK_CONCEPT(T, ConvAlgorithmDescriptor)
CHECK_CONCEPT(T, SpecifiesThreadBlock)
CHECK_CONCEPT(T, SpecifiesBlockTransferBwd)
CHECK_CONCEPT(T, SpecifiesBlockTransfer4D)
CHECK_CONCEPT(T, SpecifiesLdsTransfer)
CHECK_CONCEPT(T, SpecifiesThreadClusterAccessOrder)
CHECK_CONCEPT(T, SpecifiesSourceAccessOrder)
@@ -252,7 +252,7 @@ struct BwdXdlAlgorithm {
static constexpr bool c1 = c_ConvAlgorithmDescriptor;
static constexpr bool c2 = c_SpecifiesThreadBlock;
static constexpr bool c3 = c_SpecifiesBlockTransferBwd;
static constexpr bool c3 = c_SpecifiesBlockTransfer4D;
static constexpr bool c4 = c_SpecifiesLdsTransfer;
static constexpr bool c5 = c_SpecifiesThreadClusterAccessOrder;
static constexpr bool c6 = c_SpecifiesSourceAccessOrder;
@@ -269,7 +269,7 @@ struct BwdXdlAlgorithm {
"Concepts for BwdXdl Algorithm:\n") +
DIAGNOSTIC_LINE(ConvAlgorithmDescriptor) +
DIAGNOSTIC_LINE(SpecifiesThreadBlock) +
DIAGNOSTIC_LINE(SpecifiesBlockTransferBwd) +
DIAGNOSTIC_LINE(SpecifiesBlockTransfer4D) +
DIAGNOSTIC_LINE(SpecifiesLdsTransfer) +
DIAGNOSTIC_LINE(SpecifiesThreadClusterAccessOrder) +
DIAGNOSTIC_LINE(SpecifiesSourceAccessOrder) +
@@ -283,7 +283,7 @@ template <typename T>
struct BwdXdlV3Algorithm {
CHECK_CONCEPT(T, ConvAlgorithmDescriptor)
CHECK_CONCEPT(T, SpecifiesThreadBlock)
CHECK_CONCEPT(T, SpecifiesBlockTransferBwd)
CHECK_CONCEPT(T, SpecifiesBlockTransfer)
CHECK_CONCEPT(T, SpecifiesLdsTransfer)
CHECK_CONCEPT(T, SpecifiesThreadClusterAccessOrder)
CHECK_CONCEPT(T, SpecifiesSourceAccessOrder)
@@ -293,7 +293,7 @@ struct BwdXdlV3Algorithm {
static constexpr bool c1 = c_ConvAlgorithmDescriptor;
static constexpr bool c2 = c_SpecifiesThreadBlock;
static constexpr bool c3 = c_SpecifiesBlockTransferBwd;
static constexpr bool c3 = c_SpecifiesBlockTransfer;
static constexpr bool c4 = c_SpecifiesLdsTransfer;
static constexpr bool c5 = c_SpecifiesThreadClusterAccessOrder;
static constexpr bool c6 = c_SpecifiesSourceAccessOrder;
@@ -310,7 +310,7 @@ struct BwdXdlV3Algorithm {
"Concepts for BwdXdlV3 Algorithm:\n") +
DIAGNOSTIC_LINE(ConvAlgorithmDescriptor) +
DIAGNOSTIC_LINE(SpecifiesThreadBlock) +
DIAGNOSTIC_LINE(SpecifiesBlockTransferBwd) +
DIAGNOSTIC_LINE(SpecifiesBlockTransfer) +
DIAGNOSTIC_LINE(SpecifiesLdsTransfer) +
DIAGNOSTIC_LINE(SpecifiesThreadClusterAccessOrder) +
DIAGNOSTIC_LINE(SpecifiesSourceAccessOrder) +

View File

@@ -12,9 +12,9 @@ namespace ck_tile::builder::factory::internal {
// Block transfer parameters for A or B tensor.
struct BlockTransfer
{
ck::Array<size_t, 3> thread_cluster_dims = {0, 0, 0}; // k0, m, k1
ck::Array<size_t, 3> thread_cluster_order = {0, 0, 0};
ck::Array<size_t, 3> src_access_order = {0, 0, 0};
ck::Array<size_t, 3> thread_cluster_dims{}; // k0, m, k1
ck::Array<size_t, 3> thread_cluster_order{};
ck::Array<size_t, 3> src_access_order{};
size_t src_vector_dim = 0;
size_t src_scalar_per_vector = 0;
size_t lds_dst_scalar_per_vector = 0;
@@ -22,15 +22,15 @@ struct BlockTransfer
bool lds_padding = false;
};
template <size_t ThreadSliceDim = 3>
struct BwdBlockTransfer
{
ck::Array<size_t, 4> thread_cluster_dims = {0, 0, 0, 0};
ck::Array<size_t, 4> thread_cluster_order = {0, 0, 0, 0};
ck::Array<size_t, 4> src_access_order = {0, 0, 0, 0};
ck::Array<size_t, ThreadSliceDim> thread_cluster_dims{};
ck::Array<size_t, ThreadSliceDim> thread_cluster_order{};
ck::Array<size_t, ThreadSliceDim> src_access_order{};
size_t src_vector_dim = 0;
size_t src_scalar_per_vector = 0;
size_t lds_dst_scalar_per_vector = 0;
bool is_direct_load = false;
bool lds_padding = false;
};
@@ -55,7 +55,7 @@ constexpr BlockTransfer SetFwdConvBlockTransfer()
}
template <auto TRANSFER>
constexpr BwdBlockTransfer SetBwdConvBlockTransfer()
constexpr auto SetBwdConvBlockTransfer()
{
auto& block_xfer = TRANSFER.block_transfer;
auto& block_order = TRANSFER.block_transfer_access_order;
@@ -68,27 +68,25 @@ constexpr BwdBlockTransfer SetBwdConvBlockTransfer()
if constexpr (array_length == 3)
{
return BwdBlockTransfer{
.thread_cluster_dims = {block_xfer.k_batch_size, block_xfer.k0, block_xfer.m_n, block_xfer.k1},
return BwdBlockTransfer<3>{
.thread_cluster_dims = {block_xfer.k0, block_xfer.m_n, block_xfer.k1},
.thread_cluster_order = {block_order.order[0], block_order.order[1], block_order.order[2]},
.src_access_order = {src_order.order[0], src_order.order[1], src_order.order[2]},
.src_vector_dim = lds_cfg.src_vector_dim,
.src_scalar_per_vector = lds_cfg.src_scalar_per_vector,
.lds_dst_scalar_per_vector = lds_cfg.lds_dst_scalar_per_vector,
.is_direct_load = lds_cfg.is_direct_load,
.lds_padding = lds_cfg.lds_padding,
};
}
else if constexpr (array_length == 4)
{
return BwdBlockTransfer{
return BwdBlockTransfer<4>{
.thread_cluster_dims = {block_xfer.k_batch_size, block_xfer.k0, block_xfer.m_n, block_xfer.k1},
.thread_cluster_order = {block_order.order[0], block_order.order[1], block_order.order[2], block_order.order[3]},
.src_access_order = {src_order.order[0], src_order.order[1], src_order.order[2], src_order.order[3]},
.src_vector_dim = lds_cfg.src_vector_dim,
.src_scalar_per_vector = lds_cfg.src_scalar_per_vector,
.lds_dst_scalar_per_vector = lds_cfg.lds_dst_scalar_per_vector,
.is_direct_load = lds_cfg.is_direct_load,
.lds_padding = lds_cfg.lds_padding,
};
}