Merge commit '42048bdb7d8d931966af76c6dacfedce1c9da90a' into develop

This commit is contained in:
assistant-librarian[bot]
2026-01-28 17:20:56 +00:00
parent 78b36a13ab
commit dbadcf487a
40 changed files with 3191 additions and 780 deletions

View File

@@ -9,6 +9,7 @@ See the [main builder documentation](../README.md) for an overview.
The reflection system works by extracting properties from a convolution kernel *type* and formatting them into a string. This is useful for debugging, performance tuning, and generating documentation.
1. **Trait Extraction**: The `ConvTraits` template (in `conv_traits.hpp`) is specialized for each kernel instance. It extracts low-level details like tile sizes, data layouts, and pipeline versions from the kernel's type definition.
This template is common for xld and wmma, fwd and backwards weight kernels. std::optional is used for parameters that are only used by some kernels
2. **Description Generation**: The `describe<Instance>()` function (in `conv_description.hpp`) uses `ConvTraits` to populate a `ConvDescription` (`Description`) object.
@@ -48,6 +49,15 @@ The reflection system (`ckr::describe`) currently supports the following convolu
- **Standard XDL Forward Convolution** (`DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle`)
- **Large Tensor XDL Forward Convolution** (`DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor`)
- **V3 XDL Forward Convolution** (`DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3`)
- **V3 WMMA Forward Convolution** (`DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3`)
- **XDL Backward Weight Convolution** (`DeviceGroupedConvBwdWeight_Xdl_CShuffle`)
- **V3 XDL Backward Weight Convolution** (`DeviceGroupedConvBwdWeight_Xdl_CShuffleV3`)
- **XDL Multiple D Backward Weight Convolution** (`DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle`)
- **Two Stage XDL Backward Weight Convolution** (`DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle`)
- **V3 Two Stage XDL Backward Weight Convolution** (`DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3`)
- **Wmma Backward Weight Convolution** (`DeviceGroupedConvBwdWeight_Wmma_CShuffle`)
- **V3 Wmma Backward Weight Convolution** (`DeviceGroupedConvBwdWeight_Wmma_CShuffleV3`)
- **V3 Wmma Multiple D Backward Weight Convolution** (`DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3`)
These variants all share similar template parameter structures and are compatible with the current `ConvTraits` implementation.
@@ -59,15 +69,6 @@ The following instance types are **not yet supported** by the reflection system:
- Uses different internal structure with parameters like `K0PerBlock`, `K1`, `M1PerThread`, etc.
- Missing standard members like `kKPerBlock`, `kMPerXDL`, `kAK1`
- **WMMA Variants** (`DeviceGroupedConvFwdMultipleD_Wmma_CShuffle`)
- Uses WMMA-specific parameters like `MPerWmma`, `NPerWmma`, `MRepeat`, `NRepeat`
- Different tile transfer structure incompatible with current `ConvTraits`
- **Backward Weight Convolution** (`DeviceGroupedConvBwdWeight_Xdl_CShuffle`)
- Uses different layout naming: `InLayout`, `WeiLayout`, `OutLayout` instead of `ALayout`, `BLayout`, `ELayout`
- Different specialization type: `ConvBackwardWeightSpecialization` vs `ConvForwardSpecialization`
- Missing several members expected by forward convolution traits
### Future Work
To support these additional instance types, the reflection system would need:

View File

@@ -29,30 +29,7 @@ conv::ConvDescription describe()
const auto traits = conv::instance_to_conv_traits<Instance>();
return conv::ConvDescription(
conv::ConvSignatureInfo{
.spatial_dim = traits.spatial_dim,
.direction = traits.direction,
.input_layout = traits.layout[0],
.weight_layout = traits.layout[1],
.output_layout = traits.layout[2],
.data_type = traits.data_type,
.input_element_op = traits.input_element_op,
.weight_element_op = traits.weight_element_op,
.output_element_op = traits.output_element_op,
},
conv::GemmAlgorithmInfo{
.thread_block_size = traits.thread_block_size,
.tile_dims = traits.tile_dims,
.warp_gemm = traits.warp_gemm,
.a_tile_transfer = traits.a_tile_transfer,
.b_tile_transfer = traits.b_tile_transfer,
.c_tile_transfer = traits.c_tile_transfer,
.pipeline_version = traits.pipeline_version,
.pipeline_scheduler = traits.pipeline_scheduler,
.conv_specialization = traits.conv_specialization,
.padding = traits.gemm_padding,
},
[]<typename T = Instance>() { return reflect::instance_string<T>(); });
traits, []<typename T = Instance>() { return reflect::instance_string<T>(); });
}
} // namespace ck_tile::reflect

View File

@@ -29,44 +29,12 @@
#include <ck_tile/builder/reflect/description.hpp>
#include <ck_tile/builder/reflect/instance_traits.hpp>
#include <ck_tile/builder/reflect/tree_formatter.hpp>
#include <ck_tile/builder/reflect/conv_traits.hpp>
namespace ck_tile::reflect {
namespace conv {
/// @brief Signature information for a convolution operation
/// Contains high-level properties that define the convolution's interface,
/// including dimensionality, data layout, data types, and elementwise operations.
struct ConvSignatureInfo
{
int spatial_dim;
builder::ConvDirection direction;
builder::TensorLayout input_layout;
builder::TensorLayout weight_layout;
builder::TensorLayout output_layout;
builder::DataType data_type;
builder::ElementwiseOperation input_element_op;
builder::ElementwiseOperation weight_element_op;
builder::ElementwiseOperation output_element_op;
};
/// @brief Algorithm configuration for a convolution kernel
/// Contains low-level implementation details including thread block configuration,
/// tile dimensions, memory access patterns, and pipeline settings.
struct GemmAlgorithmInfo
{
int thread_block_size;
DataTileInfo tile_dims;
WarpGemmParams warp_gemm;
InputTileTransferInfo a_tile_transfer;
InputTileTransferInfo b_tile_transfer;
OutputTileTransferInfo c_tile_transfer;
builder::PipelineVersion pipeline_version;
builder::PipelineScheduler pipeline_scheduler;
builder::ConvSpecialization conv_specialization;
builder::GemmPadding padding;
};
/// @brief Provides human-readable descriptions of convolution kernel instances
/// Generates formatted text descriptions at various levels of detail for
/// understanding and documenting convolution kernel configurations.
@@ -74,16 +42,12 @@ class ConvDescription : public Description
{
public:
/// @brief Constructor for ConvDescription
/// @param sig The signature information containing high-level convolution properties
/// @param algo The algorithm configuration containing low-level implementation details
/// @param traits The ConvTraits object containing all relevant signature and algorithm
/// information
/// @param instance_string_getter A callable that returns a string representation of the
/// instance
ConvDescription(ConvSignatureInfo sig,
GemmAlgorithmInfo algo,
std::function<std::string()> instance_string_getter)
: signature_(std::move(sig)),
algorithm_(std::move(algo)),
instance_string_getter_(std::move(instance_string_getter))
ConvDescription(ConvTraits traits, std::function<std::string()> instance_string_getter)
: traits_(std::move(traits)), instance_string_getter_(std::move(instance_string_getter))
{
}
@@ -92,7 +56,7 @@ class ConvDescription : public Description
std::string brief() const override
{
std::ostringstream oss;
oss << signature_.spatial_dim << "D " << signature_.direction << " convolution";
oss << traits_.spatial_dim << "D " << traits_.direction << " convolution";
return oss.str();
}
@@ -101,39 +65,42 @@ class ConvDescription : public Description
std::string detailed() const override
{
TreeFormatter f;
f.writeLine(0, signature_.spatial_dim, "D ", signature_.direction, " Convolution Kernel");
f.writeLine(0, traits_.spatial_dim, "D ", traits_.direction, " Convolution Kernel");
f.writeLine(1, "Signature");
f.writeLine(2, "Tensor Type: ", signature_.data_type);
f.writeLine(2, "Input Layout: ", signature_.input_layout);
f.writeLine(2, "Weight Layout: ", signature_.weight_layout);
f.writeLine(2, "Output Layout: ", signature_.output_layout);
f.writeLine(2, "Input elementwise operation: ", signature_.input_element_op);
f.writeLine(2, "Weights elementwise operation: ", signature_.weight_element_op);
f.writeLast(2, "Output elementwise operation: ", signature_.output_element_op);
f.writeLine(2, "Tensor Type: ", traits_.data_type);
f.writeLine(2, "Input Layout: ", traits_.layout[0]);
f.writeLine(2, "Weight Layout: ", traits_.layout[1]);
f.writeLine(2, "Output Layout: ", traits_.layout[2]);
f.writeLine(2, "Input elementwise operation: ", traits_.input_element_op);
f.writeLine(2, "Weights elementwise operation: ", traits_.weight_element_op);
f.writeLast(2, "Output elementwise operation: ", traits_.output_element_op);
f.writeLast(1, "Algorithm");
// Compute Block section
f.writeLine(2, "Thread block size: ", algorithm_.thread_block_size);
f.writeLine(2, "Thread block size: ", traits_.thread_block_size);
f.writeLine(2,
"Data tile size: ",
algorithm_.tile_dims.m,
traits_.tile_dims.m,
"×",
algorithm_.tile_dims.n,
traits_.tile_dims.n,
"×",
algorithm_.tile_dims.k);
f.writeLine(2, "Gemm padding: ", algorithm_.padding);
f.writeLine(2, "Convolution specialization: ", algorithm_.conv_specialization);
traits_.tile_dims.k);
if(traits_.gemm_padding)
f.writeLine(
2, "Gemm padding: ", traits_.gemm_padding.value_or(builder::GemmPadding::DEFAULT));
else
f.writeLine(2, "Struct does not contain optional gemm_padding argument");
f.writeLine(2, "Convolution specialization: ", traits_.conv_specialization);
// Pipeline section
f.writeLine(2, "Pipeline version: ", algorithm_.pipeline_version);
f.writeLine(2, "Pipeline scheduler: ", algorithm_.pipeline_scheduler);
f.writeLine(2, "Pipeline version: ", traits_.pipeline_version);
f.writeLine(2, "Pipeline scheduler: ", traits_.pipeline_scheduler);
f.writeLine(2, "Warp Gemm parameters: ");
f.writeLine(
3, "subtile size: ", algorithm_.warp_gemm.gemm_m, "×", algorithm_.warp_gemm.gemm_n);
f.writeLine(3, "subtile size: ", traits_.warp_gemm.gemm_m, "×", traits_.warp_gemm.gemm_n);
f.writeLast(3,
"Number of warp gemm iterations: ",
algorithm_.warp_gemm.m_iter,
traits_.warp_gemm.m_iter,
"×",
algorithm_.warp_gemm.n_iter);
traits_.warp_gemm.n_iter);
// Memory Access section
f.writeLast(2, "Memory access:");
@@ -141,99 +108,126 @@ class ConvDescription : public Description
f.writeLine(3, "A Tile transfer: ");
f.writeLine(4,
"Tile dimensions: ",
algorithm_.a_tile_transfer.tile_dimensions.k0,
traits_.a_tile_transfer.tile_dimensions.k0,
"×",
algorithm_.a_tile_transfer.tile_dimensions.m_or_n,
traits_.a_tile_transfer.tile_dimensions.m_or_n,
"×",
algorithm_.a_tile_transfer.tile_dimensions.k1,
traits_.a_tile_transfer.tile_dimensions.k1,
"×");
f.writeLine(4,
"The innermost K subdimension size: ",
algorithm_.a_tile_transfer.transfer_params.k1);
f.writeLine(
4, "The innermost K subdimension size: ", traits_.a_tile_transfer.transfer_params.k1);
f.writeLine(4,
"Spatial thread distribution over the data tile: ",
algorithm_.a_tile_transfer.transfer_params.thread_cluster_order[0],
traits_.a_tile_transfer.transfer_params.thread_cluster_order[0],
"×",
algorithm_.a_tile_transfer.transfer_params.thread_cluster_order[1],
traits_.a_tile_transfer.transfer_params.thread_cluster_order[1],
"×",
algorithm_.a_tile_transfer.transfer_params.thread_cluster_order[2]);
traits_.a_tile_transfer.transfer_params.thread_cluster_order[2]);
f.writeLine(4,
"The order of accessing data tile axes: ",
algorithm_.a_tile_transfer.transfer_params.src_access_order[0],
traits_.a_tile_transfer.transfer_params.src_access_order[0],
"×",
algorithm_.a_tile_transfer.transfer_params.src_access_order[1],
traits_.a_tile_transfer.transfer_params.src_access_order[1],
"×",
algorithm_.a_tile_transfer.transfer_params.src_access_order[2]);
traits_.a_tile_transfer.transfer_params.src_access_order[2]);
f.writeLine(4,
"Vectorized memory access axis index (with contiguous memory): ",
algorithm_.a_tile_transfer.transfer_params.src_vector_dim);
traits_.a_tile_transfer.transfer_params.src_vector_dim);
f.writeLine(4,
"Vector access (GMEM read) instruction size: ",
algorithm_.a_tile_transfer.transfer_params.src_scalar_per_vector);
traits_.a_tile_transfer.transfer_params.src_scalar_per_vector);
f.writeLine(4,
"Vector access (LDS write) instruction size: ",
algorithm_.a_tile_transfer.transfer_params.dst_scalar_per_vector_k1);
traits_.a_tile_transfer.transfer_params.dst_scalar_per_vector_k1);
f.writeLast(4,
"LDS data layout padding (to prevent bank conflicts): ",
algorithm_.a_tile_transfer.transfer_params.dst_scalar_per_vector_k1);
traits_.a_tile_transfer.transfer_params.dst_scalar_per_vector_k1);
f.writeLine(3, "B Tile transfer: ");
f.writeLine(4,
"Tile dimensions: ",
algorithm_.b_tile_transfer.tile_dimensions.k0,
traits_.b_tile_transfer.tile_dimensions.k0,
"×",
algorithm_.b_tile_transfer.tile_dimensions.m_or_n,
traits_.b_tile_transfer.tile_dimensions.m_or_n,
"×",
algorithm_.b_tile_transfer.tile_dimensions.k1,
traits_.b_tile_transfer.tile_dimensions.k1,
"×");
f.writeLine(4,
"The innermost K subdimension size: ",
algorithm_.b_tile_transfer.transfer_params.k1);
f.writeLine(
4, "The innermost K subdimension size: ", traits_.b_tile_transfer.transfer_params.k1);
f.writeLine(4,
"Spatial thread distribution over the data tile: ",
algorithm_.b_tile_transfer.transfer_params.thread_cluster_order[0],
traits_.b_tile_transfer.transfer_params.thread_cluster_order[0],
"×",
algorithm_.b_tile_transfer.transfer_params.thread_cluster_order[1],
traits_.b_tile_transfer.transfer_params.thread_cluster_order[1],
"×",
algorithm_.b_tile_transfer.transfer_params.thread_cluster_order[2]);
traits_.b_tile_transfer.transfer_params.thread_cluster_order[2]);
f.writeLine(4,
"The order of accessing data tile axes: ",
algorithm_.b_tile_transfer.transfer_params.src_access_order[0],
traits_.b_tile_transfer.transfer_params.src_access_order[0],
"×",
algorithm_.b_tile_transfer.transfer_params.src_access_order[1],
traits_.b_tile_transfer.transfer_params.src_access_order[1],
"×",
algorithm_.b_tile_transfer.transfer_params.src_access_order[2]);
traits_.b_tile_transfer.transfer_params.src_access_order[2]);
f.writeLine(4,
"Vectorized memory access axis index (with contiguous memory): ",
algorithm_.b_tile_transfer.transfer_params.src_vector_dim);
traits_.b_tile_transfer.transfer_params.src_vector_dim);
f.writeLine(4,
"Vector access (GMEM read) instruction size: ",
algorithm_.b_tile_transfer.transfer_params.src_scalar_per_vector);
traits_.b_tile_transfer.transfer_params.src_scalar_per_vector);
f.writeLine(4,
"Vector access (LDS write) instruction size: ",
algorithm_.b_tile_transfer.transfer_params.dst_scalar_per_vector_k1);
traits_.b_tile_transfer.transfer_params.dst_scalar_per_vector_k1);
f.writeLast(4,
"LDS data layout padding (to prevent bank conflicts): ",
algorithm_.b_tile_transfer.transfer_params.dst_scalar_per_vector_k1);
traits_.b_tile_transfer.transfer_params.dst_scalar_per_vector_k1);
f.writeLast(3, "C Tile transfer: ");
f.writeLine(4,
"Data shuffle (number of gemm instructions per iteration): ",
algorithm_.c_tile_transfer.shuffle_params.m_gemms_per_shuffle,
traits_.c_tile_transfer.shuffle_params.m_gemms_per_shuffle,
"×",
algorithm_.c_tile_transfer.shuffle_params.n_gemms_per_shuffle);
traits_.c_tile_transfer.shuffle_params.n_gemms_per_shuffle);
f.writeLine(4,
"Spatial thread distribution used to store data: ",
algorithm_.c_tile_transfer.thread_cluster_dims[0],
traits_.c_tile_transfer.thread_cluster_dims[0],
"×",
algorithm_.c_tile_transfer.thread_cluster_dims[1],
traits_.c_tile_transfer.thread_cluster_dims[1],
"×",
algorithm_.c_tile_transfer.thread_cluster_dims[2],
traits_.c_tile_transfer.thread_cluster_dims[2],
"×",
algorithm_.c_tile_transfer.thread_cluster_dims[3]);
f.writeLast(4,
traits_.c_tile_transfer.thread_cluster_dims[3]);
f.writeLine(4,
"Vector access (GMEM write) instruction size: ",
algorithm_.c_tile_transfer.scalar_per_vector);
traits_.c_tile_transfer.scalar_per_vector);
if(traits_.num_gemm_k_prefetch_stage)
f.writeLine(
2, "Num gemm k prefetch stage: ", traits_.num_gemm_k_prefetch_stage.value_or(0));
else
f.writeLine(2,
"Struct does not contain optional "
"num_gemm_k_prefetch_stage parameter");
if(traits_.max_transpose_transfer_src_scalar_per_vector)
f.writeLine(2,
"Max Transpose transfer scr scalar per vector: ",
traits_.max_transpose_transfer_src_scalar_per_vector.value_or(0));
else
f.writeLine(2,
"Struct does not contain optional "
"max_transpose_transfer_src_scalar_per_vector parameter");
if(traits_.max_transpose_dst_scalar_per_vector)
f.writeLine(2,
"Max Transpose dst scalar per vector: ",
traits_.max_transpose_dst_scalar_per_vector.value_or(0));
else
f.writeLine(
2,
"Struct does not contain optional max_transpose_dst_scalar_per_vector parameter");
if(traits_.num_groups_to_merge)
f.writeLast(2, "Num groups to merge: ", traits_.num_groups_to_merge.value_or(0));
else
f.writeLast(2, "Struct does not contain optional num_groups_to_merge parameter");
return f.getString();
}
@@ -242,8 +236,7 @@ class ConvDescription : public Description
std::string instance_string() const override { return instance_string_getter_(); }
private:
ConvSignatureInfo signature_;
GemmAlgorithmInfo algorithm_;
ConvTraits traits_;
std::function<std::string()> instance_string_getter_;
};

View File

@@ -88,7 +88,7 @@ struct ConvTraits
builder::ElementwiseOperation weight_element_op;
builder::ElementwiseOperation output_element_op;
builder::GemmPadding gemm_padding;
std::optional<builder::GemmPadding> gemm_padding = std::nullopt;
builder::ConvSpecialization conv_specialization;
// --- Algorithm Information ---
@@ -102,8 +102,14 @@ struct ConvTraits
OutputTileTransferInfo c_tile_transfer;
std::optional<int> num_gemm_k_prefetch_stage = std::nullopt;
builder::PipelineVersion pipeline_version;
builder::PipelineScheduler pipeline_scheduler;
std::optional<int> max_transpose_transfer_src_scalar_per_vector = std::nullopt;
std::optional<int> max_transpose_dst_scalar_per_vector = std::nullopt;
std::optional<int> num_groups_to_merge = std::nullopt;
};
} // namespace ck_tile::reflect::conv

View File

@@ -0,0 +1,46 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <concepts>
#include "ck_tile/builder/reflect/conv_traits.hpp"
#include "ck_tile/builder/reflect/conv_traits_helpers.hpp"
#include "ck_tile/builder/reflect/instance_traits.hpp"
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp"
namespace ck_tile::reflect::conv {
/// @brief Tag dispatch implementation for DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffle_Tag
template <typename Instance>
requires HasInstanceTraits<Instance> &&
std::same_as<typename InstanceTraits<Instance>::device_kernel_tag,
DeviceGroupedConvBwdWeight_multiple_d_Wmma_CShuffle_V3_Tag>
constexpr ConvTraits instance_to_conv_traits()
{
using InstTraits = InstanceTraits<Instance>;
return ConvTraits{
.spatial_dim = InstTraits::kSpatialDim,
.direction = conv_direction<Instance>(),
.layout = bwd_wei_conv_layout<Instance>(),
.data_type = conv_data_type<typename InstTraits::InDataType>(),
.input_element_op = elementwise_op<typename InstTraits::InElementwiseOperation>(),
.weight_element_op = elementwise_op<typename InstTraits::WeiElementwiseOperation>(),
.output_element_op = elementwise_op<typename InstTraits::OutElementwiseOperation>(),
.conv_specialization = conv_spec<Instance>(),
.thread_block_size = InstTraits::kBlockSize,
.tile_dims = conv_traits_data_tile<InstTraits>(InstTraits::kKPerBlock),
.a_tile_transfer =
conv_traits_a_transfer_params<InstTraits>(InstTraits::kK1, InstTraits::kKPerBlock),
.b_tile_transfer =
conv_traits_b_transfer_params<InstTraits>(InstTraits::kK1, InstTraits::kKPerBlock),
.warp_gemm = conv_traits_wmma_warp_gemm_params<InstTraits>(),
.c_tile_transfer = conv_traits_wmma_c_tile_transfer<InstTraits>(),
.pipeline_version = get_pipeline_version<InstTraits>(),
.pipeline_scheduler = get_pipeline_scheduler<InstTraits>(),
};
}
} // namespace ck_tile::reflect::conv

View File

@@ -0,0 +1,53 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <concepts>
#include "ck_tile/builder/reflect/conv_traits.hpp"
#include "ck_tile/builder/reflect/conv_traits_helpers.hpp"
#include "ck_tile/builder/reflect/instance_traits.hpp"
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp"
namespace ck_tile::reflect::conv {
/// @brief Tag dispatch implementation for DeviceGroupedConvBwdWeight_Xdl_CShuffle_Tag
template <typename Instance>
requires HasInstanceTraits<Instance> &&
std::same_as<typename InstanceTraits<Instance>::device_kernel_tag,
DeviceGroupedConvBwdWeight_multiple_d_Xdl_CShuffle_Tag>
constexpr ConvTraits instance_to_conv_traits()
{
using InstTraits = InstanceTraits<Instance>;
return ConvTraits{
.spatial_dim = InstTraits::kSpatialDim,
.direction = conv_direction<Instance>(),
.layout = bwd_wei_conv_layout<Instance>(),
.data_type = conv_data_type<typename InstTraits::InDataType>(),
.input_element_op = elementwise_op<typename InstTraits::InElementwiseOperation>(),
.weight_element_op = elementwise_op<typename InstTraits::WeiElementwiseOperation>(),
.output_element_op = elementwise_op<typename InstTraits::OutElementwiseOperation>(),
.conv_specialization = conv_spec<Instance>(),
.thread_block_size = InstTraits::kBlockSize,
.tile_dims = conv_traits_data_tile<InstTraits>(InstTraits::kK0PerBlock),
.a_tile_transfer =
conv_traits_a_transfer_params<InstTraits>(InstTraits::kK1, InstTraits::kK0PerBlock),
.b_tile_transfer =
conv_traits_b_transfer_params<InstTraits>(InstTraits::kK1, InstTraits::kK0PerBlock),
.warp_gemm = conv_traits_xdl_warp_gemm_params<InstTraits>(),
.c_tile_transfer =
{.shuffle_params = {.m_gemms_per_shuffle = InstTraits::kCShuffleMXdlPerWavePerShuffle,
.n_gemms_per_shuffle = InstTraits::kCShuffleNXdlPerWavePerShuffle},
.thread_cluster_dims = {InstTraits::kCThreadClusterLengths[0],
InstTraits::kCThreadClusterLengths[1],
InstTraits::kCThreadClusterLengths[2],
InstTraits::kCThreadClusterLengths[3]},
.scalar_per_vector = InstTraits::kCBlockTransferScalarPerVector_NWaveNPerXdl},
.pipeline_version = get_pipeline_version<InstTraits>(),
.pipeline_scheduler = get_pipeline_scheduler<InstTraits>(),
};
}
} // namespace ck_tile::reflect::conv

View File

