mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 11:47:48 +00:00
Build conv bwd weight v3 instances successfully.
This commit is contained in:
@@ -65,7 +65,7 @@ concept BlockTransferDescriptor = requires(T t) {
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
concept BlockTransferDescriptorBwd = requires(T t) {
|
||||
concept BlockTransferDescriptor4D = requires(T t) {
|
||||
{ t.k0 } -> SizeType;
|
||||
{ t.m_n } -> SizeType;
|
||||
{ t.k1 } -> SizeType;
|
||||
@@ -211,11 +211,12 @@ concept SpecifiesBlockTransfer = requires(T t) {
|
||||
{ T::transfer.c.thread_cluster_dims } -> ThreadClusterDescriptor;
|
||||
};
|
||||
|
||||
// Concept to check if a struct specifies convolution input and output block transfer info (Bwd direction).
|
||||
// Concept to check if a struct specifies convolution input and output block transfer info
|
||||
// for 4D thread slices.
|
||||
template <typename T>
|
||||
concept SpecifiesBlockTransferBwd = requires(T t) {
|
||||
{ T::transfer.a.block_transfer } -> BlockTransferDescriptorBwd;
|
||||
{ T::transfer.b.block_transfer } -> BlockTransferDescriptorBwd;
|
||||
concept SpecifiesBlockTransfer4D = requires(T t) {
|
||||
{ T::transfer.a.block_transfer } -> BlockTransferDescriptor4D;
|
||||
{ T::transfer.b.block_transfer } -> BlockTransferDescriptor4D;
|
||||
{ T::transfer.c.thread_cluster_dims } -> ThreadClusterDescriptor;
|
||||
};
|
||||
|
||||
|
||||
@@ -184,9 +184,9 @@ consteval auto diagnose_block_transfer(const char* prefix) -> std::string {
|
||||
return msg;
|
||||
}
|
||||
|
||||
// BlockTransferDescriptorBwd diagnostics (requires k_batch_size)
|
||||
// BlockTransferDescriptor4D diagnostics (requires k_batch_size)
|
||||
template <typename T, typename BT>
|
||||
consteval auto diagnose_block_transfer_bwd(const char* prefix) -> std::string {
|
||||
consteval auto diagnose_block_transfer_4d(const char* prefix) -> std::string {
|
||||
std::string msg;
|
||||
|
||||
if constexpr (requires(BT t) { t.k0; }) {
|
||||
@@ -500,7 +500,7 @@ consteval auto detailed_diagnostic_SpecifiesBlockTransfer() -> std::string {
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
consteval auto detailed_diagnostic_SpecifiesBlockTransferBwd() -> std::string {
|
||||
consteval auto detailed_diagnostic_SpecifiesBlockTransfer4D() -> std::string {
|
||||
std::string msg;
|
||||
|
||||
constexpr bool has_transfer = requires { T::transfer; };
|
||||
@@ -510,16 +510,20 @@ consteval auto detailed_diagnostic_SpecifiesBlockTransferBwd() -> std::string {
|
||||
return msg;
|
||||
}
|
||||
|
||||
constexpr bool has_a = requires { { T::transfer.a.block_transfer } -> BlockTransferDescriptorBwd; };
|
||||
constexpr bool has_a = requires { { T::transfer.a.block_transfer } -> BlockTransferDescriptor4D; };
|
||||
msg += " → T::transfer.a: " + std::string(CHECK_MARK(has_a)) + "\n";
|
||||
if constexpr (!has_a) {
|
||||
msg += " → T::transfer.a.block_transfer: [✗] (missing or wrong type)\n";
|
||||
} else {
|
||||
msg += detail::diagnose_block_transfer_4d<T, decltype(T::transfer.a.block_transfer)>("transfer.a.block_transfer");
|
||||
}
|
||||
|
||||
constexpr bool has_b = requires { { T::transfer.b.block_transfer } -> BlockTransferDescriptorBwd; };
|
||||
constexpr bool has_b = requires { { T::transfer.b.block_transfer } -> BlockTransferDescriptor4D; };
|
||||
msg += " → T::transfer.b: " + std::string(CHECK_MARK(has_b)) + "\n";
|
||||
if constexpr (!has_b) {
|
||||
msg += " → T::transfer.b.block_transfer: [✗] (missing or wrong type)\n";
|
||||
} else {
|
||||
msg += detail::diagnose_block_transfer_4d<T, decltype(T::transfer.b.block_transfer)>("transfer.b.block_transfer");
|
||||
}
|
||||
|
||||
constexpr bool has_c = requires { { T::transfer.c.thread_cluster_dims } -> ThreadClusterDescriptor; };
|
||||
|
||||
@@ -242,7 +242,7 @@ template <typename T>
|
||||
struct BwdXdlAlgorithm {
|
||||
CHECK_CONCEPT(T, ConvAlgorithmDescriptor)
|
||||
CHECK_CONCEPT(T, SpecifiesThreadBlock)
|
||||
CHECK_CONCEPT(T, SpecifiesBlockTransferBwd)
|
||||
CHECK_CONCEPT(T, SpecifiesBlockTransfer4D)
|
||||
CHECK_CONCEPT(T, SpecifiesLdsTransfer)
|
||||
CHECK_CONCEPT(T, SpecifiesThreadClusterAccessOrder)
|
||||
CHECK_CONCEPT(T, SpecifiesSourceAccessOrder)
|
||||
@@ -252,7 +252,7 @@ struct BwdXdlAlgorithm {
|
||||
|
||||
static constexpr bool c1 = c_ConvAlgorithmDescriptor;
|
||||
static constexpr bool c2 = c_SpecifiesThreadBlock;
|
||||
static constexpr bool c3 = c_SpecifiesBlockTransferBwd;
|
||||
static constexpr bool c3 = c_SpecifiesBlockTransfer4D;
|
||||
static constexpr bool c4 = c_SpecifiesLdsTransfer;
|
||||
static constexpr bool c5 = c_SpecifiesThreadClusterAccessOrder;
|
||||
static constexpr bool c6 = c_SpecifiesSourceAccessOrder;
|
||||
@@ -269,7 +269,7 @@ struct BwdXdlAlgorithm {
|
||||
"Concepts for BwdXdl Algorithm:\n") +
|
||||
DIAGNOSTIC_LINE(ConvAlgorithmDescriptor) +
|
||||
DIAGNOSTIC_LINE(SpecifiesThreadBlock) +
|
||||
DIAGNOSTIC_LINE(SpecifiesBlockTransferBwd) +
|
||||
DIAGNOSTIC_LINE(SpecifiesBlockTransfer4D) +
|
||||
DIAGNOSTIC_LINE(SpecifiesLdsTransfer) +
|
||||
DIAGNOSTIC_LINE(SpecifiesThreadClusterAccessOrder) +
|
||||
DIAGNOSTIC_LINE(SpecifiesSourceAccessOrder) +
|
||||
@@ -283,7 +283,7 @@ template <typename T>
|
||||
struct BwdXdlV3Algorithm {
|
||||
CHECK_CONCEPT(T, ConvAlgorithmDescriptor)
|
||||
CHECK_CONCEPT(T, SpecifiesThreadBlock)
|
||||
CHECK_CONCEPT(T, SpecifiesBlockTransferBwd)
|
||||
CHECK_CONCEPT(T, SpecifiesBlockTransfer)
|
||||
CHECK_CONCEPT(T, SpecifiesLdsTransfer)
|
||||
CHECK_CONCEPT(T, SpecifiesThreadClusterAccessOrder)
|
||||
CHECK_CONCEPT(T, SpecifiesSourceAccessOrder)
|
||||
@@ -293,7 +293,7 @@ struct BwdXdlV3Algorithm {
|
||||
|
||||
static constexpr bool c1 = c_ConvAlgorithmDescriptor;
|
||||
static constexpr bool c2 = c_SpecifiesThreadBlock;
|
||||
static constexpr bool c3 = c_SpecifiesBlockTransferBwd;
|
||||
static constexpr bool c3 = c_SpecifiesBlockTransfer;
|
||||
static constexpr bool c4 = c_SpecifiesLdsTransfer;
|
||||
static constexpr bool c5 = c_SpecifiesThreadClusterAccessOrder;
|
||||
static constexpr bool c6 = c_SpecifiesSourceAccessOrder;
|
||||
@@ -310,7 +310,7 @@ struct BwdXdlV3Algorithm {
|
||||
"Concepts for BwdXdlV3 Algorithm:\n") +
|
||||
DIAGNOSTIC_LINE(ConvAlgorithmDescriptor) +
|
||||
DIAGNOSTIC_LINE(SpecifiesThreadBlock) +
|
||||
DIAGNOSTIC_LINE(SpecifiesBlockTransferBwd) +
|
||||
DIAGNOSTIC_LINE(SpecifiesBlockTransfer) +
|
||||
DIAGNOSTIC_LINE(SpecifiesLdsTransfer) +
|
||||
DIAGNOSTIC_LINE(SpecifiesThreadClusterAccessOrder) +
|
||||
DIAGNOSTIC_LINE(SpecifiesSourceAccessOrder) +
|
||||
|
||||
@@ -12,9 +12,9 @@ namespace ck_tile::builder::factory::internal {
|
||||
// Block transfer parameters for A or B tensor.
|
||||
struct BlockTransfer
|
||||
{
|
||||
ck::Array<size_t, 3> thread_cluster_dims = {0, 0, 0}; // k0, m, k1
|
||||
ck::Array<size_t, 3> thread_cluster_order = {0, 0, 0};
|
||||
ck::Array<size_t, 3> src_access_order = {0, 0, 0};
|
||||
ck::Array<size_t, 3> thread_cluster_dims{}; // k0, m, k1
|
||||
ck::Array<size_t, 3> thread_cluster_order{};
|
||||
ck::Array<size_t, 3> src_access_order{};
|
||||
size_t src_vector_dim = 0;
|
||||
size_t src_scalar_per_vector = 0;
|
||||
size_t lds_dst_scalar_per_vector = 0;
|
||||
@@ -22,15 +22,15 @@ struct BlockTransfer
|
||||
bool lds_padding = false;
|
||||
};
|
||||
|
||||
template <size_t ThreadSliceDim = 3>
|
||||
struct BwdBlockTransfer
|
||||
{
|
||||
ck::Array<size_t, 4> thread_cluster_dims = {0, 0, 0, 0};
|
||||
ck::Array<size_t, 4> thread_cluster_order = {0, 0, 0, 0};
|
||||
ck::Array<size_t, 4> src_access_order = {0, 0, 0, 0};
|
||||
ck::Array<size_t, ThreadSliceDim> thread_cluster_dims{};
|
||||
ck::Array<size_t, ThreadSliceDim> thread_cluster_order{};
|
||||
ck::Array<size_t, ThreadSliceDim> src_access_order{};
|
||||
size_t src_vector_dim = 0;
|
||||
size_t src_scalar_per_vector = 0;
|
||||
size_t lds_dst_scalar_per_vector = 0;
|
||||
bool is_direct_load = false;
|
||||
bool lds_padding = false;
|
||||
};
|
||||
|
||||
@@ -55,7 +55,7 @@ constexpr BlockTransfer SetFwdConvBlockTransfer()
|
||||
}
|
||||
|
||||
template <auto TRANSFER>
|
||||
constexpr BwdBlockTransfer SetBwdConvBlockTransfer()
|
||||
constexpr auto SetBwdConvBlockTransfer()
|
||||
{
|
||||
auto& block_xfer = TRANSFER.block_transfer;
|
||||
auto& block_order = TRANSFER.block_transfer_access_order;
|
||||
@@ -68,27 +68,25 @@ constexpr BwdBlockTransfer SetBwdConvBlockTransfer()
|
||||
|
||||
if constexpr (array_length == 3)
|
||||
{
|
||||
return BwdBlockTransfer{
|
||||
.thread_cluster_dims = {block_xfer.k_batch_size, block_xfer.k0, block_xfer.m_n, block_xfer.k1},
|
||||
return BwdBlockTransfer<3>{
|
||||
.thread_cluster_dims = {block_xfer.k0, block_xfer.m_n, block_xfer.k1},
|
||||
.thread_cluster_order = {block_order.order[0], block_order.order[1], block_order.order[2]},
|
||||
.src_access_order = {src_order.order[0], src_order.order[1], src_order.order[2]},
|
||||
.src_vector_dim = lds_cfg.src_vector_dim,
|
||||
.src_scalar_per_vector = lds_cfg.src_scalar_per_vector,
|
||||
.lds_dst_scalar_per_vector = lds_cfg.lds_dst_scalar_per_vector,
|
||||
.is_direct_load = lds_cfg.is_direct_load,
|
||||
.lds_padding = lds_cfg.lds_padding,
|
||||
};
|
||||
}
|
||||
else if constexpr (array_length == 4)
|
||||
{
|
||||
return BwdBlockTransfer{
|
||||
return BwdBlockTransfer<4>{
|
||||
.thread_cluster_dims = {block_xfer.k_batch_size, block_xfer.k0, block_xfer.m_n, block_xfer.k1},
|
||||
.thread_cluster_order = {block_order.order[0], block_order.order[1], block_order.order[2], block_order.order[3]},
|
||||
.src_access_order = {src_order.order[0], src_order.order[1], src_order.order[2], src_order.order[3]},
|
||||
.src_vector_dim = lds_cfg.src_vector_dim,
|
||||
.src_scalar_per_vector = lds_cfg.src_scalar_per_vector,
|
||||
.lds_dst_scalar_per_vector = lds_cfg.lds_dst_scalar_per_vector,
|
||||
.is_direct_load = lds_cfg.is_direct_load,
|
||||
.lds_padding = lds_cfg.lds_padding,
|
||||
};
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user