@@ -0,0 +1,50 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <concepts>
#include "ck_tile/builder/reflect/conv_traits.hpp"
#include "ck_tile/builder/reflect/conv_traits_helpers.hpp"
#include "ck_tile/builder/reflect/instance_traits.hpp"
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp"
namespace ck_tile::reflect::conv {
/// @brief Tag dispatch implementation for DeviceGroupedConvBwdWeight_wmma_CShuffle_Tag
template <typename Instance>
requires HasInstanceTraits<Instance> &&
std::same_as<typename InstanceTraits<Instance>::device_kernel_tag,
DeviceGroupedConvBwdWeight_two_stage_Wmma_CShuffle_Tag>
constexpr ConvTraits instance_to_conv_traits()
{
using InstTraits = InstanceTraits<Instance>;
return ConvTraits{
.spatial_dim = InstTraits::kSpatialDim,
.direction = conv_direction<Instance>(),
.layout = bwd_wei_conv_layout<Instance>(),
.data_type = conv_data_type<typename InstTraits::InDataType>(),
.input_element_op = elementwise_op<typename InstTraits::InElementwiseOperation>(),
.weight_element_op = elementwise_op<typename InstTraits::WeiElementwiseOperation>(),
.output_element_op = elementwise_op<typename InstTraits::OutElementwiseOperation>(),
.conv_specialization = conv_spec<Instance>(),
.thread_block_size = InstTraits::kBlockSize,
.tile_dims = conv_traits_data_tile<InstTraits>(InstTraits::kKPerBlock),
.a_tile_transfer =
conv_traits_a_transfer_params<InstTraits>(InstTraits::kABK1, InstTraits::kKPerBlock),
.b_tile_transfer =
conv_traits_b_transfer_params<InstTraits>(InstTraits::kABK1, InstTraits::kKPerBlock),
.warp_gemm = conv_traits_wmma_warp_gemm_params<InstTraits>(),
.c_tile_transfer = conv_traits_wmma_c_tile_transfer<InstTraits>(),
.pipeline_version = get_pipeline_version<InstTraits>(),
.pipeline_scheduler = get_pipeline_scheduler<InstTraits>(),
.max_transpose_transfer_src_scalar_per_vector =
InstTraits::kTransposeTransferSrcScalarPerVector,
.max_transpose_dst_scalar_per_vector = InstTraits::kTransposeTransferDstScalarPerVector,
.num_groups_to_merge = InstTraits::kNumGroupsToMerge,
};
}
} // namespace ck_tile::reflect::conv

View File

@@ -0,0 +1,57 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <concepts>
#include "ck_tile/builder/reflect/conv_traits.hpp"
#include "ck_tile/builder/reflect/conv_traits_helpers.hpp"
#include "ck_tile/builder/reflect/instance_traits.hpp"
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp"
namespace ck_tile::reflect::conv {
/// @brief Tag dispatch implementation for DeviceGroupedConvBwdTwoStage_Xdl_CShuffle_Tag
template <typename Instance>
requires HasInstanceTraits<Instance> &&
std::same_as<typename InstanceTraits<Instance>::device_kernel_tag,
DeviceGroupedConvBwdWeight_two_stage_Xdl_CShuffle_Tag>
constexpr ConvTraits instance_to_conv_traits()
{
using InstTraits = InstanceTraits<Instance>;
return ConvTraits{
.spatial_dim = InstTraits::kSpatialDim,
.direction = conv_direction<Instance>(),
.layout = bwd_wei_conv_layout<Instance>(),
.data_type = conv_data_type<typename InstTraits::InDataType>(),
.input_element_op = elementwise_op<typename InstTraits::InElementwiseOperation>(),
.weight_element_op = elementwise_op<typename InstTraits::WeiElementwiseOperation>(),
.output_element_op = elementwise_op<typename InstTraits::OutElementwiseOperation>(),
.conv_specialization = conv_spec<Instance>(),
.thread_block_size = InstTraits::kBlockSize,
.tile_dims = conv_traits_data_tile<InstTraits>(InstTraits::kKPerBlock),
.a_tile_transfer =
conv_traits_a_transfer_params<InstTraits>(InstTraits::kK1, InstTraits::kKPerBlock),
.b_tile_transfer =
conv_traits_b_transfer_params<InstTraits>(InstTraits::kK1, InstTraits::kKPerBlock),
.warp_gemm = conv_traits_xdl_warp_gemm_params<InstTraits>(),
.c_tile_transfer =
{.shuffle_params = {.m_gemms_per_shuffle = InstTraits::kCShuffleMXdlPerWavePerShuffle,
.n_gemms_per_shuffle = InstTraits::kCShuffleNXdlPerWavePerShuffle},
.thread_cluster_dims = {InstTraits::kCThreadClusterLengths[0],
InstTraits::kCThreadClusterLengths[1],
InstTraits::kCThreadClusterLengths[2],
InstTraits::kCThreadClusterLengths[3]},
.scalar_per_vector = InstTraits::kCBlockTransferScalarPerVector_NWaveNPerXdl},
.pipeline_version = get_pipeline_version<InstTraits>(),
.pipeline_scheduler = get_pipeline_scheduler<InstTraits>(),
.max_transpose_transfer_src_scalar_per_vector =
InstTraits::kTransposeTransferSrcScalarPerVector,
.max_transpose_dst_scalar_per_vector = InstTraits::kTransposeTransferDstScalarPerVector,
.num_groups_to_merge = InstTraits::kNumGroupsToMerge,
};
}
} // namespace ck_tile::reflect::conv

View File

@@ -0,0 +1,48 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <concepts>
#include "ck_tile/builder/reflect/conv_traits.hpp"
#include "ck_tile/builder/reflect/conv_traits_helpers.hpp"
#include "ck_tile/builder/reflect/instance_traits.hpp"
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_wmma_cshuffle.hpp"
namespace ck_tile::reflect::conv {
/// @brief Tag dispatch implementation for DeviceGroupedConvBwdWeight_Wmma_CShuffle_Tag
template <typename Instance>
requires HasInstanceTraits<Instance> &&
std::same_as<typename InstanceTraits<Instance>::device_kernel_tag,
DeviceGroupedConvBwdWeight_Wmma_CShuffle_Tag>
constexpr ConvTraits instance_to_conv_traits()
{
using InstTraits = InstanceTraits<Instance>;
return ConvTraits{
.spatial_dim = InstTraits::kSpatialDim,
.direction = conv_direction<Instance>(),
.layout = bwd_wei_conv_layout<Instance>(),
.data_type = conv_data_type<typename InstTraits::InDataType>(),
.input_element_op = elementwise_op<typename InstTraits::InElementwiseOperation>(),
.weight_element_op = elementwise_op<typename InstTraits::WeiElementwiseOperation>(),
.output_element_op = elementwise_op<typename InstTraits::OutElementwiseOperation>(),
.conv_specialization = conv_spec<Instance>(),
.thread_block_size = InstTraits::kBlockSize,
.tile_dims = conv_traits_data_tile<InstTraits>(InstTraits::kK0PerBlock),
.a_tile_transfer =
conv_traits_a_transfer_params<InstTraits>(InstTraits::kK1, InstTraits::kK0PerBlock),
.b_tile_transfer =
conv_traits_b_transfer_params<InstTraits>(InstTraits::kK1, InstTraits::kK0PerBlock),
.warp_gemm = conv_traits_wmma_warp_gemm_params<InstTraits>(),
.c_tile_transfer = conv_traits_wmma_c_tile_transfer<InstTraits>(),
.num_gemm_k_prefetch_stage = InstTraits::kNumGemmKPrefetchStage,
.pipeline_version = get_pipeline_version<InstTraits>(),
.pipeline_scheduler = get_pipeline_scheduler<InstTraits>(),
};
}
} // namespace ck_tile::reflect::conv

View File

@@ -0,0 +1,50 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <concepts>
#include "ck_tile/builder/reflect/conv_traits.hpp"
#include "ck_tile/builder/reflect/conv_traits_helpers.hpp"
#include "ck_tile/builder/reflect/instance_traits.hpp"
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp"
namespace ck_tile::reflect::conv {
/// @brief Tag dispatch implementation for DeviceGroupedConvBwdWeight_Wmma_CShuffle_Tag
template <typename Instance>
requires HasInstanceTraits<Instance> &&
std::same_as<typename InstanceTraits<Instance>::device_kernel_tag,
DeviceGroupedConvBwdWeight_Wmma_CShuffle_V3_Tag>
constexpr ConvTraits instance_to_conv_traits()
{
using InstTraits = InstanceTraits<Instance>;
return ConvTraits{
.spatial_dim = InstTraits::kSpatialDim,
.direction = conv_direction<Instance>(),
.layout = bwd_wei_conv_layout<Instance>(),
.data_type = conv_data_type<typename InstTraits::InDataType>(),
.input_element_op = elementwise_op<typename InstTraits::InElementwiseOperation>(),
.weight_element_op = elementwise_op<typename InstTraits::WeiElementwiseOperation>(),
.output_element_op = elementwise_op<typename InstTraits::OutElementwiseOperation>(),
.conv_specialization = conv_spec<Instance>(),
.thread_block_size = InstTraits::kBlockSize,
.tile_dims = conv_traits_data_tile<InstTraits>(InstTraits::kKPerBlock),
.a_tile_transfer =
conv_traits_a_transfer_params<InstTraits>(InstTraits::kK1, InstTraits::kKPerBlock),
.b_tile_transfer =
conv_traits_b_transfer_params<InstTraits>(InstTraits::kK1, InstTraits::kKPerBlock),
.warp_gemm = conv_traits_wmma_warp_gemm_params<InstTraits>(),
.c_tile_transfer = conv_traits_wmma_c_tile_transfer<InstTraits>(),
.pipeline_version = get_pipeline_version<InstTraits>(),
.pipeline_scheduler = get_pipeline_scheduler<InstTraits>(),
.max_transpose_transfer_src_scalar_per_vector =
InstTraits::kMaxTransposeTransferSrcScalarPerVector,
.max_transpose_dst_scalar_per_vector = InstTraits::kMaxTransposeTransferDstScalarPerVector,
};
}
} // namespace ck_tile::reflect::conv

View File

@@ -0,0 +1,56 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <concepts>
#include "ck_tile/builder/reflect/conv_traits.hpp"
#include "ck_tile/builder/reflect/conv_traits_helpers.hpp"
#include "ck_tile/builder/reflect/instance_traits.hpp"
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp"
namespace ck_tile::reflect::conv {
/// @brief Tag dispatch implementation for DeviceGroupedConvBwdWeight_Xdl_CShuffle_Tag
template <typename Instance>
requires HasInstanceTraits<Instance> &&
std::same_as<typename InstanceTraits<Instance>::device_kernel_tag,
DeviceGroupedConvBwdWeight_Xdl_CShuffle_Tag>
constexpr ConvTraits instance_to_conv_traits()
{
using InstTraits = InstanceTraits<Instance>;
return ConvTraits{
.spatial_dim = InstTraits::kSpatialDim,
.direction = conv_direction<Instance>(),
.layout = bwd_wei_conv_layout<Instance>(),
.data_type = conv_data_type<typename InstTraits::InDataType>(),
.input_element_op = elementwise_op<typename InstTraits::InElementwiseOperation>(),
.weight_element_op = elementwise_op<typename InstTraits::WeiElementwiseOperation>(),
.output_element_op = elementwise_op<typename InstTraits::OutElementwiseOperation>(),
.conv_specialization = conv_spec<Instance>(),
.thread_block_size = InstTraits::kBlockSize,
.tile_dims = conv_traits_data_tile<InstTraits>(InstTraits::kK0PerBlock),
.a_tile_transfer =
conv_traits_a_transfer_params<InstTraits>(InstTraits::kK1, InstTraits::kK0PerBlock),
.b_tile_transfer =
conv_traits_b_transfer_params<InstTraits>(InstTraits::kK1, InstTraits::kK0PerBlock),
.warp_gemm = conv_traits_xdl_warp_gemm_params<InstTraits>(),
.c_tile_transfer =
{.shuffle_params = {.m_gemms_per_shuffle = InstTraits::kCShuffleMXdlPerWavePerShuffle,
.n_gemms_per_shuffle = InstTraits::kCShuffleNXdlPerWavePerShuffle},
.thread_cluster_dims = {InstTraits::kCThreadClusterLengths[0],
InstTraits::kCThreadClusterLengths[1],
InstTraits::kCThreadClusterLengths[2],
InstTraits::kCThreadClusterLengths[3]},
.scalar_per_vector = InstTraits::kCBlockTransferScalarPerVector_NWaveNPerXdl},
.pipeline_version = get_pipeline_version<InstTraits>(),
.pipeline_scheduler = get_pipeline_scheduler<InstTraits>(),
.max_transpose_transfer_src_scalar_per_vector =
InstTraits::kMaxTransposeTransferSrcScalarPerVector,
.max_transpose_dst_scalar_per_vector = InstTraits::kMaxTransposeTransferDstScalarPerVector,
};
}
} // namespace ck_tile::reflect::conv

View File

@@ -0,0 +1,53 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <concepts>
#include "ck_tile/builder/reflect/conv_traits.hpp"
#include "ck_tile/builder/reflect/conv_traits_helpers.hpp"
#include "ck_tile/builder/reflect/instance_traits.hpp"
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp"
namespace ck_tile::reflect::conv {
/// @brief Tag dispatch implementation for DeviceGroupedConvBwdWeight_Xdl_V3_CShuffle_Tag
template <typename Instance>
requires HasInstanceTraits<Instance> &&
std::same_as<typename InstanceTraits<Instance>::device_kernel_tag,
DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3_Tag>
constexpr ConvTraits instance_to_conv_traits()
{
using InstTraits = InstanceTraits<Instance>;
return ConvTraits{
.spatial_dim = InstTraits::kSpatialDim,
.direction = conv_direction<Instance>(),
.layout = bwd_wei_conv_layout<Instance>(),
.data_type = conv_data_type<typename InstTraits::InDataType>(),
.input_element_op = elementwise_op<typename InstTraits::InElementwiseOperation>(),
.weight_element_op = elementwise_op<typename InstTraits::WeiElementwiseOperation>(),
.output_element_op = elementwise_op<typename InstTraits::OutElementwiseOperation>(),
.conv_specialization = conv_spec<Instance>(),
.thread_block_size = InstTraits::kBlockSize,
.tile_dims = conv_traits_data_tile<InstTraits>(InstTraits::kK0PerBlock),
.a_tile_transfer =
conv_traits_a_transfer_params<InstTraits>(InstTraits::kK1, InstTraits::kK0PerBlock),
.b_tile_transfer =
conv_traits_b_transfer_params<InstTraits>(InstTraits::kK1, InstTraits::kK0PerBlock),
.warp_gemm = conv_traits_xdl_warp_gemm_params<InstTraits>(),
.c_tile_transfer =
{.shuffle_params = {.m_gemms_per_shuffle = InstTraits::kCShuffleMXdlPerWavePerShuffle,
.n_gemms_per_shuffle = InstTraits::kCShuffleNXdlPerWavePerShuffle},
.thread_cluster_dims = {InstTraits::kCThreadClusterLengths[0],
InstTraits::kCThreadClusterLengths[1],
InstTraits::kCThreadClusterLengths[2],
InstTraits::kCThreadClusterLengths[3]},
.scalar_per_vector = InstTraits::kCBlockTransferScalarPerVector_NWaveNPerXdl},
.pipeline_version = get_pipeline_version<InstTraits>(),
.pipeline_scheduler = get_pipeline_scheduler<InstTraits>(),
};
}
} // namespace ck_tile::reflect::conv

View File

@@ -24,8 +24,8 @@ constexpr ConvTraits instance_to_conv_traits()
return ConvTraits{
.spatial_dim = InstTraits::kSpatialDim,
.direction = conv_direction<Instance>(),
.layout = conv_layout<Instance>(),
.data_type = conv_data_type<Instance>(),
.layout = fwd_conv_layout<Instance>(),
.data_type = conv_data_type<typename InstTraits::ADataType>(),
.input_element_op = elementwise_op<typename InstTraits::AElementwiseOperation>(),
.weight_element_op = elementwise_op<typename InstTraits::BElementwiseOperation>(),
.output_element_op = elementwise_op<typename InstTraits::CDEElementwiseOperation>(),

View File

@@ -24,8 +24,8 @@ constexpr ConvTraits instance_to_conv_traits()
return ConvTraits{
.spatial_dim = InstTraits::kSpatialDim,
.direction = conv_direction<Instance>(),
.layout = conv_layout<Instance>(),
.data_type = conv_data_type<Instance>(),
.layout = fwd_conv_layout<Instance>(),
.data_type = conv_data_type<typename InstTraits::ADataType>(),
.input_element_op = elementwise_op<typename InstTraits::AElementwiseOperation>(),
.weight_element_op = elementwise_op<typename InstTraits::BElementwiseOperation>(),
.output_element_op = elementwise_op<typename InstTraits::CDEElementwiseOperation>(),

View File

@@ -0,0 +1,46 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <concepts>
#include "ck_tile/builder/reflect/conv_traits.hpp"
#include "ck_tile/builder/reflect/conv_traits_helpers.hpp"
#include "ck_tile/builder/reflect/instance_traits.hpp"
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp"
namespace ck_tile::reflect::conv {
/// @brief Tag dispatch implementation for DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle
template <typename Instance>
requires HasInstanceTraits<Instance> &&
std::same_as<typename InstanceTraits<Instance>::device_kernel_tag,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_Tag>
constexpr ConvTraits instance_to_conv_traits()
{
using InstTraits = InstanceTraits<Instance>;
return ConvTraits{
.spatial_dim = InstTraits::kSpatialDim,
.direction = conv_direction<Instance>(),
.layout = fwd_conv_layout<Instance>(),
.data_type = conv_data_type<typename InstTraits::ADataType>(),
.input_element_op = elementwise_op<typename InstTraits::AElementwiseOperation>(),
.weight_element_op = elementwise_op<typename InstTraits::BElementwiseOperation>(),
.output_element_op = elementwise_op<typename InstTraits::CDEElementwiseOperation>(),
.gemm_padding = gemm_spec<Instance>(),
.conv_specialization = conv_spec<Instance>(),
.thread_block_size = InstTraits::kBlockSize,
.tile_dims = conv_traits_data_tile<InstTraits>(),
.a_tile_transfer = conv_traits_a_transfer_params<InstTraits>(InstTraits::kK1),
.b_tile_transfer = conv_traits_b_transfer_params<InstTraits>(InstTraits::kK1),
.warp_gemm = conv_traits_wmma_warp_gemm_params<InstTraits>(),
.c_tile_transfer = conv_traits_wmma_c_tile_transfer<InstTraits>(),
.num_gemm_k_prefetch_stage = InstTraits::kNumGemmKPrefetchStage,
.pipeline_version = get_pipeline_version<InstTraits>(),
.pipeline_scheduler = get_pipeline_scheduler<InstTraits>(),
};
}
} // namespace ck_tile::reflect::conv

View File

@@ -24,8 +24,8 @@ constexpr ConvTraits instance_to_conv_traits()
return ConvTraits{
.spatial_dim = InstTraits::kSpatialDim,
.direction = conv_direction<Instance>(),
.layout = conv_layout<Instance>(),
.data_type = conv_data_type<Instance>(),
.layout = fwd_conv_layout<Instance>(),
.data_type = conv_data_type<typename InstTraits::ADataType>(),
.input_element_op = elementwise_op<typename InstTraits::AElementwiseOperation>(),
.weight_element_op = elementwise_op<typename InstTraits::BElementwiseOperation>(),
.output_element_op = elementwise_op<typename InstTraits::CDEElementwiseOperation>(),

View File

@@ -80,6 +80,22 @@ namespace ck_tile::reflect::conv {
// SECTION 1: ENUM CONVERSIONS
// ============================================================================
// Forward convolution layout concept - checks for A/B/E layout types
template <typename T>
concept HasFwdConvLayouts = requires {
typename T::ALayout;
typename T::BLayout;
typename T::ELayout;
};
// Backwards weight layout concept - checks for In, wei and out layouts
template <typename T>
concept HasBwdWeiLayouts = requires {
typename T::InLayout;
typename T::WeiLayout;
typename T::OutLayout;
};
/// @brief Converts a CK BlockGemmPipelineVersion enum to a builder PipelineVersion enum.
/// @tparam ck_ver The CK BlockGemmPipelineVersion enum value to convert.
/// @return The corresponding builder::PipelineVersion enum value.
@@ -322,12 +338,25 @@ constexpr builder::ConvSpecialization conv_spec()
// Tensor Layouts
// ----------------------------------------------------------------------------
// Helper variable template to check if CK layout enums match
template <typename A,
typename B,
typename E,
typename ExpectedA,
typename ExpectedB,
typename ExpectedE>
inline constexpr bool layouts_are =
std::is_same_v<A, ExpectedA> && std::is_same_v<B, ExpectedB> && std::is_same_v<E, ExpectedE>;
/// @brief Helper function to report unsupported layout combinations with a clear error message.
/// @details This consteval function uses throw (not static_assert) to ensure the error is not
/// silently ignored during SFINAE. The thrown string becomes part of the compiler error message.
/// @details This consteval function is designed to fail at compile time with a descriptive
/// error message when an unsupported layout combination is encountered.
template <typename A, typename B, typename E, int SpatialDim>
[[noreturn]] consteval void report_unsupported_layout_error()
{
// This will produce a compile-time error with the exception message
throw "Unsupported convolution layout combination detected!\n"
"The combination of ALayout, BLayout, and ELayout template parameters\n"
"is not recognized for the given spatial dimension.\n"
@@ -335,111 +364,99 @@ template <typename A, typename B, typename E, int SpatialDim>
"Check the conv_layout() function for the list of supported layout combinations.";
}
/// @brief Derives the grouped convolution layout from a device kernel Instance type.
/// @tparam Instance The device kernel instance type.
/// @return An std::array<builder::TensorLayout, 3> containing the layouts for:
/// - [0] Input tensor layout
/// - [1] Weight tensor layout
/// - [2] Output tensor layout
/// @details This function examines the Instance's ALayout, BLayout, and ELayout types
/// along with the spatial dimension to determine the appropriate layout configuration.
///
/// Supported layout combinations vary by spatial dimension (1D, 2D, 3D convolutions).
/// Common patterns include GNHWC (grouped, batch, spatial, channels) and variants.
///
/// @note Compilation will fail with a clear error message if the layout combination
/// is not supported for the given spatial dimension.
///
/// TODO: If we don't check for supported layouts, this function can be simplified.
template <typename Instance>
constexpr std::array<builder::TensorLayout, 3> conv_layout()
template <typename A, typename B, typename E, int kSpatialDim>
constexpr auto conv_layout()
{
using InstTraits = InstanceTraits<Instance>;
using A = typename InstTraits::ALayout;
using B = typename InstTraits::BLayout;
using E = typename InstTraits::ELayout;
namespace ctl = ck::tensor_layout::convolution;
// Helper lambda to construct layout array
auto layouts = [](auto... Ls) { return std::array<builder::TensorLayout, 3>{Ls...}; };
namespace ctl = ck::tensor_layout::convolution;
using enum builder::TensorLayout;
// Helper to check if layouts match expected types
constexpr auto layouts_match = []<typename ExpA, typename ExpB, typename ExpE>() {
return std::is_same_v<A, ExpA> && std::is_same_v<B, ExpB> && std::is_same_v<E, ExpE>;
};
switch(kSpatialDim)
{
case 1:
if constexpr(layouts_are<A, B, E, ctl::GNWC, ctl::GKXC, ctl::GNWK>)
return layouts(GNWC, GKXC, GNWK);
if constexpr(layouts_are<A, B, E, ctl::G_NW_C, ctl::G_K_X_C, ctl::G_NW_K>)
return layouts(GNWC, GKXC, GNWK);
if constexpr(layouts_are<A, B, E, ctl::NWGC, ctl::GKXC, ctl::NWGK>)
return layouts(NWGC, GKXC, NWGK);
if constexpr(layouts_are<A, B, E, ctl::NGCW, ctl::GKXC, ctl::NGKW>)
return layouts(NGCW, GKXC, NGKW);
if constexpr(layouts_are<A, B, E, ctl::NGCW, ctl::GKCX, ctl::NGKW>)
return layouts(NGCW, GKCX, NGKW);
break;
case 2:
if constexpr(layouts_are<A, B, E, ctl::GNHWC, ctl::GKYXC, ctl::GNHWK>)
return layouts(GNHWC, GKYXC, GNHWK);
if constexpr(layouts_are<A, B, E, ctl::G_NHW_C, ctl::G_K_YX_C, ctl::G_NHW_K>)
return layouts(GNHWC, GKYXC, GNHWK);
if constexpr(layouts_are<A, B, E, ctl::NHWGC, ctl::GKYXC, ctl::NHWGK>)
return layouts(NHWGC, GKYXC, NHWGK);
if constexpr(layouts_are<A, B, E, ctl::NHWGC, ctl::KYXGC, ctl::NHWGK>)
return layouts(NHWGC, GKYXC, NHWGK);
if constexpr(layouts_are<A, B, E, ctl::NGCHW, ctl::GKYXC, ctl::NGKHW>)
return layouts(NGCHW, GKYXC, NGKHW);
if constexpr(layouts_are<A, B, E, ctl::NGCHW, ctl::GKCYX, ctl::NGKHW>)
return layouts(NGCHW, GKCYX, NGKHW);
break;
case 3:
if constexpr(layouts_are<A, B, E, ctl::GNDHWC, ctl::GKZYXC, ctl::GNDHWK>)
return layouts(GNDHWC, GKZYXC, GNDHWK);
if constexpr(layouts_are<A, B, E, ctl::G_NDHW_C, ctl::G_K_ZYX_C, ctl::G_NDHW_K>)
return layouts(GNDHWC, GKZYXC, GNDHWK);
if constexpr(layouts_are<A, B, E, ctl::NDHWGC, ctl::GKZYXC, ctl::NDHWGK>)
return layouts(NDHWGC, GKZYXC, NDHWGK);
if constexpr(layouts_are<A, B, E, ctl::NGCDHW, ctl::GKZYXC, ctl::NGKDHW>)
return layouts(NGCDHW, GKZYXC, NGKDHW);
if constexpr(layouts_are<A, B, E, ctl::NGCDHW, ctl::GKCZYX, ctl::NGKDHW>)
return layouts(NGCDHW, GKCZYX, NGKDHW);
break;
}
// Helper to construct layout array
constexpr auto make_layouts = [](auto in, auto weight, auto out) {
return std::array<builder::TensorLayout, 3>{in, weight, out};
};
// If we reach here, the layout combination is not supported
// Call consteval function to trigger a compile-time error with a clear message
report_unsupported_layout_error<A, B, E, kSpatialDim>();
constexpr int spatial_dim = InstTraits::kSpatialDim;
// This return is unreachable but needed to satisfy the compiler
return layouts(GNHWC, GKYXC, GNHWK);
}
if constexpr(spatial_dim == 1)
{
if constexpr(layouts_match.template operator()<ctl::GNWC, ctl::GKXC, ctl::GNWK>())
return make_layouts(GNWC, GKXC, GNWK);
else if constexpr(layouts_match
.template operator()<ctl::G_NW_C, ctl::G_K_X_C, ctl::G_NW_K>())
return make_layouts(GNWC, GKXC, GNWK);
else if constexpr(layouts_match.template operator()<ctl::NWGC, ctl::GKXC, ctl::NWGK>())
return make_layouts(NWGC, GKXC, NWGK);
else if constexpr(layouts_match.template operator()<ctl::NGCW, ctl::GKXC, ctl::NGKW>())
return make_layouts(NGCW, GKXC, NGKW);
else if constexpr(layouts_match.template operator()<ctl::NGCW, ctl::GKCX, ctl::NGKW>())
return make_layouts(NGCW, GKCX, NGKW);
else
{
report_unsupported_layout_error<A, B, E, spatial_dim>();
return make_layouts(GNWC, GKXC, GNWK); // Unreachable
}
}
else if constexpr(spatial_dim == 2)
{
if constexpr(layouts_match.template operator()<ctl::GNHWC, ctl::GKYXC, ctl::GNHWK>())
return make_layouts(GNHWC, GKYXC, GNHWK);
else if constexpr(layouts_match
.template operator()<ctl::G_NHW_C, ctl::G_K_YX_C, ctl::G_NHW_K>())
return make_layouts(GNHWC, GKYXC, GNHWK);
else if constexpr(layouts_match.template operator()<ctl::NHWGC, ctl::GKYXC, ctl::NHWGK>())
return make_layouts(NHWGC, GKYXC, NHWGK);
else if constexpr(layouts_match.template operator()<ctl::NHWGC, ctl::KYXGC, ctl::NHWGK>())
return make_layouts(NHWGC, GKYXC, NHWGK);
else if constexpr(layouts_match.template operator()<ctl::NGCHW, ctl::GKYXC, ctl::NGKHW>())
return make_layouts(NGCHW, GKYXC, NGKHW);
else if constexpr(layouts_match.template operator()<ctl::NGCHW, ctl::GKCYX, ctl::NGKHW>())
return make_layouts(NGCHW, GKCYX, NGKHW);
else
{
report_unsupported_layout_error<A, B, E, spatial_dim>();
return make_layouts(GNHWC, GKYXC, GNHWK); // Unreachable
}
}
else if constexpr(spatial_dim == 3)
{
if constexpr(layouts_match.template operator()<ctl::GNDHWC, ctl::GKZYXC, ctl::GNDHWK>())
return make_layouts(GNDHWC, GKZYXC, GNDHWK);
else if constexpr(layouts_match
.template operator()<ctl::G_NDHW_C, ctl::G_K_ZYX_C, ctl::G_NDHW_K>())
return make_layouts(GNDHWC, GKZYXC, GNDHWK);
else if constexpr(layouts_match
.template operator()<ctl::NDHWGC, ctl::GKZYXC, ctl::NDHWGK>())
return make_layouts(NDHWGC, GKZYXC, NDHWGK);
else if constexpr(layouts_match
.template operator()<ctl::NGCDHW, ctl::GKZYXC, ctl::NGKDHW>())
return make_layouts(NGCDHW, GKZYXC, NGKDHW);
else if constexpr(layouts_match
.template operator()<ctl::NGCDHW, ctl::GKCZYX, ctl::NGKDHW>())
return make_layouts(NGCDHW, GKCZYX, NGKDHW);
else
{
report_unsupported_layout_error<A, B, E, spatial_dim>();
return make_layouts(GNDHWC, GKZYXC, GNDHWK); // Unreachable
}
}
else
{
report_unsupported_layout_error<A, B, E, spatial_dim>();
return make_layouts(GNHWC, GKYXC, GNHWK); // Unreachable
}
/// @brief Derives the grouped convolution layout from a device kernel `Instance` type.
/// @tparam Instance The device kernel instance type.
/// @return An std::array corresponding to the tensor layouts:
/// index 0 -> Input layout
/// index 1 -> Weight layout
/// index 2 -> Output layout
template <typename Instance>
constexpr auto fwd_conv_layout()
requires HasFwdConvLayouts<InstanceTraits<Instance>>
{
using A = typename InstanceTraits<Instance>::ALayout;
using B = typename InstanceTraits<Instance>::BLayout;
using E = typename InstanceTraits<Instance>::ELayout;
return conv_layout<A, B, E, InstanceTraits<Instance>::kSpatialDim>();
}
/// @brief Derives the grouped convolution layout from a device kernel `Instance` type.
/// @tparam Instance The device kernel instance type.
/// @return An std::array corresponding to the tensor layouts:
/// index 0 -> Input layout
/// index 1 -> Weight layout
/// index 2 -> Output layout
template <typename Instance>
constexpr auto bwd_wei_conv_layout()
requires HasBwdWeiLayouts<InstanceTraits<Instance>>
{
using A = typename InstanceTraits<Instance>::InLayout;
using B = typename InstanceTraits<Instance>::WeiLayout;
using E = typename InstanceTraits<Instance>::OutLayout;
return conv_layout<A, B, E, InstanceTraits<Instance>::kSpatialDim>();
}
// ----------------------------------------------------------------------------
@@ -447,13 +464,11 @@ constexpr std::array<builder::TensorLayout, 3> conv_layout()
// ----------------------------------------------------------------------------
/// @brief Helper function to report unsupported data type with a clear error message.
/// @details This consteval function uses throw (not static_assert) to ensure the error is not
/// silently ignored during SFINAE. The thrown string becomes part of the compiler error message.
template <typename ADataType>
template <typename DataTypeFromInstance>
[[noreturn]] consteval void report_unsupported_data_type_error()
{
throw "Unsupported data type detected!\n"
"The ADataType is not recognized.\n"
"The DataTypeFromInstance is not recognized.\n"
"Supported types are: ck::half_t (FP16), ck::Tuple<ck::half_t, ck::half_t> (FP16_FP16), "
"ck::bhalf_t (BF16), ck::Tuple<ck::bhalf_t, ck::bhalf_t> (BF16_BF16), float (FP32), "
"ck::Tuple<float, float> (FP32_FP32), double (FP64), ck::f8_t (FP8), ck::bf8_fnuz_t "
@@ -462,62 +477,44 @@ template <typename ADataType>
"Please verify that your kernel instance uses a supported data type.";
}
/// @brief Derives the data type from a device kernel Instance type.
/// @tparam Instance The device kernel instance type.
/// @return A builder::DataType enum value representing the input data type.
/// @details This function examines the Instance's ADataType to determine the data type
/// used for the input tensor. The function supports various floating-point and integer
/// types, including tuple types for mixed-precision operations.
///
/// Supported data types include:
/// - FP16 (ck::half_t)
/// - FP16_FP16 (ck::Tuple<ck::half_t, ck::half_t>)
/// - BF16 (ck::bhalf_t)
/// - BF16_BF16 (ck::Tuple<ck::bhalf_t, ck::bhalf_t>)
/// - FP32 (float)
/// - FP32_FP32 (ck::Tuple<float, float>)
/// - FP64 (double)
/// - FP8 (ck::f8_t)
/// - BF8 (ck::bf8_fnuz_t, ck::bf8_ocp_t)
/// - I8 (int8_t)
/// - I8_I8 (ck::Tuple<int8_t, int8_t>)
/// - U8 (uint8_t)
template <typename Instance>
/// @brief Derives the data type from a device kernel `Instance` type.
/// Returns a `builder::DataType` enum value (e.g., FP16, BF16, FP32, BF8).
// Note: maybe move to types.hpp?
template <typename DataTypeFromInstance>
constexpr builder::DataType conv_data_type()
{
using InstTraits = InstanceTraits<Instance>;
using ADataType = typename InstTraits::ADataType;
using enum builder::DataType;
if constexpr(std::is_same_v<ADataType, ck::half_t>)
if constexpr(std::is_same_v<DataTypeFromInstance, ck::half_t>)
return FP16;
else if constexpr(std::is_same_v<ADataType, ck::Tuple<ck::half_t, ck::half_t>>)
else if constexpr(std::is_same_v<DataTypeFromInstance, ck::Tuple<ck::half_t, ck::half_t>>)
return FP16_FP16;
else if constexpr(std::is_same_v<ADataType, ck::bhalf_t>)
else if constexpr(std::is_same_v<DataTypeFromInstance, ck::bhalf_t>)
return BF16;
else if constexpr(std::is_same_v<ADataType, ck::Tuple<ck::bhalf_t, ck::bhalf_t>>)
else if constexpr(std::is_same_v<DataTypeFromInstance, ck::Tuple<ck::bhalf_t, ck::bhalf_t>>)
return BF16_BF16;
else if constexpr(std::is_same_v<ADataType, float>)
else if constexpr(std::is_same_v<DataTypeFromInstance, float>)
return FP32;
else if constexpr(std::is_same_v<ADataType, ck::Tuple<float, float>>)
else if constexpr(std::is_same_v<DataTypeFromInstance, ck::Tuple<float, float>>)
return FP32_FP32;
else if constexpr(std::is_same_v<ADataType, double>)
else if constexpr(std::is_same_v<DataTypeFromInstance, double>)
return FP64;
else if constexpr(std::is_same_v<ADataType, ck::f8_t>)
else if constexpr(std::is_same_v<DataTypeFromInstance, ck::f8_t>)
return FP8;
else if constexpr(std::is_same_v<ADataType, ck::bf8_fnuz_t>)
else if constexpr(std::is_same_v<DataTypeFromInstance, ck::bf8_fnuz_t>)
return BF8;
else if constexpr(std::is_same_v<ADataType, ck::bf8_ocp_t>)
else if constexpr(std::is_same_v<DataTypeFromInstance, ck::bf8_ocp_t>)
return BF8;
else if constexpr(std::is_same_v<ADataType, int8_t>)
else if constexpr(std::is_same_v<DataTypeFromInstance, int8_t>)
return I8;
else if constexpr(std::is_same_v<ADataType, ck::Tuple<int8_t, int8_t>>)
else if constexpr(std::is_same_v<DataTypeFromInstance, ck::Tuple<int8_t, int8_t>>)
return I8_I8;
else if constexpr(std::is_same_v<ADataType, uint8_t>)
else if constexpr(std::is_same_v<DataTypeFromInstance, uint8_t>)
return U8;
else
{
report_unsupported_data_type_error<ADataType>();
report_unsupported_data_type_error<DataTypeFromInstance>();
return FP32; // Unreachable
}
}
@@ -736,4 +733,92 @@ constexpr builder::PipelineScheduler get_pipeline_scheduler()
}
}
// ============================================================================
// SECTION 4: Helper functions for common structures often used in reflection
// ============================================================================
template <typename InstTraits>
constexpr DataTileInfo conv_traits_data_tile(int k_or_k0 = InstTraits::kKPerBlock)
{
return DataTileInfo{.m = InstTraits::kMPerBlock, .n = InstTraits::kNPerBlock, .k = k_or_k0};
}
template <typename InstTraits>
constexpr InputTileTransferInfo
conv_traits_a_transfer_params(int _k1, int kPerBlock = InstTraits::kKPerBlock)
{
return InputTileTransferInfo{
.tile_dimensions = {.k0 = kPerBlock / _k1, .m_or_n = InstTraits::kMPerBlock, .k1 = _k1},
.transfer_params = {.k1 = _k1,
.thread_cluster_dims = InstTraits::kAThreadClusterLengths,
.thread_cluster_order = InstTraits::kAThreadClusterArrangeOrder,
.src_access_order = InstTraits::kABlockTransferSrcAccessOrder,
.src_vector_dim = InstTraits::kABlockTransferSrcVectorDim,
.src_scalar_per_vector = InstTraits::kABlockTransferSrcScalarPerVector,
.dst_scalar_per_vector_k1 =
InstTraits::kABlockTransferDstScalarPerVectorK1,
.lds_padding = static_cast<bool>(InstTraits::kABlockLdsExtraM)}};
}
template <typename InstTraits>
constexpr InputTileTransferInfo
conv_traits_b_transfer_params(int _k1, int kPerBlock = InstTraits::kKPerBlock)
{
return InputTileTransferInfo{
.tile_dimensions = {.k0 = kPerBlock / _k1, .m_or_n = InstTraits::kNPerBlock, .k1 = _k1},
.transfer_params = {.k1 = _k1,
.thread_cluster_dims = InstTraits::kBThreadClusterLengths,
.thread_cluster_order = InstTraits::kBThreadClusterArrangeOrder,
.src_access_order = InstTraits::kBBlockTransferSrcAccessOrder,
.src_vector_dim = InstTraits::kBBlockTransferSrcVectorDim,
.src_scalar_per_vector = InstTraits::kBBlockTransferSrcScalarPerVector,
.dst_scalar_per_vector_k1 =
InstTraits::kBBlockTransferDstScalarPerVectorK1,
.lds_padding = static_cast<bool>(InstTraits::kBBlockLdsExtraN)}};
}
template <typename InstTraits>
constexpr WarpGemmParams conv_traits_wmma_warp_gemm_params()
{
return WarpGemmParams{.gemm_m = InstTraits::kMPerWmma,
.gemm_n = InstTraits::kNPerWmma,
.m_iter = InstTraits::kMRepeat,
.n_iter = InstTraits::kNRepeat};
}
template <typename InstTraits>
constexpr WarpGemmParams conv_traits_xdl_warp_gemm_params()
{
return WarpGemmParams{.gemm_m = InstTraits::kMPerXDL,
.gemm_n = InstTraits::kNPerXDL,
.m_iter = InstTraits::kMXdlPerWave,
.n_iter = InstTraits::kNXdlPerWave};
}
template <typename InstTraits>
constexpr OutputTileTransferInfo conv_traits_wmma_c_tile_transfer()
{
return OutputTileTransferInfo{
.shuffle_params = {.m_gemms_per_shuffle = InstTraits::kCShuffleMRepeatPerShuffle,
.n_gemms_per_shuffle = InstTraits::kCShuffleNRepeatPerShuffle},
.thread_cluster_dims = {InstTraits::kCDEThreadClusterLengths[0],
InstTraits::kCDEThreadClusterLengths[1],
InstTraits::kCDEThreadClusterLengths[2],
InstTraits::kCDEThreadClusterLengths[3]},
.scalar_per_vector = InstTraits::kCDEBlockTransferScalarPerVector};
}
template <typename InstTraits>
constexpr OutputTileTransferInfo conv_traits_xdl_c_tile_transfer()
{
return OutputTileTransferInfo{
.shuffle_params = {.m_gemms_per_shuffle = InstTraits::kCShuffleMXdlPerWavePerShuffle,
.n_gemms_per_shuffle = InstTraits::kCShuffleNXdlPerWavePerShuffle},
.thread_cluster_dims = {InstTraits::kCThreadClusterLengths[0],
InstTraits::kCThreadClusterLengths[1],
InstTraits::kCThreadClusterLengths[2],
InstTraits::kCThreadClusterLengths[3]},
.scalar_per_vector = InstTraits::kCBlockTransferScalarPerVector};
}
} // namespace ck_tile::reflect::conv

View File

@@ -3,6 +3,18 @@
#pragma once
// Fwd instances
#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp"
#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp"
#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp"
#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp"
// Bwd weight instances
#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp"
#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp"
#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp"
#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp"
#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp"
#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp"
#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp"
#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_wmma_cshuffle.hpp"

View File

@@ -62,6 +62,10 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3;
namespace ck_tile {
namespace reflect {
/// @brief Tag type for DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_v3 device kernel
struct DeviceGroupedConvBwdWeight_multiple_d_Wmma_CShuffle_V3_Tag
{
};
template <ck::index_t NDimSpatial,
typename InLayout_,
typename WeiLayout_,
@@ -158,7 +162,9 @@ struct InstanceTraits<
{
static constexpr auto kTensorOpName = "DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3";
static constexpr ck::index_t kNDimSpatial = NDimSpatial;
using device_kernel_tag = DeviceGroupedConvBwdWeight_multiple_d_Wmma_CShuffle_V3_Tag;
static constexpr ck::index_t kSpatialDim = NDimSpatial;
using InLayout = InLayout_;
using WeiLayout = WeiLayout_;
@@ -175,13 +181,13 @@ struct InstanceTraits<
using WeiElementwiseOperation = WeiElementwiseOperation_;
using OutElementwiseOperation = OutElementwiseOperation_;
static constexpr auto kConvBackwardWeightSpecialization = ConvBackwardWeightSpecialization;
static constexpr auto kConvBwdWeightSpecialization = ConvBackwardWeightSpecialization;
static constexpr ck::index_t kBlockSize = BlockSize;
static constexpr ck::index_t kMPerBlock = MPerBlock;
static constexpr ck::index_t kNPerBlock = NPerBlock;
static constexpr ck::index_t kKPerBlock = KPerBlock;
static constexpr ck::index_t kABK1 = ABK1;
static constexpr ck::index_t kK1 = ABK1;
static constexpr ck::index_t kMPerWmma = MPerWmma;
static constexpr ck::index_t kNPerWmma = NPerWmma;
static constexpr ck::index_t kMRepeat = MRepeat;
@@ -195,27 +201,46 @@ struct InstanceTraits<
ABlockTransferThreadClusterLengths_AK0_M_AK1_;
using ABlockTransferThreadClusterArrangeOrder = ABlockTransferThreadClusterArrangeOrder_;
using ABlockTransferSrcAccessOrder = ABlockTransferSrcAccessOrder_;
// A block transfer thread cluster dimensions (converted to std::array)
static constexpr auto kAThreadClusterLengths =
detail::SequenceToArray<ABlockTransferThreadClusterLengths_AK0_M_AK1>::value;
static constexpr auto kAThreadClusterArrangeOrder =
detail::SequenceToArray<ABlockTransferThreadClusterArrangeOrder>::value;
static constexpr auto kABlockTransferSrcAccessOrder =
detail::SequenceToArray<ABlockTransferSrcAccessOrder_>::value;
static constexpr ck::index_t kABlockTransferSrcVectorDim = ABlockTransferSrcVectorDim;
static constexpr ck::index_t kABlockTransferSrcScalarPerVector =
ABlockTransferSrcScalarPerVector;
static constexpr ck::index_t kABlockTransferDstScalarPerVector_AK1 =
static constexpr ck::index_t kABlockTransferDstScalarPerVectorK1 =
ABlockTransferDstScalarPerVector_AK1;
static constexpr bool kABlockLdsAddExtraM = ABlockLdsAddExtraM;
static constexpr bool kABlockLdsExtraM = ABlockLdsAddExtraM;
using BBlockTransferThreadClusterLengths_BK0_N_BK1 =
BBlockTransferThreadClusterLengths_BK0_N_BK1_;
using BBlockTransferThreadClusterArrangeOrder = BBlockTransferThreadClusterArrangeOrder_;
using BBlockTransferSrcAccessOrder = BBlockTransferSrcAccessOrder_;
// B block transfer thread cluster dimensions (converted to std::array)
static constexpr auto kBThreadClusterLengths =
detail::SequenceToArray<BBlockTransferThreadClusterLengths_BK0_N_BK1>::value;
static constexpr auto kBThreadClusterArrangeOrder =
detail::SequenceToArray<BBlockTransferThreadClusterArrangeOrder>::value;
static constexpr auto kBBlockTransferSrcAccessOrder =
detail::SequenceToArray<BBlockTransferSrcAccessOrder_>::value;
static constexpr ck::index_t kBBlockTransferSrcVectorDim = BBlockTransferSrcVectorDim;
static constexpr ck::index_t kBBlockTransferSrcScalarPerVector =
BBlockTransferSrcScalarPerVector;
static constexpr ck::index_t kBBlockTransferDstScalarPerVector_BK1 =
static constexpr ck::index_t kBBlockTransferDstScalarPerVectorK1 =
BBlockTransferDstScalarPerVector_BK1;
static constexpr bool kBBlockLdsAddExtraN = BBlockLdsAddExtraN;
static constexpr bool kBBlockLdsExtraN = BBlockLdsAddExtraN;
using CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock =
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_;
static constexpr auto kCDEThreadClusterLengths = detail::SequenceToArray<
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>::value;
static constexpr int kCDEBlockTransferScalarPerVector =
CShuffleBlockTransferScalarPerVector_NPerBlock;
static constexpr ck::BlockGemmPipelineScheduler kBlkGemmPipeSched = BlkGemmPipeSched;
static constexpr ck::BlockGemmPipelineVersion kBlkGemmPipelineVer = BlkGemmPipelineVer;
@@ -231,7 +256,7 @@ struct InstanceTraits<
oss << "DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3";
// Template parameters in exact order
oss << "<" << kNDimSpatial; // 1. NDimSpatial
oss << "<" << kSpatialDim; // 1. NDimSpatial
oss << "," << detail::layout_name<InLayout>(); // 2. InLayout
oss << "," << detail::layout_name<WeiLayout>(); // 3. WeiLayout
oss << "," << detail::layout_name<OutLayout>(); // 4. OutLayout
@@ -251,30 +276,30 @@ struct InstanceTraits<
// OutElementwiseOperation
oss << ","
<< detail::conv_bwd_weight_spec_name(
kConvBackwardWeightSpecialization); // 14. ConvBackwardWeightSpecialization
oss << "," << kBlockSize; // 15. BlockSize
oss << "," << kMPerBlock; // 16. MPerBlock
oss << "," << kNPerBlock; // 17. NPerBlock
oss << "," << kKPerBlock; // 18. KPerBlock
oss << "," << kABK1; // 19. ABK1
oss << "," << kMPerWmma; // 20. MPerWmma
oss << "," << kNPerWmma; // 21. NPerWmma
oss << "," << kMRepeat; // 22. MRepeat
oss << "," << kNRepeat; // 23. NRepeat
kConvBwdWeightSpecialization); // 14. ConvBackwardWeightSpecialization
oss << "," << kBlockSize; // 15. BlockSize
oss << "," << kMPerBlock; // 16. MPerBlock
oss << "," << kNPerBlock; // 17. NPerBlock
oss << "," << kKPerBlock; // 18. KPerBlock
oss << "," << kK1; // 19. ABK1
oss << "," << kMPerWmma; // 20. MPerWmma
oss << "," << kNPerWmma; // 21. NPerWmma
oss << "," << kMRepeat; // 22. MRepeat
oss << "," << kNRepeat; // 23. NRepeat
oss << "," << detail::sequence_name<ABlockTransferThreadClusterLengths_AK0_M_AK1>(); // 24.
oss << "," << detail::sequence_name<ABlockTransferThreadClusterArrangeOrder>(); // 25.
oss << "," << detail::sequence_name<ABlockTransferSrcAccessOrder>(); // 26.
oss << "," << kABlockTransferSrcVectorDim; // 27.
oss << "," << kABlockTransferSrcScalarPerVector; // 28.
oss << "," << kABlockTransferDstScalarPerVector_AK1; // 29.
oss << "," << (kABlockLdsAddExtraM ? "true" : "false"); // 30.
oss << "," << kABlockTransferDstScalarPerVectorK1; // 29.
oss << "," << (kABlockLdsExtraM ? "true" : "false"); // 30.
oss << "," << detail::sequence_name<BBlockTransferThreadClusterLengths_BK0_N_BK1>(); // 31.
oss << "," << detail::sequence_name<BBlockTransferThreadClusterArrangeOrder>(); // 32.
oss << "," << detail::sequence_name<BBlockTransferSrcAccessOrder>(); // 33.
oss << "," << kBBlockTransferSrcVectorDim; // 34.
oss << "," << kBBlockTransferSrcScalarPerVector; // 35.
oss << "," << kBBlockTransferDstScalarPerVector_BK1; // 36.
oss << "," << (kBBlockLdsAddExtraN ? "true" : "false"); // 37.
oss << "," << kBBlockTransferDstScalarPerVectorK1; // 36.
oss << "," << (kBBlockLdsExtraN ? "true" : "false"); // 37.
oss << "," << kCShuffleMRepeatPerShuffle; // 38.
oss << "," << kCShuffleNRepeatPerShuffle; // 39.
oss << ","

View File

@@ -59,6 +59,10 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle;
namespace ck_tile {
namespace reflect {
/// @brief Tag type for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle device kernel
struct DeviceGroupedConvBwdWeight_multiple_d_Xdl_CShuffle_Tag
{
};
template <ck::index_t NDimSpatial,
typename InLayout_,
@@ -152,7 +156,10 @@ struct InstanceTraits<
{
static constexpr auto kTensorOpName = "DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle";
static constexpr ck::index_t kNDimSpatial = NDimSpatial;
/// @brief Tag type identifying this device kernel variant
using device_kernel_tag = DeviceGroupedConvBwdWeight_multiple_d_Xdl_CShuffle_Tag;
static constexpr ck::index_t kSpatialDim = NDimSpatial;
using InLayout = InLayout_;
using WeiLayout = WeiLayout_;
@@ -169,7 +176,7 @@ struct InstanceTraits<
using WeiElementwiseOperation = WeiElementwiseOperation_;
using OutElementwiseOperation = OutElementwiseOperation_;
static constexpr auto kConvBackwardWeightSpecialization = ConvBackwardWeightSpecialization;
static constexpr auto kConvBwdWeightSpecialization = ConvBackwardWeightSpecialization;
static constexpr ck::index_t kBlockSize = BlockSize;
static constexpr ck::index_t kMPerBlock = MPerBlock;
@@ -188,22 +195,36 @@ struct InstanceTraits<
using ABlockTransferThreadClusterLengths_K0_M_K1 = ABlockTransferThreadClusterLengths_K0_M_K1_;
using ABlockTransferThreadClusterArrangeOrder = ABlockTransferThreadClusterArrangeOrder_;
using ABlockTransferSrcAccessOrder = ABlockTransferSrcAccessOrder_;
// A block transfer thread cluster dimensions (converted to std::array)
static constexpr auto kAThreadClusterLengths =
detail::SequenceToArray<ABlockTransferThreadClusterLengths_K0_M_K1>::value;
static constexpr auto kAThreadClusterArrangeOrder =
detail::SequenceToArray<ABlockTransferThreadClusterArrangeOrder>::value;
static constexpr auto kABlockTransferSrcAccessOrder =
detail::SequenceToArray<ABlockTransferSrcAccessOrder_>::value;
static constexpr ck::index_t kABlockTransferSrcVectorDim = ABlockTransferSrcVectorDim;
static constexpr ck::index_t kABlockTransferSrcScalarPerVector =
ABlockTransferSrcScalarPerVector;
static constexpr ck::index_t kABlockTransferDstScalarPerVector_K1 =
static constexpr ck::index_t kABlockTransferDstScalarPerVectorK1 =
ABlockTransferDstScalarPerVector_K1;
static constexpr bool kABlockLdsAddExtraM = ABlockLdsAddExtraM;
static constexpr bool kABlockLdsExtraM = ABlockLdsAddExtraM;
using BBlockTransferThreadClusterLengths_K0_N_K1 = BBlockTransferThreadClusterLengths_K0_N_K1_;
using BBlockTransferThreadClusterArrangeOrder = BBlockTransferThreadClusterArrangeOrder_;
using BBlockTransferSrcAccessOrder = BBlockTransferSrcAccessOrder_;
// B block transfer thread cluster dimensions (converted to std::array)
static constexpr auto kBThreadClusterLengths =
detail::SequenceToArray<BBlockTransferThreadClusterLengths_K0_N_K1>::value;
static constexpr auto kBThreadClusterArrangeOrder =
detail::SequenceToArray<BBlockTransferThreadClusterArrangeOrder>::value;
static constexpr auto kBBlockTransferSrcAccessOrder =
detail::SequenceToArray<BBlockTransferSrcAccessOrder_>::value;
static constexpr ck::index_t kBBlockTransferSrcVectorDim = BBlockTransferSrcVectorDim;
static constexpr ck::index_t kBBlockTransferSrcScalarPerVector =
BBlockTransferSrcScalarPerVector;
static constexpr ck::index_t kBBlockTransferDstScalarPerVector_K1 =
static constexpr ck::index_t kBBlockTransferDstScalarPerVectorK1 =
BBlockTransferDstScalarPerVector_K1;
static constexpr bool kBBlockLdsAddExtraN = BBlockLdsAddExtraN;
static constexpr bool kBBlockLdsExtraN = BBlockLdsAddExtraN;
using CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock =
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_;
@@ -211,6 +232,9 @@ struct InstanceTraits<
using ComputeTypeA = ComputeTypeA_;
using ComputeTypeB = ComputeTypeB_;
static constexpr auto kCThreadClusterLengths = detail::SequenceToArray<
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>::value;
// Static member function to generate instance string
static std::string instance_string()
{
@@ -220,7 +244,7 @@ struct InstanceTraits<
oss << "DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle";
// Template parameters in exact order
oss << "<" << kNDimSpatial; // 1. NDimSpatial
oss << "<" << kSpatialDim; // 1. NDimSpatial
oss << "," << detail::layout_name<InLayout>(); // 2. InLayout
oss << "," << detail::layout_name<WeiLayout>(); // 3. WeiLayout
oss << "," << detail::layout_name<OutLayout>(); // 4. OutLayout
@@ -240,30 +264,30 @@ struct InstanceTraits<
// OutElementwiseOperation
oss << ","
<< detail::conv_bwd_weight_spec_name(
kConvBackwardWeightSpecialization); // 14. ConvBackwardWeightSpecialization
oss << "," << kBlockSize; // 15. BlockSize
oss << "," << kMPerBlock; // 16. MPerBlock
oss << "," << kNPerBlock; // 17. NPerBlock
oss << "," << kK0PerBlock; // 18. K0PerBlock
oss << "," << kK1; // 19. K1
oss << "," << kMPerXDL; // 20. MPerXDL
oss << "," << kNPerXDL; // 21. NPerXDL
oss << "," << kMXdlPerWave; // 22. MXdlPerWave
oss << "," << kNXdlPerWave; // 23. NXdlPerWave
kConvBwdWeightSpecialization); // 14. ConvBackwardWeightSpecialization
oss << "," << kBlockSize; // 15. BlockSize
oss << "," << kMPerBlock; // 16. MPerBlock
oss << "," << kNPerBlock; // 17. NPerBlock
oss << "," << kK0PerBlock; // 18. K0PerBlock
oss << "," << kK1; // 19. K1
oss << "," << kMPerXDL; // 20. MPerXDL
oss << "," << kNPerXDL; // 21. NPerXDL
oss << "," << kMXdlPerWave; // 22. MXdlPerWave
oss << "," << kNXdlPerWave; // 23. NXdlPerWave
oss << "," << detail::sequence_name<ABlockTransferThreadClusterLengths_K0_M_K1>(); // 24.
oss << "," << detail::sequence_name<ABlockTransferThreadClusterArrangeOrder>(); // 25.
oss << "," << detail::sequence_name<ABlockTransferSrcAccessOrder>(); // 26.
oss << "," << kABlockTransferSrcVectorDim; // 27.
oss << "," << kABlockTransferSrcScalarPerVector; // 28.
oss << "," << kABlockTransferDstScalarPerVector_K1; // 29.
oss << "," << (kABlockLdsAddExtraM ? "true" : "false"); // 30.
oss << "," << kABlockTransferDstScalarPerVectorK1; // 29.
oss << "," << (kABlockLdsExtraM ? "true" : "false"); // 30.
oss << "," << detail::sequence_name<BBlockTransferThreadClusterLengths_K0_N_K1>(); // 31.
oss << "," << detail::sequence_name<BBlockTransferThreadClusterArrangeOrder>(); // 32.
oss << "," << detail::sequence_name<BBlockTransferSrcAccessOrder>(); // 33.
oss << "," << kBBlockTransferSrcVectorDim; // 34.
oss << "," << kBBlockTransferSrcScalarPerVector; // 35.
oss << "," << kBBlockTransferDstScalarPerVector_K1; // 36.
oss << "," << (kBBlockLdsAddExtraN ? "true" : "false"); // 37.
oss << "," << kBBlockTransferDstScalarPerVectorK1; // 36.
oss << "," << (kBBlockLdsExtraN ? "true" : "false"); // 37.
oss << "," << kCShuffleMXdlPerWavePerShuffle; // 38.
oss << "," << kCShuffleNXdlPerWavePerShuffle; // 39.
oss << ","

View File

@@ -63,6 +63,11 @@ struct DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3;
namespace ck_tile {
namespace reflect {
/// @brief Tag type for DeviceGroupedConvBwdWeight_two_stage_Wmma_CShuffle_Tag device kernel
struct DeviceGroupedConvBwdWeight_two_stage_Wmma_CShuffle_Tag
{
};
template <ck::index_t NDimSpatial,
typename InLayout_,
typename WeiLayout_,
@@ -161,7 +166,9 @@ struct InstanceTraits<
{
static constexpr auto kTensorOpName = "DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3";
static constexpr ck::index_t kNDimSpatial = NDimSpatial;
using device_kernel_tag = DeviceGroupedConvBwdWeight_two_stage_Wmma_CShuffle_Tag;
static constexpr ck::index_t kSpatialDim = NDimSpatial;
using InLayout = InLayout_;
using WeiLayout = WeiLayout_;
@@ -176,7 +183,7 @@ struct InstanceTraits<
using WeiElementwiseOperation = WeiElementwiseOperation_;
using OutElementwiseOperation = OutElementwiseOperation_;
static constexpr auto kConvBackwardWeightSpecialization = ConvBackwardWeightSpecialization;
static constexpr auto kConvBwdWeightSpecialization = ConvBackwardWeightSpecialization;
static constexpr ck::index_t kBlockSize = BlockSize;
static constexpr ck::index_t kMPerBlock = MPerBlock;
@@ -201,12 +208,21 @@ struct InstanceTraits<
ABlockTransferThreadClusterLengths_AK0_M_AK1_;
using ABlockTransferThreadClusterArrangeOrder = ABlockTransferThreadClusterArrangeOrder_;
using ABlockTransferSrcAccessOrder = ABlockTransferSrcAccessOrder_;
// A block transfer thread cluster dimensions (converted to std::array)
static constexpr auto kAThreadClusterLengths =
detail::SequenceToArray<ABlockTransferThreadClusterLengths_AK0_M_AK1>::value;
static constexpr auto kAThreadClusterArrangeOrder =
detail::SequenceToArray<ABlockTransferThreadClusterArrangeOrder>::value;
static constexpr auto kABlockTransferSrcAccessOrder =
detail::SequenceToArray<ABlockTransferSrcAccessOrder_>::value;
static constexpr ck::index_t kABlockTransferSrcVectorDim = ABlockTransferSrcVectorDim;
static constexpr ck::index_t kABlockTransferSrcScalarPerVector =
ABlockTransferSrcScalarPerVector;
static constexpr ck::index_t kABlockTransferDstScalarPerVector_AK1 =
static constexpr ck::index_t kABlockTransferDstScalarPerVectorK1 =
ABlockTransferDstScalarPerVector_AK1;
static constexpr bool kABlockLdsAddExtraM = ABlockLdsAddExtraM;
static constexpr bool kABlockLdsExtraM = ABlockLdsAddExtraM;
using BBlockTransferThreadClusterLengths_BK0_N_BK1 =
BBlockTransferThreadClusterLengths_BK0_N_BK1_;
@@ -215,13 +231,26 @@ struct InstanceTraits<
static constexpr ck::index_t kBBlockTransferSrcVectorDim = BBlockTransferSrcVectorDim;
static constexpr ck::index_t kBBlockTransferSrcScalarPerVector =
BBlockTransferSrcScalarPerVector;
static constexpr ck::index_t kBBlockTransferDstScalarPerVector_BK1 =
static constexpr ck::index_t kBBlockTransferDstScalarPerVectorK1 =
BBlockTransferDstScalarPerVector_BK1;
static constexpr bool kBBlockLdsAddExtraN = BBlockLdsAddExtraN;
// B block transfer thread cluster dimensions (converted to std::array)
static constexpr auto kBThreadClusterLengths =
detail::SequenceToArray<BBlockTransferThreadClusterLengths_BK0_N_BK1>::value;
static constexpr auto kBThreadClusterArrangeOrder =
detail::SequenceToArray<BBlockTransferThreadClusterArrangeOrder>::value;
static constexpr auto kBBlockTransferSrcAccessOrder =
detail::SequenceToArray<BBlockTransferSrcAccessOrder_>::value;
static constexpr bool kBBlockLdsExtraN = BBlockLdsAddExtraN;
using CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock =
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_;
static constexpr auto kCDEThreadClusterLengths = detail::SequenceToArray<
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>::value;
static constexpr int kCDEBlockTransferScalarPerVector =
CShuffleBlockTransferScalarPerVector_NPerBlock;
static constexpr ck::BlockGemmPipelineScheduler kBlkGemmPipeSched = BlkGemmPipeSched;
static constexpr ck::BlockGemmPipelineVersion kBlkGemmPipelineVer = BlkGemmPipelineVer;
@@ -237,7 +266,7 @@ struct InstanceTraits<
oss << "DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3";
// Template parameters in exact order
oss << "<" << kNDimSpatial; // 1. NDimSpatial
oss << "<" << kSpatialDim; // 1. NDimSpatial
oss << "," << detail::layout_name<InLayout>(); // 2. InLayout
oss << "," << detail::layout_name<WeiLayout>(); // 3. WeiLayout
oss << "," << detail::layout_name<OutLayout>(); // 4. OutLayout
@@ -255,30 +284,30 @@ struct InstanceTraits<
// OutElementwiseOperation
oss << ","
<< detail::conv_bwd_weight_spec_name(
kConvBackwardWeightSpecialization); // 12. ConvBackwardWeightSpecialization
oss << "," << kBlockSize; // 13. BlockSize
oss << "," << kMPerBlock; // 14. MPerBlock
oss << "," << kNPerBlock; // 15. NPerBlock
oss << "," << kKPerBlock; // 16. KPerBlock
oss << "," << kABK1; // 17. ABK1
oss << "," << kMPerWmma; // 18. MPerWmma
oss << "," << kNPerWmma; // 19. NPerWmma
oss << "," << kMRepeat; // 20. MRepeat
oss << "," << kNRepeat; // 21. NRepeat
kConvBwdWeightSpecialization); // 12. ConvBackwardWeightSpecialization
oss << "," << kBlockSize; // 13. BlockSize
oss << "," << kMPerBlock; // 14. MPerBlock
oss << "," << kNPerBlock; // 15. NPerBlock
oss << "," << kKPerBlock; // 16. KPerBlock
oss << "," << kABK1; // 17. ABK1
oss << "," << kMPerWmma; // 18. MPerWmma
oss << "," << kNPerWmma; // 19. NPerWmma
oss << "," << kMRepeat; // 20. MRepeat
oss << "," << kNRepeat; // 21. NRepeat
oss << "," << detail::sequence_name<ABlockTransferThreadClusterLengths_AK0_M_AK1>(); // 22.
oss << "," << detail::sequence_name<ABlockTransferThreadClusterArrangeOrder>(); // 23.
oss << "," << detail::sequence_name<ABlockTransferSrcAccessOrder>(); // 24.
oss << "," << kABlockTransferSrcVectorDim; // 25.
oss << "," << kABlockTransferSrcScalarPerVector; // 26.
oss << "," << kABlockTransferDstScalarPerVector_AK1; // 27.
oss << "," << (kABlockLdsAddExtraM ? "true" : "false"); // 28.
oss << "," << kABlockTransferDstScalarPerVectorK1; // 27.
oss << "," << (kABlockLdsExtraM ? "true" : "false"); // 28.
oss << "," << detail::sequence_name<BBlockTransferThreadClusterLengths_BK0_N_BK1>(); // 29.
oss << "," << detail::sequence_name<BBlockTransferThreadClusterArrangeOrder>(); // 30.
oss << "," << detail::sequence_name<BBlockTransferSrcAccessOrder>(); // 31.
oss << "," << kBBlockTransferSrcVectorDim; // 32.
oss << "," << kBBlockTransferSrcScalarPerVector; // 33.
oss << "," << kBBlockTransferDstScalarPerVector_BK1; // 34.
oss << "," << (kBBlockLdsAddExtraN ? "true" : "false"); // 35.
oss << "," << kBBlockTransferDstScalarPerVectorK1; // 34.
oss << "," << (kBBlockLdsExtraN ? "true" : "false"); // 35.
oss << "," << kCShuffleMRepeatPerShuffle; // 36.
oss << "," << kCShuffleNRepeatPerShuffle; // 37.
oss << ","

View File

@@ -63,6 +63,11 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle;
namespace ck_tile {
namespace reflect {
/// @brief Tag type for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle device kernel
struct DeviceGroupedConvBwdWeight_two_stage_Xdl_CShuffle_Tag
{
};
template <ck::index_t NDimSpatial,
typename InLayout_,
typename WeiLayout_,
@@ -160,7 +165,9 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeightTw
{
static constexpr auto kTensorOpName = "DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle";
static constexpr ck::index_t kNDimSpatial = NDimSpatial;
using device_kernel_tag = DeviceGroupedConvBwdWeight_two_stage_Xdl_CShuffle_Tag;
static constexpr ck::index_t kSpatialDim = NDimSpatial;
using InLayout = InLayout_;
using WeiLayout = WeiLayout_;
@@ -175,7 +182,7 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeightTw
using WeiElementwiseOperation = WeiElementwiseOperation_;
using OutElementwiseOperation = OutElementwiseOperation_;
static constexpr auto kConvBackwardWeightSpecialization = ConvBackwardWeightSpecialization;
static constexpr auto kConvBwdWeightSpecialization = ConvBackwardWeightSpecialization;
static constexpr ck::index_t kBlockSize = BlockSize;
static constexpr ck::index_t kMPerBlock = MPerBlock;
@@ -199,26 +206,45 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeightTw
using ABlockTransferThreadClusterLengths_K0_M_K1 = ABlockTransferThreadClusterLengths_K0_M_K1_;
using ABlockTransferThreadClusterArrangeOrder = ABlockTransferThreadClusterArrangeOrder_;
using ABlockTransferSrcAccessOrder = ABlockTransferSrcAccessOrder_;
// A block transfer thread cluster dimensions (converted to std::array)
static constexpr auto kAThreadClusterLengths =
detail::SequenceToArray<ABlockTransferThreadClusterLengths_K0_M_K1>::value;
static constexpr auto kAThreadClusterArrangeOrder =
detail::SequenceToArray<ABlockTransferThreadClusterArrangeOrder>::value;
static constexpr auto kABlockTransferSrcAccessOrder =
detail::SequenceToArray<ABlockTransferSrcAccessOrder_>::value;
static constexpr ck::index_t kABlockTransferSrcVectorDim = ABlockTransferSrcVectorDim;
static constexpr ck::index_t kABlockTransferSrcScalarPerVector =
ABlockTransferSrcScalarPerVector;
static constexpr ck::index_t kABlockTransferDstScalarPerVector_K1 =
static constexpr ck::index_t kABlockTransferDstScalarPerVectorK1 =
ABlockTransferDstScalarPerVector_K1;
static constexpr bool kABlockLdsAddExtraM = ABlockLdsAddExtraM;
static constexpr bool kABlockLdsExtraM = ABlockLdsAddExtraM;
using BBlockTransferThreadClusterLengths_K0_N_K1 = BBlockTransferThreadClusterLengths_K0_N_K1_;
using BBlockTransferThreadClusterArrangeOrder = BBlockTransferThreadClusterArrangeOrder_;
using BBlockTransferSrcAccessOrder = BBlockTransferSrcAccessOrder_;
// B block transfer thread cluster dimensions (converted to std::array)
static constexpr auto kBThreadClusterLengths =
detail::SequenceToArray<BBlockTransferThreadClusterLengths_K0_N_K1>::value;
static constexpr auto kBThreadClusterArrangeOrder =
detail::SequenceToArray<BBlockTransferThreadClusterArrangeOrder>::value;
static constexpr auto kBBlockTransferSrcAccessOrder =
detail::SequenceToArray<BBlockTransferSrcAccessOrder_>::value;
static constexpr ck::index_t kBBlockTransferSrcVectorDim = BBlockTransferSrcVectorDim;
static constexpr ck::index_t kBBlockTransferSrcScalarPerVector =
BBlockTransferSrcScalarPerVector;
static constexpr ck::index_t kBBlockTransferDstScalarPerVector_K1 =
static constexpr ck::index_t kBBlockTransferDstScalarPerVectorK1 =
BBlockTransferDstScalarPerVector_K1;
static constexpr bool kBBlockLdsAddExtraN = BBlockLdsAddExtraN;
static constexpr bool kBBlockLdsExtraN = BBlockLdsAddExtraN;
using CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock =
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_;
static constexpr auto kCThreadClusterLengths = detail::SequenceToArray<
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>::value;
static constexpr ck::BlockGemmPipelineScheduler kBlkGemmPipeSched = BlkGemmPipeSched;
static constexpr ck::BlockGemmPipelineVersion kBlkGemmPipelineVer = BlkGemmPipelineVer;
@@ -234,7 +260,7 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeightTw
oss << "DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle";
// Template parameters in exact order
oss << "<" << kNDimSpatial; // 1. NDimSpatial
oss << "<" << kSpatialDim; // 1. NDimSpatial
oss << "," << detail::layout_name<InLayout>(); // 2. InLayout
oss << "," << detail::layout_name<WeiLayout>(); // 3. WeiLayout
oss << "," << detail::layout_name<OutLayout>(); // 4. OutLayout
@@ -252,30 +278,30 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeightTw
// OutElementwiseOperation
oss << ","
<< detail::conv_bwd_weight_spec_name(
kConvBackwardWeightSpecialization); // 12. ConvBackwardWeightSpecialization
oss << "," << kBlockSize; // 13. BlockSize
oss << "," << kMPerBlock; // 14. MPerBlock
oss << "," << kNPerBlock; // 15. NPerBlock
oss << "," << kKPerBlock; // 16. KPerBlock
oss << "," << kK1; // 17. K1
oss << "," << kMPerXDL; // 18. MPerXDL
oss << "," << kNPerXDL; // 19. NPerXDL
oss << "," << kMXdlPerWave; // 20. MXdlPerWave
oss << "," << kNXdlPerWave; // 21. NXdlPerWave
kConvBwdWeightSpecialization); // 12. ConvBackwardWeightSpecialization
oss << "," << kBlockSize; // 13. BlockSize
oss << "," << kMPerBlock; // 14. MPerBlock
oss << "," << kNPerBlock; // 15. NPerBlock
oss << "," << kKPerBlock; // 16. KPerBlock
oss << "," << kK1; // 17. K1
oss << "," << kMPerXDL; // 18. MPerXDL
oss << "," << kNPerXDL; // 19. NPerXDL
oss << "," << kMXdlPerWave; // 20. MXdlPerWave
oss << "," << kNXdlPerWave; // 21. NXdlPerWave
oss << "," << detail::sequence_name<ABlockTransferThreadClusterLengths_K0_M_K1>(); // 22.
oss << "," << detail::sequence_name<ABlockTransferThreadClusterArrangeOrder>(); // 23.
oss << "," << detail::sequence_name<ABlockTransferSrcAccessOrder>(); // 24.
oss << "," << kABlockTransferSrcVectorDim; // 25.
oss << "," << kABlockTransferSrcScalarPerVector; // 26.
oss << "," << kABlockTransferDstScalarPerVector_K1; // 27.
oss << "," << (kABlockLdsAddExtraM ? "true" : "false"); // 28.
oss << "," << kABlockTransferDstScalarPerVectorK1; // 27.
oss << "," << (kABlockLdsExtraM ? "true" : "false"); // 28.
oss << "," << detail::sequence_name<BBlockTransferThreadClusterLengths_K0_N_K1>(); // 29.
oss << "," << detail::sequence_name<BBlockTransferThreadClusterArrangeOrder>(); // 30.
oss << "," << detail::sequence_name<BBlockTransferSrcAccessOrder>(); // 31.
oss << "," << kBBlockTransferSrcVectorDim; // 32.
oss << "," << kBBlockTransferSrcScalarPerVector; // 33.
oss << "," << kBBlockTransferDstScalarPerVector_K1; // 34.
oss << "," << (kBBlockLdsAddExtraN ? "true" : "false"); // 35.
oss << "," << kBBlockTransferDstScalarPerVectorK1; // 34.
oss << "," << (kBBlockLdsExtraN ? "true" : "false"); // 35.
oss << "," << kCShuffleMXdlPerWavePerShuffle; // 36.
oss << "," << kCShuffleNXdlPerWavePerShuffle; // 37.
oss << ","

View File

@@ -59,6 +59,11 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle;
namespace ck_tile {
namespace reflect {
/// @brief Tag type for DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_v3 device kernel
struct DeviceGroupedConvBwdWeight_Wmma_CShuffle_Tag
{
};
template <ck::index_t NDimSpatial,
typename InLayout_,
typename WeiLayout_,
@@ -148,8 +153,9 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeight_W
false>> // Use false to match with the default value
{
static constexpr auto kTensorOpName = "DeviceGroupedConvBwdWeight_Wmma_CShuffle";
using device_kernel_tag = DeviceGroupedConvBwdWeight_Wmma_CShuffle_Tag;
static constexpr ck::index_t kNDimSpatial = NDimSpatial;
static constexpr ck::index_t kSpatialDim = NDimSpatial;
using InLayout = InLayout_;
using WeiLayout = WeiLayout_;
@@ -164,15 +170,15 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeight_W
using WeiElementwiseOperation = WeiElementwiseOperation_;
using OutElementwiseOperation = OutElementwiseOperation_;
static constexpr auto kConvBackwardWeightSpecialization = ConvBackwardWeightSpecialization;
static constexpr auto kConvBwdWeightSpecialization = ConvBackwardWeightSpecialization;
static constexpr ck::index_t kBlockSize = BlockSize;
static constexpr ck::index_t kMPerBlock = MPerBlock;
static constexpr ck::index_t kNPerBlock = NPerBlock;
static constexpr ck::index_t kK0PerBlock = K0PerBlock;
static constexpr ck::index_t kK1 = K1;
static constexpr ck::index_t kMPerWMMA = MPerWMMA;
static constexpr ck::index_t kNPerWMMA = NPerWMMA;
static constexpr ck::index_t kMPerWmma = MPerWMMA;
static constexpr ck::index_t kNPerWmma = NPerWMMA;
static constexpr ck::index_t kMRepeat = MRepeat;
static constexpr ck::index_t kNRepeat = NRepeat;
static constexpr ck::index_t kCShuffleMRepeatPerShuffle = CShuffleMRepeatPerShuffle;
@@ -184,26 +190,43 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeight_W
using ABlockTransferThreadClusterLengths_K0_M_K1 = ABlockTransferThreadClusterLengths_K0_M_K1_;
using ABlockTransferThreadClusterArrangeOrder = ABlockTransferThreadClusterArrangeOrder_;
using ABlockTransferSrcAccessOrder = ABlockTransferSrcAccessOrder_;
// A block transfer thread cluster dimensions (converted to std::array)
static constexpr auto kAThreadClusterLengths =
detail::SequenceToArray<ABlockTransferThreadClusterLengths_K0_M_K1>::value;
static constexpr auto kAThreadClusterArrangeOrder =
detail::SequenceToArray<ABlockTransferThreadClusterArrangeOrder>::value;
static constexpr auto kABlockTransferSrcAccessOrder =
detail::SequenceToArray<ABlockTransferSrcAccessOrder_>::value;
static constexpr ck::index_t kABlockTransferSrcVectorDim = ABlockTransferSrcVectorDim;
static constexpr ck::index_t kABlockTransferSrcScalarPerVector =
ABlockTransferSrcScalarPerVector;
static constexpr ck::index_t kABlockTransferDstScalarPerVector_K1 =
static constexpr ck::index_t kABlockTransferDstScalarPerVectorK1 =
ABlockTransferDstScalarPerVector_K1;
static constexpr bool kABlockLdsAddExtraM = ABlockLdsAddExtraM;
static constexpr bool kABlockLdsExtraM = ABlockLdsAddExtraM;
using BBlockTransferThreadClusterLengths_K0_N_K1 = BBlockTransferThreadClusterLengths_K0_N_K1_;
using BBlockTransferThreadClusterArrangeOrder = BBlockTransferThreadClusterArrangeOrder_;
using BBlockTransferSrcAccessOrder = BBlockTransferSrcAccessOrder_;
// B block transfer thread cluster dimensions (converted to std::array)
static constexpr auto kBThreadClusterLengths =
detail::SequenceToArray<BBlockTransferThreadClusterLengths_K0_N_K1>::value;
static constexpr auto kBThreadClusterArrangeOrder =
detail::SequenceToArray<BBlockTransferThreadClusterArrangeOrder>::value;
static constexpr auto kBBlockTransferSrcAccessOrder =
detail::SequenceToArray<BBlockTransferSrcAccessOrder_>::value;
static constexpr ck::index_t kBBlockTransferSrcVectorDim = BBlockTransferSrcVectorDim;
static constexpr ck::index_t kBBlockTransferSrcScalarPerVector =
BBlockTransferSrcScalarPerVector;
static constexpr ck::index_t kBBlockTransferDstScalarPerVector_K1 =
static constexpr ck::index_t kBBlockTransferDstScalarPerVectorK1 =
BBlockTransferDstScalarPerVector_K1;
static constexpr bool kBBlockLdsAddExtraN = BBlockLdsAddExtraN;
static constexpr bool kBBlockLdsExtraN = BBlockLdsAddExtraN;
using CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock =
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_;
static constexpr auto kCDEThreadClusterLengths = detail::SequenceToArray<
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>::value;
static constexpr int kCDEBlockTransferScalarPerVector =
CShuffleBlockTransferScalarPerVector_NPerBlock;
static constexpr ck::LoopScheduler kLoopSched = LoopSched;
static constexpr ck::PipelineVersion kPipelineVer = PipelineVer;
@@ -216,7 +239,7 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeight_W
oss << "DeviceGroupedConvBwdWeight_Wmma_CShuffle";
// Template parameters in exact order
oss << "<" << kNDimSpatial; // 1. NDimSpatial
oss << "<" << kSpatialDim; // 1. NDimSpatial
oss << "," << detail::layout_name<InLayout>(); // 2. InLayout
oss << "," << detail::layout_name<WeiLayout>(); // 3. WeiLayout
oss << "," << detail::layout_name<OutLayout>(); // 4. OutLayout
@@ -234,30 +257,30 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeight_W
// OutElementwiseOperation
oss << ","
<< detail::conv_bwd_weight_spec_name(
kConvBackwardWeightSpecialization); // 12. ConvBackwardWeightSpecialization
oss << "," << kBlockSize; // 13. BlockSize
oss << "," << kMPerBlock; // 14. MPerBlock
oss << "," << kNPerBlock; // 15. NPerBlock
oss << "," << kK0PerBlock; // 16. K0PerBlock
oss << "," << kK1; // 17. K1
oss << "," << kMPerWMMA; // 18. MPerWMMA
oss << "," << kNPerWMMA; // 19. NPerWMMA
oss << "," << kMRepeat; // 20. MRepeat
oss << "," << kNRepeat; // 21. NRepeat
kConvBwdWeightSpecialization); // 12. ConvBackwardWeightSpecialization
oss << "," << kBlockSize; // 13. BlockSize
oss << "," << kMPerBlock; // 14. MPerBlock
oss << "," << kNPerBlock; // 15. NPerBlock
oss << "," << kK0PerBlock; // 16. K0PerBlock
oss << "," << kK1; // 17. K1
oss << "," << kMPerWmma; // 18. MPerWMMA
oss << "," << kNPerWmma; // 19. NPerWMMA
oss << "," << kMRepeat; // 20. MRepeat
oss << "," << kNRepeat; // 21. NRepeat
oss << "," << detail::sequence_name<ABlockTransferThreadClusterLengths_K0_M_K1>(); // 22.
oss << "," << detail::sequence_name<ABlockTransferThreadClusterArrangeOrder>(); // 23.
oss << "," << detail::sequence_name<ABlockTransferSrcAccessOrder>(); // 24.
oss << "," << kABlockTransferSrcVectorDim; // 25.
oss << "," << kABlockTransferSrcScalarPerVector; // 26.
oss << "," << kABlockTransferDstScalarPerVector_K1; // 27.
oss << "," << (kABlockLdsAddExtraM ? "true" : "false"); // 28.
oss << "," << kABlockTransferDstScalarPerVectorK1; // 27.
oss << "," << (kABlockLdsExtraM ? "true" : "false"); // 28.
oss << "," << detail::sequence_name<BBlockTransferThreadClusterLengths_K0_N_K1>(); // 29.
oss << "," << detail::sequence_name<BBlockTransferThreadClusterArrangeOrder>(); // 30.
oss << "," << detail::sequence_name<BBlockTransferSrcAccessOrder>(); // 31.
oss << "," << kBBlockTransferSrcVectorDim; // 32.
oss << "," << kBBlockTransferSrcScalarPerVector; // 33.
oss << "," << kBBlockTransferDstScalarPerVector_K1; // 34.
oss << "," << (kBBlockLdsAddExtraN ? "true" : "false"); // 35.
oss << "," << kBBlockTransferDstScalarPerVectorK1; // 34.
oss << "," << (kBBlockLdsExtraN ? "true" : "false"); // 35.
oss << "," << kCShuffleMRepeatPerShuffle; // 36.
oss << "," << kCShuffleNRepeatPerShuffle; // 37.
oss << ","

View File

@@ -62,6 +62,11 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3;
namespace ck_tile {
namespace reflect {
/// @brief Tag type for DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_v3 device kernel
struct DeviceGroupedConvBwdWeight_Wmma_CShuffle_V3_Tag
{
};
template <ck::index_t NDimSpatial,
typename InLayout_,
typename WeiLayout_,
@@ -156,8 +161,9 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeight_W
MaxTransposeTransferDstScalarPerVector>>
{
static constexpr auto kTensorOpName = "DeviceGroupedConvBwdWeight_Wmma_CShuffleV3";
using device_kernel_tag = DeviceGroupedConvBwdWeight_Wmma_CShuffle_V3_Tag;
static constexpr ck::index_t kNDimSpatial = NDimSpatial;
static constexpr ck::index_t kSpatialDim = NDimSpatial;
using InLayout = InLayout_;
using WeiLayout = WeiLayout_;
@@ -172,13 +178,13 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeight_W
using WeiElementwiseOperation = WeiElementwiseOperation_;
using OutElementwiseOperation = OutElementwiseOperation_;
static constexpr auto kConvBackwardWeightSpecialization = ConvBackwardWeightSpecialization;
static constexpr auto kConvBwdWeightSpecialization = ConvBackwardWeightSpecialization;
static constexpr ck::index_t kBlockSize = BlockSize;
static constexpr ck::index_t kMPerBlock = MPerBlock;
static constexpr ck::index_t kNPerBlock = NPerBlock;
static constexpr ck::index_t kKPerBlock = KPerBlock;
static constexpr ck::index_t kABK1 = ABK1;
static constexpr ck::index_t kK1 = ABK1;
static constexpr ck::index_t kMPerWmma = MPerWmma;
static constexpr ck::index_t kNPerWmma = NPerWmma;
static constexpr ck::index_t kMRepeat = MRepeat;
@@ -196,27 +202,46 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeight_W
ABlockTransferThreadClusterLengths_AK0_M_AK1_;
using ABlockTransferThreadClusterArrangeOrder = ABlockTransferThreadClusterArrangeOrder_;
using ABlockTransferSrcAccessOrder = ABlockTransferSrcAccessOrder_;
// A block transfer thread cluster dimensions (converted to std::array)
static constexpr auto kAThreadClusterLengths =
detail::SequenceToArray<ABlockTransferThreadClusterLengths_AK0_M_AK1>::value;
static constexpr auto kAThreadClusterArrangeOrder =
detail::SequenceToArray<ABlockTransferThreadClusterArrangeOrder>::value;
static constexpr auto kABlockTransferSrcAccessOrder =
detail::SequenceToArray<ABlockTransferSrcAccessOrder_>::value;
static constexpr ck::index_t kABlockTransferSrcVectorDim = ABlockTransferSrcVectorDim;
static constexpr ck::index_t kABlockTransferSrcScalarPerVector =
ABlockTransferSrcScalarPerVector;
static constexpr ck::index_t kABlockTransferDstScalarPerVector_AK1 =
static constexpr ck::index_t kABlockTransferDstScalarPerVectorK1 =
ABlockTransferDstScalarPerVector_AK1;
static constexpr bool kABlockLdsAddExtraM = ABlockLdsAddExtraM;
static constexpr bool kABlockLdsExtraM = ABlockLdsAddExtraM;
using BBlockTransferThreadClusterLengths_BK0_N_BK1 =
BBlockTransferThreadClusterLengths_BK0_N_BK1_;
using BBlockTransferThreadClusterArrangeOrder = BBlockTransferThreadClusterArrangeOrder_;
using BBlockTransferSrcAccessOrder = BBlockTransferSrcAccessOrder_;
// B block transfer thread cluster dimensions (converted to std::array)
static constexpr auto kBThreadClusterLengths =
detail::SequenceToArray<BBlockTransferThreadClusterLengths_BK0_N_BK1>::value;
static constexpr auto kBThreadClusterArrangeOrder =
detail::SequenceToArray<BBlockTransferThreadClusterArrangeOrder>::value;
static constexpr auto kBBlockTransferSrcAccessOrder =
detail::SequenceToArray<BBlockTransferSrcAccessOrder_>::value;
static constexpr ck::index_t kBBlockTransferSrcVectorDim = BBlockTransferSrcVectorDim;
static constexpr ck::index_t kBBlockTransferSrcScalarPerVector =
BBlockTransferSrcScalarPerVector;
static constexpr ck::index_t kBBlockTransferDstScalarPerVector_BK1 =
static constexpr ck::index_t kBBlockTransferDstScalarPerVectorK1 =
BBlockTransferDstScalarPerVector_BK1;
static constexpr bool kBBlockLdsAddExtraN = BBlockLdsAddExtraN;
static constexpr bool kBBlockLdsExtraN = BBlockLdsAddExtraN;
using CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock =
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_;
static constexpr auto kCDEThreadClusterLengths = detail::SequenceToArray<
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>::value;
static constexpr int kCDEBlockTransferScalarPerVector =
CShuffleBlockTransferScalarPerVector_NPerBlock;
static constexpr ck::BlockGemmPipelineScheduler kBlkGemmPipeSched = BlkGemmPipeSched;
static constexpr ck::BlockGemmPipelineVersion kBlkGemmPipelineVer = BlkGemmPipelineVer;
@@ -232,7 +257,7 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeight_W
oss << "DeviceGroupedConvBwdWeight_Wmma_CShuffleV3";
// Template parameters in exact order
oss << "<" << kNDimSpatial; // 1. NDimSpatial
oss << "<" << kSpatialDim; // 1. NDimSpatial
oss << "," << detail::layout_name<InLayout>(); // 2. InLayout
oss << "," << detail::layout_name<WeiLayout>(); // 3. WeiLayout
oss << "," << detail::layout_name<OutLayout>(); // 4. OutLayout
@@ -250,30 +275,30 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeight_W
// OutElementwiseOperation
oss << ","
<< detail::conv_bwd_weight_spec_name(
kConvBackwardWeightSpecialization); // 12. ConvBackwardWeightSpecialization
oss << "," << kBlockSize; // 13. BlockSize
oss << "," << kMPerBlock; // 14. MPerBlock
oss << "," << kNPerBlock; // 15. NPerBlock
oss << "," << kKPerBlock; // 16. KPerBlock
oss << "," << kABK1; // 17. ABK1
oss << "," << kMPerWmma; // 18. MPerWmma
oss << "," << kNPerWmma; // 19. NPerWmma
oss << "," << kMRepeat; // 20. MRepeat
oss << "," << kNRepeat; // 21. NRepeat
kConvBwdWeightSpecialization); // 12. ConvBackwardWeightSpecialization
oss << "," << kBlockSize; // 13. BlockSize
oss << "," << kMPerBlock; // 14. MPerBlock
oss << "," << kNPerBlock; // 15. NPerBlock
oss << "," << kKPerBlock; // 16. KPerBlock
oss << "," << kK1; // 17. ABK1
oss << "," << kMPerWmma; // 18. MPerWmma
oss << "," << kNPerWmma; // 19. NPerWmma
oss << "," << kMRepeat; // 20. MRepeat
oss << "," << kNRepeat; // 21. NRepeat
oss << "," << detail::sequence_name<ABlockTransferThreadClusterLengths_AK0_M_AK1>(); // 22.
oss << "," << detail::sequence_name<ABlockTransferThreadClusterArrangeOrder>(); // 23.
oss << "," << detail::sequence_name<ABlockTransferSrcAccessOrder>(); // 24.
oss << "," << kABlockTransferSrcVectorDim; // 25.
oss << "," << kABlockTransferSrcScalarPerVector; // 26.
oss << "," << kABlockTransferDstScalarPerVector_AK1; // 27.
oss << "," << (kABlockLdsAddExtraM ? "true" : "false"); // 28.
oss << "," << kABlockTransferDstScalarPerVectorK1; // 27.
oss << "," << (kABlockLdsExtraM ? "true" : "false"); // 28.
oss << "," << detail::sequence_name<BBlockTransferThreadClusterLengths_BK0_N_BK1>(); // 29.
oss << "," << detail::sequence_name<BBlockTransferThreadClusterArrangeOrder>(); // 30.
oss << "," << detail::sequence_name<BBlockTransferSrcAccessOrder>(); // 31.
oss << "," << kBBlockTransferSrcVectorDim; // 32.
oss << "," << kBBlockTransferSrcScalarPerVector; // 33.
oss << "," << kBBlockTransferDstScalarPerVector_BK1; // 34.
oss << "," << (kBBlockLdsAddExtraN ? "true" : "false"); // 35.
oss << "," << kBBlockTransferDstScalarPerVectorK1; // 34.
oss << "," << (kBBlockLdsExtraN ? "true" : "false"); // 35.
oss << "," << kCShuffleMRepeatPerShuffle; // 36.
oss << "," << kCShuffleNRepeatPerShuffle; // 37.
oss << ","

View File

@@ -61,6 +61,11 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle;
namespace ck_tile {
namespace reflect {
/// @brief Tag type for DeviceGroupedConvBwdWeight_Xdl_CShuffle device kernel
struct DeviceGroupedConvBwdWeight_Xdl_CShuffle_Tag
{
};
template <ck::index_t NDimSpatial,
typename InLayout_,
typename WeiLayout_,
@@ -152,7 +157,8 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeight_X
{
static constexpr auto kTensorOpName = "DeviceGroupedConvBwdWeight_Xdl_CShuffle";
static constexpr ck::index_t kNDimSpatial = NDimSpatial;
static constexpr ck::index_t kSpatialDim = NDimSpatial;
using device_kernel_tag = DeviceGroupedConvBwdWeight_Xdl_CShuffle_Tag;
using InLayout = InLayout_;
using WeiLayout = WeiLayout_;
@@ -167,43 +173,63 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeight_X
using WeiElementwiseOperation = WeiElementwiseOperation_;
using OutElementwiseOperation = OutElementwiseOperation_;
static constexpr auto kConvBackwardWeightSpecialization = ConvBackwardWeightSpecialization;
static constexpr ck::index_t kBlockSize = BlockSize;
static constexpr ck::index_t kMPerBlock = MPerBlock;
static constexpr ck::index_t kNPerBlock = NPerBlock;
static constexpr ck::index_t kK0PerBlock = K0PerBlock;
static constexpr ck::index_t kK1 = K1;
static constexpr ck::index_t kMPerXDL = MPerXDL;
static constexpr ck::index_t kNPerXDL = NPerXDL;
static constexpr ck::index_t kMXdlPerWave = MXdlPerWave;
static constexpr ck::index_t kNXdlPerWave = NXdlPerWave;
static constexpr auto kConvBwdWeightSpecialization = ConvBackwardWeightSpecialization;
static constexpr ck::index_t kBlockSize = BlockSize;
static constexpr ck::index_t kMPerBlock = MPerBlock;
static constexpr ck::index_t kNPerBlock = NPerBlock;
static constexpr ck::index_t kK0PerBlock = K0PerBlock;
static constexpr ck::index_t kK1 = K1;
static constexpr ck::index_t kMPerXDL = MPerXDL;
static constexpr ck::index_t kNPerXDL = NPerXDL;
static constexpr ck::index_t kMXdlPerWave = MXdlPerWave;
static constexpr ck::index_t kNXdlPerWave = NXdlPerWave;
using ABlockTransferThreadClusterLengths_K0_M_K1 = ABlockTransferThreadClusterLengths_K0_M_K1_;
using ABlockTransferThreadClusterArrangeOrder = ABlockTransferThreadClusterArrangeOrder_;
using ABlockTransferSrcAccessOrder = ABlockTransferSrcAccessOrder_;
// A block transfer thread cluster dimensions (converted to std::array)
static constexpr auto kAThreadClusterLengths =
detail::SequenceToArray<ABlockTransferThreadClusterLengths_K0_M_K1>::value;
static constexpr auto kAThreadClusterArrangeOrder =
detail::SequenceToArray<ABlockTransferThreadClusterArrangeOrder>::value;
static constexpr auto kABlockTransferSrcAccessOrder =
detail::SequenceToArray<ABlockTransferSrcAccessOrder_>::value;
static constexpr ck::index_t kABlockTransferSrcVectorDim = ABlockTransferSrcVectorDim;
static constexpr ck::index_t kABlockTransferSrcScalarPerVector =
ABlockTransferSrcScalarPerVector;
static constexpr ck::index_t kABlockTransferDstScalarPerVector_K1 =
static constexpr ck::index_t kABlockTransferDstScalarPerVectorK1 =
ABlockTransferDstScalarPerVector_K1;
static constexpr bool kABlockLdsAddExtraM = ABlockLdsAddExtraM;
static constexpr bool kABlockLdsExtraM = ABlockLdsAddExtraM;
using BBlockTransferThreadClusterLengths_K0_N_K1 = BBlockTransferThreadClusterLengths_K0_N_K1_;
using BBlockTransferThreadClusterArrangeOrder = BBlockTransferThreadClusterArrangeOrder_;
using BBlockTransferSrcAccessOrder = BBlockTransferSrcAccessOrder_;
// B block transfer thread cluster dimensions (converted to std::array)
static constexpr auto kBThreadClusterLengths =
detail::SequenceToArray<BBlockTransferThreadClusterLengths_K0_N_K1>::value;
static constexpr auto kBThreadClusterArrangeOrder =
detail::SequenceToArray<BBlockTransferThreadClusterArrangeOrder>::value;
static constexpr auto kBBlockTransferSrcAccessOrder =
detail::SequenceToArray<BBlockTransferSrcAccessOrder_>::value;
static constexpr ck::index_t kBBlockTransferSrcVectorDim = BBlockTransferSrcVectorDim;
static constexpr ck::index_t kBBlockTransferSrcScalarPerVector =
BBlockTransferSrcScalarPerVector;
static constexpr ck::index_t kBBlockTransferDstScalarPerVector_K1 =
static constexpr ck::index_t kBBlockTransferDstScalarPerVectorK1 =
BBlockTransferDstScalarPerVector_K1;
static constexpr bool kBBlockLdsAddExtraN = BBlockLdsAddExtraN;
static constexpr bool kBBlockLdsExtraN = BBlockLdsAddExtraN;
static constexpr ck::index_t kCShuffleMXdlPerWavePerShuffle = CShuffleMXdlPerWavePerShuffle;
static constexpr ck::index_t kCShuffleNXdlPerWavePerShuffle = CShuffleNXdlPerWavePerShuffle;
using CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock =
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_;
static constexpr auto kCThreadClusterLengths = detail::SequenceToArray<
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>::value;
static constexpr ck::index_t kCBlockTransferScalarPerVector_NWaveNPerXdl =
CBlockTransferScalarPerVector_NWaveNPerXdl;
@@ -224,7 +250,7 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeight_X
oss << "DeviceGroupedConvBwdWeight_Xdl_CShuffle";
// Template parameters in exact order
oss << "<" << kNDimSpatial; // 1. NDimSpatial
oss << "<" << kSpatialDim; // 1. NDimSpatial
oss << "," << detail::layout_name<InLayout>(); // 2. InLayout
oss << "," << detail::layout_name<WeiLayout>(); // 3. WeiLayout
oss << "," << detail::layout_name<OutLayout>(); // 4. OutLayout
@@ -242,30 +268,30 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeight_X
// OutElementwiseOperation
oss << ","
<< detail::conv_bwd_weight_spec_name(
kConvBackwardWeightSpecialization); // 12. ConvBackwardWeightSpecialization
oss << "," << kBlockSize; // 13. BlockSize
oss << "," << kMPerBlock; // 14. MPerBlock
oss << "," << kNPerBlock; // 15. NPerBlock
oss << "," << kK0PerBlock; // 16. K0PerBlock
oss << "," << kK1; // 17. K1
oss << "," << kMPerXDL; // 18. MPerXDL
oss << "," << kNPerXDL; // 19. NPerXDL
oss << "," << kMXdlPerWave; // 20. MXdlPerWave
oss << "," << kNXdlPerWave; // 21. NXdlPerWave
kConvBwdWeightSpecialization); // 12. ConvBackwardWeightSpecialization
oss << "," << kBlockSize; // 13. BlockSize
oss << "," << kMPerBlock; // 14. MPerBlock
oss << "," << kNPerBlock; // 15. NPerBlock
oss << "," << kK0PerBlock; // 16. K0PerBlock
oss << "," << kK1; // 17. K1
oss << "," << kMPerXDL; // 18. MPerXDL
oss << "," << kNPerXDL; // 19. NPerXDL
oss << "," << kMXdlPerWave; // 20. MXdlPerWave
oss << "," << kNXdlPerWave; // 21. NXdlPerWave
oss << "," << detail::sequence_name<ABlockTransferThreadClusterLengths_K0_M_K1>(); // 22.
oss << "," << detail::sequence_name<ABlockTransferThreadClusterArrangeOrder>(); // 23.
oss << "," << detail::sequence_name<ABlockTransferSrcAccessOrder>(); // 24.
oss << "," << kABlockTransferSrcVectorDim; // 25.
oss << "," << kABlockTransferSrcScalarPerVector; // 26.
oss << "," << kABlockTransferDstScalarPerVector_K1; // 27.
oss << "," << (kABlockLdsAddExtraM ? "true" : "false"); // 28.
oss << "," << kABlockTransferDstScalarPerVectorK1; // 27.
oss << "," << (kABlockLdsExtraM ? "true" : "false"); // 28.
oss << "," << detail::sequence_name<BBlockTransferThreadClusterLengths_K0_N_K1>(); // 29.
oss << "," << detail::sequence_name<BBlockTransferThreadClusterArrangeOrder>(); // 30.
oss << "," << detail::sequence_name<BBlockTransferSrcAccessOrder>(); // 31.
oss << "," << kBBlockTransferSrcVectorDim; // 32.
oss << "," << kBBlockTransferSrcScalarPerVector; // 33.
oss << "," << kBBlockTransferDstScalarPerVector_K1; // 34.
oss << "," << (kBBlockLdsAddExtraN ? "true" : "false"); // 35.
oss << "," << kBBlockTransferDstScalarPerVectorK1; // 34.
oss << "," << (kBBlockLdsExtraN ? "true" : "false"); // 35.
oss << "," << kCShuffleMXdlPerWavePerShuffle; // 36.
oss << "," << kCShuffleNXdlPerWavePerShuffle; // 37.
oss << ","

View File

@@ -61,6 +61,11 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3;
namespace ck_tile {
namespace reflect {
/// @brief Tag type for DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3 device kernel
struct DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3_Tag
{
};
template <ck::index_t NDimSpatial,
typename InLayout_,
typename WeiLayout_,
@@ -150,9 +155,12 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeight_X
ComputeTypeA_,
ComputeTypeB_>>
{
/// @brief Tag type identifying this device kernel variant
using device_kernel_tag = DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3_Tag;
static constexpr auto kTensorOpName = "DeviceGroupedConvBwdWeight_Xdl_CShuffleV3";
static constexpr ck::index_t kNDimSpatial = NDimSpatial;
static constexpr ck::index_t kSpatialDim = NDimSpatial;
using InLayout = InLayout_;
using WeiLayout = WeiLayout_;
@@ -167,7 +175,7 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeight_X
using WeiElementwiseOperation = WeiElementwiseOperation_;
using OutElementwiseOperation = OutElementwiseOperation_;
static constexpr auto kConvBackwardWeightSpecialization = ConvBackwardWeightSpecialization;
static constexpr auto kConvBwdWeightSpecialization = ConvBackwardWeightSpecialization;
static constexpr ck::index_t kBlockSize = BlockSize;
static constexpr ck::index_t kMPerBlock = MPerBlock;
@@ -182,28 +190,48 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeight_X
using ABlockTransferThreadClusterLengths_K0_M_K1 = ABlockTransferThreadClusterLengths_K0_M_K1_;
using ABlockTransferThreadClusterArrangeOrder = ABlockTransferThreadClusterArrangeOrder_;
using ABlockTransferSrcAccessOrder = ABlockTransferSrcAccessOrder_;
// A block transfer thread cluster dimensions (converted to std::array)
static constexpr auto kAThreadClusterLengths =
detail::SequenceToArray<ABlockTransferThreadClusterLengths_K0_M_K1>::value;
static constexpr auto kAThreadClusterArrangeOrder =
detail::SequenceToArray<ABlockTransferThreadClusterArrangeOrder>::value;
static constexpr auto kABlockTransferSrcAccessOrder =
detail::SequenceToArray<ABlockTransferSrcAccessOrder_>::value;
static constexpr ck::index_t kABlockTransferSrcVectorDim = ABlockTransferSrcVectorDim;
static constexpr ck::index_t kABlockTransferSrcScalarPerVector =
ABlockTransferSrcScalarPerVector;
static constexpr ck::index_t kABlockTransferDstScalarPerVector_K1 =
static constexpr ck::index_t kABlockTransferDstScalarPerVectorK1 =
ABlockTransferDstScalarPerVector_K1;
static constexpr bool kABlockLdsAddExtraM = ABlockLdsAddExtraM;
static constexpr bool kABlockLdsExtraM = ABlockLdsAddExtraM;
using BBlockTransferThreadClusterLengths_K0_N_K1 = BBlockTransferThreadClusterLengths_K0_N_K1_;
using BBlockTransferThreadClusterArrangeOrder = BBlockTransferThreadClusterArrangeOrder_;
using BBlockTransferSrcAccessOrder = BBlockTransferSrcAccessOrder_;
// B block transfer thread cluster dimensions (converted to std::array)
static constexpr auto kBThreadClusterLengths =
detail::SequenceToArray<BBlockTransferThreadClusterLengths_K0_N_K1>::value;
static constexpr auto kBThreadClusterArrangeOrder =
detail::SequenceToArray<BBlockTransferThreadClusterArrangeOrder>::value;
static constexpr auto kBBlockTransferSrcAccessOrder =
detail::SequenceToArray<BBlockTransferSrcAccessOrder_>::value;
static constexpr ck::index_t kBBlockTransferSrcVectorDim = BBlockTransferSrcVectorDim;
static constexpr ck::index_t kBBlockTransferSrcScalarPerVector =
BBlockTransferSrcScalarPerVector;
static constexpr ck::index_t kBBlockTransferDstScalarPerVector_K1 =
static constexpr ck::index_t kBBlockTransferDstScalarPerVectorK1 =
BBlockTransferDstScalarPerVector_K1;
static constexpr bool kBBlockLdsAddExtraN = BBlockLdsAddExtraN;
static constexpr bool kBBlockLdsExtraN = BBlockLdsAddExtraN;
static constexpr ck::index_t kCShuffleMXdlPerWavePerShuffle = CShuffleMXdlPerWavePerShuffle;
static constexpr ck::index_t kCShuffleNXdlPerWavePerShuffle = CShuffleNXdlPerWavePerShuffle;
using CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock =
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_;
static constexpr auto kCThreadClusterLengths = detail::SequenceToArray<
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>::value;
static constexpr ck::index_t kCBlockTransferScalarPerVector_NWaveNPerXdl =
CBlockTransferScalarPerVector_NWaveNPerXdl;
@@ -222,7 +250,7 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeight_X
oss << "DeviceGroupedConvBwdWeight_Xdl_CShuffleV3";
// Template parameters in exact order
oss << "<" << kNDimSpatial; // 1. NDimSpatial
oss << "<" << kSpatialDim; // 1. NDimSpatial
oss << "," << detail::layout_name<InLayout>(); // 2. InLayout
oss << "," << detail::layout_name<WeiLayout>(); // 3. WeiLayout
oss << "," << detail::layout_name<OutLayout>(); // 4. OutLayout
@@ -240,30 +268,30 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeight_X
// OutElementwiseOperation
oss << ","
<< detail::conv_bwd_weight_spec_name(
kConvBackwardWeightSpecialization); // 12. ConvBackwardWeightSpecialization
oss << "," << kBlockSize; // 13. BlockSize
oss << "," << kMPerBlock; // 14. MPerBlock
oss << "," << kNPerBlock; // 15. NPerBlock
oss << "," << kK0PerBlock; // 16. K0PerBlock
oss << "," << kK1; // 17. K1
oss << "," << kMPerXDL; // 18. MPerXDL
oss << "," << kNPerXDL; // 19. NPerXDL
oss << "," << kMXdlPerWave; // 20. MXdlPerWave
oss << "," << kNXdlPerWave; // 21. NXdlPerWave
kConvBwdWeightSpecialization); // 12. ConvBackwardWeightSpecialization
oss << "," << kBlockSize; // 13. BlockSize
oss << "," << kMPerBlock; // 14. MPerBlock
oss << "," << kNPerBlock; // 15. NPerBlock
oss << "," << kK0PerBlock; // 16. K0PerBlock
oss << "," << kK1; // 17. K1
oss << "," << kMPerXDL; // 18. MPerXDL
oss << "," << kNPerXDL; // 19. NPerXDL
oss << "," << kMXdlPerWave; // 20. MXdlPerWave
oss << "," << kNXdlPerWave; // 21. NXdlPerWave
oss << "," << detail::sequence_name<ABlockTransferThreadClusterLengths_K0_M_K1>(); // 22.
oss << "," << detail::sequence_name<ABlockTransferThreadClusterArrangeOrder>(); // 23.
oss << "," << detail::sequence_name<ABlockTransferSrcAccessOrder>(); // 24.
oss << "," << kABlockTransferSrcVectorDim; // 25.
oss << "," << kABlockTransferSrcScalarPerVector; // 26.
oss << "," << kABlockTransferDstScalarPerVector_K1; // 27.
oss << "," << (kABlockLdsAddExtraM ? "true" : "false"); // 28.
oss << "," << kABlockTransferDstScalarPerVectorK1; // 27.
oss << "," << (kABlockLdsExtraM ? "true" : "false"); // 28.
oss << "," << detail::sequence_name<BBlockTransferThreadClusterLengths_K0_N_K1>(); // 29.
oss << "," << detail::sequence_name<BBlockTransferThreadClusterArrangeOrder>(); // 30.
oss << "," << detail::sequence_name<BBlockTransferSrcAccessOrder>(); // 31.
oss << "," << kBBlockTransferSrcVectorDim; // 32.
oss << "," << kBBlockTransferSrcScalarPerVector; // 33.
oss << "," << kBBlockTransferDstScalarPerVector_K1; // 34.
oss << "," << (kBBlockLdsAddExtraN ? "true" : "false"); // 35.
oss << "," << kBBlockTransferDstScalarPerVectorK1; // 34.
oss << "," << (kBBlockLdsExtraN ? "true" : "false"); // 35.
oss << "," << kCShuffleMXdlPerWavePerShuffle; // 36.
oss << "," << kCShuffleNXdlPerWavePerShuffle; // 37.
oss << ","

View File

@@ -79,6 +79,11 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle;
} // namespace ck::tensor_operation::device
/// @brief Tag type for DeviceGroupedConvFwdMultipleD_Wmma_CShuffle device kernel
struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_Tag
{
};
namespace ck_tile::reflect {
// Specialization for DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
@@ -176,6 +181,8 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvFwdMultiple
LoopSched,
PipelineVer>>
{
/// @brief Tag type identifying this device kernel variant
using device_kernel_tag = DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_Tag;
// Spatial dimension
static constexpr int kSpatialDim = NDimSpatial;

View File

@@ -12,7 +12,8 @@ The core components are:
- **`Args`**: A struct template that holds runtime parameters for a specific test case.
- **`Input`** and **`Output`**: Helper classes that groups operation inputs and outputs.
- **`Validator`**: A utility that performs on-GPU validation and integrates with GoogleTest/GoogleMock.
- **`run()`**: Invokes an algorithm on the GPU.
- **`validate()`**: A utility that performs on-GPU validation and integrates with GoogleTest/GoogleMock.
Together, these components enable a structured approach to kernel testing that mirrors the Given-When-Then pattern commonly used in behavior-driven development.
@@ -200,26 +201,27 @@ auto reference_outputs = ck_tile::builder::test::allocate_outputs(args);
ck_tile::builder::test::run(conv, args, inputs.get(), reference_outputs.get());
```
#### `Validator<SIGNATURE>`
#### Validating Results
The `Validator` class encapsulates the validation logic. It performs on-GPU correctness checks by comparing two instances of the `Outputs` structure.
In order to actually verify that the results of the executed device operation are correct, they are compared against the reference output obtained in the previous step. This is done by calling `validate()` with the runtime arguments of the operation, as well as both the actual and reference output. This then yields a *`ValidationReport`*, a type which holds information about which tensors of the output are considered to be equivalent and which are considered to be different. Actually comparing the tensor elements is performed on the GPU to keep the tests fast.
```cpp
ck_tile::builder::test::Validator<SIGNATURE> validator(outputs.get(), reference_outputs.get());
const auto report = ck_tile::builder::test::validate(args, outputs.get(), reference_outputs.get());
```
The `Validator` provides methods that return GoogleMock matchers, enabling clean integration with GoogleTest:
`ValidationReport::get_errors()` returns a vector of tensors from both outputs which are considered to be incorrect, each error case exposes some information about what went wrong.
```cpp
EXPECT_THAT(validator.result(), validator.matches_reference_output());
for (const auto& e : report.get_errors()) {
std::cout << "error: " << e.tensor_name << " was incorrect!" << std::endl;
}
```
The `matches_reference_output()` matcher checks that the output is numerically correct within acceptable tolerances. The `Validator` can also provide more detailed diagnostics, such as:
GoogleTest/GoogleMock integration is provided using the `MatchesReference` matcher. This invokes `validate()` internally, and then turns the result into a GoogleMock-comparible value. Note that this function is closely tied to GoogleMock and the test setup that CK-Builder uses internally, and so is not exposed through the CK-Builder public interface.
- Maximum absolute error
- Maximum relative error
- Number of mismatched elements
- Specific locations of errors
```cpp
EXPECT_THAT(outputs.get(), MatchesReference(args, reference_outputs.get()));
```
## Complete Example
@@ -232,6 +234,7 @@ Here's a complete test that demonstrates the Given-When-Then pattern:
#include "ck_tile/builder/conv_builder.hpp"
#include "ck_tile/testing/tensor_memory_manager.hpp"
#include "ck_tile/testing/validator.hpp"
#include "testing_utils.hpp"
// Define the convolution signature
struct ConvSignature {
@@ -318,8 +321,7 @@ TEST(ConvolutionTest, Forward2D_FP16) {
ck_tile::builder::test::run(conv, args, inputs.get(), reference_outputs.get());
// Check the results
ck_tile::builder::test::Validator<SIGNATURE> validator(outputs.get(), reference_outputs.get());
EXPECT_THAT(validator.result(), validator.is_ok());
EXPECT_THAT(outputs.get(), ck_tile::test::MatchesReference(args, reference_outputs.get()));
}
```
@@ -333,7 +335,7 @@ TEST(ConvolutionTest, Forward2D_FP16) {
4. **Flexibility**: The `Args` struct can be easily extended to support different test scenarios, `Inputs` and `Outputs` can be modified to support additional tensors where necessary, and alternatives to `init_inputs()` can be provided to support additional testing strategies.
5. **Integration**: The `Validator` integrates seamlessly with GoogleTest/GoogleMock, providing familiar assertion syntax.
5. **Integration**: `validate()` integrates seamlessly with GoogleTest/GoogleMock through `MatchesReference`, providing familiar assertion syntax.
6. **Maintainability**: Changes to the testing infrastructure are localized to the utility classes, not scattered across individual tests.

View File

@@ -6,6 +6,7 @@
#include "ck_tile/builder/testing/tensor_descriptor.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
#include <cstdint>
#include <cassert>
#include <concepts>
#include <array>
@@ -348,4 +349,115 @@ void clear_tensor_buffer(const TensorDescriptor<DT, RANK>& desc,
fill_tensor_buffer(desc, buffer, [value]([[maybe_unused]] size_t i) { return value; });
}
/// @brief Utility for copying a tensor from one layout to another
///
/// This function copies tensor data from `src_buffer` to `dst_buffer`,
/// changing the layout from `src_desc` to `dst_desc`. Note that the src and
/// dst tensor lengths must be compatible, otherwise this function may write
/// out of bounds.
///
/// @tparam DT The element datatype of both tensors.
/// @tparam RANK The rank (number of spatial dimensions) of the tensors.
///
/// @param src_desc The descriptor of the source tensor to copy from.
/// @param src_buffer The memory of the source tensor.
/// @param dst_desc The descriptor of the destination tensor to copy to.
/// @param dst_buffer The memory of the destination tensor.
template <DataType DT, size_t RANK>
void copy_tensor(const TensorDescriptor<DT, RANK>& src_desc,
const void* src_buffer,
const TensorDescriptor<DT, RANK>& dst_desc,
void* dst_buffer)
{
assert(src_desc.get_lengths() == dst_desc.get_lengths());
const auto src_strides = src_desc.get_strides();
const auto dst_strides = dst_desc.get_strides();
tensor_foreach(dst_desc.get_lengths(),
[src_buffer, dst_buffer, src_strides, dst_strides](const auto& index) {
using T = detail::cpp_type_t<DT>;
const auto* src = static_cast<const T*>(src_buffer);
auto* dst = static_cast<T*>(dst_buffer);
const auto src_off = calculate_offset(index, src_strides);
const auto dst_off = calculate_offset(index, dst_strides);
dst[dst_off] = src[src_off];
});
}
/// @brief Simple iterator implementation over tensors.
///
/// This type implements a simple "iterator" type for tensor types,
/// basically exposing operator[] for flat indices. This type is useful
/// to be able to provide a "pointer-like" object to API that does not
/// expect higher dimensional tensor types, and expects linear pointers
/// instead. Ideally, one just needs to replace the `T* ptr` with
/// `Iterator it` to update those API to be compatible with this type.
///
/// @note This is not intended to be a full implementation of the C++
/// iterator concept. For example, it does not really hold any state,
/// because that is not really useful anyway.
///
/// @tparam DT The datatype of the tensor to iterate over. Note that this
/// is only here for reference purposes, the actual data type of the backing
/// memory is provided via the backing iterator type.
/// @tparam RANK The rank (number of spatial dimensions) of the tensors.
/// @tparam Iterator The backing iterator type. This can be a (non-void)
/// pointer type.
template <DataType DT, size_t RANK, typename Iterator>
struct FlatTensorIterator
{
/// @brief Construct a FlatTensorIterator.
///
/// Construct a FlatTensorIterator from a tensor descriptor and a backing
/// iterator. The backing iterator can just be a non-void pointer type,
/// note that the result of FlatTensorIterator::operator[] is the same as
/// that of Iterator::operator[].
///
/// @param desc The descriptor of the tensor to iterate.
/// @param inner The inner iterator, for example a (non-void) pointer.
FlatTensorIterator(const TensorDescriptor<DT, RANK>& desc, Iterator inner)
: iter_(desc.get_lengths()), strides_(desc.get_strides()), inner_(inner)
{
}
/// @brief Return the value at a particular flat index.
///
/// This function returns the value of the tensor at flat coordinate
/// `flat_index`. This index is then unflattened into a multi-dimensional
/// index according to the way described in `NdIter`, and a tensor offset
/// is computed from that according to `calculate_offset`. The value at
/// that offset in the inner iterator is then the return value of this
/// function.
///
/// @note NdIter iterates such that the inner dimension (right-most value
/// in the tensor shape) changes fastest.
///
/// @note This function performs no bounds checking.
///
/// @param flat_index The flat index into this tensor.
///
/// @pre flat_index < numel()
///
/// @see NdIter
__host__ __device__ auto& operator[](size_t flat_index) const
{
const auto index = iter_(flat_index);
const auto offset = calculate_offset(index, strides_);
return inner_[offset];
}
/// @brief Return the total number of elements to iterate over.
///
/// @see NdIter::numel()
__host__ __device__ size_t numel() const { return iter_.numel(); }
private:
NdIter<RANK> iter_;
Extent<RANK> strides_;
Iterator inner_;
};
template <DataType DT, size_t RANK, typename Iterator>
FlatTensorIterator(const TensorDescriptor<DT, RANK>&,
Iterator) -> FlatTensorIterator<DT, RANK, Iterator>;
} // namespace ck_tile::builder::test

View File

@@ -8,6 +8,7 @@
#include "ck_tile/builder/testing/tensor_foreach.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
#include "ck/utility/type_convert.hpp"
#include "ck/library/utility/gpu_verification.hpp"
#include <string_view>
#include <vector>
#include <algorithm>
@@ -48,8 +49,8 @@ struct ValidationReport
/// The total number of elements in each tensor.
uint64_t total_elements;
/// The number of elements which were bitwise 0.
uint64_t zero_elements;
/// Set to true if both tensors have all their elements be 0.
bool both_all_zero;
// Max error.
double max_error;
@@ -59,7 +60,7 @@ struct ValidationReport
/// If both tensors are all zero, it indicates either an incorrect testing setup
/// or an issue with the testing framework. For that reason we also consider that
/// a failure.
bool is_all_zero() const { return zero_elements == total_elements; }
bool is_all_zero() const { return both_all_zero; }
/// @brief Return whether the check associated to this case was successful.
///
@@ -86,7 +87,7 @@ struct ValidationReport
/// @brief Compare two tensors and record the results in the report.
///
/// This is the main function used to compare two tensors. The results of this
/// This is one of the main function used to compare two tensors. The results of this
/// comparison, including any supplemental information, is recorded into the report.
///
/// @returns `false` if the comparison failed. If so, the details can be found via
@@ -111,8 +112,45 @@ struct ValidationReport
const TensorDescriptor<DT, RANK>& descriptor,
const void* actual,
const void* expected,
double rtol = 1e-3,
double atol = 1e-3);
float rtol = 1e-3f,
float atol = 1e-3f);
/// @brief Compare two tensors and record the results in the report, with automatic
/// computation of tolerances.
///
/// This variant computes the tolerances automatically based on the compute
/// (accumulation) type, and the number of accumulations required per result value.
/// This is one of the main function used to compare two tensors. The results of this
/// comparison, including any supplemental information, is recorded into the report.
/// @returns `false` if the comparison failed. If so, the details can be found via
/// `get_errors()`.
///
/// @tparam OutDataType The data type of the tensors to check. This is the type of the
/// values in tensor memory.
/// @tparam ComputeType The data type that tensor operations are computed with internally.
/// @tparam AccType The data type that tensor values are accumulated with internally.
/// @tparam RANK The rank (number of spatial dimensions) of the tensor to check.
///
/// @param tensor_name The name of the tensors to check. This should be a value by which
/// whoever is debugging the associated test later can easily find out which of the
/// outputs of a device operation was incorrect.
/// @param descriptor The descriptor (memory layout) of the tensor.
/// @param actual The device buffer with the values of the tensor to-be-tested, ie, the
/// results of the device operation.
/// @param expected The device buffer with the values of the reference tensor. These are
/// treated as a "golden standard", and should usually be generated by a reference
/// implementation.
/// @param number_of_accumulations The maximum number of accumulations required to compute
/// a value of the result tensor.
template <DataType OutDataType,
DataType ComputeType = OutDataType,
DataType AccType = ComputeType,
size_t RANK>
bool check_by_accumulations(std::string_view tensor_name,
const TensorDescriptor<OutDataType, RANK>& descriptor,
const void* actual,
const void* expected,
const size_t number_of_accumulations);
private:
std::vector<Case> reports_;
@@ -121,89 +159,58 @@ struct ValidationReport
template <DataType DT, size_t RANK>
bool ValidationReport::check(std::string_view tensor_name,
const TensorDescriptor<DT, RANK>& descriptor,
const void* actual_data,
const void* expected_data,
double rtol,
double atol)
const void* actual,
const void* expected,
float rtol,
float atol)
{
const auto strides = descriptor.get_strides();
using CKType = detail::cpp_type_t<DT>;
// During development and CI, only the kernels that were changed would fail, and so we can
// assume that the average case does not have errors. Therefore, split out testing into a
// quick test which just counts the incorrect elements, and a more in-depth test that also
// returns the indices of the incorrect items.
const auto a_it = FlatTensorIterator(descriptor, static_cast<const CKType*>(actual));
const auto e_it = FlatTensorIterator(descriptor, static_cast<const CKType*>(expected));
const auto numel = a_it.numel();
// Initial pass: count errors
// Allocate and reset counter
auto d_counters = alloc_buffer(sizeof(uint64_t) * 3);
check_hip(hipMemset(d_counters.get(), 0, sizeof(uint64_t) * 3));
auto d_error_count = &reinterpret_cast<uint64_t*>(d_counters.get())[0];
auto d_zero_count = &reinterpret_cast<uint64_t*>(d_counters.get())[1];
auto d_max_error = &reinterpret_cast<double*>(d_counters.get())[2];
tensor_foreach(descriptor.get_lengths(), [=](auto index) {
using CKType = typename factory::internal::DataTypeToCK<DT>::type;
const auto* actual = static_cast<const CKType*>(actual_data);
const auto* expected = static_cast<const CKType*>(expected_data);
static_assert(!std::is_same_v<CKType, double>,
"TODO implement compare_kernel() for double");
const auto offset = calculate_offset(index, strides);
const auto a = actual[offset];
const auto b = expected[offset];
const auto o = static_cast<double>(type_convert<float>(a));
const auto r = static_cast<double>(type_convert<float>(b));
const auto err = std::abs(o - r);
atomicMax(d_max_error, err);
if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
{
// We expect the number of errors to be very low, so just use an atomic
// for now.
atomicAdd(d_error_count, 1);
}
// Now compare the numbers as bitwise too.
// Update the counter if they're both zero.
using Bytes = std::array<std::byte, sizeof(CKType)>;
bool all_zero = true;
for(auto x : std::bit_cast<Bytes>(a))
{
if(x != std::byte{0})
all_zero = false;
}
for(auto x : std::bit_cast<Bytes>(b))
{
if(x != std::byte{0})
all_zero = false;
}
if(all_zero)
{
atomicAdd(d_zero_count, 1);
}
});
uint64_t error_count = 0;
check_hip(hipMemcpy(&error_count, d_error_count, sizeof(uint64_t), hipMemcpyDeviceToHost));
uint64_t zero_count = 0;
check_hip(hipMemcpy(&zero_count, d_zero_count, sizeof(uint64_t), hipMemcpyDeviceToHost));
double max_error = 0;
check_hip(hipMemcpy(&max_error, d_max_error, sizeof(double), hipMemcpyDeviceToHost));
const auto result = ck::profiler::gpu_verify<CKType>(a_it, e_it, rtol, atol, numel);
// TODO: Gather detailed coordinates.
reports_.push_back(Case{
.tensor_name = std::string(tensor_name),
.wrong_elements = error_count,
.wrong_elements = result.error_count,
.total_elements = descriptor.get_element_size(),
.zero_elements = zero_count,
.max_error = max_error,
.both_all_zero = result.all_zero,
.max_error = result.max_error,
});
return reports_.back().is_ok();
}
template <DataType OutDataType, DataType ComputeType, DataType AccType, size_t RANK>
bool ValidationReport::check_by_accumulations(std::string_view tensor_name,
const TensorDescriptor<OutDataType, RANK>& descriptor,
const void* actual,
const void* expected,
const size_t number_of_accumulations)
{
using CKComputeType = detail::cpp_type_t<ComputeType>;
using CKAccType = detail::cpp_type_t<AccType>;
using CKOutDataType = detail::cpp_type_t<OutDataType>;
const auto a_it = FlatTensorIterator(descriptor, static_cast<const CKOutDataType*>(actual));
const auto e_it = FlatTensorIterator(descriptor, static_cast<const CKOutDataType*>(expected));
const auto numel = a_it.numel();
const auto result = ck::profiler::gpu_verify<CKOutDataType, CKComputeType, CKAccType>(
a_it, e_it, static_cast<int>(number_of_accumulations), numel);
// TODO: Gather detailed coordinates.
reports_.push_back(Case{
.tensor_name = std::string(tensor_name),
.wrong_elements = result.error_count,
.total_elements = descriptor.get_element_size(),
.both_all_zero = result.all_zero,
.max_error = result.max_error,
});
return reports_.back().is_ok();

File diff suppressed because it is too large Load Diff

View File

@@ -259,9 +259,118 @@ TEST(ConvDescriptionTest, DefaultInstanceHasDetailedDescription)
static constexpr const ConvSignature SIGNATURE;
static constexpr const DefaultAlgorithm ALGORITHM;
using Instance = ckb::ConvBuilder<SIGNATURE, ALGORITHM>::Instance;
EXPECT_THAT(
ckr::describe<Instance>().detailed(),
ckt::StringEqWithDiff( //
"2D Forward Convolution Kernel\n"
"├─ Signature\n"
"│ ├─ Tensor Type: FP16\n"
"│ ├─ Input Layout: GNHWC\n"
"│ ├─ Weight Layout: GKYXC\n"
"│ ├─ Output Layout: GNHWK\n"
"│ ├─ Input elementwise operation: PASS_THROUGH\n"
"│ ├─ Weights elementwise operation: PASS_THROUGH\n"
"│ └─ Output elementwise operation: PASS_THROUGH\n"
"└─ Algorithm\n"
" ├─ Thread block size: 256\n"
" ├─ Data tile size: 256×256×32\n"
" ├─ Gemm padding: DEFAULT\n"
" ├─ Convolution specialization: DEFAULT\n"
" ├─ Pipeline version: V4\n"
" ├─ Pipeline scheduler: INTRAWAVE\n"
" ├─ Warp Gemm parameters: \n"
" │ ├─ subtile size: 16×16\n"
" │ └─ Number of warp gemm iterations: 8×8\n"
" └─ Memory access:\n"
" ├─ A Tile transfer: \n"
" │ ├─ Tile dimensions: 4×256×8×\n"
" │ ├─ The innermost K subdimension size: 8\n"
" │ ├─ Spatial thread distribution over the data tile: 0×1×2\n"
" │ ├─ The order of accessing data tile axes: 0×1×2\n"
" │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n"
" │ ├─ Vector access (GMEM read) instruction size: 2\n"
" │ ├─ Vector access (LDS write) instruction size: 2\n"
" │ └─ LDS data layout padding (to prevent bank conflicts): 2\n"
" ├─ B Tile transfer: \n"
" │ ├─ Tile dimensions: 4×256×8×\n"
" │ ├─ The innermost K subdimension size: 8\n"
" │ ├─ Spatial thread distribution over the data tile: 0×1×2\n"
" │ ├─ The order of accessing data tile axes: 0×1×2\n"
" │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n"
" │ ├─ Vector access (GMEM read) instruction size: 2\n"
" │ ├─ Vector access (LDS write) instruction size: 2\n"
" │ └─ LDS data layout padding (to prevent bank conflicts): 2\n"
" └─ C Tile transfer: \n"
" ├─ Data shuffle (number of gemm instructions per iteration): 1×1\n"
" ├─ Spatial thread distribution used to store data: 1×32×1×8\n"
" ├─ Vector access (GMEM write) instruction size: 2\n"
" ├─ Struct does not contain optional num_gemm_k_prefetch_stage parameter\n"
" ├─ Struct does not contain optional max_transpose_transfer_src_scalar_per_vector "
"parameter\n"
" ├─ Struct does not contain optional max_transpose_dst_scalar_per_vector parameter\n"
" └─ Struct does not contain optional num_groups_to_merge parameter"));
}
// Test printing of optional parameters num_groups_to_merge,
// nax_transose_transfer_src_scalar_per_vector and max_transpose_dst_scalar_per_vector
TEST(ConvDescriptionTest, BwdWeightTwoStageWmmaV3DescriptionTest)
{
using Instance =
ck::tensor_operation::device::DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3<
2, // NDimSpatial
ck::tensor_layout::convolution::GNHWC, // InLayout
ck::tensor_layout::convolution::GKYXC, // WeiLayout
ck::tensor_layout::convolution::GNHWK, // OutLayout
ck::half_t, // InDataType
ck::half_t, // WeiDataType
ck::half_t, // OutDataType
float, // AccDataType
ck::tensor_operation::element_wise::PassThrough, // InElementwiseOperation
ck::tensor_operation::element_wise::PassThrough, // WeiElementwiseOperation
ck::tensor_operation::element_wise::PassThrough, // OutElementwiseOperation
ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::
Default, // ConvBackwardWeightSpecialization
256, // BlockSize
128, // MPerBlock
128, // NPerBlock
16, // K0PerBlock
8, // AK1
32, // MPerWMMA
32, // NPerXDL
4, // MRepeat
4, // NRepeat
ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder_
ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
8, // ABlockTransferSrcScalarPerVector
8, // ABlockTransferDstScalarPerVector_K1
1, // ABlockLdsAddExtraM
ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder_
ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder_
2, // BBlockTransferSrcVectorDim
8, // BBlockTransferSrcScalarPerVector
8, // BBlockTransferDstScalarPerVector_K1
1, // BBlockLdsAddExtraN
1, // CShuffleMXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle
ck::Sequence<1,
32,
1,
8>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_
8, // CDEBlockTransferScalarPerVector_NPerBlock_
ck::BlockGemmPipelineScheduler::Intrawave, // BlkGemmPipeSched
ck::BlockGemmPipelineVersion::v1, // BlkGemmPipelineVer
4, // NumGroupsToMerge
ck::half_t, // AComputeDataType
ck::half_t, // BComputeDataType
1, // MaxTransposeTransferSrcScalarPerVector
1>; // MaxTransposeTransferDstScalarPerVector>
EXPECT_THAT(ckr::describe<Instance>().detailed(),
ckt::StringEqWithDiff( //
"2D Forward Convolution Kernel\n"
"2D Backward Weight Convolution Kernel\n"
"├─ Signature\n"
"│ ├─ Tensor Type: FP16\n"
"│ ├─ Input Layout: GNHWC\n"
@@ -272,37 +381,146 @@ TEST(ConvDescriptionTest, DefaultInstanceHasDetailedDescription)
"│ └─ Output elementwise operation: PASS_THROUGH\n"
"└─ Algorithm\n"
" ├─ Thread block size: 256\n"
" ├─ Data tile size: 256×256×32\n"
" ├─ Gemm padding: DEFAULT\n"
" ├─ Data tile size: 128×128×16\n"
" ├─ Struct does not contain optional gemm_padding argument\n"
" ├─ Convolution specialization: DEFAULT\n"
" ├─ Pipeline version: V4\n"
" ├─ Pipeline scheduler: INTRAWAVE\n"
" ├─ Pipeline version: V1\n"
" ├─ Pipeline scheduler: DEFAULT\n"
" ├─ Warp Gemm parameters: \n"
" │ ├─ subtile size: 16×16\n"
" │ └─ Number of warp gemm iterations: 8×8\n"
" │ ├─ subtile size: 32×32\n"
" │ └─ Number of warp gemm iterations: 4×4\n"
" └─ Memory access:\n"
" ├─ A Tile transfer: \n"
" │ ├─ Tile dimensions: 4×256×8×\n"
" │ ├─ Tile dimensions: 2×128×8×\n"
" │ ├─ The innermost K subdimension size: 8\n"
" │ ├─ Spatial thread distribution over the data tile: 0×1×2\n"
" │ ├─ The order of accessing data tile axes: 0×1×2\n"
" │ ├─ Spatial thread distribution over the data tile: 1×0×2\n"
" │ ├─ The order of accessing data tile axes: 1×0×2\n"
" │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n"
" │ ├─ Vector access (GMEM read) instruction size: 2\n"
" │ ├─ Vector access (LDS write) instruction size: 2\n"
" │ └─ LDS data layout padding (to prevent bank conflicts): 2\n"
" │ ├─ Vector access (GMEM read) instruction size: 8\n"
" │ ├─ Vector access (LDS write) instruction size: 8\n"
" │ └─ LDS data layout padding (to prevent bank conflicts): 8\n"
" ├─ B Tile transfer: \n"
" │ ├─ Tile dimensions: 4×256×8×\n"
" │ ├─ Tile dimensions: 2×128×8×\n"
" │ ├─ The innermost K subdimension size: 8\n"
" │ ├─ Spatial thread distribution over the data tile: 0×1×2\n"
" │ ├─ The order of accessing data tile axes: 0×1×2\n"
" │ ├─ Spatial thread distribution over the data tile: 1×0×2\n"
" │ ├─ The order of accessing data tile axes: 1×0×2\n"
" │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n"
" │ ├─ Vector access (GMEM read) instruction size: 2\n"
" │ ├─ Vector access (LDS write) instruction size: 2\n"
" │ └─ LDS data layout padding (to prevent bank conflicts): 2\n"
" │ ├─ Vector access (GMEM read) instruction size: 8\n"
" │ ├─ Vector access (LDS write) instruction size: 8\n"
" │ └─ LDS data layout padding (to prevent bank conflicts): 8\n"
" └─ C Tile transfer: \n"
" ├─ Data shuffle (number of gemm instructions per iteration): 1×1\n"
" ├─ Spatial thread distribution used to store data: 1×32×1×8\n"
" ─ Vector access (GMEM write) instruction size: 2"));
" ─ Vector access (GMEM write) instruction size: 8\n"
" ├─ Struct does not contain optional num_gemm_k_prefetch_stage parameter\n"
" ├─ Max Transpose transfer scr scalar per vector: 1\n"
" ├─ Max Transpose dst scalar per vector: 1\n"
" └─ Num groups to merge: 4"));
}
// Test printing of optional parameters num_groups_to_merge,
// nax_transose_transfer_src_scalar_per_vector and max_transpose_dst_scalar_per_vector
TEST(ConvDescriptionTest, BwdWeightWmmaCshuffleV3DescriptionTest)
{
using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Wmma_CShuffle<
3, // NDimSpatial
ck::tensor_layout::convolution::GNDHWC, // InLayout
ck::tensor_layout::convolution::GKZYXC, // WeiLayout
ck::tensor_layout::convolution::GNDHWK, // OutLayout
ck::half_t, // InDataType
ck::half_t, // WeiDataType
ck::half_t, // OutDataType
float, // AccDataType
ck::tensor_operation::element_wise::PassThrough, // InElementwiseOperation
ck::tensor_operation::element_wise::PassThrough, // WeiElementwiseOperation
ck::tensor_operation::element_wise::PassThrough, // OutElementwiseOperation
ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::
Default, // ConvBackwardWeightSpecialization
256, // BlockSize
128, // MPerBlock
128, // NPerBlock
16, // K0PerBlock
8, // K1
32, // MPerWmma
32, // NPerWmma
4, // MRepeat
4, // NRepeat
ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1
ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder_
ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
8, // ABlockTransferSrcScalarPerVector
8, // ABlockTransferDstScalarPerVector_K1
1, // ABlockLdsAddExtraM
ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1
ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder_
ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder_
2, // BBlockTransferSrcVectorDim
8, // BBlockTransferSrcScalarPerVector
8, // BBlockTransferDstScalarPerVector_K1
1, // BBlockLdsAddExtraN
1, // CShuffleMXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle
ck::Sequence<1,
32,
1,
8>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_
8, // CDEBlockTransferScalarPerVector_NPerBlock_
1, // NummGemmKPrefetchStage
ck::LoopScheduler::Default, // BlkGemmPipeSched
ck::PipelineVersion::v1, // BlkGemmPipelineVer
false>; // BComputeDataType
EXPECT_THAT(
ckr::describe<Instance>().detailed(),
ckt::StringEqWithDiff( //
"3D Backward Weight Convolution Kernel\n"
"├─ Signature\n"
"│ ├─ Tensor Type: FP16\n"
"│ ├─ Input Layout: GNDHWC\n"
"│ ├─ Weight Layout: GKZYXC\n"
"│ ├─ Output Layout: GNDHWK\n"
"│ ├─ Input elementwise operation: PASS_THROUGH\n"
"│ ├─ Weights elementwise operation: PASS_THROUGH\n"
"│ └─ Output elementwise operation: PASS_THROUGH\n"
"└─ Algorithm\n"
" ├─ Thread block size: 256\n"
" ├─ Data tile size: 128×128×16\n"
" ├─ Struct does not contain optional gemm_padding argument\n"
" ├─ Convolution specialization: DEFAULT\n"
" ├─ Pipeline version: V1\n"
" ├─ Pipeline scheduler: DEFAULT\n"
" ├─ Warp Gemm parameters: \n"
" │ ├─ subtile size: 32×32\n"
" │ └─ Number of warp gemm iterations: 4×4\n"
" └─ Memory access:\n"
" ├─ A Tile transfer: \n"
" │ ├─ Tile dimensions: 2×128×8×\n"
" │ ├─ The innermost K subdimension size: 8\n"
" │ ├─ Spatial thread distribution over the data tile: 1×0×2\n"
" │ ├─ The order of accessing data tile axes: 1×0×2\n"
" │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n"
" │ ├─ Vector access (GMEM read) instruction size: 8\n"
" │ ├─ Vector access (LDS write) instruction size: 8\n"
" │ └─ LDS data layout padding (to prevent bank conflicts): 8\n"
" ├─ B Tile transfer: \n"
" │ ├─ Tile dimensions: 2×128×8×\n"
" │ ├─ The innermost K subdimension size: 8\n"
" │ ├─ Spatial thread distribution over the data tile: 1×0×2\n"
" │ ├─ The order of accessing data tile axes: 1×0×2\n"
" │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n"
" │ ├─ Vector access (GMEM read) instruction size: 8\n"
" │ ├─ Vector access (LDS write) instruction size: 8\n"
" │ └─ LDS data layout padding (to prevent bank conflicts): 8\n"
" └─ C Tile transfer: \n"
" ├─ Data shuffle (number of gemm instructions per iteration): 1×1\n"
" ├─ Spatial thread distribution used to store data: 1×32×1×8\n"
" ├─ Vector access (GMEM write) instruction size: 8\n"
" ├─ Num gemm k prefetch stage: 1\n"
" ├─ Struct does not contain optional max_transpose_transfer_src_scalar_per_vector "
"parameter\n"
" ├─ Struct does not contain optional max_transpose_dst_scalar_per_vector parameter\n"
" └─ Struct does not contain optional num_groups_to_merge parameter"));
}
TEST(ConvDescriptionTest, DefaultInstanceHasInstanceString)

View File

@@ -209,7 +209,8 @@ struct ReferenceOutputMatcher
// Round to 2 digits
const float percentage = e.wrong_elements * 10000 / e.total_elements / 100.f;
*listener << e.wrong_elements << "/" << e.total_elements
<< " incorrect elements (~" << percentage << "%)";
<< " incorrect elements (~" << percentage << "%)," << " max error "
<< e.max_error;
}
}
}

View File

@@ -98,8 +98,10 @@ TEST(ConvFwdTesting, Validate)
[&]([[maybe_unused]] std::string_view name,
const auto& desc,
void* ckt::Outputs<SIGNATURE>::*ptr) {
ckt::clear_tensor_buffer(desc, a.get().*ptr, ck::bhalf_t{123});
ckt::clear_tensor_buffer(desc, b.get().*ptr, ck::bhalf_t{123});
ckt::clear_tensor_buffer(
desc, a.get().*ptr, ck::type_convert<ck::bhalf_t, float>(123));
ckt::clear_tensor_buffer(
desc, b.get().*ptr, ck::type_convert<ck::bhalf_t, float>(123));
});
const auto report = ckt::validate(ARGS, a.get(), b.get());
@@ -115,8 +117,10 @@ TEST(ConvFwdTesting, Validate)
const auto& desc,
void* ckt::Outputs<SIGNATURE>::*ptr) {
++field_count;
ckt::clear_tensor_buffer(desc, a.get().*ptr, ck::bhalf_t{2});
ckt::clear_tensor_buffer(desc, b.get().*ptr, ck::bhalf_t{1});
ckt::clear_tensor_buffer(
desc, a.get().*ptr, ck::type_convert<ck::bhalf_t, float>(2));
ckt::clear_tensor_buffer(
desc, b.get().*ptr, ck::type_convert<ck::bhalf_t, float>(1));
});
const auto report = ckt::validate(ARGS, a.get(), b.get());

View File

@@ -225,3 +225,99 @@ TEST(TensorForeach, ClearTensorZeros)
EXPECT_THAT(actual, Eq(0));
}
TEST(TensorForeach, CopyTensor)
{
constexpr auto dt = ckb::DataType::I32;
const ckt::Extent shape = {10, 3, 45, 23, 6};
using Counter = uint32_t;
const auto src_desc = ckt::make_descriptor<dt>(shape, ckt::PackedRightLayout{});
const auto dst_desc = ckt::make_descriptor<dt>(shape, ckt::PackedLeftLayout{});
auto src_buffer = ckt::alloc_tensor_buffer(src_desc);
auto dst_buffer = ckt::alloc_tensor_buffer(dst_desc);
const auto gen = [](const auto& index, const auto& lengths) {
// Simple incrementing counter
return static_cast<Counter>(ckt::calculate_offset(index, lengths));
};
ckt::fill_tensor(
src_desc, src_buffer.get(), [lengths = src_desc.get_lengths(), gen](const auto& index) {
return gen(index, lengths);
});
ckt::clear_tensor_buffer(dst_desc, dst_buffer.get());
// Perform the actual test
ckt::copy_tensor(src_desc, src_buffer.get(), dst_desc, dst_buffer.get());
// Check that the dst tensor has the same data
auto d_invalid = ckt::alloc_buffer(sizeof(Counter));
ckt::check_hip(hipMemset(d_invalid.get(), 0, sizeof(Counter)));
ckt::tensor_foreach(shape,
[lengths = dst_desc.get_lengths(),
gen,
dst = dst_buffer.get(),
invalid = reinterpret_cast<Counter*>(d_invalid.get()),
strides = dst_desc.get_strides()](const auto& index) {
const auto offset = ckt::calculate_offset(index, strides);
const auto expected = gen(index, lengths);
const auto actual = reinterpret_cast<const Counter*>(dst)[offset];
if(expected != actual)
atomicAdd(invalid, 1);
});
Counter invalid = 0;
ckt::check_hip(hipMemcpy(&invalid, d_invalid.get(), sizeof(Counter), hipMemcpyDeviceToHost));
EXPECT_THAT(invalid, Eq(0));
}
TEST(TensorForeach, FlatTensorIterator)
{
using Counter = uint32_t;
constexpr auto dt = ckb::DataType::I32;
const ckt::Extent shape = {10, 9, 8, 7, 6, 5, 4, 3, 2, 1};
const ckt::Extent packed_strides = ckt::PackedRightLayout{}(shape);
const auto desc = ckt::make_descriptor<dt>(shape, ckt::PackedLeftLayout{});
auto buffer = ckt::alloc_tensor_buffer(desc);
// Fill the tensor with random values according to the *flat* index. The
// FlatTensorIterator iterates over flat values even if the strides are not
// packed, so indexing these elements according to the flat index in the
// iterator should yield again this value.
ckt::fill_tensor(desc, buffer.get(), [packed_strides](const auto& index) {
const auto flat_index = ckt::calculate_offset(index, packed_strides);
return static_cast<int32_t>(flat_index * 10001 % 1001);
});
auto iterator = ckt::FlatTensorIterator(desc, reinterpret_cast<const int32_t*>(buffer.get()));
auto d_invalid = ckt::alloc_buffer(sizeof(Counter));
ckt::check_hip(hipMemset(d_invalid.get(), 0, sizeof(Counter)));
ckt::tensor_foreach(shape,
[iterator,
packed_strides,
strides = desc.get_strides(),
data = reinterpret_cast<const int32_t*>(buffer.get()),
invalid = reinterpret_cast<Counter*>(d_invalid.get())](const auto& index) {
const auto flat_index = ckt::calculate_offset(index, packed_strides);
const auto offset = ckt::calculate_offset(index, strides);
if(iterator[flat_index] != data[offset])
atomicAdd(invalid, 1);
});
Counter invalid = 0;
ckt::check_hip(hipMemcpy(&invalid, d_invalid.get(), sizeof(Counter), hipMemcpyDeviceToHost));
EXPECT_THAT(invalid, Eq(0));
}

View File

@@ -74,7 +74,8 @@ TYPED_TEST(ValidationReportTests, SingleCorrect)
ckt::fill_tensor(desc, b.get(), generator);
ckt::ValidationReport report;
report.check("correct", desc, b.get(), a.get());
report.check("correct - explicit tolerance", desc, b.get(), a.get());
report.check_by_accumulations("correct - implicit tolerance", desc, b.get(), a.get(), 0);
EXPECT_THAT(report.get_errors().size(), Eq(0));
}
@@ -97,17 +98,22 @@ TYPED_TEST(ValidationReportTests, SingleIncorrect)
});
ckt::ValidationReport report;
report.check("incorrect", desc, b.get(), a.get());
report.check("incorrect - explicit tolerance", desc, b.get(), a.get());
report.check_by_accumulations("incorrect - implicit tolerance", desc, b.get(), a.get(), 0);
const auto errors = report.get_errors();
const auto flat_size = desc.get_element_size();
const auto expected_errors = flat_size >= 999999 ? 3 : flat_size >= 12345 ? 2 : 1;
ASSERT_THAT(errors.size(), Eq(1));
EXPECT_THAT(errors[0].tensor_name, StrEq("incorrect"));
EXPECT_THAT(errors[0].wrong_elements, Eq(expected_errors));
EXPECT_THAT(errors[0].total_elements, Eq(desc.get_element_size()));
ASSERT_THAT(errors.size(), Eq(2));
EXPECT_THAT(errors[0].tensor_name, StrEq("incorrect - explicit tolerance"));
EXPECT_THAT(errors[1].tensor_name, StrEq("incorrect - implicit tolerance"));
for(int i = 0; i < 2; ++i)
{
EXPECT_THAT(errors[i].wrong_elements, Eq(expected_errors));
EXPECT_THAT(errors[i].total_elements, Eq(desc.get_element_size()));
}
}
TYPED_TEST(ValidationReportTests, ZeroIsIncorrect)
@@ -121,14 +127,20 @@ TYPED_TEST(ValidationReportTests, ZeroIsIncorrect)
ckt::clear_tensor_buffer(desc, b.get());
ckt::ValidationReport report;
report.check("zero_is_incorrect", desc, b.get(), a.get());
report.check("zero_is_incorrect - explicit tolerance", desc, b.get(), a.get());
report.check_by_accumulations(
"zero_is_incorrect - implicit tolerance", desc, b.get(), a.get(), 0);
const auto errors = report.get_errors();
ASSERT_THAT(errors.size(), Eq(1));
EXPECT_THAT(errors[0].tensor_name, StrEq("zero_is_incorrect"));
EXPECT_THAT(errors[0].wrong_elements, Eq(0));
EXPECT_THAT(errors[0].total_elements, Eq(desc.get_element_size()));
EXPECT_THAT(errors[0].zero_elements, Eq(desc.get_element_size()));
ASSERT_THAT(errors.size(), Eq(2));
EXPECT_THAT(errors[0].tensor_name, StrEq("zero_is_incorrect - explicit tolerance"));
EXPECT_THAT(errors[1].tensor_name, StrEq("zero_is_incorrect - implicit tolerance"));
for(int i = 0; i < 2; ++i)
{
EXPECT_THAT(errors[i].wrong_elements, Eq(0));
EXPECT_THAT(errors[i].total_elements, Eq(desc.get_element_size()));
EXPECT_THAT(errors[i].both_all_zero, Eq(true));
}
}
TEST(ValidationReportTests, MultipleSomeIncorrect)
@@ -143,11 +155,12 @@ TEST(ValidationReportTests, MultipleSomeIncorrect)
auto b = ckt::alloc_tensor_buffer(desc);
ckt::fill_tensor_buffer(
desc, a.get(), [](size_t i) { return ck::type_convert<ck::bhalf_t>(i % 100); });
desc, a.get(), [](size_t i) { return ck::type_convert<ck::bhalf_t>(float(i % 100)); });
ckt::fill_tensor_buffer(
desc, b.get(), [](size_t i) { return ck::type_convert<ck::bhalf_t>(i % 101); });
desc, b.get(), [](size_t i) { return ck::type_convert<ck::bhalf_t>(float(i % 101)); });
report.check("incorrect 1", desc, b.get(), a.get());
report.check("incorrect 1 - explicit tolerance", desc, b.get(), a.get());
report.check("incorrect 1 - implicit tolerance", desc, b.get(), a.get(), 0);
}
{
@@ -169,7 +182,8 @@ TEST(ValidationReportTests, MultipleSomeIncorrect)
}
});
report.check("correct", desc, b.get(), a.get());
report.check("correct - explicit tolerance", desc, b.get(), a.get());
report.check("correct - implicit tolerance", desc, b.get(), a.get(), 0);
}
{
@@ -182,16 +196,21 @@ TEST(ValidationReportTests, MultipleSomeIncorrect)
ckt::fill_tensor_buffer(desc, a.get(), []([[maybe_unused]] size_t i) { return 1; });
ckt::fill_tensor_buffer(desc, b.get(), []([[maybe_unused]] size_t i) { return 555; });
report.check("incorrect 2", desc, b.get(), a.get());
report.check("incorrect 2 - explicit tolerance", desc, b.get(), a.get());
report.check("incorrect 2 - implicit tolerance", desc, b.get(), a.get(), 0);
}
const auto errors = report.get_errors();
ASSERT_THAT(errors.size(), Eq(2));
EXPECT_THAT(errors[0].tensor_name, StrEq("incorrect 1"));
ASSERT_THAT(errors.size(), Eq(4));
EXPECT_THAT(errors[0].tensor_name, StrEq("incorrect 1 - explicit tolerance"));
EXPECT_THAT(errors[0].wrong_elements, Eq(46840334));
EXPECT_THAT(errors[1].tensor_name, StrEq("incorrect 2"));
EXPECT_THAT(errors[1].wrong_elements, Eq(482800));
EXPECT_THAT(errors[1].tensor_name, StrEq("incorrect 1 - implicit tolerance"));
EXPECT_THAT(errors[1].wrong_elements, Eq(46840334));
EXPECT_THAT(errors[2].tensor_name, StrEq("incorrect 2 - explicit tolerance"));
EXPECT_THAT(errors[2].wrong_elements, Eq(482800));
EXPECT_THAT(errors[3].tensor_name, StrEq("incorrect 2 - implicit tolerance"));
EXPECT_THAT(errors[3].wrong_elements, Eq(482800));
}
// MatchesReference operates on the types defined in testing.hpp, so just
@@ -234,7 +253,7 @@ ValidationReport validate<DUMMY_SIGNATURE>(const Args<DUMMY_SIGNATURE>& args,
{
ValidationReport report;
report.check("a", args.make_a_descriptor(), actual.a, expected.a);
report.check("b", args.make_b_descriptor(), actual.b, expected.b);
report.check_by_accumulations("b", args.make_b_descriptor(), actual.b, expected.b, 0);
return report;
}
@@ -299,5 +318,5 @@ TEST(MatchesReference, Incorrect)
EXPECT_THAT(listener.str(),
StringEqWithDiff( //
"1 tensors failed to validate\n"
" - a: 625/625 incorrect elements (~100%)"));
" - a: 625/625 incorrect elements (~100%), max error 1"));
}

View File

@@ -67,8 +67,12 @@ __global__ void fill_tensor_uniform_rand_int_values(T* p,
}
else
{
p[i] = ck::type_convert<T, int>(
static_cast<int>((ran_gen_round_u32(s)) % (max_value - min_value)) + min_value);
const auto value =
static_cast<int>((ran_gen_round_u32(s)) % (max_value - min_value)) + min_value;
if constexpr(std::is_integral_v<T> && !std::is_same_v<T, ck::bhalf_t>)
p[i] = ck::type_convert<T, int>(value);
else
p[i] = ck::type_convert<T, float>(value);
}
}
}

View File

@@ -5,10 +5,15 @@
#include <iomanip>
#include <iostream>
#include <tuple>
#include <type_traits>
#include <cmath>
#include <array>
#include "ck/utility/data_type.hpp"
#include "ck/utility/type_convert.hpp"
#include "ck/utility/type.hpp"
#include "ck/utility/env.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/hip_check_error.hpp"
#include "ck/library/utility/check_err.hpp"
@@ -106,6 +111,102 @@ inline float compute_relative_tolerance(const int number_of_accumulations = 1)
}
}
/// @brief Turn an iterator type into an iterator that can be dereferenced.
///
/// In gpu_verify and gpu_reduce_max, it is valid to pass a void pointer and
/// have the function automatically derive the "concrete" pointer type to
/// be used in the kernel. This function does that: depending on whether
/// the `Iterator` is a void pointer or not, it returns either the iterator
/// (assuming that it is already concrete), or returns the pointer casted
/// to the concrete type.
///
/// @tparam T The value type of the pointer, when dereferenced.
/// @tparam Iterator The abstract iterator, can be void* or an actual pointer.
///
/// @param it The iterator to make concrete.
template <typename T, typename Iterator>
__device__ Iterator make_concrete_iterator(Iterator it)
{
return it;
}
template <typename T>
__device__ const T* make_concrete_iterator(const void* it)
{
return reinterpret_cast<const T*>(it);
}
template <typename T>
__device__ const T* make_concrete_iterator(void* it)
{
return reinterpret_cast<const T*>(it);
}
/// @brief Utility to launch persistent kernels.
///
/// This function launches a GPU kernel with a grid size derived from the kernel's
/// occupancy and the total number of multiprocessors on the GPU.
///
/// @tparam Kernel The type of the kernel function.
/// @tparam Args The types of the kernel arguments.
///
/// @param kernel An instance of the kernel function. This should be a __global__ function.
/// @param block_size The kernel's (1D) block size.
/// @param stream The stream to launch the kernel on.
/// @param args The kernel launch arguments.
template <typename Kernel, typename... Args>
void launch_persistent_kernel(Kernel kernel,
int block_size,
hipStream_t stream,
const Args&... args)
{
int occupancy;
hip_check_error(
hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, block_size, 0));
int device;
hip_check_error(hipGetDevice(&device));
int multiprocessors;
hip_check_error(
hipDeviceGetAttribute(&multiprocessors, hipDeviceAttributeMultiprocessorCount, device));
kernel<<<occupancy * multiprocessors, block_size, 0, stream>>>(args...);
hip_check_error(hipGetLastError());
}
/// @brief Simple block reduce kernel.
///
/// This function reduces all `value`s across a block according to `reduce`. This function
/// is a relatively simple implementation as its primary purpose is to be correct and
/// readable: No special cases are done for warp reductions, and the function allocates
/// its own shared memory. The result is broadcasted to all threads.
///
/// @tparam BlockSize The number of threads in a block.
/// @tparam T The value type to reduce over.
/// @tparam F The reduction functor type.
///
/// @param value This thread's value to reduce over.
/// @param reduce The reduction functor, used to combine two values. Should be associative.
template <int BlockSize, typename T, typename F>
__device__ T block_reduce(const T& value, F reduce)
{
__shared__ T workspace[BlockSize];
workspace[threadIdx.x] = value;
__syncthreads();
for(unsigned int s = BlockSize / 2; s >= 1; s >>= 1)
{
if(threadIdx.x < s)
workspace[threadIdx.x] = reduce(workspace[threadIdx.x], workspace[threadIdx.x + s]);
__syncthreads();
}
return workspace[0];
}
// Device-side result structure for kernel output
// Packed into a single struct to minimize device memory allocations
struct GpuVerifyDeviceResult
@@ -113,121 +214,142 @@ struct GpuVerifyDeviceResult
unsigned long long error_count; // Number of errors found
float max_error; // Maximum error value
int all_zero; // 1 = device result is all zeros, 0 = has non-zero values
/// @brief Return the neutral element of a GpuVerifyDeviceResult
///
/// This function returns the "neutral element", the element which does nothing
/// when reduced with another with `reduce_results`. Good to be used as an
/// initial value.
__host__ __device__ static GpuVerifyDeviceResult identity()
{
GpuVerifyDeviceResult result;
result.error_count = 0; // No errors yet
result.max_error = 0.0f; // No error observed
result.all_zero = 1; // Start assuming all zeros (will be cleared if nonzero found)
return result;
}
};
/// @brief Combine two device verify results.
///
/// This function returns the "combined" version of two GpuVerifyDeviceResult values, which
/// adds the total amount of errors, sets the correct max error, and records whether
/// any of the values had any zeros.
__device__ GpuVerifyDeviceResult reduce_results(const GpuVerifyDeviceResult& a,
const GpuVerifyDeviceResult& b)
{
GpuVerifyDeviceResult result;
result.error_count = a.error_count + b.error_count;
result.max_error = std::max(a.max_error, b.max_error);
result.all_zero = a.all_zero & b.all_zero;
return result;
}
/// @brief Compare individual tensor elements.
///
/// This function is what actually does the comparison between two tensor
/// elements. The function returns a tuple of three elements.
/// - The absolute maximum difference.
/// - If the second value is set to false, it indicates either that the elements are not
/// equal according to the thresholds `rtol` and `atol`, or that either value is not
/// finite (NaN/Infinity). If set to true, the values are considered equal.
/// - If the third value is set to true, it indicates that both elements are bitwise
/// equal to zero.
template <typename T>
__device__ std::tuple<float, bool, bool>
compare_elements(const T& actual, const T& expected, const float rtol, const float atol)
{
static_assert(!std::is_same_v<T, double>, "TODO: implement compare_elements() for double");
const auto o = type_convert<float>(actual);
const auto r = type_convert<float>(expected);
const auto e = std::abs(o - r);
const auto inequal = e > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r);
using Bytes = std::array<std::byte, sizeof(T)>;
const auto o_bytes = *reinterpret_cast<const Bytes*>(&actual);
const auto r_bytes = *reinterpret_cast<const Bytes*>(&expected);
bool all_zero = true;
for(const auto x : o_bytes)
{
if(x != std::byte{0})
all_zero = false;
}
for(const auto x : r_bytes)
{
if(x != std::byte{0})
all_zero = false;
}
return std::make_tuple(e, inequal, all_zero);
}
// GPU verification kernel - compares device result against reference using relative and absolute
// tolerance. Tracks all errors (no early exit) to provide detailed error reporting.
//
// Uses LDS (shared memory) for block-level reduction to minimize atomic contention.
// This reduces atomic operations from O(errors) to O(blocks), providing massive speedup
// when there are many errors.
//
// Assumption: Block size is 256
template <typename T>
__global__ void gpu_verify_kernel(const T* __restrict__ device_result,
const T* __restrict__ reference_result,
float rtol,
float atol,
long long size,
GpuVerifyDeviceResult* result)
template <int BlockSize, typename T, typename IteratorA, typename IteratorB>
__global__ __launch_bounds__(BlockSize) //
void gpu_verify_kernel(IteratorA device_result_it,
IteratorB reference_result_it,
float rtol,
float atol,
long long size,
GpuVerifyDeviceResult* result)
{
constexpr int block_size = 256;
auto device_result = make_concrete_iterator<T>(device_result_it);
auto reference_result = make_concrete_iterator<T>(reference_result_it);
// Shared memory for block-level reduction
__shared__ unsigned long long shared_error_count[block_size];
__shared__ float shared_max_error[block_size];
__shared__ int shared_has_error[block_size];
__shared__ int shared_has_nonzero[block_size];
// Thread-local accumulators (in registers)
unsigned long long local_error_count = 0;
float local_max_error = 0.0f;
int local_has_error = 0;
int local_has_nonzero = 0;
auto local_result = GpuVerifyDeviceResult::identity();
// Grid-stride loop to handle any tensor size
long long idx = blockIdx.x * blockDim.x + threadIdx.x;
long long stride = blockDim.x * gridDim.x;
long long idx = blockIdx.x * BlockSize + threadIdx.x;
long long stride = BlockSize * gridDim.x;
for(long long i = idx; i < size; i += stride)
{
// Convert to float for comparison
float dev_val = type_convert<float>(device_result[i]);
float ref_val = type_convert<float>(reference_result[i]);
const auto [abs_diff, inequal, bitwise_zero] =
compare_elements(device_result[i], reference_result[i], rtol, atol);
// Check if device value is non-zero
if(dev_val != 0.0f)
{
local_has_nonzero = 1;
}
// Compute absolute difference
float abs_diff = fabsf(dev_val - ref_val);
// Check tolerance (matches CPU check_err logic: err > atol + rtol * abs(ref))
if(abs_diff > atol + rtol * fabsf(ref_val))
{
local_has_error = 1;
local_error_count++;
local_max_error = fmaxf(local_max_error, abs_diff);
}
local_result = reduce_results(local_result,
GpuVerifyDeviceResult{
static_cast<uint64_t>(inequal), // error_count
abs_diff, // max_error
bitwise_zero // all_zero
});
}
// Store thread-local results to shared memory
shared_error_count[threadIdx.x] = local_error_count;
shared_max_error[threadIdx.x] = local_max_error;
shared_has_error[threadIdx.x] = local_has_error;
shared_has_nonzero[threadIdx.x] = local_has_nonzero;
__syncthreads();
// Block-level reduction: 256 -> 128 -> 64 -> 32
for(unsigned int s = block_size / 2; s >= 32; s >>= 1)
{
if(threadIdx.x < s)
{
shared_error_count[threadIdx.x] += shared_error_count[threadIdx.x + s];
shared_max_error[threadIdx.x] =
fmaxf(shared_max_error[threadIdx.x], shared_max_error[threadIdx.x + s]);
shared_has_error[threadIdx.x] |= shared_has_error[threadIdx.x + s];
shared_has_nonzero[threadIdx.x] |= shared_has_nonzero[threadIdx.x + s];
}
__syncthreads();
}
const auto block_result = block_reduce<BlockSize>(local_result, reduce_results);
// Final reduction of remaining 32 elements in thread 0
if(threadIdx.x == 0)
{
for(int i = 1; i < 32; ++i)
// Single atomic update per block (reduces contention from O(errors) to O(blocks))
if(block_result.error_count > 0)
{
shared_error_count[0] += shared_error_count[i];
shared_max_error[0] = fmaxf(shared_max_error[0], shared_max_error[i]);
shared_has_error[0] |= shared_has_error[i];
shared_has_nonzero[0] |= shared_has_nonzero[i];
atomicAdd(&result->error_count, block_result.error_count);
atomicMax(&result->max_error, block_result.max_error);
}
// Single atomic update per block (reduces contention from O(errors) to O(blocks))
if(shared_has_error[0])
if(!block_result.all_zero)
{
atomicAdd(&result->error_count, shared_error_count[0]);
atomicMax(&result->max_error, shared_max_error[0]);
}
// Update all_zero flag: if no nonzero values found, mark as all zero
if(!shared_has_nonzero[0])
{
atomicMin(&result->all_zero, 1);
}
else
{
atomicMin(&result->all_zero, 0);
// A nonzero was found, so set the global value to false.
// Note: this is a benign race condition; technically a race condition but
// all blocks write the same value, so its fine.
result->all_zero = 0;
}
}
}
// Host-side wrapper for GPU verification with explicit tolerances
// Returns GpuVerifyResult with detailed error information
template <typename T>
GpuVerifyResult gpu_verify(const void* device_result,
const void* reference_result,
template <typename T, typename IteratorA, typename IteratorB>
GpuVerifyResult gpu_verify(IteratorA device_result,
IteratorB reference_result,
float rtol,
float atol,
std::size_t size,
@@ -238,31 +360,25 @@ GpuVerifyResult gpu_verify(const void* device_result,
hip_check_error(hipMalloc(&result_dev, sizeof(GpuVerifyDeviceResult)));
// Initialize result struct
GpuVerifyDeviceResult result_host;
result_host.error_count = 0; // No errors yet
result_host.max_error = 0.0f; // No error observed
result_host.all_zero = 1; // Start assuming all zeros (will be cleared if nonzero found)
auto result_host = GpuVerifyDeviceResult::identity();
hip_check_error(
hipMemcpy(result_dev, &result_host, sizeof(GpuVerifyDeviceResult), hipMemcpyHostToDevice));
// Launch kernel with grid-stride loop
// Use 65535 as max grid size (hardware limit for grid dimension in x)
// Grid-stride loop handles any tensor size regardless of grid dimensions
// Launch persistent kernel.
// automatically derive the optimal grid size from the kernel's occupancy and the
// number of multiprocessors.
constexpr int block_size = 256;
int grid_size = std::min<int>(65535, (size + block_size - 1) / block_size);
const auto kernel = gpu_verify_kernel<block_size, T, IteratorA, IteratorB>;
gpu_verify_kernel<T>
<<<grid_size, block_size, 0, stream>>>(static_cast<const T*>(device_result),
static_cast<const T*>(reference_result),
rtol,
atol,
static_cast<long long>(size),
result_dev);
hip_check_error(hipGetLastError());
// Synchronize the stream to ensure kernel completion before reading results
hip_check_error(hipStreamSynchronize(stream));
launch_persistent_kernel(kernel,
block_size,
stream,
device_result,
reference_result,
rtol,
atol,
static_cast<long long>(size),
result_dev);
// Get result
hip_check_error(
@@ -276,23 +392,25 @@ GpuVerifyResult gpu_verify(const void* device_result,
result.error_count = result_host.error_count;
result.max_error = result_host.max_error;
result.total = size;
result.all_zero = (result_host.all_zero == 1);
result.all_zero = result_host.all_zero == 1;
return result;
}
// Forward declaration of gpu_reduce_max
template <typename T>
float gpu_reduce_max(const void* device_buffer, std::size_t size, hipStream_t stream = nullptr);
template <typename T, typename Iterator>
float gpu_reduce_max(Iterator device_buffer, std::size_t size, hipStream_t stream = nullptr);
// Host-side wrapper for GPU verification with automatic tolerance computation
// Computes max value on GPU, then computes tolerances and verifies
// Returns GpuVerifyResult with detailed error information
template <typename OutDataType,
typename ComputeDataType = OutDataType,
typename AccDataType = ComputeDataType>
GpuVerifyResult gpu_verify(const void* device_result,
const void* reference_result,
typename AccDataType = ComputeDataType,
typename IteratorA,
typename IteratorB>
GpuVerifyResult gpu_verify(IteratorA device_result,
IteratorB reference_result,
int number_of_accumulations,
std::size_t size,
hipStream_t stream = nullptr)
@@ -323,23 +441,26 @@ GpuVerifyResult gpu_verify(const void* device_result,
max_abs_value, number_of_accumulations));
}
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "verify: accumulations=" << number_of_accumulations << " rtol = " << rtol
<< " atol=" << atol << std::endl;
}
// Call the explicit tolerance version
return gpu_verify<OutDataType>(device_result, reference_result, rtol, atol, size, stream);
}
// GPU reduction kernel for computing max(abs(data))
// This is an internal kernel called only by gpu_reduce_max() wrapper.
//
// Assumption: Block size is 256
template <typename T>
__global__ void
gpu_reduce_max_kernel(const T* __restrict__ data, long long size, float* __restrict__ max_val)
template <int BlockSize, typename T, typename Iterator>
__global__ __launch_bounds__((BlockSize)) //
void gpu_reduce_max_kernel(Iterator it, long long size, float* __restrict__ max_val)
{
constexpr int block_size = 256;
__shared__ float shared_max[block_size];
auto data = make_concrete_iterator<T>(it);
long long idx = blockIdx.x * blockDim.x + threadIdx.x;
long long stride = blockDim.x * gridDim.x;
long long idx = blockIdx.x * BlockSize + threadIdx.x;
long long stride = BlockSize * gridDim.x;
float local_max = 0.0f;
@@ -349,37 +470,18 @@ gpu_reduce_max_kernel(const T* __restrict__ data, long long size, float* __restr
local_max = fmaxf(local_max, val);
}
shared_max[threadIdx.x] = local_max;
__syncthreads();
const auto block_max = block_reduce<BlockSize>(
local_max, [](const auto& a, const auto& b) { return std::max(a, b); });
// Block-level reduction: 256 -> 128 -> 64 -> 32
for(unsigned int s = block_size / 2; s >= 32; s >>= 1)
{
if(threadIdx.x < s)
{
shared_max[threadIdx.x] = fmaxf(shared_max[threadIdx.x], shared_max[threadIdx.x + s]);
}
__syncthreads();
}
// Final reduction of remaining 32 elements in thread 0
if(threadIdx.x == 0)
{
for(int i = 1; i < 32; ++i)
{
shared_max[0] = fmaxf(shared_max[0], shared_max[i]);
}
// Single atomic update per block
atomicMax(max_val, shared_max[0]);
}
atomicMax(max_val, block_max);
}
// Host-side wrapper for GPU max reduction
// Computes max(abs(data)) and returns as float
// Only transfers 4 bytes (the final max value) instead of entire tensor
template <typename T>
float gpu_reduce_max(const void* device_buffer, std::size_t size, hipStream_t stream)
template <typename T, typename Iterator>
float gpu_reduce_max(Iterator device_buffer, std::size_t size, hipStream_t stream)
{
if(size == 0)
{
@@ -394,22 +496,14 @@ float gpu_reduce_max(const void* device_buffer, std::size_t size, hipStream_t st
float init_val = 0.0f;
hip_check_error(hipMemcpy(max_dev, &init_val, sizeof(float), hipMemcpyHostToDevice));
// Launch reduction kernel
// Use 1024 blocks max for reduction to balance occupancy vs. grid-stride iterations
// For very large tensors (>256M elements), grid-stride loop handles the remainder
// Launch persistent kernel.
// automatically derive the optimal grid size from the kernel's occupancy and the
// number of multiprocessors.
constexpr int block_size = 256;
int grid_size = std::min<int>(1024, (size + block_size - 1) / block_size);
const auto kernel = gpu_reduce_max_kernel<block_size, T, Iterator>;
gpu_reduce_max_kernel<T><<<grid_size, block_size, 0, stream>>>(
static_cast<const T*>(device_buffer), static_cast<long long>(size), max_dev);
hip_check_error(hipGetLastError());
// Synchronize if using default stream
if(stream == nullptr)
{
hip_check_error(hipDeviceSynchronize());
}
launch_persistent_kernel(
kernel, block_size, stream, device_buffer, static_cast<long long>(size), max_dev);
// Copy result to host (only 4 bytes!)
float max_host;

View File

@@ -11,32 +11,37 @@
#include "ck/utility/common_header.hpp"
#include "ck/ck.hpp"
template <typename inType, typename outType>
void convertTypeFromDevice(std::vector<inType>& fromDevice,
std::vector<outType>& res,
template <typename InType, typename OutType>
void convertTypeFromDevice(std::vector<InType>& fromDevice,
std::vector<OutType>& res,
uint64_t num_elements)
{
for(uint64_t i = 0; i < num_elements / ck::packed_size_v<inType>; i++)
for(uint64_t i = 0; i < num_elements / ck::packed_size_v<InType>; i++)
{
// since the CPU dosen't have non-standard data types, we need to convert to float
if constexpr(ck::is_same_v<ck::remove_cvref_t<inType>, ck::f4x2_pk_t>)
if constexpr(ck::is_same_v<ck::remove_cvref_t<InType>, ck::f4x2_pk_t>)
{
ck::float2_t tmp = ck::type_convert<ck::float2_t, ck::f4x2_t>(fromDevice[i]);
res[i * 2] = tmp.x;
res[i * 2 + 1] = tmp.y;
}
else if constexpr(ck::is_same_v<ck::remove_cvref_t<inType>, ck::pk_i4_t>)
else if constexpr(ck::is_same_v<ck::remove_cvref_t<InType>, ck::pk_i4_t>)
{
uint8_t packed = fromDevice[i].data;
int hi = (packed >> 4) & 0x0f;
int lo = packed & 0x0f;
res[i * 2] = static_cast<outType>(hi - 8);
res[i * 2 + 1] = static_cast<outType>(lo - 8);
res[i * 2] = static_cast<OutType>(hi - 8);
res[i * 2 + 1] = static_cast<OutType>(lo - 8);
}
else if constexpr(ck::is_same_v<InType, ck::bhalf_t>)
{
res[i] = ck::type_convert<OutType, float>(
ck::type_convert<float, ck::bhalf_t>(fromDevice[i]));
}
else
{
res[i] = ck::type_convert<outType, inType>(fromDevice[i]);
res[i] = ck::type_convert<OutType, InType>(fromDevice[i]);
}
}
}
@@ -198,12 +203,13 @@ void TDevRanNormGenFp(double sigma,
}
TEST(TDevIntegerRanUniGen, U8) { TDevRanUniGenInt<uint8_t>(0, 2, 15000); }
TEST(TDevIntegerRanUniGen, U16) { TDevRanUniGenInt<uint16_t>(0, 100, 100000); }
// Note: U16 conflicts with ck::bhalf_t
TEST(TDevIntegerRanUniGen, U32) { TDevRanUniGenInt<uint32_t>(0, 10000, 10000000); }
TEST(TDevIntegerRanUniGen, I4) { TDevRanUniGenInt<ck::pk_i4_t>(-2, 2, 10000000); }
TEST(TDevIntegerRanUniGen, F32) { TDevRanUniGenInt<float>(-2, 2, 10000000); }
TEST(TDevIntegerRanUniGen, F16) { TDevRanUniGenInt<ck::half_t>(-2, 2, 1000000); }
TEST(TDevIntegerRanUniGen, BF16) { TDevRanUniGenInt<ck::bhalf_t>(-2, 2, 1000000); }
TEST(TDevFpRanUniGen, F32_1) { TDevRanUniGenFp<float>(0, 1, 100000); }
TEST(TDevFpRanUniGen, F32_2) { TDevRanUniGenFp<float>(0, 37, 73000); }

View File

@@ -83,7 +83,7 @@ class GPUVerificationTest : public ::testing::Test
// Use test fixture's RNG (rng_) for reproducibility
// RNG is seeded in SetUp() with fixed seed or CK_TEST_SEED environment variable
if constexpr(std::is_integral<T>::value)
if constexpr(std::is_integral_v<T> && !std::is_same_v<T, ck::bhalf_t>)
{
std::uniform_int_distribution<int> dis(static_cast<int>(min_val),
static_cast<int>(max_val));