Merge branch 'develop' into ck-tile-docs

This commit is contained in:
Thomas Ning
2025-11-12 10:42:30 -08:00
committed by GitHub
262 changed files with 11842 additions and 1667 deletions

View File

@@ -53,8 +53,8 @@ jobs:
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
repository: "ROCm/TheRock"
ref: c2921b151b8285a1d29942aceb33cfe0fea77ac9 # 10-15-2025 commit
path: "TheRock"
ref: f3f77a3161922df3eee006b888b439d75b2b4668 # 2025-10-29 commit
- name: Setup ccache
run: |
@@ -77,6 +77,8 @@ jobs:
- name: Patch rocm-libraries
run: |
git config --global --add safe.directory '*'
# Remove patches here if they cannot be applied cleanly, and they have not been deleted from TheRock repo
rm -f ./TheRock/patches/amd-mainline/rocm-libraries/0008-Revert-remove-options-no-enumerate-966.patch
git -c user.name="therockbot" -c "user.email=therockbot@amd.com" am --whitespace=nowarn ./TheRock/patches/amd-mainline/rocm-libraries/*.patch
- name: Install python deps
@@ -128,7 +130,7 @@ jobs:
run: |
python3 TheRock/build_tools/github_actions/post_build_upload.py \
--run-id ${{ github.run_id }} \
--amdgpu-family ${{ env.AMDGPU_FAMILIES }} \
--artifact-group ${{ env.AMDGPU_FAMILIES }} \
--build-dir TheRock/build \
--upload

View File

@@ -51,13 +51,13 @@ jobs:
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
with:
repository: "ROCm/TheRock"
ref: c2921b151b8285a1d29942aceb33cfe0fea77ac9 # 10-15-2025 commit
ref: f3f77a3161922df3eee006b888b439d75b2b4668 # 2025-10-29 commit
- name: Run setup test environment workflow
uses: './.github/actions/setup_test_environment'
with:
ARTIFACT_RUN_ID: ${{ env.ARTIFACT_RUN_ID }}
AMDGPU_FAMILIES: ${{ inputs.amdgpu_families }}
ARTIFACT_GROUP: ${{ inputs.amdgpu_families }}
OUTPUT_ARTIFACTS_DIR: ${{ env.OUTPUT_ARTIFACTS_DIR }}
VENV_DIR: ${{ env.VENV_DIR }}
FETCH_ARTIFACT_ARGS: ${{ fromJSON(inputs.component).fetch_artifact_args }}

View File

@@ -27,7 +27,7 @@ jobs:
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
repository: "ROCm/TheRock"
ref: c2921b151b8285a1d29942aceb33cfe0fea77ac9 # 10-15-2025 commit
ref: f3f77a3161922df3eee006b888b439d75b2b4668 # 2025-10-29 commit
- name: "Configuring CI options"
env:

6
.gitignore vendored
View File

@@ -66,6 +66,12 @@ docs/doxygen/xml
cmake-build*/
build*/
# LSP configuration
.clangd
# User-defined CMake presets
CMakeUserPresets.json
# Python virtualenv
.venv/

View File

@@ -683,6 +683,12 @@ if(NOT GPU_ARCHS AND USER_GPU_TARGETS AND NOT MIOPEN_REQ_LIBS_ONLY)
PACKAGE_NAME examples
)
add_subdirectory(example)
add_subdirectory(tutorial)
rocm_package_setup_component(tutorials
LIBRARY_NAME composablekernel
PACKAGE_NAME tutorials
)
add_subdirectory(tile_engine)
if(BUILD_TESTING)
add_subdirectory(test)

3
Jenkinsfile vendored
View File

@@ -1836,10 +1836,11 @@ pipeline {
}
agent{ label rocmnode("gfx90a") }
environment{
setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx90a" -DCMAKE_CXX_FLAGS=" -O3 " """
setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx90a" -DCK_CXX_STANDARD="17" -DCMAKE_CXX_FLAGS=" -O3 " """
execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \
cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \
-DGPU_TARGETS="gfx90a" \
-DCK_CXX_STANDARD="17" \
-DCMAKE_CXX_COMPILER="${build_compiler()}" \
-DCMAKE_C_COMPILER=/opt/rocm/llvm/bin/clang \
-DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """

View File

@@ -13,7 +13,7 @@ using CDataType = ck::bhalf_t;
using ComputeTypeA = ck::f8_t;
using ComputeTypeB = ck::f8_t;
using ALayout = Row;
using ALayout = Col;
using BLayout = Col;
using CLayout = Row;
@@ -30,13 +30,13 @@ using DeviceGemmV2Instance = ck::tensor_operation::device::DeviceGemm_Wmma_CShuf
PassThrough, PassThrough, PassThrough, GemmDefault,
128,
128, 64, 64,
8, 8,
16, 16, // AK1, BK1
16, 16,
4, 2,
S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>,
1, 4, 16, 0,
S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 8, 8, 0,
S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 8, 8, 0,
2, 16, 16, 0,
1, 1, S<1, 32, 1, 4>, 8,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1,
ComputeTypeA, ComputeTypeB>;

View File

@@ -221,8 +221,8 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
b0_tensors_device.emplace_back(std::make_unique<DeviceMem>(
sizeof(B0DataType) * problem_size.Ns[i] * problem_size.Ks[i]));
b1_tensors_device.emplace_back(
std::make_unique<DeviceMem>(sizeof(B1DataType) * problem_size.Ns[i]));
b1_tensors_device.emplace_back(std::make_unique<DeviceMem>(
sizeof(B1DataType) * problem_size.Ns[i] * problem_size.Ks[i]));
d0_tensors_device.emplace_back(
std::make_unique<DeviceMem>(sizeof(D0DataType) * problem_size.Ns[i]));

View File

@@ -181,7 +181,7 @@ constexpr ck::index_t ScaleBlockSize = 32; // scaling block
constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 f4 = 128 fp4x2
static constexpr ck::index_t Nswizzle = false;
static constexpr ck::index_t ActOP = 0; // 0: gelu_and_mul, 1: silu_and_mul
static constexpr ck::index_t MPerBlock = 128;
static constexpr ck::index_t MPerBlock = 32;
static constexpr bool MulRoutedWeight = true;
// clang-format off
@@ -190,10 +190,10 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmMXBPreShuffl
A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CDEElementOp, GemmSpec,
ScaleBlockSize, 256,
MPerBlock, 64, KPerBlock,
MPerBlock, 128, KPerBlock,
16, 16,
16, 16,
4, 2,
2, 2,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1,
2, 2, S<1, 32, 1, 8>, S<8, 1, 1, 1>,
@@ -213,10 +213,10 @@ int main(int argc, char* argv[])
ck::index_t sorted_size = sorted_tile_num * MPerBlock;
ck::index_t valid_size = valid_tile_num * MPerBlock;
ck::index_t N = 6144;
ck::index_t K = 4096;
ck::index_t N = 7168;
ck::index_t K = 256;
ck::index_t experts = 8;
ck::index_t tokens = 832;
ck::index_t tokens = 208;
ck::index_t topk = 2;
if(argc == 1)

View File

@@ -309,8 +309,8 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser,
if(init_method == 0)
{
ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_m_k);
ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_k_n);
ck_tile::FillUniformDistribution<ADataType>{-2.f, 2.f}(a_m_k);
ck_tile::FillUniformDistribution<BDataType>{-2.f, 2.f}(b_k_n);
}
else if(init_method == 1)
{

View File

@@ -14,28 +14,14 @@
struct ConvConfigBase
{
static constexpr bool kPadM = true;
static constexpr bool kPadN = true;
static constexpr bool kPadK = true;
static constexpr bool PermuteA = false;
static constexpr bool PermuteB = false;
static constexpr bool TransposeC = false;
static constexpr bool UseStructuredSparsity = false;
static constexpr ck_tile::index_t VectorSizeA = 4;
static constexpr ck_tile::index_t VectorSizeB = 8;
static constexpr ck_tile::index_t VectorSizeC = 8;
static constexpr int kBlockPerCu = 1;
static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
static constexpr ck_tile::index_t TileParitionerM01 = 4;
static constexpr int kBlockPerCu = 1;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
static constexpr ck_tile::index_t NumWaveGroups = 1;
static constexpr bool Preshuffle = false;
static constexpr bool TiledMMAPermuteN = false;
static constexpr ck_tile::index_t NumGroupsToMerge = 1;
};
@@ -216,9 +202,9 @@ struct ConvConfigComputeV5 : public ConvConfigBase
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = 16;
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V5;
static constexpr ck_tile::index_t NumWaNumWaveGroups = 2;
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V5;
static constexpr ck_tile::index_t NumWaveGroups = 2;
};
template <typename PrecType>

View File

@@ -14,7 +14,7 @@
#include "grouped_convolution_backward_data_invoker.hpp"
#include "run_grouped_convolution_bwd_data_example.inc"
template <template <typename PrecType> typename GemmConfig>
template <template <typename PrecType> typename ConvConfig>
int run_grouped_conv_bwd_data_example(int argc, char* argv[])
{
using Invoker = GroupedConvolutionBackwardDataInvoker;
@@ -31,14 +31,14 @@ int run_grouped_conv_bwd_data_example(int argc, char* argv[])
if(data_type == "fp16")
{
return run_grouped_conv_bwd_data_example_prec_type<Invoker,
GemmConfig<ck_tile::half_t>,
ConvConfig<ck_tile::half_t>,
ck_tile::half_t>(
in_layout, wei_layout, out_layout, argc, argv);
}
else if(data_type == "bf16")
{
return run_grouped_conv_bwd_data_example_prec_type<Invoker,
GemmConfig<ck_tile::bf16_t>,
ConvConfig<ck_tile::bf16_t>,
ck_tile::bf16_t>(
in_layout, wei_layout, out_layout, argc, argv);
}

View File

@@ -8,7 +8,7 @@ struct GroupedConvolutionBackwardDataInvoker
{
template <ck_tile::index_t NDimSpatial,
typename GemmConfig,
typename ConvConfig,
typename InDataType,
typename WeiDataType,
typename AccDataType,
@@ -22,64 +22,59 @@ struct GroupedConvolutionBackwardDataInvoker
static float grouped_conv_bwd_data(const ck_tile::GroupedConvBwdDataHostArgs& args,
const ck_tile::stream_config& s)
{
constexpr int kBlockPerCu = 1;
// Implicit GEMM Traits
using GemmShape = ck_tile::TileGemmShape<
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
ck_tile::
sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>,
GemmConfig::PermuteA,
GemmConfig::PermuteB>;
ck_tile::sequence<ConvConfig::M_Tile, ConvConfig::N_Tile, ConvConfig::K_Tile>,
ck_tile::sequence<ConvConfig::M_Warp, ConvConfig::N_Warp, ConvConfig::K_Warp>,
ck_tile::sequence<ConvConfig::M_Warp_Tile,
ConvConfig::N_Warp_Tile,
ConvConfig::K_Warp_Tile>>;
constexpr ck_tile::index_t VectorSizeA = 8;
constexpr ck_tile::index_t VectorSizeB = 8;
constexpr ck_tile::index_t VectorSizeC = 8;
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
using TilePartitioner =
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
GemmConfig::TileParitionerGroupNum,
GemmConfig::TileParitionerM01>;
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
using GroupedConvTraitsType = ck_tile::GroupedConvTraits<NDimSpatial,
ConvSpec,
InLayout,
WeiLayout,
DsLayout,
OutLayout,
VectorSizeA,
VectorSizeB,
VectorSizeC>;
ConvConfig::VectorSizeA,
ConvConfig::VectorSizeB,
ConvConfig::VectorSizeC>;
using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner<
GemmShape,
GroupedConvTraitsType::FixedGemmParams::TilePartitionerGroupNum,
GroupedConvTraitsType::FixedGemmParams::TilePartitionerM01>;
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<
GemmConfig::kPadM,
GemmConfig::kPadN,
GemmConfig::kPadK,
GemmConfig::DoubleSmemBuffer,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdData::AsLayout,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdData::BsLayout,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdData::CLayout,
GemmConfig::TransposeC,
GemmConfig::UseStructuredSparsity,
false, // Persistent,
GemmConfig::NumWaveGroups>;
GroupedConvTraitsType::FixedGemmParams::kPadM,
GroupedConvTraitsType::FixedGemmParams::kPadN,
GroupedConvTraitsType::FixedGemmParams::kPadK,
ConvConfig::DoubleSmemBuffer,
typename GroupedConvTraitsType::AsLayoutBwdData,
typename GroupedConvTraitsType::BsLayoutBwdData,
typename GroupedConvTraitsType::CLayoutBwdData,
GroupedConvTraitsType::FixedGemmParams::TransposeC,
GroupedConvTraitsType::FixedGemmParams::UseStructuredSparsity,
GroupedConvTraitsType::FixedGemmParams::Persistent,
ConvConfig::NumWaveGroups>;
using GemmPipelineProblem = ck_tile::GemmPipelineProblem<
OutDataType,
WeiDataType,
AccDataType,
GemmShape,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdData,
typename GroupedConvTraitsType::template GroupedConvImplicitGemmTraitsBwdData<
ConvConfig::NumWaveGroups>,
ck_tile::element_wise::PassThrough,
ck_tile::element_wise::PassThrough,
InDataType,
true,
VectorSizeA,
VectorSizeB>;
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeA,
GroupedConvTraitsType::VectorSizeB>;
using BaseGemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
ConvConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
const ck_tile::index_t gemm_k =
args.K_ * std::accumulate(args.filter_spatial_lengths_.begin(),
@@ -87,102 +82,103 @@ struct GroupedConvolutionBackwardDataInvoker
1,
std::multiplies<ck_tile::index_t>());
const ck_tile::index_t k_grain = args.k_batch * GemmConfig::K_Tile;
const ck_tile::index_t K_split = (gemm_k + k_grain - 1) / k_grain * GemmConfig::K_Tile;
const ck_tile::index_t k_grain = args.k_batch * ConvConfig::K_Tile;
const ck_tile::index_t K_split = (gemm_k + k_grain - 1) / k_grain * ConvConfig::K_Tile;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
float ave_time{0};
const auto Run =
[&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = GemmConfig::Scheduler;
constexpr auto memory_operation = memory_operation_.value;
const auto Run = [&](const auto has_hot_loop_,
const auto tail_number_,
const auto memory_operation_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = ConvConfig::Scheduler;
constexpr auto memory_operation = memory_operation_.value;
using UniversalGemmProblem =
ck_tile::UniversalGemmPipelineProblem<OutDataType,
WeiDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler,
has_hot_loop_v,
tail_number_v,
ck_tile::element_wise::PassThrough,
ck_tile::element_wise::PassThrough,
InDataType,
true,
VectorSizeA,
VectorSizeB>;
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
OutDataType,
WeiDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler,
has_hot_loop_v,
tail_number_v,
ck_tile::element_wise::PassThrough,
ck_tile::element_wise::PassThrough,
InDataType,
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeA,
GroupedConvTraitsType::VectorSizeB>;
using GemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
using GemmPipeline = typename PipelineTypeTraits<
ConvConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
using ConvEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
OutDataType,
WeiDataType,
DsDataType,
AccDataType,
InDataType,
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
ck_tile::tensor_layout::gemm::RowMajor,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
GemmConfig::N_Warp,
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
GemmConfig::TransposeC,
memory_operation,
1,
true,
GroupedConvTraitsType::VectorSizeC>>;
using ConvEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
OutDataType,
WeiDataType,
DsDataType,
AccDataType,
InDataType,
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
typename GroupedConvTraitsType::FixedGemmParams::ELayout,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
ConvConfig::M_Warp,
ConvConfig::N_Warp,
ConvConfig::M_Warp_Tile,
ConvConfig::N_Warp_Tile,
ConvConfig::K_Warp_Tile,
GroupedConvTraitsType::FixedGemmParams::TransposeC,
memory_operation,
ConvConfig::NumWaveGroups,
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeC>>;
using Kernel = ck_tile::GroupedConvolutionBackwardDataKernel<GroupedConvTraitsType,
TilePartitioner,
GemmPipeline,
ConvEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
using Kernel = ck_tile::GroupedConvolutionBackwardDataKernel<GroupedConvTraitsType,
TilePartitioner,
GemmPipeline,
ConvEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(args);
const dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::GridSize(args);
const dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n");
}
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n");
}
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
<< "shape: " << GemmShape::GetName() << '\n'
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
<< "pipeline: " << GemmPipeline::GetName() << '\n'
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
<< "}" << '\n'
<< "Vector size A: " << GemmPipeline::GetVectorSizeA()
<< ", Vector size B: " << GemmPipeline::GetVectorSizeB()
<< ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl;
}
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
<< "shape: " << GemmShape::GetName() << '\n'
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
<< "pipeline: " << GemmPipeline::GetName() << '\n'
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
<< "}" << '\n'
<< "Vector size A: " << GemmPipeline::GetVectorSizeA()
<< ", Vector size B: " << GemmPipeline::GetVectorSizeB()
<< ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl;
}
auto preprocess = [&]() {
ck_tile::hip_check_error(hipMemsetAsync(
kargs.in_ptr, 0, args.template GetInputByte<InDataType>(), s.stream_id_));
};
ave_time = ck_tile::launch_kernel_time_mask(
s,
preprocess,
ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
return ave_time;
auto preprocess = [&]() {
ck_tile::hip_check_error(hipMemsetAsync(
kargs.in_ptr, 0, args.template GetInputByte<InDataType>(), s.stream_id_));
};
ave_time = ck_tile::launch_kernel_time_mask(
s,
preprocess,
ck_tile::make_kernel<ConvConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
return ave_time;
};
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
if(args.k_batch == 1)
{

View File

@@ -21,48 +21,42 @@ struct GroupedConvolutionBackwardWeightInvoker
static float grouped_conv_bwd_weight(const ck_tile::GroupedConvBwdWeightHostArgs& args,
const ck_tile::stream_config& s)
{
constexpr int kBlockPerCu = 1;
// Implicit GEMM Traits
using GemmShape = ck_tile::TileGemmShape<
ck_tile::sequence<ConvConfig::M_Tile, ConvConfig::N_Tile, ConvConfig::K_Tile>,
ck_tile::sequence<ConvConfig::M_Warp, ConvConfig::N_Warp, ConvConfig::K_Warp>,
ck_tile::
sequence<ConvConfig::M_Warp_Tile, ConvConfig::N_Warp_Tile, ConvConfig::K_Warp_Tile>,
ConvConfig::PermuteA,
ConvConfig::PermuteB>;
ck_tile::sequence<ConvConfig::M_Warp_Tile,
ConvConfig::N_Warp_Tile,
ConvConfig::K_Warp_Tile>>;
constexpr ck_tile::index_t VectorSizeA = ConvConfig::VectorSizeA;
constexpr ck_tile::index_t VectorSizeB = ConvConfig::VectorSizeB;
constexpr ck_tile::index_t VectorSizeC = ConvConfig::VectorSizeC;
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
using TilePartitioner =
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
ConvConfig::TileParitionerGroupNum,
ConvConfig::TileParitionerM01>;
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
using GroupedConvTraitsType = ck_tile::GroupedConvTraits<NDimSpatial,
ConvSpec,
InLayout,
WeiLayout,
DsLayout,
OutLayout,
VectorSizeA,
VectorSizeB,
VectorSizeC,
ConvConfig::VectorSizeA,
ConvConfig::VectorSizeB,
ConvConfig::VectorSizeC,
ConvConfig::NumGroupsToMerge>;
using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner<
GemmShape,
GroupedConvTraitsType::FixedGemmParams::TilePartitionerGroupNum,
GroupedConvTraitsType::FixedGemmParams::TilePartitionerM01>;
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<
ConvConfig::kPadM,
ConvConfig::kPadN,
ConvConfig::kPadK,
GroupedConvTraitsType::FixedGemmParams::kPadM,
GroupedConvTraitsType::FixedGemmParams::kPadN,
GroupedConvTraitsType::FixedGemmParams::kPadK,
ConvConfig::DoubleSmemBuffer,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::AsLayout,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::BsLayout,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::CLayout,
ConvConfig::TransposeC,
ConvConfig::UseStructuredSparsity,
false, // Persistent,
typename GroupedConvTraitsType::AsLayoutBwdWeight,
typename GroupedConvTraitsType::BsLayoutBwdWeight,
typename GroupedConvTraitsType::CLayoutBwdWeight,
GroupedConvTraitsType::FixedGemmParams::TransposeC,
GroupedConvTraitsType::FixedGemmParams::UseStructuredSparsity,
GroupedConvTraitsType::FixedGemmParams::Persistent,
ConvConfig::NumWaveGroups>;
using GemmPipelineProblem = ck_tile::GemmPipelineProblem<
@@ -70,13 +64,14 @@ struct GroupedConvolutionBackwardWeightInvoker
InDataType,
AccDataType,
GemmShape,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight,
typename GroupedConvTraitsType::template GroupedConvImplicitGemmTraitsBwdWeight<
ConvConfig::NumWaveGroups>,
ck_tile::element_wise::PassThrough,
ck_tile::element_wise::PassThrough,
WeiDataType,
true,
VectorSizeA,
VectorSizeB>;
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeA,
GroupedConvTraitsType::VectorSizeB>;
using BaseGemmPipeline = typename PipelineTypeTraits<
ConvConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
@@ -102,21 +97,21 @@ struct GroupedConvolutionBackwardWeightInvoker
constexpr auto scheduler = ConvConfig::Scheduler;
constexpr auto memory_operation = memory_operation_.value;
using UniversalGemmProblem =
ck_tile::UniversalGemmPipelineProblem<OutDataType,
InDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler,
has_hot_loop_v,
tail_number_v,
ck_tile::element_wise::PassThrough,
ck_tile::element_wise::PassThrough,
WeiDataType,
true,
VectorSizeA,
VectorSizeB>;
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
OutDataType,
InDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler,
has_hot_loop_v,
tail_number_v,
ck_tile::element_wise::PassThrough,
ck_tile::element_wise::PassThrough,
WeiDataType,
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeA,
GroupedConvTraitsType::VectorSizeB>;
using GemmPipeline = typename PipelineTypeTraits<
ConvConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
@@ -128,7 +123,7 @@ struct GroupedConvolutionBackwardWeightInvoker
AccDataType,
WeiDataType,
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
ck_tile::tensor_layout::gemm::RowMajor,
typename GroupedConvTraitsType::FixedGemmParams::ELayout,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
@@ -137,10 +132,10 @@ struct GroupedConvolutionBackwardWeightInvoker
ConvConfig::M_Warp_Tile,
ConvConfig::N_Warp_Tile,
ConvConfig::K_Warp_Tile,
ConvConfig::TransposeC,
GroupedConvTraitsType::FixedGemmParams::TransposeC,
memory_operation,
1,
true,
ConvConfig::NumWaveGroups,
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeC>>;
using Kernel = ck_tile::GroupedConvolutionBackwardWeightKernel<GroupedConvTraitsType,
@@ -185,7 +180,7 @@ struct GroupedConvolutionBackwardWeightInvoker
ave_time = ck_tile::launch_kernel_time_mask(
s,
preprocess,
ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
ck_tile::make_kernel<ConvConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
return ave_time;
};

View File

@@ -23,47 +23,42 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
{
using WorkspaceDataType = float;
constexpr int kBlockPerCu = 1;
// Implicit GEMM Traits
using GemmShape = ck_tile::TileGemmShape<
ck_tile::sequence<ConvConfig::M_Tile, ConvConfig::N_Tile, ConvConfig::K_Tile>,
ck_tile::sequence<ConvConfig::M_Warp, ConvConfig::N_Warp, ConvConfig::K_Warp>,
ck_tile::
sequence<ConvConfig::M_Warp_Tile, ConvConfig::N_Warp_Tile, ConvConfig::K_Warp_Tile>,
ConvConfig::PermuteA,
ConvConfig::PermuteB>;
ck_tile::sequence<ConvConfig::M_Warp_Tile,
ConvConfig::N_Warp_Tile,
ConvConfig::K_Warp_Tile>>;
constexpr ck_tile::index_t VectorSizeA = 4;
constexpr ck_tile::index_t VectorSizeB = 8;
constexpr ck_tile::index_t VectorSizeC = 8;
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
using TilePartitioner =
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
ConvConfig::TileParitionerGroupNum,
ConvConfig::TileParitionerM01>;
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
using GroupedConvTraitsType = ck_tile::GroupedConvTraits<NDimSpatial,
ConvSpec,
InLayout,
WeiLayout,
DsLayout,
OutLayout,
VectorSizeA,
VectorSizeB,
VectorSizeC>;
ConvConfig::VectorSizeA,
ConvConfig::VectorSizeB,
ConvConfig::VectorSizeC,
ConvConfig::NumGroupsToMerge>;
using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner<
GemmShape,
GroupedConvTraitsType::FixedGemmParams::TilePartitionerGroupNum,
GroupedConvTraitsType::FixedGemmParams::TilePartitionerM01>;
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<
ConvConfig::kPadM,
ConvConfig::kPadN,
ConvConfig::kPadK,
GroupedConvTraitsType::FixedGemmParams::kPadM,
GroupedConvTraitsType::FixedGemmParams::kPadN,
GroupedConvTraitsType::FixedGemmParams::kPadK,
ConvConfig::DoubleSmemBuffer,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::AsLayout,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::BsLayout,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::CLayout,
ConvConfig::TransposeC,
ConvConfig::UseStructuredSparsity,
false, // Persistent,
typename GroupedConvTraitsType::AsLayoutBwdWeight,
typename GroupedConvTraitsType::BsLayoutBwdWeight,
typename GroupedConvTraitsType::CLayoutBwdWeight,
GroupedConvTraitsType::FixedGemmParams::TransposeC,
GroupedConvTraitsType::FixedGemmParams::UseStructuredSparsity,
GroupedConvTraitsType::FixedGemmParams::Persistent,
ConvConfig::NumWaveGroups>;
using GemmPipelineProblem = ck_tile::GemmPipelineProblem<
@@ -71,13 +66,14 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
InDataType,
AccDataType,
GemmShape,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight,
typename GroupedConvTraitsType::template GroupedConvImplicitGemmTraitsBwdWeight<
ConvConfig::NumWaveGroups>,
ck_tile::element_wise::PassThrough,
ck_tile::element_wise::PassThrough,
WeiDataType,
true,
VectorSizeA,
VectorSizeB>;
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeA,
GroupedConvTraitsType::VectorSizeB>;
using BaseGemmPipeline = typename PipelineTypeTraits<
ConvConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
@@ -103,21 +99,21 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
constexpr auto scheduler = ConvConfig::Scheduler;
constexpr auto memory_operation = memory_operation_.value;
using UniversalGemmProblem =
ck_tile::UniversalGemmPipelineProblem<OutDataType,
InDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler,
has_hot_loop_v,
tail_number_v,
ck_tile::element_wise::PassThrough,
ck_tile::element_wise::PassThrough,
WeiDataType,
true,
VectorSizeA,
VectorSizeB>;
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
OutDataType,
InDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler,
has_hot_loop_v,
tail_number_v,
ck_tile::element_wise::PassThrough,
ck_tile::element_wise::PassThrough,
WeiDataType,
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeA,
GroupedConvTraitsType::VectorSizeB>;
using GemmPipeline = typename PipelineTypeTraits<
ConvConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
@@ -129,7 +125,7 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
AccDataType,
WorkspaceDataType, // C: Workspace normally Out
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
ck_tile::tensor_layout::gemm::RowMajor,
typename GroupedConvTraitsType::FixedGemmParams::ELayout,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
@@ -140,8 +136,8 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
ConvConfig::K_Warp_Tile,
GemmPipelineProblem::TransposeC,
memory_operation,
1,
true,
ConvConfig::NumWaveGroups,
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeC>>;
using Kernel = ck_tile::GroupedConvolutionBackwardWeightKernel<GroupedConvTraitsType,
@@ -236,16 +232,17 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
ave_time = ck_tile::launch_kernel_time_mask(
s,
preprocess,
ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs),
ck_tile::make_kernel<kBlockPerCu>(ElementwiseKernel{},
kGridSize,
kBlockSize,
0,
input_size,
ck_tile::make_tuple(shape[1], 1), // Input Stride
ck_tile::make_tuple(shape[1], 1), // Output Stride
input_tensors,
static_cast<WeiDataType*>(c_ptr)));
ck_tile::make_kernel<ConvConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs),
ck_tile::make_kernel<ConvConfig::kBlockPerCu>(
ElementwiseKernel{},
kGridSize,
kBlockSize,
0,
input_size,
ck_tile::make_tuple(shape[1], 1), // Input Stride
ck_tile::make_tuple(shape[1], 1), // Output Stride
input_tensors,
static_cast<WeiDataType*>(c_ptr)));
return ave_time;
};

View File

@@ -14,7 +14,7 @@
#include "grouped_convolution_forward_invoker.hpp"
#include "run_grouped_convolution_fwd_example.inc"
template <template <typename PrecType> typename GemmConfig>
template <template <typename PrecType> typename ConvConfig>
int run_grouped_conv_fwd_example(int argc, char* argv[])
{
using Invoker = GroupedConvolutionForwardInvoker;
@@ -31,14 +31,14 @@ int run_grouped_conv_fwd_example(int argc, char* argv[])
if(data_type == "fp16")
{
return run_grouped_conv_fwd_example_prec_type<Invoker,
GemmConfig<ck_tile::half_t>,
ConvConfig<ck_tile::half_t>,
ck_tile::half_t>(
in_layout, wei_layout, out_layout, argc, argv);
}
else if(data_type == "bf16")
{
return run_grouped_conv_fwd_example_prec_type<Invoker,
GemmConfig<ck_tile::bf16_t>,
ConvConfig<ck_tile::bf16_t>,
ck_tile::bf16_t>(
in_layout, wei_layout, out_layout, argc, argv);
}

View File

@@ -14,7 +14,7 @@
#include "grouped_convolution_forward_invoker.hpp"
#include "run_grouped_convolution_fwd_bias_clamp_example.inc"
template <template <typename PrecType> typename GemmConfig>
template <template <typename PrecType> typename ConvConfig>
int run_grouped_conv_fwd_bias_clamp_example(int argc, char* argv[])
{
using Invoker = GroupedConvolutionForwardInvoker;
@@ -31,14 +31,14 @@ int run_grouped_conv_fwd_bias_clamp_example(int argc, char* argv[])
if(data_type == "fp16")
{
return run_grouped_conv_fwd_bias_clamp_example_prec_type<Invoker,
GemmConfig<ck_tile::half_t>,
ConvConfig<ck_tile::half_t>,
ck_tile::half_t>(
in_layout, wei_layout, out_layout, argc, argv);
}
else if(data_type == "bf16")
{
return run_grouped_conv_fwd_bias_clamp_example_prec_type<Invoker,
GemmConfig<ck_tile::bf16_t>,
ConvConfig<ck_tile::bf16_t>,
ck_tile::bf16_t>(
in_layout, wei_layout, out_layout, argc, argv);
}

View File

@@ -14,7 +14,7 @@
struct GroupedConvolutionForwardInvoker
{
template <ck_tile::index_t NDimSpatial,
typename GemmConfig,
typename ConvConfig,
typename InDataType,
typename WeiDataType,
typename AccDataType,
@@ -32,68 +32,60 @@ struct GroupedConvolutionForwardInvoker
{
std::cout << "[INVOKER] grouped_conv_fwd called, NDimSpatial=" << NDimSpatial << "\n";
}
constexpr int kBlockPerCu = 1;
// Implicit GEMM Traits
using GemmShape = ck_tile::TileGemmShape<
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
ck_tile::
sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>,
GemmConfig::PermuteA,
GemmConfig::PermuteB>;
ck_tile::sequence<ConvConfig::M_Tile, ConvConfig::N_Tile, ConvConfig::K_Tile>,
ck_tile::sequence<ConvConfig::M_Warp, ConvConfig::N_Warp, ConvConfig::K_Warp>,
ck_tile::sequence<ConvConfig::M_Warp_Tile,
ConvConfig::N_Warp_Tile,
ConvConfig::K_Warp_Tile>>;
constexpr ck_tile::index_t VectorSizeA = 8;
constexpr ck_tile::index_t VectorSizeB = 8;
constexpr ck_tile::index_t VectorSizeC = 8;
constexpr ck_tile::index_t NumGroupsToMerge = 1;
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
using TilePartitioner =
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
GemmConfig::TileParitionerGroupNum,
GemmConfig::TileParitionerM01>;
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
using GroupedConvTraitsType = ck_tile::GroupedConvTraits<NDimSpatial,
ConvSpec,
InLayout,
WeiLayout,
DsLayout,
OutLayout,
VectorSizeA,
VectorSizeB,
VectorSizeC,
NumGroupsToMerge,
CDElementWise>;
ConvConfig::VectorSizeA,
ConvConfig::VectorSizeB,
ConvConfig::VectorSizeC,
ConvConfig::NumGroupsToMerge>;
using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner<
GemmShape,
GroupedConvTraitsType::FixedGemmParams::TilePartitionerGroupNum,
GroupedConvTraitsType::FixedGemmParams::TilePartitionerM01>;
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<
GemmConfig::kPadM,
GemmConfig::kPadN,
GemmConfig::kPadK,
GemmConfig::DoubleSmemBuffer,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsFwd::AsLayout,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsFwd::BsLayout,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsFwd::CLayout,
GemmConfig::TransposeC,
GemmConfig::UseStructuredSparsity,
false, // Persistent,
GemmConfig::NumWaveGroups,
GemmConfig::Preshuffle>;
GroupedConvTraitsType::FixedGemmParams::kPadM,
GroupedConvTraitsType::FixedGemmParams::kPadN,
GroupedConvTraitsType::FixedGemmParams::kPadK,
ConvConfig::DoubleSmemBuffer,
typename GroupedConvTraitsType::AsLayoutFwd,
typename GroupedConvTraitsType::BsLayoutFwd,
typename GroupedConvTraitsType::CLayoutFwd,
GroupedConvTraitsType::FixedGemmParams::TransposeC,
GroupedConvTraitsType::FixedGemmParams::UseStructuredSparsity,
GroupedConvTraitsType::FixedGemmParams::Persistent,
ConvConfig::NumWaveGroups>;
using GemmPipelineProblem = ck_tile::GemmPipelineProblem<
InDataType,
WeiDataType,
AccDataType,
GemmShape,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsFwd,
typename GroupedConvTraitsType::template GroupedConvImplicitGemmTraitsFwd<
ConvConfig::NumWaveGroups>,
ck_tile::element_wise::PassThrough,
ck_tile::element_wise::PassThrough,
OutDataType,
true,
VectorSizeA,
VectorSizeB>;
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeA,
GroupedConvTraitsType::VectorSizeB>;
using BaseGemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
ConvConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
const ck_tile::index_t gemm_k =
args.C_ * std::accumulate(args.filter_spatial_lengths_.begin(),
@@ -102,8 +94,8 @@ struct GroupedConvolutionForwardInvoker
std::multiplies<ck_tile::index_t>());
// Split-K parameters
const ck_tile::index_t k_grain = args.k_batch * GemmConfig::K_Tile;
const ck_tile::index_t K_split = (gemm_k + k_grain - 1) / k_grain * GemmConfig::K_Tile;
const ck_tile::index_t k_grain = args.k_batch * ConvConfig::K_Tile;
const ck_tile::index_t K_split = (gemm_k + k_grain - 1) / k_grain * ConvConfig::K_Tile;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
@@ -112,89 +104,88 @@ struct GroupedConvolutionForwardInvoker
// =====================================================================
// Regular Convolution: Simple, no split-image
// =====================================================================
const auto Run = [&]<bool EnableSplitImage>(const auto has_hot_loop_,
const auto tail_number_,
const auto memory_operation_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = GemmConfig::Scheduler;
constexpr auto memory_operation = memory_operation_.value;
const auto Run =
[&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = ConvConfig::Scheduler;
constexpr auto memory_operation = memory_operation_.value;
using UniversalGemmProblem =
ck_tile::UniversalGemmPipelineProblem<InDataType,
WeiDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler,
has_hot_loop_v,
tail_number_v,
ck_tile::element_wise::PassThrough,
ck_tile::element_wise::PassThrough,
OutDataType,
true,
VectorSizeA,
VectorSizeB>;
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
InDataType,
WeiDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler,
has_hot_loop_v,
tail_number_v,
ck_tile::element_wise::PassThrough,
ck_tile::element_wise::PassThrough,
OutDataType,
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeA,
GroupedConvTraitsType::VectorSizeB>;
using GemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
using GemmPipeline = typename PipelineTypeTraits<
ConvConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
using ConvEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
InDataType,
WeiDataType,
DsDataType,
AccDataType,
OutDataType,
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
ck_tile::tensor_layout::gemm::RowMajor,
CDElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
GemmConfig::N_Warp,
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
GemmConfig::TransposeC,
memory_operation,
1,
true,
GroupedConvTraitsType::VectorSizeC>>;
using ConvEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
InDataType,
WeiDataType,
DsDataType,
AccDataType,
OutDataType,
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
typename GroupedConvTraitsType::FixedGemmParams::ELayout,
CDElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
ConvConfig::M_Warp,
ConvConfig::N_Warp,
ConvConfig::M_Warp_Tile,
ConvConfig::N_Warp_Tile,
ConvConfig::K_Warp_Tile,
GroupedConvTraitsType::FixedGemmParams::TransposeC,
memory_operation,
ConvConfig::NumWaveGroups,
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeC>>;
using Kernel = ck_tile::GroupedConvolutionForwardKernel<EnableSplitImage,
GroupedConvTraitsType,
TilePartitioner,
GemmPipeline,
ConvEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
using Kernel = ck_tile::GroupedConvolutionForwardKernel<GroupedConvTraitsType,
TilePartitioner,
GemmPipeline,
ConvEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(kargs);
const dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::GridSize(kargs);
const dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n");
}
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n");
}
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
<< "shape: " << GemmShape::GetName() << '\n'
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
<< "pipeline: " << GemmPipeline::GetName() << '\n'
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
<< "}" << '\n'
<< "Vector size A: " << GemmPipeline::GetVectorSizeA()
<< ", Vector size B: " << GemmPipeline::GetVectorSizeB()
<< ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl;
}
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
<< "shape: " << GemmShape::GetName() << '\n'
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
<< "pipeline: " << GemmPipeline::GetName() << '\n'
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
<< "}" << '\n'
<< "Vector size A: " << GemmPipeline::GetVectorSizeA()
<< ", Vector size B: " << GemmPipeline::GetVectorSizeB()
<< ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl;
}
ave_time = ck_tile::launch_kernel(
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
ave_time = ck_tile::launch_kernel(s,
ck_tile::make_kernel<ConvConfig::kBlockPerCu>(
Kernel{}, grids, blocks, 0, kargs));
return ave_time;
};
return ave_time;
};
// =====================================================================
// Split-K lambda
@@ -202,11 +193,11 @@ struct GroupedConvolutionForwardInvoker
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
if(args.k_batch == 1)
{
Run.template operator()<false>(has_hot_loop_, tail_number_, MemoryOpSet{});
Run.template operator()(has_hot_loop_, tail_number_, MemoryOpSet{});
}
else
{
Run.template operator()<false>(has_hot_loop_, tail_number_, MemoryOpAtomicAdd{});
Run.template operator()(has_hot_loop_, tail_number_, MemoryOpAtomicAdd{});
}
};

View File

@@ -19,7 +19,7 @@
#include "grouped_convolution_forward_large_tensor_invoker.hpp"
#include "run_grouped_convolution_fwd_example.inc"
template <template <typename PrecType> typename GemmConfig>
template <template <typename PrecType> typename ConvConfig>
int run_grouped_conv_fwd_example(int argc, char* argv[])
{
using Invoker = GroupedConvolutionForwardInvoker;
@@ -36,14 +36,14 @@ int run_grouped_conv_fwd_example(int argc, char* argv[])
if(data_type == "fp16")
{
return run_grouped_conv_fwd_example_prec_type<Invoker,
GemmConfig<ck_tile::half_t>,
ConvConfig<ck_tile::half_t>,
ck_tile::half_t>(
in_layout, wei_layout, out_layout, argc, argv);
}
else if(data_type == "bf16")
{
return run_grouped_conv_fwd_example_prec_type<Invoker,
GemmConfig<ck_tile::bf16_t>,
ConvConfig<ck_tile::bf16_t>,
ck_tile::bf16_t>(
in_layout, wei_layout, out_layout, argc, argv);
}

View File

@@ -7,7 +7,7 @@
struct GroupedConvolutionForwardInvoker
{
template <ck_tile::index_t NDimSpatial,
typename GemmConfig,
typename ConvConfig,
typename InDataType,
typename WeiDataType,
typename AccDataType,
@@ -25,65 +25,75 @@ struct GroupedConvolutionForwardInvoker
{
std::cout << "[INVOKER] grouped_conv_fwd called, NDimSpatial=" << NDimSpatial << "\n";
}
constexpr int kBlockPerCu = 1;
// Implicit GEMM Traits
using GemmShape = ck_tile::TileGemmShape<
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
ck_tile::
sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>,
GemmConfig::PermuteA,
GemmConfig::PermuteB>;
constexpr ck_tile::index_t VectorSizeA = 8;
constexpr ck_tile::index_t VectorSizeB = 8;
constexpr ck_tile::index_t VectorSizeC = 8;
ck_tile::sequence<ConvConfig::M_Tile, ConvConfig::N_Tile, ConvConfig::K_Tile>,
ck_tile::sequence<ConvConfig::M_Warp, ConvConfig::N_Warp, ConvConfig::K_Warp>,
ck_tile::sequence<ConvConfig::M_Warp_Tile,
ConvConfig::N_Warp_Tile,
ConvConfig::K_Warp_Tile>>;
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
using TilePartitioner =
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
GemmConfig::TileParitionerGroupNum,
GemmConfig::TileParitionerM01>;
using GroupedConvTraitsType = ck_tile::GroupedConvTraits<NDimSpatial,
ConvSpec,
InLayout,
WeiLayout,
DsLayout,
OutLayout,
VectorSizeA,
VectorSizeB,
VectorSizeC>;
using GroupedConvTraitsTypeDefault =
ck_tile::GroupedConvTraits<NDimSpatial,
ConvSpec,
InLayout,
WeiLayout,
DsLayout,
OutLayout,
ConvConfig::VectorSizeA,
ConvConfig::VectorSizeB,
ConvConfig::VectorSizeC,
ConvConfig::NumGroupsToMerge>;
using GroupedConvTraitsTypeLargeTensor =
ck_tile::GroupedConvTraits<NDimSpatial,
ConvSpec,
InLayout,
WeiLayout,
DsLayout,
OutLayout,
ConvConfig::VectorSizeA,
ConvConfig::VectorSizeB,
ConvConfig::VectorSizeC,
ConvConfig::NumGroupsToMerge,
true /*EnableSplitImage*/>;
using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner<
GemmShape,
GroupedConvTraitsTypeDefault::FixedGemmParams::TilePartitionerGroupNum,
GroupedConvTraitsTypeDefault::FixedGemmParams::TilePartitionerM01>;
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<
GemmConfig::kPadM,
GemmConfig::kPadN,
GemmConfig::kPadK,
GemmConfig::DoubleSmemBuffer,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsFwd::AsLayout,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsFwd::BsLayout,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsFwd::CLayout,
GemmConfig::TransposeC,
GemmConfig::UseStructuredSparsity,
false, // Persistent,
GemmConfig::NumWaveGroups,
GemmConfig::Preshuffle>;
GroupedConvTraitsTypeDefault::FixedGemmParams::kPadM,
GroupedConvTraitsTypeDefault::FixedGemmParams::kPadN,
GroupedConvTraitsTypeDefault::FixedGemmParams::kPadK,
ConvConfig::DoubleSmemBuffer,
typename GroupedConvTraitsTypeDefault::AsLayoutFwd,
typename GroupedConvTraitsTypeDefault::BsLayoutFwd,
typename GroupedConvTraitsTypeDefault::CLayoutFwd,
GroupedConvTraitsTypeDefault::FixedGemmParams::TransposeC,
GroupedConvTraitsTypeDefault::FixedGemmParams::UseStructuredSparsity,
GroupedConvTraitsTypeDefault::FixedGemmParams::Persistent,
ConvConfig::NumWaveGroups>;
using GemmPipelineProblem = ck_tile::GemmPipelineProblem<
InDataType,
WeiDataType,
AccDataType,
GemmShape,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsFwd,
typename GroupedConvTraitsTypeDefault::template GroupedConvImplicitGemmTraitsFwd<
ConvConfig::NumWaveGroups>,
ck_tile::element_wise::PassThrough,
ck_tile::element_wise::PassThrough,
OutDataType,
true,
VectorSizeA,
VectorSizeB>;
GroupedConvTraitsTypeDefault::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsTypeDefault::VectorSizeA,
GroupedConvTraitsTypeDefault::VectorSizeB>;
using BaseGemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
ConvConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
const ck_tile::index_t gemm_k =
args.C_ * std::accumulate(args.filter_spatial_lengths_.begin(),
@@ -92,8 +102,8 @@ struct GroupedConvolutionForwardInvoker
std::multiplies<ck_tile::index_t>());
// Split-K parameters
const ck_tile::index_t k_grain = args.k_batch * GemmConfig::K_Tile;
const ck_tile::index_t K_split = (gemm_k + k_grain - 1) / k_grain * GemmConfig::K_Tile;
const ck_tile::index_t k_grain = args.k_batch * ConvConfig::K_Tile;
const ck_tile::index_t K_split = (gemm_k + k_grain - 1) / k_grain * ConvConfig::K_Tile;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
@@ -102,9 +112,9 @@ struct GroupedConvolutionForwardInvoker
using TransformType =
ck_tile::TransformConvFwdToGemm<NDimSpatial,
ck_tile::ConvolutionSpecialization::Default,
VectorSizeA,
VectorSizeB,
VectorSizeC,
GroupedConvTraitsTypeDefault::VectorSizeA,
GroupedConvTraitsTypeDefault::VectorSizeB,
GroupedConvTraitsTypeDefault::VectorSizeC,
1, // NumGroupsToMerge
false, // SplitN
InDataType,
@@ -243,27 +253,31 @@ struct GroupedConvolutionForwardInvoker
const auto memory_operation_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = GemmConfig::Scheduler;
constexpr auto scheduler = ConvConfig::Scheduler;
constexpr auto memory_operation = memory_operation_.value;
using UniversalGemmProblem =
ck_tile::UniversalGemmPipelineProblem<InDataType,
WeiDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler,
has_hot_loop_v,
tail_number_v,
ck_tile::element_wise::PassThrough,
ck_tile::element_wise::PassThrough,
OutDataType,
true,
VectorSizeA,
VectorSizeB>;
using GroupedConvTraitsType = std::conditional_t<EnableSplitImage,
GroupedConvTraitsTypeLargeTensor,
GroupedConvTraitsTypeDefault>;
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
InDataType,
WeiDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler,
has_hot_loop_v,
tail_number_v,
ck_tile::element_wise::PassThrough,
ck_tile::element_wise::PassThrough,
OutDataType,
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeA,
GroupedConvTraitsType::VectorSizeB>;
using GemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
ConvConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
using ConvEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
InDataType,
@@ -272,24 +286,23 @@ struct GroupedConvolutionForwardInvoker
AccDataType,
OutDataType,
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
ck_tile::tensor_layout::gemm::RowMajor,
typename GroupedConvTraitsType::FixedGemmParams::ELayout,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
GemmConfig::N_Warp,
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
GemmConfig::TransposeC,
ConvConfig::M_Warp,
ConvConfig::N_Warp,
ConvConfig::M_Warp_Tile,
ConvConfig::N_Warp_Tile,
ConvConfig::K_Warp_Tile,
GroupedConvTraitsType::FixedGemmParams::TransposeC,
memory_operation,
1,
true,
ConvConfig::NumWaveGroups,
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeC>>;
// Use split-image kernel if layout supports it, otherwise use regular kernel
using Kernel = ck_tile::GroupedConvolutionForwardKernel<EnableSplitImage,
GroupedConvTraitsType,
using Kernel = ck_tile::GroupedConvolutionForwardKernel<GroupedConvTraitsType,
TilePartitioner,
GemmPipeline,
ConvEpilogue>;
@@ -351,7 +364,8 @@ struct GroupedConvolutionForwardInvoker
}
ave_time = ck_tile::launch_kernel(
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
s,
ck_tile::make_kernel<ConvConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
return ave_time;
};

View File

@@ -11,7 +11,9 @@
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/grouped_convolution.hpp"
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
#include "conv_configs.hpp"
using MemoryOpSet =
std::integral_constant<ck_tile::memory_operation_enum, ck_tile::memory_operation_enum::set>;
using MemoryOpAtomicAdd = std::integral_constant<ck_tile::memory_operation_enum,

View File

@@ -3,7 +3,7 @@
#pragma once
template <ck_tile::index_t NDimSpatial,
typename GemmConfig,
typename ConvConfig,
typename Invoker,
typename InDataType,
typename WeiDataType,
@@ -17,7 +17,7 @@ float invoke_grouped_conv_bwd_data(ck_tile::GroupedConvBwdDataHostArgs& args,
int n_repeat)
{
float ave_time = Invoker::template grouped_conv_bwd_data<NDimSpatial,
GemmConfig,
ConvConfig,
InDataType,
WeiDataType,
AccDataType,
@@ -39,7 +39,7 @@ float invoke_grouped_conv_bwd_data(ck_tile::GroupedConvBwdDataHostArgs& args,
}
template <ck_tile::index_t NDimSpatial,
typename GemmConfig,
typename ConvConfig,
typename Invoker,
typename InDataType,
typename WeiDataType = InDataType,
@@ -141,7 +141,7 @@ int run_grouped_conv_bwd_data_example_with_layouts(
std::cout << "output: " << output.mDesc << std::endl;
invoke_grouped_conv_bwd_data<NDimSpatial,
GemmConfig,
ConvConfig,
Invoker,
InDataType,
WeiDataType,
@@ -193,7 +193,7 @@ int run_grouped_conv_bwd_data_example_with_layouts(
}
template <typename Invoker,
typename GemmConfig,
typename ConvConfig,
typename InPrecType,
typename WeiPrecType = InPrecType,
typename OutPrecType = InPrecType>
@@ -215,7 +215,7 @@ int run_grouped_conv_bwd_data_example_prec_type(
if(in_layout == "NWGC" && wei_layout == "GKXC" && out_layout == "NWGK")
{
return run_grouped_conv_bwd_data_example_with_layouts<ck_tile::number<1>{},
GemmConfig,
ConvConfig,
Invoker,
InPrecType,
WeiPrecType,
@@ -225,7 +225,7 @@ int run_grouped_conv_bwd_data_example_prec_type(
else if(in_layout == "NHWGC" && wei_layout == "GKYXC" && out_layout == "NHWGK")
{
return run_grouped_conv_bwd_data_example_with_layouts<ck_tile::number<2>{},
GemmConfig,
ConvConfig,
Invoker,
InPrecType,
WeiPrecType,
@@ -235,7 +235,7 @@ int run_grouped_conv_bwd_data_example_prec_type(
else if(in_layout == "NDHWGC" && wei_layout == "GKZYXC" && out_layout == "NDHWGK")
{
return run_grouped_conv_bwd_data_example_with_layouts<ck_tile::number<3>{},
GemmConfig,
ConvConfig,
Invoker,
InPrecType,
WeiPrecType,

View File

@@ -3,7 +3,7 @@
#pragma once
template <ck_tile::index_t NDimSpatial,
typename GemmConfig,
typename ConvConfig,
typename Invoker,
typename InDataType,
typename WeiDataType,
@@ -17,7 +17,7 @@ float invoke_grouped_conv_fwd(const ck_tile::GroupedConvFwdHostArgs<>& args,
int n_repeat)
{
float ave_time = Invoker::template grouped_conv_fwd<NDimSpatial,
GemmConfig,
ConvConfig,
InDataType,
WeiDataType,
AccDataType,
@@ -39,7 +39,7 @@ float invoke_grouped_conv_fwd(const ck_tile::GroupedConvFwdHostArgs<>& args,
}
template <ck_tile::index_t NDimSpatial,
typename GemmConfig,
typename ConvConfig,
typename Invoker,
typename InDataType,
typename WeiDataType = InDataType,
@@ -141,7 +141,7 @@ int run_grouped_conv_fwd_example_with_layouts(
std::cout << "output: " << output.mDesc << std::endl;
invoke_grouped_conv_fwd<NDimSpatial,
GemmConfig,
ConvConfig,
Invoker,
InDataType,
WeiDataType,
@@ -193,7 +193,7 @@ int run_grouped_conv_fwd_example_with_layouts(
}
template <typename Invoker,
typename GemmConfig,
typename ConvConfig,
typename InPrecType,
typename WeiPrecType = InPrecType,
typename OutPrecType = InPrecType>
@@ -215,7 +215,7 @@ int run_grouped_conv_fwd_example_prec_type(
if(in_layout == "NWGC" && wei_layout == "GKXC" && out_layout == "NWGK")
{
return run_grouped_conv_fwd_example_with_layouts<ck_tile::number<1>{},
GemmConfig,
ConvConfig,
Invoker,
InPrecType,
WeiPrecType,
@@ -225,7 +225,7 @@ int run_grouped_conv_fwd_example_prec_type(
else if(in_layout == "NHWGC" && wei_layout == "GKYXC" && out_layout == "NHWGK")
{
return run_grouped_conv_fwd_example_with_layouts<ck_tile::number<2>{},
GemmConfig,
ConvConfig,
Invoker,
InPrecType,
WeiPrecType,
@@ -235,7 +235,7 @@ int run_grouped_conv_fwd_example_prec_type(
else if(in_layout == "NDHWGC" && wei_layout == "GKZYXC" && out_layout == "NDHWGK")
{
return run_grouped_conv_fwd_example_with_layouts<ck_tile::number<3>{},
GemmConfig,
ConvConfig,
Invoker,
InPrecType,
WeiPrecType,

View File

@@ -2,6 +2,116 @@
This folder contains example for the pooling operator using ck_tile tile-programming implementation. Currently the pooling kernel only supports 2D and 3D pooling.
## Tensor Descriptor Transformations
The pooling kernel transforms the input tensor into 2D format suitable for reduction. This section explains the transformation pipeline for both 2D and 3D pooling operations.
### 3D Pooling Transformations
For 3D pooling, the input tensor has shape `(N, D, H, W, C)` where:
- `N`: batch size
- `D`: depth dimension
- `H`: height dimension
- `W`: width dimension
- `C`: channel dimension
The transformations convert this 5D tensor into a 2D tensor where rows represent output positions (M) and columns represent pooling window elements (K).
```mermaid
graph TD
%% Input Tensor: (N, D, H, W, C)
Input["Input Tensor<br/>(N, D, H, W, C)"]
style Input fill:#e1f5fe
%% Pass-through N dimension
PassN["Pass-through N<br/>(batch size)"]
style PassN fill:#f3e5f5
Input --> PassN
%% Pad spatial dimensions
PadD["Pad D<br/>(depth with left/right padding)"]
style PadD fill:#fff9c4
Input --> PadD
PadH["Pad H<br/>(height with left/right padding)"]
style PadH fill:#fff9c4
Input --> PadH
PadW["Pad W<br/>(width with left/right padding)"]
style PadW fill:#fff9c4
Input --> PadW
%% Pass-through C dimension
PassC["Pass-through C<br/>(channels)"]
style PassC fill:#f3e5f5
Input --> PassC
%% Embed sliding windows
EmbedD["Embed D<br/>window(Z) × output_positions(Dₒ)"]
style EmbedD fill:#fff3e0
PadD --> EmbedD
EmbedH["Embed H<br/>window(Y) × output_positions(Hₒ)"]
style EmbedH fill:#fff3e0
PadH --> EmbedH
EmbedW["Embed W<br/>window(X) × output_positions(Wₒ)"]
style EmbedW fill:#fff3e0
PadW --> EmbedW
%% Merge into 2D matrix
MergeM["Merge M<br/>(N, Dₒ, Hₒ, Wₒ, C)<br/>→ output positions"]
style MergeM fill:#e8f5e9
PassN --> MergeM
EmbedD --> MergeM
EmbedH --> MergeM
EmbedW --> MergeM
PassC --> MergeM
MergeK["Merge K<br/>(Z, Y, X)<br/>→ window elements"]
style MergeK fill:#e8f5e9
EmbedD --> MergeK
EmbedH --> MergeK
EmbedW --> MergeK
%% Final padding for block alignment
PadM["Right-pad M<br/>(for block alignment)"]
style PadM fill:#fff9c4
MergeM --> PadM
PadK["Right-pad K<br/>(for block alignment)"]
style PadK fill:#fff9c4
MergeK --> PadK
%% Result
Result["2D Matrix<br/>(M × K)"]
style Result fill:#c8e6c9
PadM --> Result
PadK --> Result
```
**Transformation Steps:**
1. **Padding**: Apply left and right padding to spatial dimensions (D, H, W) to handle boundary conditions
2. **Sliding Windows**: Use embed transforms to create sliding windows across each spatial dimension, expanding each dimension into (window_size, output_positions)
3. **Reshaping**: Merge all dimensions into a 2D matrix where:
- M dimension = N × Dₒ × Hₒ × Wₒ × C (total output positions)
- K dimension = Z × Y × X (elements per pooling window)
4. **Block Alignment**: Apply right padding to ensure M and K dimensions are aligned to block size
### 2D Pooling Transformations
2D pooling follows the same transformation pipeline but operates on 4D tensors with shape `(N, H, W, C)`. The process is identical except:
- Only H and W dimensions are padded and embedded
- K dimension merges only (Y, X) window elements
- M dimension merges (N, Hₒ, Wₒ, C)
### Output Tensor Transformations
The output tensor transformations are simpler:
- Merge all output dimensions (N, Dₒ/Hₒ, Wₒ, C) into a single M dimension
- Apply right padding for block alignment
- The result is a 1D tensor that maps directly to the M dimension of the computation matrix
## build
```
# in the root of ck_tile

View File

@@ -5,7 +5,7 @@ endif()
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0)
if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95")
if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
add_executable(tile_example_gemm_quant_basic EXCLUDE_FROM_ALL gemm_quant_basic.cpp)
target_compile_options(tile_example_gemm_quant_basic PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
else()

View File

@@ -419,6 +419,10 @@ int dispatch_group_size_ct(int m, int n, int k, F&& f)
int main(int argc, char* argv[])
{
#if CK_TILE_USE_WMMA
return !run_gemm_example<GemmConfigBQuantPrefill_Wmma>(argc, argv);
#else
// Use non-preshuffled GemmConfig for 2D block scale support
return !run_gemm_example<GemmConfigBQuantPrefill>(argc, argv);
#endif
}

View File

@@ -12,12 +12,11 @@
#include "ck_tile/ops/gemm_quant.hpp"
#define CK_TILE_SUPPORTED_QUANT_GROUPS(X) \
X(1, 1, 64) /* 1D */ \
X(1, 1, 128) /* 1D */ \
X(1, 8, 128) /* 2D N=8 */ \
X(1, 32, 128) /* 2D N=32 */ \
X(1, 64, 128) /* 2D N=64 */ \
X(1, 128, 128) /* 2D N=128 */
X(1, 1, 64) /* 1D */ \
X(1, 1, 128) /* 1D */ \
X(1, 8, 128) /* 2D N=8 */ \
X(1, 32, 128) /* 2D N=32 */ \
X(1, 64, 128) /* 2D N=64 */
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
constexpr ck_tile::index_t get_k_warp_tile()
@@ -217,6 +216,14 @@ struct GemmConfigBQuantPrefill : public GemmConfigBase
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
};
template <typename PrecType>
struct GemmConfigBQuantPrefill_Wmma : public GemmConfigBQuantPrefill<PrecType>
{
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 16;
};
template <typename ADataType_,
typename BDataType_ = ADataType_,
typename CDataType_ = ADataType_,

View File

@@ -22,8 +22,8 @@ args:
-a_layout tensor A data layout (default: R)
-b_layout tensor B data layout (default: C)
-c_layout tensor C data layout (default: R)
-num_sk_blocks number of Stream-K blocks. -1: chosen by algorithm, or user selected (default:-1)
-reduction_strategy strategy for storing results in C tensor. atomic/reduction (default:atomic)
-persistent_dp persistent strategy for data-parallel section. Set to 0 for non-persistent or to 1 for persistent. (default:0)
-stride_a tensor A stride (default:0)
-stride_b tensor B stride (default:0)
-stride_c tensor C stride (default:0)

View File

@@ -18,7 +18,6 @@ struct GemmConfigBase
static constexpr bool TransposeC = false;
static constexpr bool UseStructuredSparsity = false;
static constexpr bool Persistent = false;
static constexpr int kBlockPerCu = 1;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
@@ -27,12 +26,12 @@ struct GemmConfigBase
static constexpr bool DoubleSmemBuffer = false;
};
template <typename PrecType>
template <typename PrecType, bool Persistent_>
struct GemmConfigMemoryInterwave : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 32;
static constexpr ck_tile::index_t M_Tile = 256;
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 16;
static constexpr ck_tile::index_t M_Warp = 2;
static constexpr ck_tile::index_t N_Warp = 2;
@@ -42,7 +41,8 @@ struct GemmConfigMemoryInterwave : public GemmConfigBase
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 8 : 16;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
static constexpr bool Persistent = Persistent_;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
};
template <typename ADataType_, typename BDataType_ = ADataType_, typename CDataType_ = ADataType_>
@@ -96,12 +96,12 @@ auto create_args(int argc, char* argv[])
.insert("a_layout", "R", "A tensor data layout - Row by default")
.insert("b_layout", "C", "B tensor data layout - Column by default")
.insert("c_layout", "R", "C tensor data layout - Row by default")
.insert("num_sk_blocks",
"-1",
"number of Stream-K blocks. -1: chosen by algorithm, or user selected")
.insert("reduction_strategy",
"atomic",
"strategy for storing results in C tensor - atomic/reduction")
.insert("persistent_dp",
"0",
"0. Non-persistent data-parallel section, 1 Fully persistent kernel.")
.insert("stride_a", "0", "Tensor A stride")
.insert("stride_b", "0", "Tensor B stride")
.insert("stride_c", "0", "Tensor C stride")

View File

@@ -69,20 +69,18 @@ invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
int n_warmup,
int n_repeat,
bool flush_cache,
ck_tile::StreamKReductionStrategy reduction_strategy,
uint32_t num_sk_blocks)
ck_tile::StreamKReductionStrategy reduction_strategy)
{
ck_tile::StreamKHostArgs args{a_m_k_dev_buf.GetDeviceBuffer(),
b_k_n_dev_buf.GetDeviceBuffer(),
c_m_n_dev_buf.GetDeviceBuffer(),
M,
N,
K,
stride_A,
stride_B,
stride_C,
reduction_strategy,
num_sk_blocks};
ck_tile::reboot::StreamKHostArgs args{a_m_k_dev_buf.GetDeviceBuffer(),
b_k_n_dev_buf.GetDeviceBuffer(),
c_m_n_dev_buf.GetDeviceBuffer(),
M,
N,
K,
stride_A,
stride_B,
stride_C,
reduction_strategy};
std::tuple<float, ck_tile::index_t> ave_time_and_batch;
@@ -197,7 +195,6 @@ int run_gemm_example_with_layouts(int argc,
ck_tile::StreamKReductionStrategy reduction_strategy =
get_reduction_strategy_value(arg_parser.get_str("reduction_strategy"));
uint32_t num_sk_blocks = static_cast<uint32_t>(arg_parser.get_int("num_sk_blocks"));
stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout));
stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
@@ -261,8 +258,7 @@ int run_gemm_example_with_layouts(int argc,
n_warmup,
n_repeat,
flush_cache,
reduction_strategy,
num_sk_blocks);
reduction_strategy);
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
@@ -279,10 +275,10 @@ int run_gemm_example_with_layouts(int argc,
<< " B_Type=" << DataTypeTraits<BDataType>::name
<< " C_Type=" << DataTypeTraits<CDataType>::name
<< " reduction_strategy=" << arg_parser.get_str("reduction_strategy") << " "
<< ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< std::endl;
<< " persistent_dp=" << arg_parser.get_str("persistent_dp") << " " << ave_time
<< " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl;
bool pass = true;
bool pass = false;
// Memory on host to store gpu reference result
ck_tile::HostTensor<CDataType> c_m_n_ref(

View File

@@ -2,7 +2,6 @@
// SPDX-License-Identifier: MIT
#include "gemm_utils.hpp"
#include "run_gemm_example.inc"
#include "ck_tile/ops/common.hpp"
template <typename GemmConfig,
@@ -17,9 +16,8 @@ template <typename GemmConfig,
typename ELayout,
typename CDEElementWise,
ck_tile::StreamKReductionStrategy ReductionStrategy>
std::tuple<float, ck_tile::index_t> gemm(const ck_tile::StreamKHostArgs& args,
std::tuple<float, ck_tile::index_t> gemm(const ck_tile::reboot::StreamKHostArgs& args,
const ck_tile::stream_config& s)
{
using GemmShape = ck_tile::TileGemmShape<
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
@@ -29,7 +27,8 @@ std::tuple<float, ck_tile::index_t> gemm(const ck_tile::StreamKHostArgs& args,
GemmConfig::PermuteA,
GemmConfig::PermuteB>;
using TilePartitioner = ck_tile::StreamKTilePartitioner<GemmShape, ReductionStrategy>;
using TilePartitioner =
ck_tile::StreamKTilePartitioner_v2<GemmShape, ReductionStrategy, GemmConfig::Persistent>;
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<GemmConfig::kPadM,
GemmConfig::kPadN,
@@ -78,9 +77,13 @@ std::tuple<float, ck_tile::index_t> gemm(const ck_tile::StreamKHostArgs& args,
memory_operation.value,
GemmConfig::NumWaveGroups>>;
using Kernel = ck_tile::StreamKKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
using Kernel = ck_tile::reboot::StreamKKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
auto kargs = Kernel::MakeKernelArgs(args);
const auto workspace_size = Kernel::GetWorkSpaceSize(kargs);
ck_tile::DeviceMem workspace_data(workspace_size);
workspace_data.SetZero();
kargs.workspace_ptr = workspace_data.GetDeviceBuffer();
dim3 grids = Kernel::GridSize(kargs.tile_partitioner);
dim3 blocks = Kernel::BlockSize();
@@ -101,28 +104,28 @@ std::tuple<float, ck_tile::index_t> gemm(const ck_tile::StreamKHostArgs& args,
<< std::endl;
}
// Function to clear the output C tensor results after each repetition of the kernel
auto clear_gemm_output = [&]() {
auto reset_data_buffers = [&]() {
if(ReductionStrategy == ck_tile::StreamKReductionStrategy::Atomic)
{
// Clear the output C tensor results after each repetition of the kernel
hipGetErrorString(hipMemsetAsync(
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
}
else if(ReductionStrategy == ck_tile::StreamKReductionStrategy::Reduction)
{
// Reset sk flags to zero before each repetition of the kernel
workspace_data.SetZero();
}
};
std::function<void()> preprocess = clear_gemm_output;
std::function<void()> preprocess = reset_data_buffers;
float ave_time = ck_tile::launch_kernel_time_mask(
s,
preprocess,
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
ck_tile::index_t num_wgs_per_tile = ck_tile::estimate_num_wgs_per_tile<ReductionStrategy>(
kargs.tile_partitioner.sk_num_blocks,
// k_iters_per_big_block could be 1, which indicates that all Stream-K workgroups are
// big and each does one iteration. Thus, we ensure the value passed in is at least 1 to
// avoid division by zero errors.
ck_tile::max(kargs.tile_partitioner.k_iters_per_big_block - 1, 1u),
kargs.tile_partitioner.k_iters_per_tile.get());
ck_tile::index_t num_wgs_per_tile = kargs.tile_partitioner.estimate_num_wgs_per_tile();
return std::tuple{ave_time, num_wgs_per_tile};
};
@@ -145,6 +148,8 @@ std::tuple<float, ck_tile::index_t> gemm(const ck_tile::StreamKHostArgs& args,
}
}
#include "run_gemm_example.inc"
template <typename GemmConfig, typename TypeConfig>
int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
{
@@ -164,7 +169,7 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
return 0;
}
template <template <typename PreType> typename GemmConfig>
template <template <typename PreType, bool Persistent_> typename GemmConfig>
int run_gemm_example(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
@@ -174,30 +179,63 @@ int run_gemm_example(int argc, char* argv[])
std::string data_type = arg_parser.get_str("prec");
std::string a_layout = arg_parser.get_str("a_layout");
std::string b_layout = arg_parser.get_str("b_layout");
auto persistent_dp = arg_parser.get_bool("persistent_dp");
if(data_type == "bf16")
{
using TypeConfig = StreamKGemmTypeConfig<ck_tile::bf16_t>;
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf16_t>, TypeConfig>(
a_layout, b_layout, argc, argv);
if(persistent_dp)
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf16_t, true>, TypeConfig>(
a_layout, b_layout, argc, argv);
}
else
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf16_t, false>, TypeConfig>(
a_layout, b_layout, argc, argv);
}
}
else if(data_type == "fp16")
{
using TypeConfig = StreamKGemmTypeConfig<ck_tile::half_t>;
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t>, TypeConfig>(
a_layout, b_layout, argc, argv);
if(persistent_dp)
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t, true>, TypeConfig>(
a_layout, b_layout, argc, argv);
}
else
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t, false>, TypeConfig>(
a_layout, b_layout, argc, argv);
}
}
else if(data_type == "fp8")
{
using TypeConfig = StreamKGemmTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t>;
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, TypeConfig>(
a_layout, b_layout, argc, argv);
if(persistent_dp)
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t, true>, TypeConfig>(
a_layout, b_layout, argc, argv);
}
else
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t, false>, TypeConfig>(
a_layout, b_layout, argc, argv);
}
}
else if(data_type == "bf8")
{
using TypeConfig = StreamKGemmTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t>;
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>, TypeConfig>(
a_layout, b_layout, argc, argv);
if(persistent_dp)
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t, true>, TypeConfig>(
a_layout, b_layout, argc, argv);
}
else
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t, false>, TypeConfig>(
a_layout, b_layout, argc, argv);
}
}
else
{

View File

@@ -25,7 +25,6 @@ add_subdirectory(22_gemm_multi_abd)
add_subdirectory(35_batched_transpose)
add_subdirectory(36_pooling)
add_subdirectory(38_block_scale_gemm)
add_subdirectory(39_copy)
add_subdirectory(40_streamk_gemm)
add_subdirectory(41_batched_contraction)

View File

@@ -1,5 +1,5 @@
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -78,66 +78,4 @@ struct UnsupportedEnumValue
{
};
// Helper functions to convert enums to strings
constexpr std::string_view ConvDirectionToString(ConvDirection dir)
{
switch(dir)
{
case ConvDirection::FORWARD: return "Forward";
case ConvDirection::BACKWARD_DATA: return "Backward Data";
case ConvDirection::BACKWARD_WEIGHT: return "Backward Weight";
default: return "Unknown";
}
}
constexpr std::string_view DataTypeToString(DataType dt)
{
switch(dt)
{
case DataType::FP16: return "FP16";
case DataType::FP32: return "FP32";
case DataType::BF16: return "BF16";
case DataType::FP8: return "FP8";
case DataType::I8: return "I8";
case DataType::U8: return "U8";
default: return "Unknown";
}
}
constexpr std::string_view LayoutToString(GroupConvLayout1D layout)
{
switch(layout)
{
case GroupConvLayout1D::GNWC_GKXC_GNWK: return "GNWC_GKXC_GNWK";
case GroupConvLayout1D::NWGC_GKXC_NWGK: return "NWGC_GKXC_NWGK";
case GroupConvLayout1D::NGCW_GKXC_NGKW: return "NGCW_GKXC_NGKW";
case GroupConvLayout1D::NGCW_GKCX_NGKW: return "NGCW_GKCX_NGKW";
default: return "Unknown";
}
}
constexpr std::string_view LayoutToString(GroupConvLayout2D layout)
{
switch(layout)
{
case GroupConvLayout2D::GNHWC_GKYXC_GNHWK: return "GNHWC_GKYXC_GNHWK";
case GroupConvLayout2D::NHWGC_GKYXC_NHWGK: return "NHWGC_GKYXC_NHWGK";
case GroupConvLayout2D::NGCHW_GKYXC_NGKHW: return "NGCHW_GKYXC_NGKHW";
case GroupConvLayout2D::NGCHW_GKCYX_NGKHW: return "NGCHW_GKCYX_NGKHW";
default: return "Unknown";
}
}
constexpr std::string_view LayoutToString(GroupConvLayout3D layout)
{
switch(layout)
{
case GroupConvLayout3D::GNDHWC_GKZYXC_GNDHWK: return "GNDHWC_GKZYXC_GNDHWK";
case GroupConvLayout3D::NDHWGC_GKZYXC_NDHWGK: return "NDHWGC_GKZYXC_NDHWGK";
case GroupConvLayout3D::NGCDHW_GKZYXC_NGKDHW: return "NGCDHW_GKZYXC_NGKDHW";
case GroupConvLayout3D::NGCDHW_GKCZYX_NGKDHW: return "NGCDHW_GKCZYX_NGKDHW";
default: return "Unknown";
}
}
} // namespace ck_tile::builder

View File

@@ -1,5 +1,5 @@
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -38,8 +38,8 @@ concept GridwiseXdlGemmDescriptor = requires(T t) {
// Concept for parameter that describe block GEMM problem.
template <typename T>
concept BlockGemmDescriptor = requires(T t) {
{ t.pipeline_version } -> std::convertible_to<BlockGemmPipelineVersion>;
{ t.scheduler } -> std::convertible_to<BlockGemmPipelineScheduler>;
{ t.pipeline_version } -> std::convertible_to<PipelineVersion>;
{ t.scheduler } -> std::convertible_to<PipelineScheduler>;
};
// Concept for parameters that describe a gridwise WMMA GEMM problem.
@@ -50,7 +50,7 @@ concept GridwiseWmmaGemmDescriptor = requires(T t) {
{ t.n_per_wmma } -> std::convertible_to<size_t>;
{ t.m_wmma_per_wave } -> std::convertible_to<size_t>;
{ t.n_wmma_per_wave } -> std::convertible_to<size_t>;
{ t.pipeline_version } -> std::convertible_to<GridwiseGemmPipelineVersion>;
{ t.pipeline_version } -> std::convertible_to<PipelineVersion>;
};
// Concept for vectorized data transfer for convolution input tensors.
@@ -154,8 +154,8 @@ concept SpecifiesSourceAccessOrder = requires(T t) {
// Concept to check if struct specifies block GEMM.
template <typename T>
concept SpecifiesBlockGemm = requires {
{ T::block_gemm.pipeline_version } -> std::convertible_to<BlockGemmPipelineVersion>;
{ T::block_gemm.scheduler } -> std::convertible_to<BlockGemmPipelineScheduler>;
{ T::block_gemm.pipeline_version } -> std::convertible_to<PipelineVersion>;
{ T::block_gemm.scheduler } -> std::convertible_to<PipelineScheduler>;
};
template <typename T>
@@ -180,7 +180,90 @@ concept SpecifiesNumGroupsToMerge = requires {
template <typename T>
concept SpecifiesLoopScheduler = requires {
{ T::loop_scheduler } -> std::convertible_to<LoopScheduler>;
{ T::loop_scheduler } -> std::convertible_to<PipelineScheduler>;
};
/******************************************** */
/* DL-specific descriptors and requirements */
/******************************************** */
// Concept for DL thread configuration
template <typename T>
concept DlThreadConfigDescriptor = requires(T t) {
{ t.k0_per_block } -> std::convertible_to<size_t>;
{ t.k1 } -> std::convertible_to<size_t>;
{ t.m1_per_thread } -> std::convertible_to<size_t>;
{ t.n1_per_thread } -> std::convertible_to<size_t>;
{ t.k_per_thread } -> std::convertible_to<size_t>;
};
// Concept for DL thread cluster
template <typename T>
concept DlThreadClusterDescriptor = requires(T t) {
{ t.m1_xs } -> std::convertible_to<std::array<size_t, 2>>;
{ t.n1_xs } -> std::convertible_to<std::array<size_t, 2>>;
};
// Concept for DL block transfer K0_M0_M1_K1 format
template <typename T>
concept DlBlockTransferK0M0M1K1Descriptor = requires(T t) {
{ t.thread_slice_lengths } -> std::convertible_to<std::array<size_t, 4>>;
{ t.thread_cluster_lengths } -> std::convertible_to<std::array<size_t, 4>>;
{ t.thread_cluster_arrange_order } -> std::convertible_to<std::array<size_t, 4>>;
{ t.src_access_order } -> std::convertible_to<std::array<size_t, 4>>;
{ t.src_vector_tensor_lengths } -> std::convertible_to<std::array<size_t, 4>>;
{ t.src_vector_tensor_contiguous_dim_order } -> std::convertible_to<std::array<size_t, 4>>;
{ t.dst_vector_tensor_lengths } -> std::convertible_to<std::array<size_t, 4>>;
};
// Concept for DL block transfer K0_N0_N1_K1 format
template <typename T>
concept DlBlockTransferK0N0N1K1Descriptor = requires(T t) {
{ t.thread_slice_lengths } -> std::convertible_to<std::array<size_t, 4>>;
{ t.thread_cluster_lengths } -> std::convertible_to<std::array<size_t, 4>>;
{ t.thread_cluster_arrange_order } -> std::convertible_to<std::array<size_t, 4>>;
{ t.src_access_order } -> std::convertible_to<std::array<size_t, 4>>;
{ t.src_vector_tensor_lengths } -> std::convertible_to<std::array<size_t, 4>>;
{ t.src_vector_tensor_contiguous_dim_order } -> std::convertible_to<std::array<size_t, 4>>;
{ t.dst_vector_tensor_lengths } -> std::convertible_to<std::array<size_t, 4>>;
};
// Concept for DL C thread transfer
template <typename T>
concept DlCThreadTransferDescriptor = requires(T t) {
{ t.src_dst_access_order } -> std::convertible_to<std::array<size_t, 6>>;
{ t.src_dst_vector_dim } -> std::convertible_to<size_t>;
{ t.dst_scalar_per_vector } -> std::convertible_to<size_t>;
};
// Concept to check if algorithm specifies DL thread config
template <typename T>
concept SpecifiesDlThreadConfig = requires {
{ T::dl_thread_config } -> DlThreadConfigDescriptor;
};
// Concept to check if algorithm specifies DL thread cluster
template <typename T>
concept SpecifiesDlThreadCluster = requires {
{ T::dl_thread_cluster } -> DlThreadClusterDescriptor;
};
// Concept to check if algorithm specifies DL A block transfer
template <typename T>
concept SpecifiesDlBlockTransferA = requires {
{ T::dl_block_transfer_a } -> DlBlockTransferK0M0M1K1Descriptor;
};
// Concept to check if algorithm specifies DL B block transfer
template <typename T>
concept SpecifiesDlBlockTransferB = requires {
{ T::dl_block_transfer_b } -> DlBlockTransferK0N0N1K1Descriptor;
};
// Concept to check if algorithm specifies DL C thread transfer
template <typename T>
concept SpecifiesDlCThreadTransfer = requires {
{ T::dl_c_thread_transfer } -> DlCThreadTransferDescriptor;
};
} // namespace ck_tile::builder

View File

@@ -1,5 +1,5 @@
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

View File

@@ -1,5 +1,5 @@
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

View File

@@ -1,5 +1,5 @@
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
// A factory for instantiating CK convolution kernels.
//
@@ -36,9 +36,21 @@
#pragma once
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp"
// WORKAROUND: Macro namespace collision in upstream CK device operation headers.
// device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp (line 41) and
// device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp (line 51) both define
// GridwiseGemmTemplateParameters macro without #undef, causing redefinition errors.
// Use pragma push/pop to isolate the Large_Tensor header's macro scope.
#pragma push_macro("GridwiseGemmTemplateParameters")
#ifdef GridwiseGemmTemplateParameters
#undef GridwiseGemmTemplateParameters
#endif
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp"
#pragma pop_macro("GridwiseGemmTemplateParameters")
#include "ck_tile/builder/conv_signature_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_limits.hpp"
@@ -297,42 +309,42 @@ constexpr BlockGemmSpec SetBlockGemm()
ck::BlockGemmPipelineScheduler scheduler;
ck::BlockGemmPipelineVersion version;
if constexpr(BG.scheduler == BlockGemmPipelineScheduler::INTRAWAVE)
if constexpr(BG.scheduler == PipelineScheduler::INTRAWAVE)
{
scheduler = ck::BlockGemmPipelineScheduler::Intrawave;
}
else if constexpr(BG.scheduler == BlockGemmPipelineScheduler::INTERWAVE)
else if constexpr(BG.scheduler == PipelineScheduler::INTERWAVE)
{
scheduler = ck::BlockGemmPipelineScheduler::Interwave;
}
else
{
static_assert(false, "Unknown BlockGemmPipelineScheduler");
static_assert(false, "Unknown PipelineScheduler");
}
if constexpr(BG.pipeline_version == BlockGemmPipelineVersion::V1)
if constexpr(BG.pipeline_version == PipelineVersion::V1)
{
version = ck::BlockGemmPipelineVersion::v1;
}
else if constexpr(BG.pipeline_version == BlockGemmPipelineVersion::V2)
else if constexpr(BG.pipeline_version == PipelineVersion::V2)
{
version = ck::BlockGemmPipelineVersion::v2;
}
else if constexpr(BG.pipeline_version == BlockGemmPipelineVersion::V3)
else if constexpr(BG.pipeline_version == PipelineVersion::V3)
{
version = ck::BlockGemmPipelineVersion::v3;
}
else if constexpr(BG.pipeline_version == BlockGemmPipelineVersion::V4)
else if constexpr(BG.pipeline_version == PipelineVersion::V4)
{
version = ck::BlockGemmPipelineVersion::v4;
}
else if constexpr(BG.pipeline_version == BlockGemmPipelineVersion::V5)
else if constexpr(BG.pipeline_version == PipelineVersion::V5)
{
version = ck::BlockGemmPipelineVersion::v5;
}
else
{
static_assert(false, "Unknown BlockGemmPipelineVersion");
static_assert(false, "Unknown PipelineVersion");
}
return BlockGemmSpec{.pipeline_version = version, .scheduler = scheduler};
@@ -442,17 +454,17 @@ consteval ck::LoopScheduler SetLoopScheduler()
{
constexpr auto loop_scheduler = ALGORITHM.loop_scheduler;
if constexpr(loop_scheduler == LoopScheduler::DEFAULT)
if constexpr(loop_scheduler == PipelineScheduler::DEFAULT)
{
return ck::LoopScheduler::Default;
}
else if constexpr(loop_scheduler == LoopScheduler::INTERWAVE)
else if constexpr(loop_scheduler == PipelineScheduler::INTERWAVE)
{
return ck::LoopScheduler::Interwave;
}
else
{
static_assert(false, "Unknown LoopScheduler");
static_assert(false, "Unknown PipelineScheduler");
}
}
@@ -460,29 +472,29 @@ template <ConvAlgorithmDescriptor auto ALGORITHM>
consteval ck::PipelineVersion SetGridwiseGemmPipelineVersion()
{
constexpr auto pipeline_version = ALGORITHM.gridwise_gemm.pipeline_version;
if constexpr(pipeline_version == GridwiseGemmPipelineVersion::V1)
if constexpr(pipeline_version == PipelineVersion::V1)
{
return ck::PipelineVersion::v1;
}
else if constexpr(pipeline_version == GridwiseGemmPipelineVersion::V2)
else if constexpr(pipeline_version == PipelineVersion::V2)
{
return ck::PipelineVersion::v2;
}
else if constexpr(pipeline_version == GridwiseGemmPipelineVersion::V3)
else if constexpr(pipeline_version == PipelineVersion::V3)
{
static_assert(false, "V3 is used only for stream-K.");
}
else if constexpr(pipeline_version == GridwiseGemmPipelineVersion::V4)
else if constexpr(pipeline_version == PipelineVersion::V4)
{
return ck::PipelineVersion::v4;
}
else if constexpr(pipeline_version == GridwiseGemmPipelineVersion::WEIGHT_ONLY)
else if constexpr(pipeline_version == PipelineVersion::WEIGHT_ONLY)
{
return ck::PipelineVersion::weight_only;
}
else
{
static_assert(false, "Unknown GridwiseGemmPipelineVersion");
static_assert(false, "Unknown PipelineVersion");
}
}
@@ -566,29 +578,29 @@ consteval ck::BlockGemmPipelineVersion SetBlockGemmPipelineVersion()
{
constexpr auto version = ALGORITHM.pipeline_version;
if constexpr(version == BlockGemmPipelineVersion::V1)
if constexpr(version == PipelineVersion::V1)
{
return ck::BlockGemmPipelineVersion::v1;
}
else if constexpr(version == BlockGemmPipelineVersion::V2)
else if constexpr(version == PipelineVersion::V2)
{
return ck::BlockGemmPipelineVersion::v2;
}
else if constexpr(version == BlockGemmPipelineVersion::V3)
else if constexpr(version == PipelineVersion::V3)
{
return ck::BlockGemmPipelineVersion::v3;
}
else if constexpr(version == BlockGemmPipelineVersion::V4)
else if constexpr(version == PipelineVersion::V4)
{
return ck::BlockGemmPipelineVersion::v4;
}
else if constexpr(version == BlockGemmPipelineVersion::V5)
else if constexpr(version == PipelineVersion::V5)
{
return ck::BlockGemmPipelineVersion::v5;
}
else
{
static_assert(false, "Unknown BlockGemmPipelineVersion");
static_assert(false, "Unknown PipelineVersion");
}
}
@@ -990,4 +1002,263 @@ struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
GRIDWISE_GEMM_PIPELINE_VERSION>;
};
// Factory specialization for DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK instance
// of a grouped forward convolution kernel using Direct Load (DL) approach.
template <ConvSignatureDescriptor auto SIGNATURE,
ConvAlgorithmDescriptor auto ALGORITHM,
StringLiteral VERSION>
requires ConvDirectionIsForward<SIGNATURE> &&
ConvDeviceOpIs_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK<SIGNATURE>
struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
{
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
using Layouts = decltype(factory_internal::GetTensorLayout<SIGNATURE.layout,
SPATIAL_DIM,
ConvDirection::FORWARD>());
using Types = factory_internal::ConvTensorTypes<SIGNATURE.data_type>;
using Ops = factory_internal::ElementwiseOps<SIGNATURE.elementwise_operation>;
using AlgorithmType = decltype(ALGORITHM);
static_assert(SpecifiesThreadBlock<AlgorithmType>,
"The convolution algorithm descriptor must specify thread block info.");
static_assert(SpecifiesFwdConcSpecialization<AlgorithmType>,
"The convolution algorithm descriptor must specify forward convolution "
"specialization.");
static_assert(SpecifiesGemmSpecialization<AlgorithmType>,
"The convolution algorithm descriptor must specify gemm specialization.");
static_assert(SpecifiesDlThreadConfig<AlgorithmType>,
"DL algorithm must specify thread config.");
static_assert(SpecifiesDlThreadCluster<AlgorithmType>,
"DL algorithm must specify thread cluster.");
static_assert(SpecifiesDlBlockTransferA<AlgorithmType>,
"DL algorithm must specify A block transfer.");
static_assert(SpecifiesDlBlockTransferB<AlgorithmType>,
"DL algorithm must specify B block transfer.");
static_assert(SpecifiesDlCThreadTransfer<AlgorithmType>,
"DL algorithm must specify C thread transfer.");
static constexpr auto FWD_CONV_SPECIALIZATION =
factory_internal::SetFwdConvSpecialization<ALGORITHM>();
static constexpr auto GEMM_SPECIALIZATION =
factory_internal::SetGemmSpecialization<ALGORITHM>();
static constexpr auto BLOCK = factory_internal::SetThreadBlockInfo<ALGORITHM>();
// DL-specific parameters from algorithm descriptor
static constexpr auto DL_THREAD_CFG = ALGORITHM.dl_thread_config;
static constexpr ck::index_t K0PerBlock = DL_THREAD_CFG.k0_per_block;
static constexpr ck::index_t K1 = DL_THREAD_CFG.k1;
static constexpr ck::index_t M1PerThread = DL_THREAD_CFG.m1_per_thread;
static constexpr ck::index_t N1PerThread = DL_THREAD_CFG.n1_per_thread;
static constexpr ck::index_t KPerThread = DL_THREAD_CFG.k_per_thread;
// Thread cluster from descriptor
static constexpr auto DL_CLUSTER = ALGORITHM.dl_thread_cluster;
using M1N1ThreadClusterM1Xs = to_sequence_v<DL_CLUSTER.m1_xs>;
using M1N1ThreadClusterN1Xs = to_sequence_v<DL_CLUSTER.n1_xs>;
// A Block Transfer from descriptor - K0_M0_M1_K1 tensor format
static constexpr auto DL_A_TRANSFER = ALGORITHM.dl_block_transfer_a;
using ABlockTransferThreadSliceLengths_K0_M0_M1_K1 =
to_sequence_v<DL_A_TRANSFER.thread_slice_lengths>;
using ABlockTransferThreadClusterLengths_K0_M0_M1_K1 =
to_sequence_v<DL_A_TRANSFER.thread_cluster_lengths>;
using ABlockTransferThreadClusterArrangeOrder =
to_sequence_v<DL_A_TRANSFER.thread_cluster_arrange_order>;
using ABlockTransferSrcAccessOrder = to_sequence_v<DL_A_TRANSFER.src_access_order>;
using ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1 =
to_sequence_v<DL_A_TRANSFER.src_vector_tensor_lengths>;
using ABlockTransferSrcVectorTensorContiguousDimOrder =
to_sequence_v<DL_A_TRANSFER.src_vector_tensor_contiguous_dim_order>;
using ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1 =
to_sequence_v<DL_A_TRANSFER.dst_vector_tensor_lengths>;
// B Block Transfer from descriptor - K0_N0_N1_K1 tensor format
static constexpr auto DL_B_TRANSFER = ALGORITHM.dl_block_transfer_b;
using BBlockTransferThreadSliceLengths_K0_N0_N1_K1 =
to_sequence_v<DL_B_TRANSFER.thread_slice_lengths>;
using BBlockTransferThreadClusterLengths_K0_N0_N1_K1 =
to_sequence_v<DL_B_TRANSFER.thread_cluster_lengths>;
using BBlockTransferThreadClusterArrangeOrder =
to_sequence_v<DL_B_TRANSFER.thread_cluster_arrange_order>;
using BBlockTransferSrcAccessOrder = to_sequence_v<DL_B_TRANSFER.src_access_order>;
using BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1 =
to_sequence_v<DL_B_TRANSFER.src_vector_tensor_lengths>;
using BBlockTransferSrcVectorTensorContiguousDimOrder =
to_sequence_v<DL_B_TRANSFER.src_vector_tensor_contiguous_dim_order>;
using BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1 =
to_sequence_v<DL_B_TRANSFER.dst_vector_tensor_lengths>;
// C Thread Transfer from descriptor
static constexpr auto DL_C_TRANSFER = ALGORITHM.dl_c_thread_transfer;
using CThreadTransferSrcDstAccessOrder = to_sequence_v<DL_C_TRANSFER.src_dst_access_order>;
static constexpr ck::index_t CThreadTransferSrcDstVectorDim = DL_C_TRANSFER.src_dst_vector_dim;
static constexpr ck::index_t CThreadTransferDstScalarPerVector =
DL_C_TRANSFER.dst_scalar_per_vector;
// The DL forward convolution kernel class instance
using Instance = ck::tensor_operation::device::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK<
SPATIAL_DIM,
typename Types::ADataType,
typename Types::BDataType,
typename Types::DsDataTypes,
typename Types::EDataType,
typename Types::AccDataType,
typename Layouts::ALayout,
typename Layouts::BLayout,
typename Layouts::DsLayout,
typename Layouts::ELayout,
typename Ops::AElementwiseOp,
typename Ops::BElementwiseOp,
typename Ops::CDEElementwiseOp,
FWD_CONV_SPECIALIZATION,
GEMM_SPECIALIZATION,
BLOCK.block_size,
BLOCK.per_block.m,
BLOCK.per_block.n,
K0PerBlock,
K1,
M1PerThread,
N1PerThread,
KPerThread,
M1N1ThreadClusterM1Xs,
M1N1ThreadClusterN1Xs,
ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
ABlockTransferSrcVectorTensorContiguousDimOrder,
ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
BBlockTransferSrcVectorTensorContiguousDimOrder,
BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector>;
};
// Factory specialization for DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor instance
// of a grouped forward convolution kernel with large tensor support (N-splitting).
template <ConvSignatureDescriptor auto SIGNATURE,
ConvAlgorithmDescriptor auto ALGORITHM,
StringLiteral VERSION>
requires ConvDirectionIsForward<SIGNATURE> &&
ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<SIGNATURE>
struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
{
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
using Layouts = decltype(factory_internal::GetTensorLayout<SIGNATURE.layout,
SPATIAL_DIM,
ConvDirection::FORWARD>());
using Types = factory_internal::ConvTensorTypes<SIGNATURE.data_type>;
using Ops = factory_internal::ElementwiseOps<SIGNATURE.elementwise_operation>;
using AlgorithmType = decltype(ALGORITHM);
static_assert(SpecifiesThreadBlock<AlgorithmType>,
"The convolution algorithm descriptor must specify thread block info.");
static_assert(SpecifiesGridwiseXdlGemm<AlgorithmType>,
"The convolution algorithm descriptor must specify gridwise GEMM info.");
static_assert(SpecifiesBlockTransfer<AlgorithmType>,
"The convolution algorithm descriptor must specify block transfer info.");
static_assert(SpecifiesLdsTransfer<AlgorithmType>,
"The convolution algorithm descriptor must specify LDS transfer info.");
static_assert(
SpecifiesThreadClusterAccessOrder<AlgorithmType>,
"The convolution algorithm descriptor must specify thread cluster access order info.");
static_assert(SpecifiesSourceAccessOrder<AlgorithmType>,
"The convolution algorithm descriptor must specify source access order info.");
static_assert(SpecifiesFwdConcSpecialization<AlgorithmType>,
"The convolution algorithm descriptor must specify forward convolution "
"specialization.");
static_assert(SpecifiesGemmSpecialization<AlgorithmType>,
"The convolution algorithm descriptor must specify gemm specialization.");
static_assert(SpecifiesNumPrefetchStages<AlgorithmType>,
"The convolution algorithm descriptor must specify number of prefetch stages.");
static_assert(SpecifiesLoopScheduler<AlgorithmType>,
"The convolution algorithm descriptor must specify loop scheduler.");
static constexpr auto FWD_CONV_SPECIALIZATION =
factory_internal::SetFwdConvSpecialization<ALGORITHM>();
static constexpr auto GEMM_SPECIALIZATION =
factory_internal::SetGemmSpecialization<ALGORITHM>();
static constexpr factory_internal::ConvSpec SPECIALIZATION{.conv_spec = FWD_CONV_SPECIALIZATION,
.gemm_spec = GEMM_SPECIALIZATION};
static constexpr auto LOOP_SCHEDULER = factory_internal::SetLoopScheduler<ALGORITHM>();
static constexpr auto BLOCK = factory_internal::SetThreadBlockInfo<ALGORITHM>();
static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
static constexpr auto A_BLOCK_TRANSFER =
factory_internal::SetFwdConvABlockTransfer<ALGORITHM>();
static constexpr auto B_BLOCK_TRANSFER =
factory_internal::SetFwdConvBBlockTransfer<ALGORITHM>();
static constexpr auto C_BLOCK_TRANSFER =
factory_internal::SetCBlockTransfer<SIGNATURE, ALGORITHM>();
// Check limits for the algorithm parameters.
static_assert(InputVectorTransferLimits<A_BLOCK_TRANSFER>);
static_assert(InputVectorTransferLimits<B_BLOCK_TRANSFER>);
static_assert(OutputVectorTransferLimits<C_BLOCK_TRANSFER>);
static_assert(AccessOrderLimits<A_BLOCK_TRANSFER.thread_cluster_order>);
static_assert(AccessOrderLimits<B_BLOCK_TRANSFER.thread_cluster_order>);
static_assert(AccessOrderLimits<A_BLOCK_TRANSFER.src_access_order>);
static_assert(AccessOrderLimits<B_BLOCK_TRANSFER.src_access_order>);
// The forward convolution kernel class instance with large tensor support.
using Instance =
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<
SPATIAL_DIM,
typename Layouts::ALayout,
typename Layouts::BLayout,
typename Layouts::DsLayout,
typename Layouts::ELayout,
typename Types::ADataType,
typename Types::BDataType,
typename Types::AccDataType,
typename Types::CShuffleDataType,
typename Types::DsDataTypes,
typename Types::EDataType,
typename Ops::AElementwiseOp,
typename Ops::BElementwiseOp,
typename Ops::CDEElementwiseOp,
SPECIALIZATION.conv_spec,
SPECIALIZATION.gemm_spec,
ALGORITHM.num_gemm_k_prefetch_stages,
BLOCK.block_size,
BLOCK.per_block.m,
BLOCK.per_block.n,
BLOCK.per_block.k,
GRIDWISE_GEMM.ak1,
GRIDWISE_GEMM.bk1,
GRIDWISE_GEMM.m_per_xdl,
GRIDWISE_GEMM.n_per_xdl,
GRIDWISE_GEMM.m_xdl_per_wave,
GRIDWISE_GEMM.n_xdl_per_wave,
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,
A_BLOCK_TRANSFER.src_vector_dim,
A_BLOCK_TRANSFER.src_scalar_per_vector,
A_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
A_BLOCK_TRANSFER.lds_padding,
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_dims>,
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_order>,
to_sequence_v<B_BLOCK_TRANSFER.src_access_order>,
B_BLOCK_TRANSFER.src_vector_dim,
B_BLOCK_TRANSFER.src_scalar_per_vector,
B_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
B_BLOCK_TRANSFER.lds_padding,
C_BLOCK_TRANSFER.m_per_wave_per_shuffle,
C_BLOCK_TRANSFER.n_per_wave_per_shuffle,
to_sequence_v<C_BLOCK_TRANSFER.thread_cluster_dims>,
C_BLOCK_TRANSFER.scalar_per_vector,
typename Types::AComputeType,
typename Types::BComputeType,
LOOP_SCHEDULER>;
};
} // namespace ck_tile::builder

View File

@@ -1,5 +1,5 @@
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
// This file defines the compile-time "signature" for grouped convolution operations.
// A signature is a collection of properties that fully describe a convolution kernel's

View File

@@ -1,5 +1,5 @@
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -33,30 +33,35 @@ concept ConvDirectionIsBackwardWeight = (Sig.direction == ConvDirection::BACKWAR
// Predicate for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 operation.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 =
ConvDirectionIsForward<Sig> &&
(Sig.device_operation._fwd ==
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3);
// Predicate for DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK operation.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK =
ConvDirectionIsForward<Sig> &&
(Sig.device_operation._fwd ==
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK);
// Predicate for DeviceGroupedConvFwdMultipleD_Wmma_CShuffle operation.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle =
ConvDirectionIsForward<Sig> &&
(Sig.device_operation._fwd ==
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleD_Wmma_CShuffle);
// Predicate for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle operation.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle =
ConvDirectionIsForward<Sig> &&
(Sig.device_operation._fwd ==
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle);
// Predicate for DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor operation.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor =
ConvDirectionIsForward<Sig> &&
(Sig.device_operation._fwd ==
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor);
@@ -76,48 +81,56 @@ concept ConvDeviceOpIsForward =
// Predicate for DeviceGroupedConvBwdWeight operation.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight =
ConvDirectionIsBackwardWeight<Sig> &&
(Sig.device_operation._bwd_weight ==
BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight);
// Predicate for DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle operation.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle =
ConvDirectionIsBackwardWeight<Sig> &&
(Sig.device_operation._bwd_weight ==
BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle);
// Predicate for DeviceGroupedConvBwdWeight_Xdl_CShuffle operation.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Xdl_CShuffle =
ConvDirectionIsBackwardWeight<Sig> &&
(Sig.device_operation._bwd_weight ==
BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Xdl_CShuffle);
// Predicate for DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle operation.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle =
ConvDirectionIsBackwardWeight<Sig> &&
(Sig.device_operation._bwd_weight ==
BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle);
// Predicate for DeviceGroupedConvBwdWeight_Wmma_CShuffle operation.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Wmma_CShuffle =
ConvDirectionIsBackwardWeight<Sig> &&
(Sig.device_operation._bwd_weight ==
BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Wmma_CShuffle);
// Predicate for DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 operation.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 =
ConvDirectionIsBackwardWeight<Sig> &&
(Sig.device_operation._bwd_weight ==
BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Xdl_CShuffleV3);
// Predicate for DeviceGroupedConvBwdWeightMultipleD operation.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeightMultipleD =
ConvDirectionIsBackwardWeight<Sig> &&
(Sig.device_operation._bwd_weight ==
BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeightMultipleD);
// Predicate for DeviceGroupedConvBwdWeight_Dl operation.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Dl =
ConvDirectionIsBackwardWeight<Sig> &&
(Sig.device_operation._bwd_weight ==
BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Dl);
@@ -140,18 +153,21 @@ concept ConvDeviceOpIsBackwardWeight =
// Predicate for DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 operation.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 =
ConvDirectionIsBackwardData<Sig> &&
(Sig.device_operation._bwd_data ==
BwdDataGroupConvDeviceOperation::DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1);
// Predicate for DeviceGroupedConvBwdDataMultipleD operation.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD =
ConvDirectionIsBackwardData<Sig> &&
(Sig.device_operation._bwd_data ==
BwdDataGroupConvDeviceOperation::DeviceGroupedConvBwdDataMultipleD);
// Predicate for DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle operation.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle =
ConvDirectionIsBackwardData<Sig> &&
(Sig.device_operation._bwd_data ==
BwdDataGroupConvDeviceOperation::DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle);

View File

@@ -0,0 +1,22 @@
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
namespace ck_tile::builder {
// Enumeration for CK Device Operation types.
// This allows the builder to select which device operation template to instantiate
// based on the user's requirements.
enum class DeviceOpType
{
// Forward Convolution - Non-grouped
CONV_FWD, // Maps to: DeviceConvFwd (TODO: No implementation with tuning params exists yet)
// Forward Convolution - Grouped
GROUPED_CONV_FWD_MULTIPLE_ABD, // Maps to: DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
GROUPED_CONV_FWD_MULTIPLE_ABD_XDL_CSHUFFLE_V3, // Maps to:
// DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
};
} // namespace ck_tile::builder

View File

@@ -0,0 +1,268 @@
// SPDX-License-Identifier: MIT
// Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <concepts>
#include <string_view>
#include <sstream>
#include <type_traits>
#include <variant>
#include <ck_tile/builder/conv_signature_concepts.hpp>
#include <ck_tile/builder/reflect/conv_traits.hpp>
#include <ck_tile/builder/reflect/tree_formatter.hpp>
/// @file conv_description.hpp
/// @brief Provides human-readable descriptions of ConvBuilder configurations
namespace ck_tile::reflect::conv {
struct ConvSignatureInfo
{
int spatial_dim;
builder::ConvDirection direction;
std::variant<builder::GroupConvLayout1D, builder::GroupConvLayout2D, builder::GroupConvLayout3D>
layout;
builder::DataType data_type;
builder::ElementwiseOperation input_element_op;
builder::ElementwiseOperation weight_element_op;
builder::ElementwiseOperation output_element_op;
};
// Algorithm information - groups all algorithm-related configuration
struct GemmAlgorithmInfo
{
int thread_block_size;
DataTileInfo tile_dims;
WarpGemmParams warp_gemm;
InputTileTransferInfo a_tile_transfer;
InputTileTransferInfo b_tile_transfer;
OutputTileTransferInfo c_tile_transfer;
builder::PipelineVersion pipeline_version;
builder::PipelineScheduler pipeline_scheduler;
std::variant<builder::ConvFwdSpecialization,
builder::ConvBwdDataSpecialization,
builder::ConvBwdWeightSpecialization>
conv_specialization;
builder::GemmPadding padding;
};
// Provides human-readable descriptions of ConvBuilder configurations.
struct ConvDescription
{
ConvSignatureInfo signature;
GemmAlgorithmInfo algorithm;
// Brief one-line summary
std::string brief() const
{
std::ostringstream oss;
oss << signature.spatial_dim << "D " << signature.direction << " convolution";
return oss.str();
}
// Detailed hierarchical description
std::string detailed() const
{
TreeFormatter f;
f.writeLine(0, signature.spatial_dim, "D ", signature.direction, " Convolution Kernel");
f.writeLine(1, "Signature");
f.writeLine(2, "Tensor Type: ", signature.data_type);
f.writeLine(2, "Memory Layout: ", signature.layout);
f.writeLine(2, "Input elementwise operation: ", signature.input_element_op);
f.writeLine(2, "Weights elementwise operation: ", signature.weight_element_op);
f.writeLast(2, "Output elementwise operation: ", signature.output_element_op);
f.writeLine(1, "Algorithm");
// Compute Block section
f.writeLine(2, "Thread block size: ", algorithm.thread_block_size);
f.writeLine(2,
"Data tile size: ",
algorithm.tile_dims.m,
"×",
algorithm.tile_dims.n,
"×",
algorithm.tile_dims.k);
f.writeLine(2, "Gemm padding: ", algorithm.padding);
f.writeLine(2, "Convolution specialization: ", algorithm.conv_specialization);
// Pipeline section
f.writeLine(2, "Pipeline version: ", algorithm.pipeline_version);
f.writeLine(2, "Pipeline scheduler: ", algorithm.pipeline_scheduler);
f.writeLine(2, "Warp Gemm parameters: ");
f.writeLine(
3, "subtile size: ", algorithm.warp_gemm.gemm_m, "×", algorithm.warp_gemm.gemm_n);
f.writeLast(3,
"Number of warp gemm iterations: ",
algorithm.warp_gemm.m_iter,
"×",
algorithm.warp_gemm.n_iter);
// Memory Access section
f.writeLine(2, "Memory access:");
f.writeLine(3, "A Tile transfer: ");
f.writeLine(4,
"Tile dimensions: ",
algorithm.a_tile_transfer.tile_dimensions.k0,
"×",
algorithm.a_tile_transfer.tile_dimensions.m_or_n,
"×",
algorithm.a_tile_transfer.tile_dimensions.k1,
"×");
f.writeLine(
4, "The innermost K subdimension size: ", algorithm.a_tile_transfer.transfer_params.k1);
f.writeLine(4,
"Spatial thread distribution over the data tile: ",
algorithm.a_tile_transfer.transfer_params.thread_cluster_order[0],
"×",
algorithm.a_tile_transfer.transfer_params.thread_cluster_order[1],
"×",
algorithm.a_tile_transfer.transfer_params.thread_cluster_order[2]);
f.writeLine(4,
"The order of accessing data tile axes: ",
algorithm.a_tile_transfer.transfer_params.src_access_order[0],
"×",
algorithm.a_tile_transfer.transfer_params.src_access_order[1],
"×",
algorithm.a_tile_transfer.transfer_params.src_access_order[2]);
f.writeLine(4,
"Vectorized memory access axis index (with contiguous memory): ",
algorithm.a_tile_transfer.transfer_params.src_vector_dim);
f.writeLine(4,
"Vector access (GMEM read) instruction size: ",
algorithm.a_tile_transfer.transfer_params.src_scalar_per_vector);
f.writeLine(4,
"Vector access (LDS write) instruction size: ",
algorithm.a_tile_transfer.transfer_params.dst_scalar_per_vector_k1);
f.writeLast(4,
"LDS data layout padding (to prevent bank conflicts): ",
algorithm.a_tile_transfer.transfer_params.dst_scalar_per_vector_k1);
f.writeLine(3, "B Tile transfer: ");
f.writeLine(4,
"Tile dimensions: ",
algorithm.b_tile_transfer.tile_dimensions.k0,
"×",
algorithm.b_tile_transfer.tile_dimensions.m_or_n,
"×",
algorithm.b_tile_transfer.tile_dimensions.k1,
"×");
f.writeLine(
4, "The innermost K subdimension size: ", algorithm.b_tile_transfer.transfer_params.k1);
f.writeLine(4,
"Spatial thread distribution over the data tile: ",
algorithm.b_tile_transfer.transfer_params.thread_cluster_order[0],
"×",
algorithm.b_tile_transfer.transfer_params.thread_cluster_order[1],
"×",
algorithm.b_tile_transfer.transfer_params.thread_cluster_order[2]);
f.writeLine(4,
"The order of accessing data tile axes: ",
algorithm.b_tile_transfer.transfer_params.src_access_order[0],
"×",
algorithm.b_tile_transfer.transfer_params.src_access_order[1],
"×",
algorithm.b_tile_transfer.transfer_params.src_access_order[2]);
f.writeLine(4,
"Vectorized memory access axis index (with contiguous memory): ",
algorithm.b_tile_transfer.transfer_params.src_vector_dim);
f.writeLine(4,
"Vector access (GMEM read) instruction size: ",
algorithm.b_tile_transfer.transfer_params.src_scalar_per_vector);
f.writeLine(4,
"Vector access (LDS write) instruction size: ",
algorithm.b_tile_transfer.transfer_params.dst_scalar_per_vector_k1);
f.writeLast(4,
"LDS data layout padding (to prevent bank conflicts): ",
algorithm.b_tile_transfer.transfer_params.dst_scalar_per_vector_k1);
f.writeLast(3, "C Tile transfer: ");
f.writeLine(4,
"Data shuffle (number of gemm instructions per iteration): ",
algorithm.c_tile_transfer.shuffle_params.m_gemms_per_shuffle,
"×",
algorithm.c_tile_transfer.shuffle_params.n_gemms_per_shuffle);
f.writeLine(4,
"Spatial thread distribution used to store data: ",
algorithm.c_tile_transfer.thread_cluster_dims[0],
"×",
algorithm.c_tile_transfer.thread_cluster_dims[1],
"×",
algorithm.c_tile_transfer.thread_cluster_dims[2],
"×",
algorithm.c_tile_transfer.thread_cluster_dims[3]);
f.writeLast(4,
"Vector access (GMEM write) instruction size: ",
algorithm.c_tile_transfer.scalar_per_vector);
f.writeLast(2);
f.writeLast(1);
return f.getString();
}
// Educational explanation of optimization choices
std::string explain() const
{
std::ostringstream oss;
// Placeholder for future implementation
return oss.str();
}
// Performance characteristics and use case guidance
std::string suggest() const
{
std::ostringstream oss;
// Placeholder for future implementation
return oss.str();
}
};
// Helper concept to detect if a type has InstanceTraits specialization
template <typename T>
concept HasInstanceTraits = requires { typename InstanceTraits<T>; };
// Helper concept to detect ConvBuilder types
template <typename T>
concept IsConvBuilder = requires {
typename T::Factory;
typename T::Instance;
};
// Primary factory function: Create ConvDescription from Instance type directly
template <typename Instance>
requires HasInstanceTraits<Instance>
ConvDescription Describe()
{
using Traits = ConvTraits<Instance>;
return ConvDescription{
.signature = ConvSignatureInfo{.spatial_dim = Traits::spatial_dim,
.direction = Traits::direction,
.layout = Traits::layout,
.data_type = Traits::data_type,
.input_element_op = Traits::input_element_op,
.weight_element_op = Traits::weight_element_op,
.output_element_op = Traits::output_element_op},
.algorithm = GemmAlgorithmInfo{.thread_block_size = Traits::thread_block_size,
.tile_dims = Traits::tile_dims,
.warp_gemm = Traits::warp_gemm,
.a_tile_transfer = Traits::a_tile_transfer,
.b_tile_transfer = Traits::b_tile_transfer,
.c_tile_transfer = Traits::c_tile_transfer,
.pipeline_version = Traits::pipeline_version,
.pipeline_scheduler = Traits::pipeline_scheduler,
.conv_specialization = Traits::conv_specialization,
.padding = Traits::gemm_padding}};
}
// Backward compatibility: Create ConvDescription from Builder type
template <typename Builder>
requires IsConvBuilder<Builder> && (!HasInstanceTraits<Builder>)
ConvDescription Describe()
{
// Delegate to Instance-based version
using Instance = typename Builder::Instance;
return Describe<Instance>();
}
} // namespace ck_tile::reflect::conv

View File

@@ -0,0 +1,722 @@
// SPDX-License-Identifier: MIT
// Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <concepts>
#include <ck_tile/builder/conv_builder.hpp>
#include <ck_tile/builder/conv_factory.hpp>
#include <ck_tile/builder/conv_signature_concepts.hpp>
#include <ck_tile/builder/reflect/instance_traits.hpp>
#include <ck_tile/builder/types.hpp>
#include <ck/tensor_operation/gpu/device/tensor_layout.hpp>
#include <ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp>
#include <ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp>
#include <ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp>
#include <ck/utility/loop_scheduler.hpp>
#include <ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp>
#include <ck_tile/ops/gemm.hpp>
#include "ck_tile/ops/epilogue.hpp"
#include <ck_tile/ops/grouped_convolution.hpp>
namespace ck_tile::reflect::conv {
// Helper metafunctions to convert from ck enums to builder enums
/// @brief Converts a CK BlockGemmPipelineVersion enum to a builder PipelineVersion enum.
/// @tparam ck_ver The CK BlockGemmPipelineVersion enum value to convert.
/// @return The corresponding builder::PipelineVersion enum value (V1, V2, V3, V4, or V5).
/// @details This function maps CK's block GEMM pipeline version identifiers to the
/// builder framework's standardized pipeline version enum. The pipeline version
/// determines the strategy used for data movement and computation overlap in the
/// GEMM kernel's main loop.
template <ck::BlockGemmPipelineVersion ck_ver>
constexpr auto convert_pipeline_version()
{
using enum ck::BlockGemmPipelineVersion;
using enum builder::PipelineVersion;
if constexpr(ck_ver == v1)
return V1;
else if constexpr(ck_ver == v2)
return V2;
else if constexpr(ck_ver == v3)
return V3;
else if constexpr(ck_ver == v4)
return V4;
else if constexpr(ck_ver == v5)
return V5;
}
/// @brief Converts a CK PipelineVersion enum to a builder PipelineVersion enum.
/// @tparam ck_ver The CK PipelineVersion enum value to convert.
/// @return The corresponding builder::PipelineVersion enum value (V1, V2, V4, or WEIGHT_ONLY).
/// @details This function maps CK's general pipeline version identifiers to the
/// builder framework's standardized pipeline version enum. Note that this overload
/// handles a different set of pipeline versions compared to the BlockGemmPipelineVersion
/// variant, including support for specialized weight-only pipelines.
template <ck::PipelineVersion ck_ver>
constexpr auto convert_pipeline_version()
{
using enum ck::PipelineVersion;
using enum builder::PipelineVersion;
if constexpr(ck_ver == v1)
return V1;
else if constexpr(ck_ver == v2)
return V2;
else if constexpr(ck_ver == v4)
return V4;
else if constexpr(ck_ver == weight_only)
return WEIGHT_ONLY;
}
/// @brief Converts a CK BlockGemmPipelineScheduler enum to a builder PipelineScheduler enum.
/// @tparam ck_sched The CK BlockGemmPipelineScheduler enum value to convert.
/// @return The corresponding builder::PipelineScheduler enum value (INTRAWAVE or INTERWAVE).
/// @details This function maps CK's block GEMM pipeline scheduler identifiers to the
/// builder framework's standardized scheduler enum. The scheduler determines how work
/// is distributed and synchronized within and across wavefronts during pipeline execution.
/// INTRAWAVE scheduling operates within a single wavefront, while INTERWAVE coordinates
/// across multiple wavefronts.
template <ck::BlockGemmPipelineScheduler ck_sched>
constexpr auto convert_pipeline_scheduler()
{
using enum ck::BlockGemmPipelineScheduler;
using enum builder::PipelineScheduler;
if constexpr(ck_sched == Intrawave)
return INTRAWAVE;
else if constexpr(ck_sched == Interwave)
return INTERWAVE;
}
/// @brief Converts a CK LoopScheduler enum to a builder PipelineScheduler enum.
/// @tparam ck_sched The CK LoopScheduler enum value to convert.
/// @return The corresponding builder::PipelineScheduler enum value (DEFAULT or INTERWAVE).
/// @details This function maps CK's loop scheduler identifiers to the builder framework's
/// standardized pipeline scheduler enum. The loop scheduler controls how iterations of
/// the main computational loop are scheduled across threads. DEFAULT uses the standard
/// scheduling strategy, while INTERWAVE enables cross-wavefront coordination for improved
/// performance in certain scenarios.
template <ck::LoopScheduler ck_sched>
constexpr auto convert_pipeline_scheduler()
{
using enum ck::LoopScheduler;
using enum builder::PipelineScheduler;
if constexpr(ck_sched == Default)
return DEFAULT;
else if constexpr(ck_sched == Interwave)
return INTERWAVE;
}
/// @brief Helper structures for organizing trait data with domain-specific naming
/// @brief Data tile dimensions processed by a workgroup.
/// @details This struct defines the M, N, and K dimensions of the data tile
/// that a single workgroup (thread block) is responsible for processing in the
/// underlying GEMM computation.
struct DataTileInfo
{
int m; ///< M dimension of the tile processed by the workgroup (MPerBlock).
int n; ///< N dimension of the tile processed by the workgroup (NPerBlock).
int k; ///< K dimension of the tile processed by the workgroup (KPerBlock).
};
/// @brief Dimensions for an input data tile transfer.
/// @details Defines the shape of the input tile (A or B matrix) as it is
/// transferred from global memory to LDS. The tile is conceptually divided
/// into k0 and k1 dimensions.
struct InputTileTransferDimensions
{
int k0; ///< The outer dimension of K, where K = k0 * k1.
int m_or_n; ///< The M dimension for the A matrix transfer, or the N dimension for the B matrix.
int k1; ///< The inner dimension of K, often corresponding to the vector load size from global
///< memory.
};
/// @brief Parameters governing the transfer of an input tile.
/// @details This struct holds configuration details for how an input tile is
/// loaded from global memory into LDS, including thread clustering, memory
/// access patterns, and vectorization settings.
struct InputTileTransferParams
{
int k1; ///< The inner K dimension size, often matching the vectorization width.
std::array<int, 3>
thread_cluster_dims; ///< Spatial thread distribution over the input data tile; defines how
///< many threads are arranged on each axis.
std::array<int, 3> thread_cluster_order; ///< The order of thread spatial distribution over the
///< input tensor dimensions.
std::array<int, 3> src_access_order; ///< The order of accessing input tensor axes (e.g., which
///< dimension to read first).
int src_vector_dim; ///< The index of the axis on which vectorized memory access is performed
///< (the contiguous dimension).
int src_scalar_per_vector; ///< The size of the vector access instruction; the number of
///< elements accessed per thread per instruction.
int dst_scalar_per_vector_k1; ///< The size of the vectorized store into LDS memory along the K1
///< dimension.
bool lds_padding; ///< Flag indicating if padding is used for the LDS tensor to prevent bank
///< conflicts.
};
/// @brief Complete information for an input tile transfer.
/// @details Combines the dimensional information and transfer parameters for
/// a full description of an input tile's journey from global memory to LDS.
struct InputTileTransferInfo
{
InputTileTransferDimensions tile_dimensions; ///< The shape and layout of the tile.
InputTileTransferParams transfer_params; ///< The parameters for the memory transfer operation.
};
/// @brief Parameters for the warp-level GEMM computation.
/// @details Defines the configuration of the GEMM operation performed by each
/// warp using hardware MFMA (Matrix Fused Multiply-Add) instructions.
struct WarpGemmParams
{
int gemm_m; ///< The M dimension of a single MFMA instruction (MPerXdl).
int gemm_n; ///< The N dimension of a single MFMA instruction (NPerXdl).
int m_iter; ///< The number of MFMA iterations along the M dimension of the output tile per
///< wavefront (MXdlPerWave).
int n_iter; ///< The number of MFMA iterations along the N dimension of the output tile per
///< wavefront (NXdlPerWave).
};
/// @brief Parameters for shuffling data between warps (CShuffle optimization).
/// @details Configures how many MFMA instruction results are processed per
/// wave in each iteration of the CShuffle routine.
struct WarpShuffleParams
{
int m_gemms_per_shuffle; ///< Number of MFMA results along the M dimension to process per wave
///< per shuffle iteration.
int n_gemms_per_shuffle; ///< Number of MFMA results along the N dimension to process per wave
///< per shuffle iteration.
};
/// @brief Information for the output tile transfer (CShuffle).
/// @details Describes how the final computed tile (C matrix) is written out from
/// LDS to global memory, including shuffling, thread clustering, and vectorization.
struct OutputTileTransferInfo
{
WarpShuffleParams shuffle_params; ///< Configuration for cross-warp data shuffling.
// m_block, m_wave_per_xdl, n_block, n_wave_per_xdl
std::array<int, 4> thread_cluster_dims; ///< The spatial thread distribution used for storing
///< data into the output tensor.
int scalar_per_vector; ///< The size of the vectorized memory access when storing data to the
///< output tensor.
};
// Helper metafunctions to derive signature information from Instance types
/// @brief Derives the convolution direction from a device kernel `Instance` type.
/// @tparam Instance The device kernel instance type.
/// @return A `builder::ConvDirection` enum value (FORWARD, BACKWARD_DATA, or BACKWARD_WEIGHT).
template <typename Instance>
constexpr builder::ConvDirection conv_direction()
{
using InstTraits = InstanceTraits<Instance>;
if constexpr(requires { &InstTraits::kConvForwardSpecialization; })
{
return builder::ConvDirection::FORWARD;
}
else if constexpr(requires { &InstTraits::kConvBwdDataSpecialization; })
{
return builder::ConvDirection::BACKWARD_DATA;
}
else if constexpr(requires { &InstTraits::kConvBwdWeightSpecialization; })
{
return builder::ConvDirection::BACKWARD_WEIGHT;
}
else
{
return builder::ConvDirection::FORWARD; // Default fallback
}
}
/// @brief Derives the convolution-specific specialization from a device kernel `Instance` type.
/// @tparam Instance The device kernel instance type.
/// @return A `builder::ConvFwdSpecialization`, `builder::ConvBwdDataSpecialization`, or
/// `builder::ConvBwdWeightSpecialization` enum value.
template <typename Instance>
constexpr auto conv_spec()
{
using InstTraits = InstanceTraits<Instance>;
if constexpr(requires { InstTraits::kConvForwardSpecialization; })
{
using enum ck::tensor_operation::device::ConvolutionForwardSpecialization;
if constexpr(InstTraits::kConvForwardSpecialization == Default)
{
return builder::ConvFwdSpecialization::DEFAULT;
}
else if constexpr(InstTraits::kConvForwardSpecialization == Filter1x1Pad0)
{
return builder::ConvFwdSpecialization::FILTER_1X1_PAD0;
}
else if constexpr(InstTraits::kConvForwardSpecialization == Filter1x1Stride1Pad0)
{
return builder::ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0;
}
else if constexpr(InstTraits::kConvForwardSpecialization == Filter3x3)
{
return builder::ConvFwdSpecialization::FILTER_3x3;
}
}
else if constexpr(requires { InstTraits::kConvBwdDataSpecialization; })
{
using enum ck::tensor_operation::device::ConvolutionBackwardDataSpecialization;
if constexpr(InstTraits::kConvBwdDataSpecialization == Default)
{
return builder::ConvBwdDataSpecialization::DEFAULT;
}
else if constexpr(InstTraits::kConvBwdDataSpecialization == Filter1x1Stride1Pad0)
{
return builder::ConvBwdDataSpecialization::FILTER_1X1_STRIDE1_PAD0;
}
}
else if constexpr(requires { InstTraits::kConvBwdWeightSpecialization; })
{
using enum ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization;
if constexpr(InstTraits::kConvBwdWeightSpecialization == Default)
{
return builder::ConvBwdWeightSpecialization::DEFAULT;
}
else if constexpr(InstTraits::kConvBwdWeightSpecialization == Filter1x1Stride1Pad0)
{
return builder::ConvBwdWeightSpecialization::FILTER_1X1_STRIDE1_PAD0;
}
else if constexpr(InstTraits::kConvBwdWeightSpecialization == Filter1x1Pad0)
{
return builder::ConvBwdWeightSpecialization::FILTER_1X1_PAD0;
}
else if constexpr(InstTraits::kConvBwdWeightSpecialization == OddC)
{
return builder::ConvBwdWeightSpecialization::ODD_C;
}
}
}
/// @brief Derives the grouped convolution layout from a device kernel `Instance` type.
/// @tparam Instance The device kernel instance type.
/// @return A `builder::GroupConvLayout{1D|2D|3D}` enum value corresponding to the tensor layouts.
template <typename Instance>
constexpr auto conv_layout()
{
using InstTraits = InstanceTraits<Instance>;
using ALayout = typename InstTraits::ALayout;
using BLayout = typename InstTraits::BLayout;
using ELayout = typename InstTraits::ELayout;
namespace ctc = ck::tensor_layout::convolution;
if constexpr(InstTraits::kSpatialDim == 1)
{
if constexpr(std::is_same_v<ALayout, ctc::GNWC> && std::is_same_v<BLayout, ctc::GKXC> &&
std::is_same_v<ELayout, ctc::GNWK>)
{
return builder::GroupConvLayout1D::GNWC_GKXC_GNWK;
}
else if constexpr(std::is_same_v<ALayout, ctc::NWGC> &&
std::is_same_v<BLayout, ctc::GKXC> && std::is_same_v<ELayout, ctc::NWGK>)
{
return builder::GroupConvLayout1D::NWGC_GKXC_NWGK;
}
else if constexpr(std::is_same_v<ALayout, ctc::NGCW> &&
std::is_same_v<BLayout, ctc::GKXC> && std::is_same_v<ELayout, ctc::NGKW>)
{
return builder::GroupConvLayout1D::NGCW_GKXC_NGKW;
}
else if constexpr(std::is_same_v<ALayout, ctc::NGCW> &&
std::is_same_v<BLayout, ctc::GKCX> && std::is_same_v<ELayout, ctc::NGKW>)
{
return builder::GroupConvLayout1D::NGCW_GKCX_NGKW;
}
}
else if constexpr(InstTraits::kSpatialDim == 2)
{
if constexpr(std::is_same_v<ALayout, ctc::GNHWC> && std::is_same_v<BLayout, ctc::GKYXC> &&
std::is_same_v<ELayout, ctc::GNHWK>)
{
return builder::GroupConvLayout2D::GNHWC_GKYXC_GNHWK;
}
else if constexpr(std::is_same_v<ALayout, ctc::NHWGC> &&
std::is_same_v<BLayout, ctc::GKYXC> &&
std::is_same_v<ELayout, ctc::NHWGK>)
{
return builder::GroupConvLayout2D::NHWGC_GKYXC_NHWGK;
}
else if constexpr(std::is_same_v<ALayout, ctc::NGCHW> &&
std::is_same_v<BLayout, ctc::GKYXC> &&
std::is_same_v<ELayout, ctc::NGKHW>)
{
return builder::GroupConvLayout2D::NGCHW_GKYXC_NGKHW;
}
else if constexpr(std::is_same_v<ALayout, ctc::NGCHW> &&
std::is_same_v<BLayout, ctc::GKCYX> &&
std::is_same_v<ELayout, ctc::NGKHW>)
{
return builder::GroupConvLayout2D::NGCHW_GKCYX_NGKHW;
}
}
else if constexpr(InstTraits::kSpatialDim == 3)
{
if constexpr(std::is_same_v<ALayout, ctc::GNDHWC> && std::is_same_v<BLayout, ctc::GKZYXC> &&
std::is_same_v<ELayout, ctc::GNDHWK>)
{
return builder::GroupConvLayout3D::GNDHWC_GKZYXC_GNDHWK;
}
else if constexpr(std::is_same_v<ALayout, ctc::NDHWGC> &&
std::is_same_v<BLayout, ctc::GKZYXC> &&
std::is_same_v<ELayout, ctc::NDHWGK>)
{
return builder::GroupConvLayout3D::NDHWGC_GKZYXC_NDHWGK;
}
else if constexpr(std::is_same_v<ALayout, ctc::NGCDHW> &&
std::is_same_v<BLayout, ctc::GKZYXC> &&
std::is_same_v<ELayout, ctc::NGKDHW>)
{
return builder::GroupConvLayout3D::NGCDHW_GKZYXC_NGKDHW;
}
else if constexpr(std::is_same_v<ALayout, ctc::NGCDHW> &&
std::is_same_v<BLayout, ctc::GKCZYX> &&
std::is_same_v<ELayout, ctc::NGKDHW>)
{
return builder::GroupConvLayout3D::NGCDHW_GKCZYX_NGKDHW;
}
}
}
/// @brief Derives the data type from a device kernel `Instance` type.
/// @tparam Instance The device kernel instance type.
/// @return A `builder::DataType` enum value (e.g., FP16, BF16, FP32).
template <typename Instance>
constexpr builder::DataType conv_data_type()
{
using InstTraits = InstanceTraits<Instance>;
using ADataType = typename InstTraits::ADataType;
if constexpr(std::is_same_v<ADataType, ck::half_t>)
{
return builder::DataType::FP16;
}
else if constexpr(std::is_same_v<ADataType, ck::bhalf_t>)
{
return builder::DataType::BF16;
}
else if constexpr(std::is_same_v<ADataType, float>)
{
return builder::DataType::FP32;
}
else if constexpr(std::is_same_v<ADataType, ck::f8_t>)
{
return builder::DataType::FP8;
}
else if constexpr(std::is_same_v<ADataType, int8_t>)
{
return builder::DataType::I8;
}
else if constexpr(std::is_same_v<ADataType, uint8_t>)
{
return builder::DataType::U8;
}
else
{
// Default fallback
return builder::DataType::FP32;
}
}
/// @brief Derives the elementwise operation from op type.
/// @tparam ElementwiseOp Elementwise operation functor type.
/// @return A `builder::ElementwiseOperation` enum value corresponding to elementwise operation.
template <typename ElementwiseOp>
constexpr builder::ElementwiseOperation elementwise_op()
{
constexpr std::string_view name = detail::elementwise_op_name<ElementwiseOp>();
if constexpr(detail::case_insensitive_equal(name, "Bias"))
{
return builder::ElementwiseOperation::BIAS;
}
else if constexpr(detail::case_insensitive_equal(name, "BiasClamp"))
{
return builder::ElementwiseOperation::BIAS_CLAMP;
}
else if constexpr(detail::case_insensitive_equal(name, "BiasBnormClamp"))
{
return builder::ElementwiseOperation::BIAS_BNORM_CLAMP;
}
else if constexpr(detail::case_insensitive_equal(name, "Bilinear"))
{
return builder::ElementwiseOperation::BILINEAR;
}
else if constexpr(detail::case_insensitive_equal(name, "Clamp"))
{
return builder::ElementwiseOperation::CLAMP;
}
else if constexpr(detail::case_insensitive_equal(name, "Scale"))
{
return builder::ElementwiseOperation::SCALE;
}
else if constexpr(detail::case_insensitive_equal(name, "PassThrough"))
{
return builder::ElementwiseOperation::PASS_THROUGH;
}
}
/// @brief Derives a gemm padding from a kernel instance type.
/// @tparam Instance - A Device Kernel object type.
/// @return A `builder::GemmPadding` enum value corresponding to kernel padding.
template <typename Instance>
constexpr builder::GemmPadding gemm_spec()
{
using InstTraits = InstanceTraits<Instance>;
using enum builder::GemmPadding;
using enum ck::tensor_operation::device::GemmSpecialization;
constexpr auto gemm_spec = InstTraits::kGemmSpecialization;
if constexpr(gemm_spec == Default)
{
return DEFAULT;
}
else if constexpr(gemm_spec == MPadding)
{
return M_PADDING;
}
else if constexpr(gemm_spec == NPadding)
{
return N_PADDING;
}
else if constexpr(gemm_spec == KPadding)
{
return K_PADDING;
}
else if constexpr(gemm_spec == MNPadding)
{
return MN_PADDING;
}
else if constexpr(gemm_spec == MKPadding)
{
return MK_PADDING;
}
else if constexpr(gemm_spec == NKPadding)
{
return NK_PADDING;
}
else if constexpr(gemm_spec == MNKPadding)
{
return MNK_PADDING;
}
else if constexpr(gemm_spec == OPadding)
{
return O_PADDING;
}
else if constexpr(gemm_spec == MOPadding)
{
return MO_PADDING;
}
else if constexpr(gemm_spec == NOPadding)
{
return NO_PADDING;
}
else if constexpr(gemm_spec == KOPadding)
{
return KO_PADDING;
}
else if constexpr(gemm_spec == MNOPadding)
{
return MNO_PADDING;
}
else if constexpr(gemm_spec == MKOPadding)
{
return MKO_PADDING;
}
else if constexpr(gemm_spec == NKOPadding)
{
return NKO_PADDING;
}
else if constexpr(gemm_spec == MNKOPadding)
{
return MNKO_PADDING;
}
}
/// @brief Primary template for extracting convolution traits.
/// @details This struct is the main entry point for reflecting on a convolution
/// kernel's properties. It is specialized to handle different kinds of input types.
template <typename T>
struct ConvTraits;
/// @brief Specialization of `ConvTraits` for a direct device kernel `Instance`.
/// @details This is the primary specialization used to extract a comprehensive
/// set of traits directly from a fully-formed device kernel `Instance` type.
/// It uses `InstanceTraits` to access the kernel's template parameters.
template <typename Instance>
requires requires { typename InstanceTraits<Instance>; }
struct ConvTraits<Instance>
{
using InstTraits = InstanceTraits<Instance>;
// --- Signature Information ---
/// @brief The number of spatial dimensions in the convolution (1, 2, or 3).
static constexpr int spatial_dim = InstTraits::kSpatialDim;
/// @brief The direction of the convolution (Forward, Backward Data, or Backward Weight).
static constexpr builder::ConvDirection direction = conv_direction<Instance>();
/// @brief The memory layout of the convolution tensors (e.g., GNHWC_GKYXC_GNHWK).
static constexpr auto layout = conv_layout<Instance>();
/// @brief The primary data type used in the computation (e.g., FP16, FP32).
static constexpr builder::DataType data_type = conv_data_type<Instance>();
static constexpr builder::ElementwiseOperation input_element_op =
elementwise_op<typename InstTraits::AElementwiseOperation>();
static constexpr builder::ElementwiseOperation weight_element_op =
elementwise_op<typename InstTraits::BElementwiseOperation>();
static constexpr builder::ElementwiseOperation output_element_op =
elementwise_op<typename InstTraits::CDEElementwiseOperation>();
/// @brief The GEMM specialization used by the kernel - padding
static constexpr auto gemm_padding = gemm_spec<Instance>();
/// @brief The convolution-specific specialization (e.g., Default, 1x1).
static constexpr auto conv_specialization = conv_spec<Instance>();
// --- Algorithm Information ---
/// @brief The total number of threads in a thread block (workgroup).
static constexpr int thread_block_size = InstTraits::kBlockSize;
/// @brief The dimensions of the data tile processed by the thread block.
static constexpr DataTileInfo tile_dims = {
.m = InstTraits::kMPerBlock, .n = InstTraits::kNPerBlock, .k = InstTraits::kKPerBlock};
/// @brief Configuration for the A-matrix (input) tile transfer.
static constexpr InputTileTransferInfo a_tile_transfer = {
.tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kAK1,
.m_or_n = InstTraits::kMPerBlock,
.k1 = InstTraits::kAK1},
.transfer_params = {.k1 = InstTraits::kAK1,
.thread_cluster_dims = InstTraits::kAThreadClusterLengths,
.thread_cluster_order = InstTraits::kAThreadClusterArrangeOrder,
.src_access_order = InstTraits::kABlockTransferSrcAccessOrder,
.src_vector_dim = InstTraits::kABlockTransferSrcVectorDim,
.src_scalar_per_vector = InstTraits::kABlockTransferSrcScalarPerVector,
.dst_scalar_per_vector_k1 =
InstTraits::kABlockTransferDstScalarPerVectorK1,
.lds_padding = static_cast<bool>(InstTraits::kABlockLdsExtraM)}};
/// @brief Configuration for the B-matrix (weights) tile transfer.
static constexpr InputTileTransferInfo b_tile_transfer = {
.tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kBK1,
.m_or_n = InstTraits::kNPerBlock,
.k1 = InstTraits::kBK1},
.transfer_params = {.k1 = InstTraits::kBK1,
.thread_cluster_dims = InstTraits::kBThreadClusterLengths,
.thread_cluster_order = InstTraits::kBThreadClusterArrangeOrder,
.src_access_order = InstTraits::kBBlockTransferSrcAccessOrder,
.src_vector_dim = InstTraits::kBBlockTransferSrcVectorDim,
.src_scalar_per_vector = InstTraits::kBBlockTransferSrcScalarPerVector,
.dst_scalar_per_vector_k1 =
InstTraits::kBBlockTransferDstScalarPerVectorK1,
.lds_padding = static_cast<bool>(InstTraits::kBBlockLdsExtraN)}};
/// @brief Parameters for the warp-level GEMM computation.
static constexpr WarpGemmParams warp_gemm = {.gemm_m = InstTraits::kMPerXDL,
.gemm_n = InstTraits::kNPerXDL,
.m_iter = InstTraits::kMXdlPerWave,
.n_iter = InstTraits::kNXdlPerWave};
/// @brief Configuration for the C-matrix (output) tile transfer.
static constexpr OutputTileTransferInfo c_tile_transfer = {
.shuffle_params = {.m_gemms_per_shuffle = InstTraits::kCShuffleMXdlPerWavePerShuffle,
.n_gemms_per_shuffle = InstTraits::kCShuffleNXdlPerWavePerShuffle},
.thread_cluster_dims = {InstTraits::kCThreadClusterLengths[0],
InstTraits::kCThreadClusterLengths[1],
InstTraits::kCThreadClusterLengths[2],
InstTraits::kCThreadClusterLengths[3]},
.scalar_per_vector = InstTraits::kCBlockTransferScalarPerVector};
/// @brief Helper to safely get the pipeline version.
/// @details This is only available for some convolutions (e.g., forward).
/// If not present in `InstanceTraits`, it returns a default value.
template <typename T = InstTraits>
static constexpr auto get_pipeline_version()
{
if constexpr(requires { T::kPipelineVersion; })
{
return convert_pipeline_version<T::kPipelineVersion>();
}
else
{
// Return a default or indicate not available
return builder::PipelineVersion::V1;
}
}
/// @brief The block GEMM pipeline version used by the kernel.
static constexpr auto pipeline_version = get_pipeline_version();
/// @brief Helper to safely get the pipeline scheduler.
/// @details This is only available for some convolutions. If not present
/// in `InstanceTraits`, it returns a default value.
template <typename T = InstTraits>
static constexpr auto get_pipeline_scheduler()
{
if constexpr(requires { T::kPipelineScheduler; })
{
return convert_pipeline_scheduler<T::kPipelineScheduler>();
}
else if constexpr(requires { T::kLoopScheduler; })
{
return convert_pipeline_scheduler<T::kLoopScheduler>();
}
else
{
// Return a default or indicate not available
return builder::PipelineScheduler::DEFAULT;
}
}
/// @brief The pipeline scheduler used by the kernel.
static constexpr auto pipeline_scheduler = get_pipeline_scheduler();
};
/// @brief Specialization of `ConvTraits` for a `ConvBuilder` type.
/// @details This specialization provides backward compatibility for reflecting
/// on kernels defined via the `ConvBuilder` interface. It works by first
/// creating the `Instance` via the builder's factory, and then delegating
/// all trait extraction to the `ConvTraits<Instance>` specialization.
template <builder::ConvSignatureDescriptor auto SIGNATURE,
builder::ConvAlgorithmDescriptor auto ALGORITHM,
builder::StringLiteral VERSION>
struct ConvTraits<builder::ConvBuilder<SIGNATURE, ALGORITHM, VERSION>>
{
using Factory = builder::ConvFactory<SIGNATURE, ALGORITHM, VERSION>;
using Instance = typename Factory::Instance;
// Delegate to Instance-based ConvTraits
using InstanceConvTraits = ConvTraits<Instance>;
// Forward all members from Instance-based traits
static constexpr int spatial_dim = InstanceConvTraits::spatial_dim;
static constexpr builder::ConvDirection direction = InstanceConvTraits::direction;
static constexpr auto layout = InstanceConvTraits::layout;
static constexpr builder::DataType data_type = InstanceConvTraits::data_type;
static constexpr builder::ElementwiseOperation input_element_op =
InstanceConvTraits::input_element_op;
static constexpr builder::ElementwiseOperation weight_element_op =
InstanceConvTraits::weight_element_op;
static constexpr builder::ElementwiseOperation output_element_op =
InstanceConvTraits::output_element_op;
static constexpr auto gemm_padding = InstanceConvTraits::gemm_padding;
static constexpr auto conv_specialization = InstanceConvTraits::conv_specialization;
static constexpr int thread_block_size = InstanceConvTraits::thread_block_size;
static constexpr DataTileInfo tile_dims = InstanceConvTraits::tile_dims;
static constexpr InputTileTransferInfo a_tile_transfer = InstanceConvTraits::a_tile_transfer;
static constexpr InputTileTransferInfo b_tile_transfer = InstanceConvTraits::b_tile_transfer;
static constexpr WarpGemmParams warp_gemm = InstanceConvTraits::warp_gemm;
static constexpr OutputTileTransferInfo c_tile_transfer = InstanceConvTraits::c_tile_transfer;
static constexpr auto pipeline_version = InstanceConvTraits::pipeline_version;
static constexpr auto pipeline_scheduler = InstanceConvTraits::pipeline_scheduler;
};
} // namespace ck_tile::reflect::conv

View File

@@ -1,5 +1,5 @@
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
// Compile-time reflection for CK device kernel instances.
//
@@ -14,18 +14,9 @@
#pragma once
#include <array>
#include <string>
#include <sstream>
#include <type_traits>
#include <ck/utility/data_type.hpp>
#include <ck/utility/sequence.hpp>
#include <ck/utility/blkgemmpipe_scheduler.hpp>
#include <ck/tensor_operation/gpu/device/tensor_layout.hpp>
#include <ck/tensor_operation/gpu/element/element_wise_operation.hpp>
#include <ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp>
#include <ck/tensor_operation/gpu/device/gemm_specialization.hpp>
#include "instance_traits_util.hpp"
#include <concepts>
namespace ck_tile::reflect {

View File

@@ -0,0 +1,286 @@
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "instance_traits.hpp"
#include "instance_traits_util.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
// Forward declaration to avoid circular dependency
namespace ck::tensor_operation::device {
template <ck::index_t NDimSpatial,
typename InLayout,
typename WeiLayout,
typename OutLayout,
typename InDataType,
typename WeiDataType,
typename OutDataType,
typename AccDataType,
typename InElementwiseOperation,
typename WeiElementwiseOperation,
typename OutElementwiseOperation,
ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization
ConvBackwardWeightSpecialization,
ck::index_t BlockSize,
ck::index_t MPerBlock,
ck::index_t NPerBlock,
ck::index_t K0PerBlock,
ck::index_t K1,
ck::index_t MPerXDL,
ck::index_t NPerXDL,
ck::index_t MXdlPerWave,
ck::index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_K0_M_K1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
ck::index_t ABlockTransferSrcVectorDim,
ck::index_t ABlockTransferSrcScalarPerVector,
ck::index_t ABlockTransferDstScalarPerVector_K1,
bool ABlockLdsAddExtraM,
typename BBlockTransferThreadClusterLengths_K0_N_K1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
ck::index_t BBlockTransferSrcVectorDim,
ck::index_t BBlockTransferSrcScalarPerVector,
ck::index_t BBlockTransferDstScalarPerVector_K1,
bool BBlockLdsAddExtraN,
ck::index_t CShuffleMXdlPerWavePerShuffle,
ck::index_t CShuffleNXdlPerWavePerShuffle,
typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
ck::index_t CBlockTransferScalarPerVector_NWaveNPerXdl,
typename ComputeTypeA,
typename ComputeTypeB,
ck::index_t MaxTransposeTransferSrcScalarPerVector,
ck::index_t MaxTransposeTransferDstScalarPerVector>
struct DeviceGroupedConvBwdWeight_Xdl_CShuffle;
} // namespace ck::tensor_operation::device
namespace ck_tile {
namespace reflect {
template <ck::index_t NDimSpatial,
typename InLayout_,
typename WeiLayout_,
typename OutLayout_,
typename InDataType_,
typename WeiDataType_,
typename OutDataType_,
typename AccDataType_,
typename InElementwiseOperation_,
typename WeiElementwiseOperation_,
typename OutElementwiseOperation_,
ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization
ConvBackwardWeightSpecialization,
ck::index_t BlockSize,
ck::index_t MPerBlock,
ck::index_t NPerBlock,
ck::index_t K0PerBlock,
ck::index_t K1,
ck::index_t MPerXDL,
ck::index_t NPerXDL,
ck::index_t MXdlPerWave,
ck::index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_K0_M_K1_,
typename ABlockTransferThreadClusterArrangeOrder_,
typename ABlockTransferSrcAccessOrder_,
ck::index_t ABlockTransferSrcVectorDim,
ck::index_t ABlockTransferSrcScalarPerVector,
ck::index_t ABlockTransferDstScalarPerVector_K1,
bool ABlockLdsAddExtraM,
typename BBlockTransferThreadClusterLengths_K0_N_K1_,
typename BBlockTransferThreadClusterArrangeOrder_,
typename BBlockTransferSrcAccessOrder_,
ck::index_t BBlockTransferSrcVectorDim,
ck::index_t BBlockTransferSrcScalarPerVector,
ck::index_t BBlockTransferDstScalarPerVector_K1,
bool BBlockLdsAddExtraN,
ck::index_t CShuffleMXdlPerWavePerShuffle,
ck::index_t CShuffleNXdlPerWavePerShuffle,
typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_,
ck::index_t CBlockTransferScalarPerVector_NWaveNPerXdl,
typename ComputeTypeA_,
typename ComputeTypeB_,
ck::index_t MaxTransposeTransferSrcScalarPerVector,
ck::index_t MaxTransposeTransferDstScalarPerVector>
struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffle<
NDimSpatial,
InLayout_,
WeiLayout_,
OutLayout_,
InDataType_,
WeiDataType_,
OutDataType_,
AccDataType_,
InElementwiseOperation_,
WeiElementwiseOperation_,
OutElementwiseOperation_,
ConvBackwardWeightSpecialization,
BlockSize,
MPerBlock,
NPerBlock,
K0PerBlock,
K1,
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_K0_M_K1_,
ABlockTransferThreadClusterArrangeOrder_,
ABlockTransferSrcAccessOrder_,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1,
ABlockLdsAddExtraM,
BBlockTransferThreadClusterLengths_K0_N_K1_,
BBlockTransferThreadClusterArrangeOrder_,
BBlockTransferSrcAccessOrder_,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1,
BBlockLdsAddExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_,
CBlockTransferScalarPerVector_NWaveNPerXdl,
ComputeTypeA_,
ComputeTypeB_,
MaxTransposeTransferSrcScalarPerVector,
MaxTransposeTransferDstScalarPerVector>>
{
static constexpr auto kTensorOpName = "DeviceGroupedConvBwdWeight_Xdl_CShuffle";
static constexpr ck::index_t kNDimSpatial = NDimSpatial;
using InLayout = InLayout_;
using WeiLayout = WeiLayout_;
using OutLayout = OutLayout_;
using InDataType = InDataType_;
using WeiDataType = WeiDataType_;
using OutDataType = OutDataType_;
using AccDataType = AccDataType_;
using InElementwiseOperation = InElementwiseOperation_;
using WeiElementwiseOperation = WeiElementwiseOperation_;
using OutElementwiseOperation = OutElementwiseOperation_;
static constexpr auto kConvBackwardWeightSpecialization = ConvBackwardWeightSpecialization;
static constexpr ck::index_t kBlockSize = BlockSize;
static constexpr ck::index_t kMPerBlock = MPerBlock;
static constexpr ck::index_t kNPerBlock = NPerBlock;
static constexpr ck::index_t kK0PerBlock = K0PerBlock;
static constexpr ck::index_t kK1 = K1;
static constexpr ck::index_t kMPerXDL = MPerXDL;
static constexpr ck::index_t kNPerXDL = NPerXDL;
static constexpr ck::index_t kMXdlPerWave = MXdlPerWave;
static constexpr ck::index_t kNXdlPerWave = NXdlPerWave;
using ABlockTransferThreadClusterLengths_K0_M_K1 = ABlockTransferThreadClusterLengths_K0_M_K1_;
using ABlockTransferThreadClusterArrangeOrder = ABlockTransferThreadClusterArrangeOrder_;
using ABlockTransferSrcAccessOrder = ABlockTransferSrcAccessOrder_;
static constexpr ck::index_t kABlockTransferSrcVectorDim = ABlockTransferSrcVectorDim;
static constexpr ck::index_t kABlockTransferSrcScalarPerVector =
ABlockTransferSrcScalarPerVector;
static constexpr ck::index_t kABlockTransferDstScalarPerVector_K1 =
ABlockTransferDstScalarPerVector_K1;
static constexpr bool kABlockLdsAddExtraM = ABlockLdsAddExtraM;
using BBlockTransferThreadClusterLengths_K0_N_K1 = BBlockTransferThreadClusterLengths_K0_N_K1_;
using BBlockTransferThreadClusterArrangeOrder = BBlockTransferThreadClusterArrangeOrder_;
using BBlockTransferSrcAccessOrder = BBlockTransferSrcAccessOrder_;
static constexpr ck::index_t kBBlockTransferSrcVectorDim = BBlockTransferSrcVectorDim;
static constexpr ck::index_t kBBlockTransferSrcScalarPerVector =
BBlockTransferSrcScalarPerVector;
static constexpr ck::index_t kBBlockTransferDstScalarPerVector_K1 =
BBlockTransferDstScalarPerVector_K1;
static constexpr bool kBBlockLdsAddExtraN = BBlockLdsAddExtraN;
static constexpr ck::index_t kCShuffleMXdlPerWavePerShuffle = CShuffleMXdlPerWavePerShuffle;
static constexpr ck::index_t kCShuffleNXdlPerWavePerShuffle = CShuffleNXdlPerWavePerShuffle;
using CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock =
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_;
static constexpr ck::index_t kCBlockTransferScalarPerVector_NWaveNPerXdl =
CBlockTransferScalarPerVector_NWaveNPerXdl;
using ComputeTypeA = ComputeTypeA_;
using ComputeTypeB = ComputeTypeB_;
static constexpr ck::index_t kMaxTransposeTransferSrcScalarPerVector =
MaxTransposeTransferSrcScalarPerVector;
static constexpr ck::index_t kMaxTransposeTransferDstScalarPerVector =
MaxTransposeTransferDstScalarPerVector;
// Static member function to generate instance string
static std::string instance_string()
{
std::ostringstream oss;
// Kernel type name
oss << "DeviceGroupedConvBwdWeight_Xdl_CShuffle";
// Template parameters in exact order
oss << "<" << kNDimSpatial; // 1. NDimSpatial
oss << "," << detail::layout_name<InLayout>(); // 2. InLayout
oss << "," << detail::layout_name<WeiLayout>(); // 3. WeiLayout
oss << "," << detail::layout_name<OutLayout>(); // 4. OutLayout
oss << "," << detail::type_name<InDataType>(); // 5. InDataType
oss << "," << detail::type_name<WeiDataType>(); // 6. WeiDataType
oss << "," << detail::type_name<OutDataType>(); // 7. OutDataType
oss << "," << detail::type_name<AccDataType>(); // 8. AccDataType
oss << ","
<< detail::elementwise_op_name<InElementwiseOperation>(); // 9. InElementwiseOperation
oss << ","
<< detail::elementwise_op_name<WeiElementwiseOperation>(); // 10.
// WeiElementwiseOperation
oss << ","
<< detail::elementwise_op_name<OutElementwiseOperation>(); // 11.
// OutElementwiseOperation
oss << ","
<< detail::conv_bwd_weight_spec_name(
kConvBackwardWeightSpecialization); // 12. ConvBackwardWeightSpecialization
oss << "," << kBlockSize; // 13. BlockSize
oss << "," << kMPerBlock; // 14. MPerBlock
oss << "," << kNPerBlock; // 15. NPerBlock
oss << "," << kK0PerBlock; // 16. K0PerBlock
oss << "," << kK1; // 17. K1
oss << "," << kMPerXDL; // 18. MPerXDL
oss << "," << kNPerXDL; // 19. NPerXDL
oss << "," << kMXdlPerWave; // 20. MXdlPerWave
oss << "," << kNXdlPerWave; // 21. NXdlPerWave
oss << "," << detail::sequence_name<ABlockTransferThreadClusterLengths_K0_M_K1>(); // 22.
oss << "," << detail::sequence_name<ABlockTransferThreadClusterArrangeOrder>(); // 23.
oss << "," << detail::sequence_name<ABlockTransferSrcAccessOrder>(); // 24.
oss << "," << kABlockTransferSrcVectorDim; // 25.
oss << "," << kABlockTransferSrcScalarPerVector; // 26.
oss << "," << kABlockTransferDstScalarPerVector_K1; // 27.
oss << "," << (kABlockLdsAddExtraM ? "true" : "false"); // 28.
oss << "," << detail::sequence_name<BBlockTransferThreadClusterLengths_K0_N_K1>(); // 29.
oss << "," << detail::sequence_name<BBlockTransferThreadClusterArrangeOrder>(); // 30.
oss << "," << detail::sequence_name<BBlockTransferSrcAccessOrder>(); // 31.
oss << "," << kBBlockTransferSrcVectorDim; // 32.
oss << "," << kBBlockTransferSrcScalarPerVector; // 33.
oss << "," << kBBlockTransferDstScalarPerVector_K1; // 34.
oss << "," << (kBBlockLdsAddExtraN ? "true" : "false"); // 35.
oss << "," << kCShuffleMXdlPerWavePerShuffle; // 36.
oss << "," << kCShuffleNXdlPerWavePerShuffle; // 37.
oss << ","
<< detail::sequence_name<
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>(); // 38.
oss << "," << kCBlockTransferScalarPerVector_NWaveNPerXdl; // 39.
oss << "," << detail::type_name<ComputeTypeA>(); // 40.
oss << "," << detail::type_name<ComputeTypeB>(); // 41.
oss << "," << kMaxTransposeTransferSrcScalarPerVector; // 42.
oss << "," << kMaxTransposeTransferDstScalarPerVector; // 43.
oss << ">";
return oss.str();
}
};
} // namespace reflect
} // namespace ck_tile

View File

@@ -1,5 +1,5 @@
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
// InstanceTraits specialization for DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
//

View File

@@ -1,5 +1,5 @@
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
// InstanceTraits specialization for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
//

View File

@@ -1,5 +1,5 @@
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
// InstanceTraits specialization for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
//
@@ -15,6 +15,7 @@
#pragma once
#include "instance_traits.hpp"
#include "instance_traits_util.hpp"
// Forward declaration to avoid circular dependency.
// This file will be included by the device implementation header, so we cannot include

View File

@@ -1,5 +1,5 @@
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
// InstanceTraits specialization for DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
//
@@ -14,6 +14,7 @@
#pragma once
#include "instance_traits.hpp"
#include "instance_traits_util.hpp"
// Forward declaration to avoid circular dependency.
// This file will be included by the device implementation header, so we cannot include

View File

@@ -1,5 +1,5 @@
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
// InstanceTraits specialization for DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
//

View File

@@ -0,0 +1,140 @@
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// InstanceTraits specialization for GroupedConvolutionForwardKernel
//
// CRITICAL MAINTENANCE NOTE:
// This InstanceTraits file MUST be kept strictly in sync with the device implementation header:
// ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp
// "In sync" means that the template parameter order, names, and types in the declaration below
// MUST EXACTLY MATCH those in the device implementation. If these diverge, you may encounter
// compilation errors, subtle template instantiation mismatches, or silent runtime bugs that are
// difficult to diagnose. Always update both files together and review changes carefully.
#pragma once
#include "instance_traits.hpp"
#include "instance_traits_util.hpp"
// Forward declaration to avoid circular dependency.
namespace ck_tile::device {
template <typename GroupedConvTraitsType_,
typename TilePartitioner_,
typename GemmPipeline_,
typename EpiloguePipeline_>
struct GroupedConvolutionForwardKernel;
} // namespace ck_tile::device
namespace ck_tile {
namespace reflect {
// Specialization for GroupedConvolutionForwardKernel
template <typename GroupedConvTraitsType_,
typename TilePartitioner_,
typename GemmPipeline_,
typename EpiloguePipeline_>
struct InstanceTraits<ck_tile::device::GroupedConvolutionForwardKernel<GroupedConvTraitsType_,
TilePartitioner_,
GemmPipeline_,
EpiloguePipeline_>>
{
// CK Tile Conv Traits
// Spatial dimension
static constexpr int kSpatialDim = GroupedConvTraitsType_::NDimSpatial;
// Specialization
static constexpr ck_tile::ConvolutionSpecialization ConvSpecialization =
GroupedConvTraitsType_::ConvSpecialization;
// DataType types
using InLayout = typename GroupedConvTraitsType_::InLayout;
using WeiLayout = typename GroupedConvTraitsType_::WeiLayout;
using DsLayout = typename GroupedConvTraitsType_::DsLayout;
using OutLayout = typename GroupedConvTraitsType_::OutLayout;
// Vector size
static constexpr int kVectorSizeA = GroupedConvTraitsType_::VectorSizeA;
static constexpr int kVectorSizeB = GroupedConvTraitsType_::VectorSizeB;
static constexpr int kVectorSizeC = GroupedConvTraitsType_::VectorSizeC;
// Num Groups To Merge
static constexpr int kNumGroupsToMerge = GroupedConvTraitsType_::NumGroupsToMerge;
// Split image (large tensors)
static constexpr bool kEnableSplitImage = GroupedConvTraitsType_::EnableSplitImage;
// TilePartitioner
// Block configuration
static constexpr int kMPerBlock = TilePartitioner_::MPerBlock;
static constexpr int kNPerBlock = TilePartitioner_::NPerBlock;
static constexpr int kKPerBlock = TilePartitioner_::KPerBlock;
static constexpr int kMWarp = TilePartitioner_::BlockGemmShape::BlockWarps::at(number<0>{});
static constexpr int kNWarp = TilePartitioner_::BlockGemmShape::BlockWarps::at(number<1>{});
static constexpr int kKWarp = TilePartitioner_::BlockGemmShape::BlockWarps::at(number<2>{});
static constexpr int kMWarpTile = TilePartitioner_::BlockGemmShape::WarpTile::at(number<0>{});
static constexpr int kNWarpTile = TilePartitioner_::BlockGemmShape::WarpTile::at(number<1>{});
static constexpr int kKWarpTile = TilePartitioner_::BlockGemmShape::WarpTile::at(number<2>{});
// Data types
using ADataType = typename GemmPipeline_::ADataType;
using BDataType = typename GemmPipeline_::BDataType;
// Gemm Pipeline
using GemmPipeline = GemmPipeline_;
static constexpr ck_tile::GemmPipelineScheduler kPipelineScheduler = GemmPipeline_::Scheduler;
static constexpr bool kDoubleSmemBuffer = GemmPipeline_::DoubleSmemBuffer;
static constexpr int kNumWaveGroups = GemmPipeline_::NumWaveGroups;
// Epilogue Pipeline
using AccDataType = typename EpiloguePipeline_::AccDataType;
using EDataType = typename EpiloguePipeline_::ODataType;
using DsDataType = typename EpiloguePipeline_::DsDataType;
using CDEElementwiseOperation = typename EpiloguePipeline_::CDElementwise;
// Static member function to generate instance string
static std::string instance_string()
{
std::ostringstream oss;
// Kernel type name
oss << "GroupedConvolutionForwardKernel";
// Template parameters in exact order matching InstanceTraits member order
oss << "<" << kSpatialDim; // 1. NDimSpatial
oss << ","
<< ck_tile::getConvSpecializationString(ConvSpecialization); // 2. ConvSpecialization
oss << "," << detail::layout_name<InLayout>(); // 3. InLayout
oss << "," << detail::layout_name<WeiLayout>(); // 4. WeiLayout
oss << "," << detail::tuple_name<DsLayout>(); // 5. DsLayout
oss << "," << detail::layout_name<OutLayout>(); // 6. OutLayout
oss << "," << kVectorSizeA; // 7. VectorSizeA
oss << "," << kVectorSizeB; // 8. VectorSizeB
oss << "," << kVectorSizeC; // 9. VectorSizeC
oss << "," << kNumGroupsToMerge; // 10. NumGroupsToMerge
oss << "," << kEnableSplitImage; // 11. EnableSplitImage
oss << "," << kMPerBlock; // 12. MPerBlock
oss << "," << kNPerBlock; // 13. NPerBlock
oss << "," << kKPerBlock; // 14. KPerBlock
oss << "," << kMWarp; // 15. MWarp
oss << "," << kNWarp; // 16. NWarp
oss << "," << kKWarp; // 17. KWarp
oss << "," << kMWarpTile; // 18. MWarpTile
oss << "," << kNWarpTile; // 19. NWarpTile
oss << "," << kKWarpTile; // 20. KWarpTile
oss << "," << detail::type_name<ADataType>(); // 21. ADataType
oss << "," << detail::type_name<BDataType>(); // 22. BDataType
oss << "," << GemmPipeline::GetPipelineName(); // 23. BlkGemmPipelineVer
oss << "," << detail::pipeline_scheduler_name(kPipelineScheduler); // 24. BlkGemmPipeSched
oss << "," << kDoubleSmemBuffer; // 25. NumWaveGroups
oss << "," << kNumWaveGroups; // 26. NumWaveGroups
oss << "," << detail::type_name<AccDataType>(); // 27. AccDataType
oss << "," << detail::type_name<EDataType>(); // 28. EDataType
oss << "," << detail::tuple_name<DsDataType>(); // 29. DsDataType
oss << ","
<< detail::elementwise_op_name<CDEElementwiseOperation>(); // 30.
// CDEElementwiseOperation
oss << ">";
return oss.str();
}
};
} // namespace reflect
} // namespace ck_tile

View File

@@ -1,5 +1,5 @@
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
// Utility functions and helpers for instance_traits.hpp
// Contains helper functions to convert types, enums, and sequences to string representations.
@@ -9,9 +9,14 @@
#include <array>
#include <string>
#include <concepts>
#include <string_view>
#include <sstream>
#include <type_traits>
#include <limits.h>
#include <cmath>
#include <ostream>
#include <iostream>
#include <ck/utility/data_type.hpp>
#include <ck/utility/sequence.hpp>
#include <ck/utility/blkgemmpipe_scheduler.hpp>
@@ -21,7 +26,12 @@
#include <ck_tile/ops/common/tensor_layout.hpp>
#include <ck/tensor_operation/gpu/element/element_wise_operation.hpp>
#include <ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp>
#include <ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp>
#include <ck/tensor_operation/gpu/device/gemm_specialization.hpp>
#include <ck_tile/ops/gemm.hpp>
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/grouped_convolution/utils/convolution_specialization.hpp"
#include "ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp"
namespace ck_tile::reflect::detail {
@@ -32,7 +42,7 @@ namespace impl {
template <typename T>
consteval std::string_view type_name_impl()
{
if constexpr(std::is_same_v<T, ck::half_t>)
if constexpr(std::is_same_v<T, ck::half_t> || std::is_same_v<T, ck_tile::half_t>)
return "fp16";
else if constexpr(std::is_same_v<T, float>)
return "fp32";
@@ -44,11 +54,11 @@ consteval std::string_view type_name_impl()
return "s8";
else if constexpr(std::is_same_v<T, int32_t>)
return "s32";
else if constexpr(std::is_same_v<T, ck::bhalf_t>)
else if constexpr(std::is_same_v<T, ck::bhalf_t> || std::is_same_v<T, ck_tile::bf16_t>)
return "bf16";
else if constexpr(std::is_same_v<T, ck::f8_t>)
else if constexpr(std::is_same_v<T, ck::f8_t> || std::is_same_v<T, ck_tile::fp8_t>)
return "fp8";
else if constexpr(std::is_same_v<T, ck::bf8_t>)
else if constexpr(std::is_same_v<T, ck::bf8_t> || std::is_same_v<T, ck_tile::bf8_t>)
return "bf8";
else
return std::string_view{}; // Return empty for supported types
@@ -112,6 +122,20 @@ conv_fwd_spec_name(ck::tensor_operation::device::ConvolutionForwardSpecializatio
}
}
// Convert ConvolutionBackwardWeightSpecialization enum to string
constexpr std::string_view conv_bwd_weight_spec_name(
ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization spec)
{
using enum ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization;
switch(spec)
{
case Default: return "Default";
case Filter1x1Stride1Pad0: return "Filter1x1Stride1Pad0";
case Filter1x1Pad0: return "Filter1x1Pad0";
case OddC: return "OddC";
}
}
// Convert GemmSpecialization enum to string
constexpr std::string_view gemm_spec_name(ck::tensor_operation::device::GemmSpecialization spec)
{
@@ -148,6 +172,17 @@ constexpr std::string_view pipeline_scheduler_name(ck::BlockGemmPipelineSchedule
}
}
constexpr std::string_view pipeline_scheduler_name(ck_tile::GemmPipelineScheduler sched)
{
using enum ck_tile::GemmPipelineScheduler;
switch(sched)
{
case Default: return "Default";
case Intrawave: return "Intrawave";
case Interwave: return "Interwave";
}
}
// Convert BlockGemmPipelineVersion enum to string
constexpr std::string_view pipeline_version_name(ck::BlockGemmPipelineVersion ver)
{
@@ -186,6 +221,26 @@ constexpr std::string_view loop_scheduler_name(ck::LoopScheduler sched)
}
}
// Convert TailNumber enum to string
constexpr std::string_view tail_number_name(ck_tile::TailNumber tail_num)
{
using enum ck_tile::TailNumber;
switch(tail_num)
{
case Odd: return "Odd";
case Even: return "Even";
case One: return "One";
case Two: return "Two";
case Three: return "Three";
case Four: return "Four";
case Five: return "Five";
case Six: return "Six";
case Seven: return "Seven";
case Empty: return "Empty";
case Full: return "Full";
}
}
// Convert std::array to string
template <typename T, std::size_t N>
inline std::string array_to_string(const std::array<T, N>& arr)
@@ -336,17 +391,53 @@ constexpr std::string tuple_name()
}(static_cast<T*>(nullptr));
}
template <typename T>
requires requires { []<typename... Ts>(ck_tile::tuple<Ts...>*) {}(static_cast<T*>(nullptr)); }
constexpr std::string tuple_name()
{
return []<typename... Ts>(ck_tile::tuple<Ts...>*) constexpr {
if constexpr(sizeof...(Ts) == 0)
{
return std::string("EmptyTuple");
}
else if constexpr((IsLayoutType<Ts> && ...))
{
// Lambda wrapper for layout_name
auto layout_name_fn = []<typename U>() { return layout_name<U>(); };
return detail::build_list_string<decltype(layout_name_fn), Ts...>("tuple",
layout_name_fn);
}
else if constexpr((IsDataType<Ts> && ...))
{
// Lambda wrapper for type_name
auto type_name_fn = []<typename U>() { return type_name<U>(); };
return detail::build_list_string<decltype(type_name_fn), Ts...>("tuple", type_name_fn);
}
else
{
static_assert((IsLayoutType<Ts> && ...) || (IsDataType<Ts> && ...),
"tuple elements must be all layouts or all data types, not mixed");
return std::string{}; // unreachable
}
}(static_cast<T*>(nullptr));
}
// Concept to check if a type is a ck::Tuple
template <typename T>
concept IsCkTuple =
requires { []<typename... Ts>(ck::Tuple<Ts...>*) {}(static_cast<T*>(nullptr)); };
// Concept to check if a type is a ck_tile::tuple
template <typename T>
concept IsCkTileTuple =
requires { []<typename... Ts>(ck_tile::tuple<Ts...>*) {}(static_cast<T*>(nullptr)); };
// Deduces whether to use tuple_name or type_name
// Handles both scalar data types and ck::Tuple types
template <typename T>
constexpr std::string type_or_type_tuple_name()
{
if constexpr(IsCkTuple<T>)
if constexpr(IsCkTuple<T> || IsCkTileTuple<T>)
{
return tuple_name<T>();
}
@@ -356,4 +447,30 @@ constexpr std::string type_or_type_tuple_name()
}
}
/// @brief Makes a case insensitive comparison of two string views.
/// @param a First string view
/// @param b Second string view
/// @return Whether two string views a equal case insensitive
constexpr bool case_insensitive_equal(std::string_view a, std::string_view b)
{
if(a.size() != b.size())
return false;
for(size_t i = 0; i < a.size(); ++i)
{
char c1 = a[i];
char c2 = b[i];
// Convert to lowercase for comparison
if(c1 >= 'A' && c1 <= 'Z')
c1 += 32;
if(c2 >= 'A' && c2 <= 'Z')
c2 += 32;
if(c1 != c2)
return false;
}
return true;
}
} // namespace ck_tile::reflect::detail

View File

@@ -0,0 +1,106 @@
#pragma once
#include <sstream>
#include <string>
#include <type_traits>
#include <vector>
namespace ck_tile::reflect {
// Helper class for formatting hierarchical tree structures with proper indentation
// and tree-drawing characters (├─, └─, │, etc.)
//
// Example Usage:
//
// TreeFormatter f;
// f.writeLine(0, "Root");
// f.writeLine(1, "Branch 1");
// f.writeLine(2, "Item 1a");
// f.writeLast(2, "Item 1b");
// f.writeLast(1, "Branch 2");
// f.writeLast(2, "Item 2a");
// std::cout << f.getString() << "\n";
//
// Generated Output:
//
// Root
// ├─ Branch 1
// │ ├─ Item 1a
// │ └─ Item 1b
// └─ Branch 2
// └─ Item 2a
class TreeFormatter
{
public:
TreeFormatter() = default;
// Write a line at the specified indentation level (branch continues after this)
template <typename... Args>
void writeLine(int indent_level, Args&&... args)
{
writeLineImpl(indent_level, false, std::forward<Args>(args)...);
}
// Write the last line at the specified indentation level (branch ends)
template <typename... Args>
void writeLast(int indent_level, Args&&... args)
{
writeLineImpl(indent_level, true, std::forward<Args>(args)...);
}
// Get the formatted string (removes trailing newline if present)
std::string getString() const
{
std::string result = oss_.str();
if(!result.empty() && result.back() == '\n')
{
result.pop_back();
}
return result;
}
private:
std::ostringstream oss_;
std::vector<bool> is_last_at_level_; // Tracks which levels have ended
// Implementation of line writing with tree symbols
template <typename... Args>
void writeLineImpl(int indent_level, bool is_last, Args&&... args)
{
// Ensure we have enough tracking space
if(static_cast<size_t>(indent_level) >= is_last_at_level_.size())
{
is_last_at_level_.resize(indent_level + 1, false);
// Level 0 (root) should always be treated as "last" since it has no tree symbols
if(is_last_at_level_.size() > 0)
{
is_last_at_level_[0] = true;
}
}
// Draw the tree structure
// Start from level 1 (skip level 0 which is the root with no symbols)
for(int i = 1; i < indent_level; ++i)
{
// For all parent levels, draw vertical line or space based on whether they ended
oss_ << (is_last_at_level_[i] ? " " : "");
}
// Draw the branch symbol for the current level
if(indent_level > 0)
{
oss_ << (is_last ? "└─ " : "├─ ");
}
// Write the content using fold expression with direct stream insertion
((oss_ << std::forward<Args>(args)), ...);
oss_ << '\n';
// Update tracking for this level AFTER writing the line
// This ensures future lines at deeper levels know if this level ended
is_last_at_level_[indent_level] = is_last;
}
};
} // namespace ck_tile::reflect

View File

@@ -1,8 +1,12 @@
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <ostream>
#include <string_view>
#include <variant>
namespace ck_tile::builder {
enum class DataType
@@ -128,29 +132,14 @@ enum class ElementwiseOperation
PASS_THROUGH
};
// Enums for the current block GEMM pipeline versions.
enum class BlockGemmPipelineVersion
// Enums for pipeline versions & schedulers
enum class PipelineVersion
{
V1,
V2,
V3,
V4,
V5
};
enum struct BlockGemmPipelineScheduler
{
INTRAWAVE,
INTERWAVE,
};
// Enums for the gridwise GEMM pipeline versions.
enum class GridwiseGemmPipelineVersion
{
V1,
V2,
V3, // Only used in stream-K implementation
V4,
V5,
WEIGHT_ONLY
};
@@ -186,10 +175,319 @@ enum class ConvFwdSpecialization
FILTER_3x3
};
enum class LoopScheduler
// Enums for the backward data convolution specialization.
enum class ConvBwdDataSpecialization
{
DEFAULT,
FILTER_1X1_STRIDE1_PAD0,
};
// Enums for the backward weight convolution specialization.
enum class ConvBwdWeightSpecialization
{
DEFAULT,
FILTER_1X1_STRIDE1_PAD0,
FILTER_1X1_PAD0,
ODD_C,
};
// Enums for the Gemm padding.
enum class GemmPadding
{
DEFAULT,
M_PADDING,
N_PADDING,
K_PADDING,
MN_PADDING,
MK_PADDING,
NK_PADDING,
MNK_PADDING,
O_PADDING,
MO_PADDING,
NO_PADDING,
KO_PADDING,
MNO_PADDING,
MKO_PADDING,
NKO_PADDING,
MNKO_PADDING,
};
enum class PipelineScheduler
{
DEFAULT,
INTRAWAVE,
INTERWAVE
};
// ostream operator overloads for enum classes
inline std::ostream& operator<<(std::ostream& os, DataType dt)
{
using enum DataType;
switch(dt)
{
case FP16: return os << "FP16";
case FP32: return os << "FP32";
case BF16: return os << "BF16";
case FP8: return os << "FP8";
case I8: return os << "I8";
case U8: return os << "U8";
default: return os << "Unknown";
}
}
inline std::ostream& operator<<(std::ostream& os, ConvDirection dir)
{
using enum ConvDirection;
switch(dir)
{
case FORWARD: return os << "Forward";
case BACKWARD_DATA: return os << "Backward Data";
case BACKWARD_WEIGHT: return os << "Backward Weight";
default: return os << "Unknown";
}
}
inline std::ostream& operator<<(std::ostream& os, GroupConvLayout1D layout)
{
using enum GroupConvLayout1D;
switch(layout)
{
case GNWC_GKXC_GNWK: return os << "GNWC_GKXC_GNWK";
case NWGC_GKXC_NWGK: return os << "NWGC_GKXC_NWGK";
case NGCW_GKXC_NGKW: return os << "NGCW_GKXC_NGKW";
case NGCW_GKCX_NGKW: return os << "NGCW_GKCX_NGKW";
default: return os << "Unknown";
}
}
inline std::ostream& operator<<(std::ostream& os, GroupConvLayout2D layout)
{
using enum GroupConvLayout2D;
switch(layout)
{
case GNHWC_GKYXC_GNHWK: return os << "GNHWC_GKYXC_GNHWK";
case NHWGC_GKYXC_NHWGK: return os << "NHWGC_GKYXC_NHWGK";
case NGCHW_GKYXC_NGKHW: return os << "NGCHW_GKYXC_NGKHW";
case NGCHW_GKCYX_NGKHW: return os << "NGCHW_GKCYX_NGKHW";
default: return os << "Unknown";
}
}
inline std::ostream& operator<<(std::ostream& os, GroupConvLayout3D layout)
{
using enum GroupConvLayout3D;
switch(layout)
{
case GNDHWC_GKZYXC_GNDHWK: return os << "GNDHWC_GKZYXC_GNDHWK";
case NDHWGC_GKZYXC_NDHWGK: return os << "NDHWGC_GKZYXC_NDHWGK";
case NGCDHW_GKZYXC_NGKDHW: return os << "NGCDHW_GKZYXC_NGKDHW";
case NGCDHW_GKCZYX_NGKDHW: return os << "NGCDHW_GKCZYX_NGKDHW";
default: return os << "Unknown";
}
}
inline std::ostream& operator<<(std::ostream& os, FwdGroupConvDeviceOperation op)
{
using enum FwdGroupConvDeviceOperation;
switch(op)
{
case DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK:
return os << "DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK";
case DeviceGroupedConvFwdMultipleD_Wmma_CShuffle:
return os << "DeviceGroupedConvFwdMultipleD_Wmma_CShuffle";
case DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle:
return os << "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle";
case DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3:
return os << "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3";
case DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor:
return os << "DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor";
default: return os << "Unknown";
}
}
inline std::ostream& operator<<(std::ostream& os, BwdDataGroupConvDeviceOperation op)
{
using enum BwdDataGroupConvDeviceOperation;
switch(op)
{
case DeviceGroupedConvBwdDataMultipleD: return os << "DeviceGroupedConvBwdDataMultipleD";
case DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle:
return os << "DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle";
case DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1:
return os << "DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1";
default: return os << "Unknown";
}
}
inline std::ostream& operator<<(std::ostream& os, BwdWeightGroupConvDeviceOperation op)
{
using enum BwdWeightGroupConvDeviceOperation;
switch(op)
{
case DeviceGroupedConvBwdWeight: return os << "DeviceGroupedConvBwdWeight";
case DeviceGroupedConvBwdWeight_Dl: return os << "DeviceGroupedConvBwdWeight_Dl";
case DeviceGroupedConvBwdWeight_Xdl_CShuffle:
return os << "DeviceGroupedConvBwdWeight_Xdl_CShuffle";
case DeviceGroupedConvBwdWeight_Xdl_CShuffleV3:
return os << "DeviceGroupedConvBwdWeight_Xdl_CShuffleV3";
case DeviceGroupedConvBwdWeight_Wmma_CShuffle:
return os << "DeviceGroupedConvBwdWeight_Wmma_CShuffle";
case DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle:
return os << "DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle";
case DeviceGroupedConvBwdWeightMultipleD: return os << "DeviceGroupedConvBwdWeightMultipleD";
case DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle:
return os << "DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle";
default: return os << "Unknown";
}
}
inline std::ostream& operator<<(std::ostream& os, ElementwiseOperation op)
{
using enum ElementwiseOperation;
switch(op)
{
case BIAS: return os << "BIAS";
case BIAS_CLAMP: return os << "BIAS_CLAMP";
case BIAS_BNORM_CLAMP: return os << "BIAS_BNORM_CLAMP";
case BILINEAR: return os << "BILINEAR";
case CLAMP: return os << "CLAMP";
case SCALE: return os << "SCALE";
case PASS_THROUGH: return os << "PASS_THROUGH";
default: return os << "Unknown";
}
}
inline std::ostream& operator<<(std::ostream& os, PipelineVersion ver)
{
using enum PipelineVersion;
switch(ver)
{
case V1: return os << "V1";
case V2: return os << "V2";
case V3: return os << "V3";
case V4: return os << "V4";
case V5: return os << "V5";
case WEIGHT_ONLY: return os << "WEIGHT_ONLY";
default: return os << "Unknown";
}
}
inline std::ostream& operator<<(std::ostream& os, GemmSpecialization spec)
{
using enum GemmSpecialization;
switch(spec)
{
case Default: return os << "Default";
case MPadding: return os << "MPadding";
case NPadding: return os << "NPadding";
case KPadding: return os << "KPadding";
case MNPadding: return os << "MNPadding";
case MKPadding: return os << "MKPadding";
case NKPadding: return os << "NKPadding";
case MNKPadding: return os << "MNKPadding";
case OPadding: return os << "OPadding";
case MOPadding: return os << "MOPadding";
case NOPadding: return os << "NOPadding";
case KOPadding: return os << "KOPadding";
case MNOPadding: return os << "MNOPadding";
case MKOPadding: return os << "MKOPadding";
case NKOPadding: return os << "NKOPadding";
case MNKOPadding: return os << "MNKOPadding";
default: return os << "Unknown";
}
}
inline std::ostream& operator<<(std::ostream& os, ConvFwdSpecialization spec)
{
using enum ConvFwdSpecialization;
switch(spec)
{
case DEFAULT: return os << "DEFAULT";
case FILTER_1X1_PAD0: return os << "FILTER_1X1_PAD0";
case FILTER_1X1_STRIDE1_PAD0: return os << "FILTER_1X1_STRIDE1_PAD0";
case FILTER_3x3: return os << "FILTER_3x3";
default: return os << "Unknown";
}
}
inline std::ostream& operator<<(std::ostream& os, ConvBwdDataSpecialization spec)
{
using enum ConvBwdDataSpecialization;
switch(spec)
{
case DEFAULT: return os << "DEFAULT";
case FILTER_1X1_STRIDE1_PAD0: return os << "FILTER_1X1_STRIDE1_PAD0";
default: return os << "Unknown";
}
}
inline std::ostream& operator<<(std::ostream& os, ConvBwdWeightSpecialization spec)
{
using enum ConvBwdWeightSpecialization;
switch(spec)
{
case DEFAULT: return os << "DEFAULT";
case FILTER_1X1_STRIDE1_PAD0: return os << "FILTER_1X1_STRIDE1_PAD0";
case FILTER_1X1_PAD0: return os << "FILTER_1X1_PAD0";
case ODD_C: return os << "ODD_C";
default: return os << "Unknown";
}
}
inline std::ostream& operator<<(std::ostream& os, GemmPadding padding)
{
using enum GemmPadding;
switch(padding)
{
case DEFAULT: return os << "DEFAULT";
case M_PADDING: return os << "M_PADDING";
case N_PADDING: return os << "N_PADDING";
case K_PADDING: return os << "K_PADDING";
case MN_PADDING: return os << "MN_PADDING";
case MK_PADDING: return os << "MK_PADDING";
case NK_PADDING: return os << "NK_PADDING";
case MNK_PADDING: return os << "MNK_PADDING";
case O_PADDING: return os << "O_PADDING";
case MO_PADDING: return os << "MO_PADDING";
case NO_PADDING: return os << "NO_PADDING";
case KO_PADDING: return os << "KO_PADDING";
case MNO_PADDING: return os << "MNO_PADDING";
case MKO_PADDING: return os << "MKO_PADDING";
case NKO_PADDING: return os << "NKO_PADDING";
case MNKO_PADDING: return os << "MNKO_PADDING";
default: return os << "Unknown";
}
}
inline std::ostream& operator<<(std::ostream& os, PipelineScheduler sched)
{
using enum PipelineScheduler;
switch(sched)
{
case DEFAULT: return os << "DEFAULT";
case INTRAWAVE: return os << "INTRAWAVE";
case INTERWAVE: return os << "INTERWAVE";
default: return os << "Unknown";
}
}
// ostream operator overload for std::variant of layout types
inline std::ostream&
operator<<(std::ostream& os,
const std::variant<GroupConvLayout1D, GroupConvLayout2D, GroupConvLayout3D>& layout)
{
std::visit([&os](const auto& l) { os << l; }, layout);
return os;
}
// ostream operator overload for std::variant of convolution specializations
inline std::ostream& operator<<(std::ostream& os,
const std::variant<ConvFwdSpecialization,
ConvBwdDataSpecialization,
ConvBwdWeightSpecialization>& spec)
{
std::visit([&os](const auto& s) { os << s; }, spec);
return os;
}
} // namespace ck_tile::builder

View File

@@ -1,3 +1,6 @@
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <concepts>

View File

@@ -20,6 +20,7 @@ endfunction()
add_ck_builder_test(test_ckb_conv_builder
test_conv_builder.cpp
test_fwd_instance_traits.cpp
test_bwd_weight_instance_traits.cpp
test_instance_traits_util.cpp)
add_ck_builder_test(test_ckb_inline_diff test_inline_diff.cpp)
@@ -30,7 +31,8 @@ add_ck_builder_test(test_ckb_get_instance_string
test_get_instance_string_fwd_grp_conv.cpp
test_get_instance_string_fwd_grp_conv_large_tensor.cpp
test_get_instance_string_fwd_grp_conv_wmma.cpp
test_get_instance_string_fwd_grp_conv_dl.cpp)
test_get_instance_string_fwd_grp_conv_dl.cpp
test_get_instance_string_bwd_weight_grp_conv_xdl.cpp)
# Testing the fwd convolution builder requires kernel compilation.
# To enable parallel compilation, the individual tests are split into separate files.
@@ -41,6 +43,8 @@ add_ck_builder_test(test_ckb_build_fwd_instances
conv/test_ckb_conv_fwd_2d_bf16.cpp
conv/test_ckb_conv_fwd_2d_fp16.cpp
conv/test_ckb_conv_fwd_2d_fp32.cpp
conv/test_ckb_conv_fwd_2d_dl_fp16.cpp
conv/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp
conv/test_ckb_conv_fwd_3d_bf16.cpp
conv/test_ckb_conv_fwd_3d_fp16.cpp
conv/test_ckb_conv_fwd_3d_fp32.cpp)
@@ -62,6 +66,12 @@ add_ck_factory_test(test_ckb_factory_grouped_convolution_forward_bias_bnorm_clam
add_ck_factory_test(test_ckb_factory_grouped_convolution_forward_scaleadd_scaleadd_relu test_ck_factory_grouped_convolution_forward_scaleadd_scaleadd_relu.cpp)
add_ck_factory_test(test_ckb_factory_grouped_convolution_forward_dynamic_op test_ck_factory_grouped_convolution_forward_dynamic_op.cpp)
add_ck_builder_test(test_conv_traits
conv/test_conv_traits.cpp)
add_ck_builder_test(test_conv_description
test_conv_description.cpp)
# Function to add all test_ckb targets to a list
function(collect_test_ckb_targets result_var)
# Get all targets in current directory

View File

@@ -1,3 +1,6 @@
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "utils/ckb_conv_test_common.hpp"
using namespace ck_tile::builder::test_utils;
@@ -24,7 +27,7 @@ TEST(FwdConvInstances,
run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<
FwdConvSignature,
FwdThreadBlock,
BlockGemmPipelineVersion::V2,
PipelineVersion::V2,
ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0>();
}

View File

@@ -1,3 +1,6 @@
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "utils/ckb_conv_test_common.hpp"
using namespace ck_tile::builder::test_utils;

View File

@@ -1,3 +1,6 @@
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "utils/ckb_conv_test_common.hpp"
using namespace ck_tile::builder::test_utils;

View File

@@ -1,3 +1,6 @@
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "utils/ckb_conv_test_common.hpp"
using namespace ck_tile::builder::test_utils;
@@ -22,7 +25,7 @@ TEST(FwdConvInstances,
run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<FwdConvSignature,
FwdThreadBlock,
BlockGemmPipelineVersion::V1,
PipelineVersion::V1,
ConvFwdSpecialization::DEFAULT>();
}
@@ -44,7 +47,7 @@ TEST(FwdConvInstances,
run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<FwdConvSignature,
FwdThreadBlock,
BlockGemmPipelineVersion::V5,
PipelineVersion::V5,
ConvFwdSpecialization::FILTER_3x3>();
}

View File

@@ -0,0 +1,69 @@
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "utils/ckb_conv_test_common.hpp"
using namespace ck_tile::builder::test_utils;
namespace ck_tile::builder::testing {
TEST(FwdConvInstances, Create_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK_Instance_2D_FP16_GNHWC)
{
constexpr ConvSignature FwdConvSignature{
.spatial_dim = 2,
.direction = ConvDirection::FORWARD,
.layout = GroupConvLayout2D::GNHWC_GKYXC_GNHWK,
.data_type = DataType::FP16,
.elementwise_operation = ElementwiseOperation::PASS_THROUGH,
.device_operation =
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK};
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
.tile_size = {.m = 128, .n = 128, .k = 16}};
run_test_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK<FwdConvSignature,
FwdThreadBlock,
ConvFwdSpecialization::DEFAULT>();
}
TEST(FwdConvInstances, Create_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK_Instance_2D_FP16_NHWGC)
{
constexpr ConvSignature FwdConvSignature{
.spatial_dim = 2,
.direction = ConvDirection::FORWARD,
.layout = GroupConvLayout2D::NHWGC_GKYXC_NHWGK,
.data_type = DataType::FP16,
.elementwise_operation = ElementwiseOperation::PASS_THROUGH,
.device_operation =
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK};
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
.tile_size = {.m = 128, .n = 128, .k = 16}};
run_test_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK<FwdConvSignature,
FwdThreadBlock,
ConvFwdSpecialization::DEFAULT>();
}
TEST(FwdConvInstances,
Create_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK_Instance_2D_FP16_FILTER_1X1_PAD0)
{
constexpr ConvSignature FwdConvSignature{
.spatial_dim = 2,
.direction = ConvDirection::FORWARD,
.layout = GroupConvLayout2D::GNHWC_GKYXC_GNHWK,
.data_type = DataType::FP16,
.elementwise_operation = ElementwiseOperation::PASS_THROUGH,
.device_operation =
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK};
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
.tile_size = {.m = 128, .n = 128, .k = 16}};
run_test_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK<
FwdConvSignature,
FwdThreadBlock,
ConvFwdSpecialization::FILTER_1X1_PAD0>();
}
} // namespace ck_tile::builder::testing

View File

@@ -1,3 +1,6 @@
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "utils/ckb_conv_test_common.hpp"
using namespace ck_tile::builder::test_utils;
@@ -22,7 +25,7 @@ TEST(FwdConvInstances,
run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<
FwdConvSignature,
FwdThreadBlock,
BlockGemmPipelineVersion::V3,
PipelineVersion::V3,
ConvFwdSpecialization::FILTER_1X1_PAD0>();
}

View File

@@ -1,3 +1,6 @@
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "utils/ckb_conv_test_common.hpp"
using namespace ck_tile::builder::test_utils;
@@ -22,7 +25,7 @@ TEST(FwdConvInstances,
run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<
FwdConvSignature,
FwdThreadBlock,
BlockGemmPipelineVersion::V4,
PipelineVersion::V4,
ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0>();
}

View File

@@ -0,0 +1,53 @@
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "utils/ckb_conv_test_common.hpp"
using namespace ck_tile::builder::test_utils;
namespace ck_tile::builder::testing {
TEST(FwdConvInstances,
Create_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor_Instance_2D_FP16_GNHWC)
{
constexpr ConvSignature FwdConvSignature{
.spatial_dim = 2,
.direction = ConvDirection::FORWARD,
.layout = GroupConvLayout2D::GNHWC_GKYXC_GNHWK,
.data_type = DataType::FP16,
.elementwise_operation = ElementwiseOperation::PASS_THROUGH,
.device_operation =
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor};
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
.tile_size = {.m = 256, .n = 128, .k = 32}};
run_test_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<
FwdConvSignature,
FwdThreadBlock,
ConvFwdSpecialization::DEFAULT>();
}
TEST(
FwdConvInstances,
Create_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor_Instance_2D_FP16_GNHWC_Filter1x1Pad0)
{
constexpr ConvSignature FwdConvSignature{
.spatial_dim = 2,
.direction = ConvDirection::FORWARD,
.layout = GroupConvLayout2D::GNHWC_GKYXC_GNHWK,
.data_type = DataType::FP16,
.elementwise_operation = ElementwiseOperation::PASS_THROUGH,
.device_operation =
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor};
constexpr ThreadBlock FwdThreadBlock{.block_size = 128,
.tile_size = {.m = 128, .n = 128, .k = 32}};
run_test_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<
FwdConvSignature,
FwdThreadBlock,
ConvFwdSpecialization::FILTER_1X1_PAD0>();
}
} // namespace ck_tile::builder::testing

View File

@@ -1,3 +1,6 @@
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "utils/ckb_conv_test_common.hpp"
using namespace ck_tile::builder::test_utils;
@@ -22,7 +25,7 @@ TEST(FwdConvInstances,
run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<FwdConvSignature,
FwdThreadBlock,
BlockGemmPipelineVersion::V3,
PipelineVersion::V3,
ConvFwdSpecialization::DEFAULT>();
}

View File

@@ -1,3 +1,6 @@
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "utils/ckb_conv_test_common.hpp"
using namespace ck_tile::builder::test_utils;
@@ -23,7 +26,7 @@ TEST(FwdConvInstances,
run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<
FwdConvSignature,
FwdThreadBlock,
BlockGemmPipelineVersion::V4,
PipelineVersion::V4,
ConvFwdSpecialization::FILTER_1X1_PAD0>();
}

View File

@@ -1,3 +1,6 @@
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "utils/ckb_conv_test_common.hpp"
using namespace ck_tile::builder::test_utils;
@@ -23,7 +26,7 @@ TEST(FwdConvInstances,
run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<
FwdConvSignature,
FwdThreadBlock,
BlockGemmPipelineVersion::V1,
PipelineVersion::V1,
ConvFwdSpecialization::FILTER_1X1_PAD0>();
}

View File

@@ -0,0 +1,316 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <gtest/gtest.h>
#include <gmock/gmock.h>
#include <concepts>
#include <ck_tile/builder/reflect/conv_traits.hpp>
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp>
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp>
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp>
namespace {
using ::testing::ElementsAre;
// Test fixture for ConvTraits tests
class ConvTraitsTest : public ::testing::Test
{
};
// Test ConvTraits with DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
TEST_F(ConvTraitsTest, ConvFwdTraitsExtraction)
{
// Define a concrete instance type with specific template parameters
using DeviceInstance =
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<
2, // NDimSpatial
ck::tensor_layout::convolution::GNHWC, // ALayout
ck::tensor_layout::convolution::GKYXC, // BLayout
ck::Tuple<>, // DsLayout
ck::tensor_layout::convolution::GNHWK, // ELayout
ck::half_t, // ADataType
ck::half_t, // BDataType
float, // AccDataType
ck::half_t, // CShuffleDataType
ck::Tuple<>, // DsDataType
ck::half_t, // EDataType
ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation
ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation
ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation
ck::tensor_operation::device::ConvolutionForwardSpecialization::
Default, // ConvForwardSpecialization
ck::tensor_operation::device::GemmSpecialization::Default, // GemmSpec
256, // BlockSize
128, // MPerBlock
128, // NPerBlock
16, // KPerBlock
8, // AK1
8, // BK1
32, // MPerXDL
32, // NPerXDL
4, // MXdlPerWave
4, // NXdlPerWave
ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
8, // ABlockTransferSrcScalarPerVector
8, // ABlockTransferDstScalarPerVector_AK1
1, // ABlockLdsExtraM
ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
8, // BBlockTransferSrcScalarPerVector
8, // BBlockTransferDstScalarPerVector_BK1
1, // BBlockLdsExtraN
1, // CShuffleMXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle
ck::Sequence<1,
32,
1,
8>, // CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CDEBlockTransferScalarPerVector_NPerBlock
ck::BlockGemmPipelineScheduler::Intrawave, // BlkGemmPipeSched
ck::PipelineVersion::v1, // BlkGemmPipelineVer
ck::half_t, // AComputeDataType
ck::half_t, // BComputeDataType
false>; // DirectLoad
// Use ConvTraits to extract compile-time information
using Traits = ck_tile::reflect::conv::ConvTraits<DeviceInstance>;
// Verify signature information
EXPECT_EQ(Traits::spatial_dim, 2);
EXPECT_EQ(Traits::direction, ck_tile::builder::ConvDirection::FORWARD);
EXPECT_EQ(Traits::layout, ck_tile::builder::GroupConvLayout2D::GNHWC_GKYXC_GNHWK);
EXPECT_EQ(Traits::data_type, ck_tile::builder::DataType::FP16);
EXPECT_EQ(Traits::input_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH);
EXPECT_EQ(Traits::weight_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH);
EXPECT_EQ(Traits::output_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH);
// Verify specializations
EXPECT_EQ(Traits::gemm_padding, ck_tile::builder::GemmPadding::DEFAULT);
EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvFwdSpecialization::DEFAULT);
// Verify algorithm information
EXPECT_EQ(Traits::thread_block_size, 256);
// Verify tile dimensions
EXPECT_EQ(Traits::tile_dims.m, 128);
EXPECT_EQ(Traits::tile_dims.n, 128);
EXPECT_EQ(Traits::tile_dims.k, 16);
// Verify A tile transfer info
EXPECT_EQ(Traits::a_tile_transfer.tile_dimensions.k0, 2);
EXPECT_EQ(Traits::a_tile_transfer.tile_dimensions.m_or_n, 128);
EXPECT_EQ(Traits::a_tile_transfer.tile_dimensions.k1, 8);
EXPECT_EQ(Traits::a_tile_transfer.transfer_params.k1, 8);
EXPECT_THAT(Traits::a_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1));
EXPECT_THAT(Traits::a_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2));
EXPECT_THAT(Traits::a_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2));
EXPECT_EQ(Traits::a_tile_transfer.transfer_params.src_vector_dim, 2);
EXPECT_EQ(Traits::a_tile_transfer.transfer_params.src_scalar_per_vector, 8);
EXPECT_EQ(Traits::a_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8);
EXPECT_TRUE(Traits::a_tile_transfer.transfer_params.lds_padding);
// Verify B tile transfer info
EXPECT_EQ(Traits::b_tile_transfer.tile_dimensions.k0, 2);
EXPECT_EQ(Traits::b_tile_transfer.tile_dimensions.m_or_n, 128);
EXPECT_EQ(Traits::b_tile_transfer.tile_dimensions.k1, 8);
EXPECT_EQ(Traits::b_tile_transfer.transfer_params.k1, 8);
EXPECT_THAT(Traits::b_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1));
EXPECT_THAT(Traits::b_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2));
EXPECT_THAT(Traits::b_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2));
EXPECT_EQ(Traits::b_tile_transfer.transfer_params.src_vector_dim, 2);
EXPECT_EQ(Traits::b_tile_transfer.transfer_params.src_scalar_per_vector, 8);
EXPECT_EQ(Traits::b_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8);
EXPECT_TRUE(Traits::b_tile_transfer.transfer_params.lds_padding);
// Verify warp GEMM params
EXPECT_EQ(Traits::warp_gemm.gemm_m, 32);
EXPECT_EQ(Traits::warp_gemm.gemm_n, 32);
EXPECT_EQ(Traits::warp_gemm.m_iter, 4);
EXPECT_EQ(Traits::warp_gemm.n_iter, 4);
// Verify output tile transfer info
EXPECT_EQ(Traits::c_tile_transfer.shuffle_params.m_gemms_per_shuffle, 1);
EXPECT_EQ(Traits::c_tile_transfer.shuffle_params.n_gemms_per_shuffle, 1);
EXPECT_THAT(Traits::c_tile_transfer.thread_cluster_dims, ElementsAre(1, 32, 1, 8));
EXPECT_EQ(Traits::c_tile_transfer.scalar_per_vector, 8);
// Verify pipeline configuration
EXPECT_EQ(Traits::pipeline_scheduler, ck_tile::builder::PipelineScheduler::INTRAWAVE);
EXPECT_EQ(Traits::pipeline_version, ck_tile::builder::PipelineVersion::V1);
}
// Test ConvTraits with DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
TEST_F(ConvTraitsTest, ConvFwdBaseTraitsExtraction)
{
// Define a concrete instance type with specific template parameters
using DeviceInstance =
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<
2, // NDimSpatial
ck::tensor_layout::convolution::GNHWC, // ALayout
ck::tensor_layout::convolution::GKYXC, // BLayout
ck::Tuple<>, // DsLayout
ck::tensor_layout::convolution::GNHWK, // ELayout
ck::half_t, // ADataType
ck::half_t, // BDataType
float, // AccDataType
ck::half_t, // CShuffleDataType
ck::Tuple<>, // DsDataType
ck::half_t, // EDataType
ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation
ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation
ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation
ck::tensor_operation::device::ConvolutionForwardSpecialization::
Default, // ConvForwardSpecialization
ck::tensor_operation::device::GemmSpecialization::Default, // GemmSpec
1, // NumGemmKPrefetchStage
256, // BlockSize
128, // MPerBlock
128, // NPerBlock
16, // KPerBlock
8, // AK1
8, // BK1
32, // MPerXDL
32, // NPerXDL
4, // MXdlPerWave
4, // NXdlPerWave
ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
8, // ABlockTransferSrcScalarPerVector
8, // ABlockTransferDstScalarPerVector_AK1
1, // ABlockLdsExtraM
ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
8, // BBlockTransferSrcScalarPerVector
8, // BBlockTransferDstScalarPerVector_BK1
1, // BBlockLdsExtraN
1, // CShuffleMXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle
ck::Sequence<1,
32,
1,
8>, // CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CDEBlockTransferScalarPerVector_NPerBlock
ck::half_t, // AComputeDataType
ck::half_t, // BComputeDataType
ck::LoopScheduler::Default, // LoopSched
1>; // NumGroupsToMerge
// Use ConvTraits to extract compile-time information
using Traits = ck_tile::reflect::conv::ConvTraits<DeviceInstance>;
// Verify signature information
EXPECT_EQ(Traits::spatial_dim, 2);
EXPECT_EQ(Traits::direction, ck_tile::builder::ConvDirection::FORWARD);
EXPECT_EQ(Traits::layout, ck_tile::builder::GroupConvLayout2D::GNHWC_GKYXC_GNHWK);
EXPECT_EQ(Traits::data_type, ck_tile::builder::DataType::FP16);
EXPECT_EQ(Traits::input_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH);
EXPECT_EQ(Traits::weight_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH);
EXPECT_EQ(Traits::output_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH);
// Verify specializations
EXPECT_EQ(Traits::gemm_padding, ck_tile::builder::GemmPadding::DEFAULT);
EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvFwdSpecialization::DEFAULT);
// Verify algorithm information
EXPECT_EQ(Traits::thread_block_size, 256);
// Verify tile dimensions
EXPECT_EQ(Traits::tile_dims.m, 128);
EXPECT_EQ(Traits::tile_dims.n, 128);
EXPECT_EQ(Traits::tile_dims.k, 16);
}
// Test ConvTraits with DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
TEST_F(ConvTraitsTest, ConvFwdLargeTensorTraitsExtraction)
{
// Define a concrete instance type with specific template parameters
using DeviceInstance =
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<
2, // NDimSpatial
ck::tensor_layout::convolution::GNHWC, // ALayout
ck::tensor_layout::convolution::GKYXC, // BLayout
ck::Tuple<>, // DsLayout
ck::tensor_layout::convolution::GNHWK, // ELayout
ck::half_t, // ADataType
ck::half_t, // BDataType
float, // AccDataType
ck::half_t, // CShuffleDataType
ck::Tuple<>, // DsDataType
ck::half_t, // EDataType
ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation
ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation
ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation
ck::tensor_operation::device::ConvolutionForwardSpecialization::
Default, // ConvForwardSpecialization
ck::tensor_operation::device::GemmSpecialization::Default, // GemmSpec
1, // NumGemmKPrefetchStage
256, // BlockSize
128, // MPerBlock
128, // NPerBlock
16, // KPerBlock
8, // AK1
8, // BK1
32, // MPerXDL
32, // NPerXDL
4, // MXdlPerWave
4, // NXdlPerWave
ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
8, // ABlockTransferSrcScalarPerVector
8, // ABlockTransferDstScalarPerVector_AK1
1, // ABlockLdsExtraM
ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
8, // BBlockTransferSrcScalarPerVector
8, // BBlockTransferDstScalarPerVector_BK1
1, // BBlockLdsExtraN
1, // CShuffleMXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle
ck::Sequence<1,
32,
1,
8>, // CDEBlockTransferClusterLengths
8, // CDEBlockTransferScalarPerVector_NPerBlock
ck::half_t, // AComputeDataType
ck::half_t, // BComputeDataType
ck::LoopScheduler::Default>; // LoopSched
// Use ConvTraits to extract compile-time information
using Traits = ck_tile::reflect::conv::ConvTraits<DeviceInstance>;
// Verify signature information
EXPECT_EQ(Traits::spatial_dim, 2);
EXPECT_EQ(Traits::direction, ck_tile::builder::ConvDirection::FORWARD);
EXPECT_EQ(Traits::layout, ck_tile::builder::GroupConvLayout2D::GNHWC_GKYXC_GNHWK);
EXPECT_EQ(Traits::data_type, ck_tile::builder::DataType::FP16);
EXPECT_EQ(Traits::input_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH);
EXPECT_EQ(Traits::weight_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH);
EXPECT_EQ(Traits::output_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH);
// Verify specializations
EXPECT_EQ(Traits::gemm_padding, ck_tile::builder::GemmPadding::DEFAULT);
EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvFwdSpecialization::DEFAULT);
// Verify algorithm information
EXPECT_EQ(Traits::thread_block_size, 256);
// Verify tile dimensions
EXPECT_EQ(Traits::tile_dims.m, 128);
EXPECT_EQ(Traits::tile_dims.n, 128);
EXPECT_EQ(Traits::tile_dims.k, 16);
}
} // anonymous namespace

View File

@@ -1,5 +1,5 @@
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -49,14 +49,14 @@ struct GridwiseWmmaGemm
size_t n_per_wmma = 0;
size_t m_wmma_per_wave = 0;
size_t n_wmma_per_wave = 0;
GridwiseGemmPipelineVersion pipeline_version;
PipelineVersion pipeline_version;
};
static_assert(ckb::GridwiseWmmaGemmDescriptor<GridwiseWmmaGemm>);
struct BlockGemm
{
BlockGemmPipelineVersion pipeline_version;
BlockGemmPipelineScheduler scheduler;
PipelineVersion pipeline_version;
PipelineScheduler scheduler;
};
static_assert(ckb::BlockGemmDescriptor<BlockGemm>);
@@ -156,7 +156,7 @@ struct ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
GemmSpecialization gemm_specialization;
size_t num_gemm_k_prefetch_stages;
size_t num_groups_to_merge;
LoopScheduler loop_scheduler;
PipelineScheduler loop_scheduler;
};
static_assert(
ckb::ConvAlgorithmDescriptor<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>);
@@ -191,7 +191,7 @@ struct ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
ConvFwdSpecialization fwd_specialization;
GemmSpecialization gemm_specialization;
size_t num_gemm_k_prefetch_stages;
LoopScheduler loop_scheduler;
PipelineScheduler loop_scheduler;
};
static_assert(
ckb::ConvAlgorithmDescriptor<ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle>);
@@ -214,4 +214,84 @@ static_assert(
static_assert(
ckb::SpecifiesLoopScheduler<ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle>);
// DL-specific descriptors
struct DlThreadConfig
{
size_t k0_per_block;
size_t k1;
size_t m1_per_thread;
size_t n1_per_thread;
size_t k_per_thread;
};
static_assert(ckb::DlThreadConfigDescriptor<DlThreadConfig>);
struct DlThreadCluster
{
std::array<size_t, 2> m1_xs; // e.g., {8, 2}
std::array<size_t, 2> n1_xs; // e.g., {8, 2}
};
static_assert(ckb::DlThreadClusterDescriptor<DlThreadCluster>);
struct DlBlockTransferK0M0M1K1
{
std::array<size_t, 4> thread_slice_lengths;
std::array<size_t, 4> thread_cluster_lengths;
std::array<size_t, 4> thread_cluster_arrange_order;
std::array<size_t, 4> src_access_order;
std::array<size_t, 4> src_vector_tensor_lengths;
std::array<size_t, 4> src_vector_tensor_contiguous_dim_order;
std::array<size_t, 4> dst_vector_tensor_lengths;
};
static_assert(ckb::DlBlockTransferK0M0M1K1Descriptor<DlBlockTransferK0M0M1K1>);
struct DlBlockTransferK0N0N1K1
{
std::array<size_t, 4> thread_slice_lengths;
std::array<size_t, 4> thread_cluster_lengths;
std::array<size_t, 4> thread_cluster_arrange_order;
std::array<size_t, 4> src_access_order;
std::array<size_t, 4> src_vector_tensor_lengths;
std::array<size_t, 4> src_vector_tensor_contiguous_dim_order;
std::array<size_t, 4> dst_vector_tensor_lengths;
};
static_assert(ckb::DlBlockTransferK0N0N1K1Descriptor<DlBlockTransferK0N0N1K1>);
struct DlCThreadTransfer
{
std::array<size_t, 6> src_dst_access_order;
size_t src_dst_vector_dim;
size_t dst_scalar_per_vector;
};
static_assert(ckb::DlCThreadTransferDescriptor<DlCThreadTransfer>);
struct ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
{
ThreadBlock thread_block;
ConvFwdSpecialization fwd_specialization;
GemmSpecialization gemm_specialization;
DlThreadConfig dl_thread_config;
DlThreadCluster dl_thread_cluster;
DlBlockTransferK0M0M1K1 dl_block_transfer_a;
DlBlockTransferK0N0N1K1 dl_block_transfer_b;
DlCThreadTransfer dl_c_thread_transfer;
};
static_assert(
ckb::ConvAlgorithmDescriptor<ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK>);
static_assert(
ckb::SpecifiesThreadBlock<ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK>);
static_assert(ckb::SpecifiesFwdConcSpecialization<
ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK>);
static_assert(
ckb::SpecifiesGemmSpecialization<ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK>);
static_assert(
ckb::SpecifiesDlThreadConfig<ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK>);
static_assert(
ckb::SpecifiesDlThreadCluster<ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK>);
static_assert(
ckb::SpecifiesDlBlockTransferA<ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK>);
static_assert(
ckb::SpecifiesDlBlockTransferB<ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK>);
static_assert(
ckb::SpecifiesDlCThreadTransfer<ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK>);
} // namespace ck_tile::builder::test

View File

@@ -1,5 +1,5 @@
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

View File

@@ -0,0 +1,112 @@
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <gtest/gtest.h>
#include <ck/ck.hpp>
#include <ck_tile/builder/reflect/instance_traits.hpp>
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp>
namespace {
TEST(InstanceTraits, BwdWeightXdlCShuffleInstanceStringReturnsCorrectFormat)
{
using DeviceInstance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffle<
2, // NDimSpatial
ck::tensor_layout::convolution::GNHWC, // InLayout
ck::tensor_layout::convolution::GKYXC, // WeiLayout
ck::tensor_layout::convolution::GNHWK, // OutLayout
ck::half_t, // InDataType
ck::half_t, // WeiDataType
ck::half_t, // OutDataType
float, // AccDataType
ck::tensor_operation::element_wise::PassThrough, // InElementwiseOperation
ck::tensor_operation::element_wise::PassThrough, // WeiElementwiseOperation
ck::tensor_operation::element_wise::PassThrough, // OutElementwiseOperation
ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::
Default, // ConvBackwardWeightSpecialization
256, // BlockSize
128, // MPerBlock
128, // NPerBlock
4, // K0PerBlock
8, // K1
32, // MPerXDL
32, // NPerXDL
2, // MXdlPerWave
2, // NXdlPerWave
ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1
ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
8, // ABlockTransferSrcScalarPerVector
8, // ABlockTransferDstScalarPerVector_K1
false, // ABlockLdsAddExtraM
ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1
ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
8, // BBlockTransferSrcScalarPerVector
8, // BBlockTransferDstScalarPerVector_K1
false, // BBlockLdsAddExtraN
1, // CShuffleMXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle
ck::Sequence<1,
32,
1,
8>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CBlockTransferScalarPerVector_NWaveNPerXdl
ck::half_t, // ComputeTypeA
ck::half_t, // ComputeTypeB
1, // MaxTransposeTransferSrcScalarPerVector
1>; // MaxTransposeTransferDstScalarPerVector
std::string instance_str = ck_tile::reflect::instance_string<DeviceInstance>();
std::string expected_str = "DeviceGroupedConvBwdWeight_Xdl_CShuffle"
"<2" // NDimSpatial
",GNHWC" // InLayout
",GKYXC" // WeiLayout
",GNHWK" // OutLayout
",fp16" // InDataType
",fp16" // WeiDataType
",fp16" // OutDataType
",fp32" // AccDataType
",PassThrough" // InElementwiseOperation
",PassThrough" // WeiElementwiseOperation
",PassThrough" // OutElementwiseOperation
",Default" // ConvBackwardWeightSpecialization
",256" // BlockSize
",128" // MPerBlock
",128" // NPerBlock
",4" // K0PerBlock
",8" // K1
",32" // MPerXDL
",32" // NPerXDL
",2" // MXdlPerWave
",2" // NXdlPerWave
",Seq(4,64,1)" // ABlockTransferThreadClusterLengths_K0_M_K1
",Seq(1,0,2)" // ABlockTransferThreadClusterArrangeOrder
",Seq(1,0,2)" // ABlockTransferSrcAccessOrder
",2" // ABlockTransferSrcVectorDim
",8" // ABlockTransferSrcScalarPerVector
",8" // ABlockTransferDstScalarPerVector_K1
",false" // ABlockLdsAddExtraM
",Seq(4,64,1)" // BBlockTransferThreadClusterLengths_K0_N_K1
",Seq(1,0,2)" // BBlockTransferThreadClusterArrangeOrder
",Seq(1,0,2)" // BBlockTransferSrcAccessOrder
",2" // BBlockTransferSrcVectorDim
",8" // BBlockTransferSrcScalarPerVector
",8" // BBlockTransferDstScalarPerVector_K1
",false" // BBlockLdsAddExtraN
",1" // CShuffleMXdlPerWavePerShuffle
",1" // CShuffleNXdlPerWavePerShuffle
",Seq(1,32,1,8)" // CBlockTransferClusterLengths
",8" // CBlockTransferScalarPerVector_NWaveNPerXdl
",fp16" // ComputeTypeA
",fp16" // ComputeTypeB
",1" // MaxTransposeTransferSrcScalarPerVector
",1>"; // MaxTransposeTransferDstScalarPerVector
EXPECT_EQ(instance_str, expected_str);
}
} // anonymous namespace

View File

@@ -1,5 +1,5 @@
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp>

View File

@@ -1,5 +1,5 @@
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_bnorm_clamp.hpp>

View File

@@ -1,5 +1,5 @@
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_clamp.hpp>

View File

@@ -1,5 +1,5 @@
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bilinear.hpp>
#include "ck/utility/data_type.hpp"

View File

@@ -1,5 +1,5 @@
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_clamp.hpp>
#include "ck/utility/data_type.hpp"

View File

@@ -1,5 +1,5 @@
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale.hpp>
#include <ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale_relu.hpp>

View File

@@ -1,5 +1,5 @@
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_dynamic_op.hpp>
#include "ck/utility/data_type.hpp"

View File

@@ -1,5 +1,5 @@
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scale.hpp>

View File

@@ -1,5 +1,5 @@
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scaleadd_ab.hpp>

View File

@@ -1,5 +1,5 @@
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scaleadd_scaleadd_relu.hpp>

View File

@@ -1,3 +1,6 @@
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <gtest/gtest.h>
class ConvBuilderTest : public ::testing::Test

View File

@@ -0,0 +1,169 @@
// SPDX-License-Identifier: MIT
// Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
#include <gtest/gtest.h>
#include <gmock/gmock.h>
#include <ck_tile/builder/conv_builder.hpp>
#include <ck_tile/builder/reflect/conv_description.hpp>
#include "testing_utils.hpp"
#include "impl/conv_signature_types.hpp"
#include "impl/conv_algorithm_types.hpp"
namespace {
namespace ckb = ck_tile::builder;
namespace ckr = ck_tile::reflect::conv;
namespace ckt = ck_tile::test;
// Defines the signature of the convolution operation to be tested.
// This includes dimensionality, direction, data layout, and data type.
struct ConvSignature
{
int spatial_dim = 2;
ckb::ConvDirection direction = ckb::ConvDirection::FORWARD;
ckb::GroupConvLayout layout = ckb::GroupConvLayout2D::GNHWC_GKYXC_GNHWK;
ckb::DataType data_type = ckb::DataType::FP16;
ckb::ElementwiseOperation elementwise_operation = ckb::ElementwiseOperation::PASS_THROUGH;
ckb::GroupConvDeviceOp device_operation =
ckb::FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3;
};
static_assert(ckb::ConvSignatureDescriptor<ConvSignature>);
struct DefaultAlgorithm
{
ckb::test::ThreadBlock thread_block{.block_size = 256,
.tile_size = {.m = 256, .n = 256, .k = 32}};
ckb::test::GridwiseXdlGemm gridwise_gemm{.ak1 = 8,
.bk1 = 8,
.m_per_xdl = 16,
.n_per_xdl = 16,
.m_xdl_per_wave = 4,
.n_xdl_per_wave = 4};
ckb::test::BlockTransferABC block_transfer{
.block_transfer_a = {.k0 = 4, .m_n = 256, .k1 = 8},
.block_transfer_b = {.k0 = 4, .m_n = 256, .k1 = 8},
.thread_cluster_dims_c = {.m_block = 1,
.m_wave_per_xdl = 32,
.n_block = 1,
.n_wave_per_xdl = 8},
.lds_transfer_a = {.src_vector_dim = 2,
.src_scalar_per_vector = 8,
.lds_dst_scalar_per_vector = 8,
.is_direct_load = true,
.lds_padding = false},
.lds_transfer_b = {.src_vector_dim = 2,
.src_scalar_per_vector = 8,
.lds_dst_scalar_per_vector = 8,
.is_direct_load = true,
.lds_padding = false},
.epilogue_c = {.m_per_wave_per_shuffle = 1,
.n_per_wave_per_shuffle = 1,
.scalar_per_vector = 8},
.block_transfer_access_order_a = {.order = {0, 1, 2}},
.block_transfer_access_order_b = {.order = {0, 1, 2}},
.src_access_order_a = {.order = {0, 1, 2}},
.src_access_order_b = {.order = {0, 1, 2}}};
ckb::ConvFwdSpecialization fwd_specialization = ckb::ConvFwdSpecialization::DEFAULT;
ckb::GemmSpecialization gemm_specialization = ckb::GemmSpecialization::Default;
ckb::test::BlockGemm block_gemm{.pipeline_version = ckb::PipelineVersion::V4,
.scheduler = ckb::PipelineScheduler::INTRAWAVE};
};
static_assert(ckb::ConvAlgorithmDescriptor<DefaultAlgorithm>);
TEST(ConvDescriptionTest, DefaultInstanceHasBriefDescription)
{
static constexpr const ConvSignature SIGNATURE;
static constexpr const DefaultAlgorithm ALGORITHM;
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
EXPECT_THAT(ckr::Describe<Builder>().brief(), ckt::StringEqWithDiff("2D Forward convolution"));
}
TEST(ConvDescriptionTest, DefaultInstanceHasDetailedDescription)
{
static constexpr const ConvSignature SIGNATURE;
static constexpr const DefaultAlgorithm ALGORITHM;
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
EXPECT_THAT(ckr::Describe<Builder>().detailed(),
ckt::StringEqWithDiff( //
"2D Forward Convolution Kernel\n"
"├─ Signature\n"
"│ ├─ Tensor Type: FP16\n"
"│ ├─ Memory Layout: GNHWC_GKYXC_GNHWK\n"
"│ ├─ Input elementwise operation: PASS_THROUGH\n"
"│ ├─ Weights elementwise operation: PASS_THROUGH\n"
"│ └─ Output elementwise operation: PASS_THROUGH\n"
"├─ Algorithm\n"
"│ ├─ Thread block size: 256\n"
"│ ├─ Data tile size: 256×256×32\n"
"│ ├─ Gemm padding: DEFAULT\n"
"│ ├─ Convolution specialization: DEFAULT\n"
"│ ├─ Pipeline version: V4\n"
"│ ├─ Pipeline scheduler: INTRAWAVE\n"
"│ ├─ Warp Gemm parameters: \n"
"│ │ ├─ subtile size: 16×16\n"
"│ │ └─ Number of warp gemm iterations: 4×4\n"
"│ ├─ Memory access:\n"
"│ │ ├─ A Tile transfer: \n"
"│ │ │ ├─ Tile dimensions: 4×256×8×\n"
"│ │ │ ├─ The innermost K subdimension size: 8\n"
"│ │ │ ├─ Spatial thread distribution over the data tile: 0×1×2\n"
"│ │ │ ├─ The order of accessing data tile axes: 0×1×2\n"
"│ │ │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n"
"│ │ │ ├─ Vector access (GMEM read) instruction size: 8\n"
"│ │ │ ├─ Vector access (LDS write) instruction size: 8\n"
"│ │ │ └─ LDS data layout padding (to prevent bank conflicts): 8\n"
"│ │ ├─ B Tile transfer: \n"
"│ │ │ ├─ Tile dimensions: 4×256×8×\n"
"│ │ │ ├─ The innermost K subdimension size: 8\n"
"│ │ │ ├─ Spatial thread distribution over the data tile: 0×1×2\n"
"│ │ │ ├─ The order of accessing data tile axes: 0×1×2\n"
"│ │ │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n"
"│ │ │ ├─ Vector access (GMEM read) instruction size: 8\n"
"│ │ │ ├─ Vector access (LDS write) instruction size: 8\n"
"│ │ │ └─ LDS data layout padding (to prevent bank conflicts): 8\n"
"│ │ └─ C Tile transfer: \n"
"│ │ ├─ Data shuffle (number of gemm instructions per iteration): 1×1\n"
"│ │ ├─ Spatial thread distribution used to store data: 1×32×1×8\n"
"│ │ └─ Vector access (GMEM write) instruction size: 8\n"
"│ └─ \n"
"└─ "));
}
// NOTE: BackwardDataInstanceHasDetailedDescription test is disabled because ConvFactory
// does not have a specialization for backward data convolutions. The test fails with:
// "implicit instantiation of undefined template 'ck_tile::builder::ConvFactory<...>'"
//
// To enable this test, a ConvFactory specialization for backward data operations must be
// implemented first.
//
// TEST(ConvDescriptionTest, BackwardDataInstanceHasDetailedDescription)
// {
// struct BackwardDataSignature
// {
// int spatial_dim = 2;
// ckb::ConvDirection direction = ckb::ConvDirection::BACKWARD_DATA;
// ckb::GroupConvLayout layout = ckb::GroupConvLayout2D::GNHWC_GKYXC_GNHWK;
// ckb::DataType data_type = ckb::DataType::FP16;
// ckb::ElementwiseOperation elementwise_operation =
// ckb::ElementwiseOperation::PASS_THROUGH; ckb::GroupConvDeviceOp device_operation =
// ckb::BwdDataGroupConvDeviceOperation::DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1;
// };
// static_assert(ckb::ConvSignatureDescriptor<BackwardDataSignature>);
//
// static constexpr const BackwardDataSignature SIGNATURE;
// static constexpr const DefaultAlgorithm ALGORITHM;
// using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
//
// // Verify Brief works
// EXPECT_THAT(ckr::Describe<Builder>().brief(),
// ckt::StringEqWithDiff("2D Backward Data convolution"));
//
// // Verify detailed works - to be updated once ConvFactory is implemented
// EXPECT_THAT(ckr::Describe<Builder>().detailed(),
// ckt::StringEqWithDiff("PLACEHOLDER"));
// }
} // namespace

View File

@@ -1,5 +1,5 @@
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <gtest/gtest.h>
#include <gmock/gmock.h>
@@ -11,6 +11,7 @@
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp>
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp>
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp>
#include <ck_tile/builder/reflect/instance_traits_tile_grouped_convolution_forward.hpp>
namespace {
@@ -720,4 +721,126 @@ TEST(InstanceTraits, DlInstanceStringReturnsCorrectFormat)
EXPECT_EQ(instance_str, expected_str);
}
TEST(InstanceTraits, TileInstanceStringReturnsCorrectFormat)
{
using GroupedConvTraitsType =
ck_tile::GroupedConvTraits<2 /*NDimSpatial*/,
ck_tile::ConvolutionSpecialization::Default /*ConvSpec*/,
ck_tile::tensor_layout::convolution::NHWGC /*InLayout*/,
ck_tile::tensor_layout::convolution::GKYXC /*WeiLayout*/,
ck_tile::tuple<> /*DsLayout*/,
ck_tile::tensor_layout::convolution::NHWGK /*OutLayout*/,
4 /*VectorSizeA*/,
4 /*VectorSizeB*/,
4 /*VectorSizeC*/,
1 /*NumGroupsToMerge*/,
false /*EnableSplitImage*/>;
using GemmShape = ck_tile::TileGemmShape<
ck_tile::sequence<128 /*M_Tile*/, 128 /*N_Tile*/, 32 /*K_Tile*/>,
ck_tile::sequence<4 /*M_Warp*/, 1 /*N_Warp*/, 1 /*K_Warp*/>,
ck_tile::sequence<16 /*M_Warp_Tile*/, 16 /*N_Warp_Tile*/, 16 /*K_Warp_Tile*/>>;
using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner<
GemmShape,
GroupedConvTraitsType::FixedGemmParams::TilePartitionerGroupNum,
GroupedConvTraitsType::FixedGemmParams::TilePartitionerM01>;
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<
GroupedConvTraitsType::FixedGemmParams::kPadM,
GroupedConvTraitsType::FixedGemmParams::kPadN,
GroupedConvTraitsType::FixedGemmParams::kPadK,
false /*DoubleSmemBuffer*/,
typename GroupedConvTraitsType::AsLayoutFwd,
typename GroupedConvTraitsType::BsLayoutFwd,
typename GroupedConvTraitsType::CLayoutFwd,
GroupedConvTraitsType::FixedGemmParams::TransposeC,
GroupedConvTraitsType::FixedGemmParams::UseStructuredSparsity,
GroupedConvTraitsType::FixedGemmParams::Persistent,
1 /*NumWaveGroups*/>;
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
ck_tile::bf16_t /*InDataType*/,
ck_tile::bf16_t /*WeiDataType*/,
float /*AccDataType*/,
GemmShape,
GemmUniversalTraits,
ck_tile::GemmPipelineScheduler::Intrawave /*scheduler*/,
true /*has_hot_loop_v*/,
ck_tile::TailNumber::Full /*tail_number_v*/,
ck_tile::element_wise::PassThrough /*AElementwiseOperation*/,
ck_tile::element_wise::PassThrough /*BElementwiseOperation*/,
ck_tile::bf16_t /*OutDataType*/,
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeA,
GroupedConvTraitsType::VectorSizeB>;
using GemmPipeline = typename ck_tile::GemmPipelineAgBgCrCompV3<UniversalGemmProblem>;
using ConvEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ck_tile::bf16_t /*InDataType*/,
ck_tile::bf16_t /*WeiDataType*/,
ck_tile::tuple<> /*DsDataType*/,
float /*AccDataType*/,
ck_tile::bf16_t /*OutDataType*/,
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
typename GroupedConvTraitsType::FixedGemmParams::ELayout,
ck_tile::element_wise::PassThrough /*CDElementWise*/,
128 /*MPerBlock*/,
128 /*NPerBlock*/,
4 /*M_Warp*/,
1 /*N_Warp*/,
16 /*M_Warp_Tile*/,
16 /*N_Warp_Tile*/,
16 /*K_Warp_Tile*/,
GroupedConvTraitsType::FixedGemmParams::TransposeC,
ck_tile::memory_operation_enum::set /*memory_operation*/,
1 /*kNumWaveGroups*/,
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeC>>;
using GroupedConvFwdKernel =
ck_tile::device::GroupedConvolutionForwardKernel<GroupedConvTraitsType,
TilePartitioner,
GemmPipeline,
ConvEpilogue>;
std::string instance_str = ck_tile::reflect::instance_string<GroupedConvFwdKernel>();
std::string expected_str = "GroupedConvolutionForwardKernel"
"<2" // NDimSpatial
",Default" // ConvSpecialization
",NHWGC" // InLayout
",GKYXC" // WeiLayout
",EmptyTuple" // DsLayout
",NHWGK" // OutLayout
",4" // VectorSizeA
",4" // VectorSizeB
",4" // VectorSizeC
",1" // NumGroupsToMerge
",0" // EnableSplitImage
",128" // MPerBlock
",128" // NPerBlock
",32" // KPerBlock
",4" // MWarp
",1" // NWarp
",1" // KWarp
",16" // MWarpTile
",16" // NWarpTile
",16" // KWarpTile
",bf16" // ADataType
",bf16" // BDataType
",COMPUTE_V3" // BlkGemmPipelineVer
",Intrawave" // BlkGemmPipeSched
",0" // DoubleSmemBuffer
",1" // NumWaveGroups
",fp32" // AccDataType
",bf16" // EDataType
",EmptyTuple" // DsDataType
",PassThrough" // CDEElementwiseOperation
">";
EXPECT_EQ(instance_str, expected_str);
}
} // anonymous namespace

View File

@@ -0,0 +1,86 @@
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <gtest/gtest.h>
#include <ck_tile/builder/reflect/instance_traits.hpp>
#include <ck/tensor_operation/gpu/device/device_base.hpp>
#include <ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp>
// Test GetInstanceString through base class pointer for backward weight XDL variant
TEST(GetInstanceString, ReturnsStringForBwdWeightGrpConvXdlInstance)
{
// Use the template helper to get a working instance configuration
using InstanceTuple = ck::tensor_operation::device::instance::
device_grouped_conv_bwd_weight_xdl_c_shuffle_f16_instances<
2, // NDimSpatial
ck::tensor_operation::device::instance::GNHWC, // InLayout
ck::tensor_operation::device::instance::GKYXC, // WeiLayout
ck::tensor_operation::device::instance::GNHWK, // OutLayout
ck::tensor_operation::device::instance::
ConvBwdWeightDefault>; // ConvBwdWeightSpecialization
// Get the first instance from the tuple
using DeviceInstance = typename std::tuple_element<0, InstanceTuple>::type;
// Define the base class type using the most general operator base
using BaseClass = ck::tensor_operation::device::BaseOperator;
// Create an instance of the derived class
DeviceInstance device_instance;
// Get a pointer to the base class
BaseClass* base_ptr = &device_instance;
// Call GetInstanceString through the base class pointer
std::string instance_str = base_ptr->GetInstanceString();
// Expected complete instance string based on the first instance from
// device_grouped_conv_bwd_weight_xdl_c_shuffle_f16_instances
// This corresponds to the configuration with BlockSize=64, MPerBlock=64, NPerBlock=64, etc.
std::string expected_str = "DeviceGroupedConvBwdWeight_Xdl_CShuffle"
"<2" // NDimSpatial
",GNHWC" // InLayout
",GKYXC" // WeiLayout
",GNHWK" // OutLayout
",fp16" // InDataType
",fp16" // WeiDataType
",fp16" // OutDataType
",fp32" // AccDataType
",PassThrough" // InElementwiseOperation
",PassThrough" // WeiElementwiseOperation
",PassThrough" // OutElementwiseOperation
",Default" // ConvBackwardWeightSpecialization
",64" // BlockSize
",64" // MPerBlock
",64" // NPerBlock
",4" // K0PerBlock
",8" // K1
",32" // MPerXDL
",32" // NPerXDL
",2" // MXdlPerWave
",2" // NXdlPerWave
",Seq(1,4,8,2)" // ABlockTransferThreadClusterLengths_K0_M_K1
",Seq(0,3,1,2)" // ABlockTransferThreadClusterArrangeOrder
",Seq(0,2,1,3)" // ABlockTransferSrcAccessOrder
",2" // ABlockTransferSrcVectorDim
",2" // ABlockTransferSrcScalarPerVector
",4" // ABlockTransferDstScalarPerVector_K1
",true" // ABlockLdsAddExtraM
",Seq(1,4,8,2)" // BBlockTransferThreadClusterLengths_K0_N_K1
",Seq(0,3,1,2)" // BBlockTransferThreadClusterArrangeOrder
",Seq(0,2,1,3)" // BBlockTransferSrcAccessOrder
",2" // BBlockTransferSrcVectorDim
",2" // BBlockTransferSrcScalarPerVector
",4" // BBlockTransferDstScalarPerVector_K1
",true" // BBlockLdsAddExtraN
",1" // CShuffleMXdlPerWavePerShuffle
",1" // CShuffleNXdlPerWavePerShuffle
",Seq(1,16,1,4)" // CBlockTransferClusterLengths
",2" // CBlockTransferScalarPerVector_NWaveNPerXdl
",fp16" // ComputeTypeA
",fp16" // ComputeTypeB
",1" // MaxTransposeTransferSrcScalarPerVector
",1>"; // MaxTransposeTransferDstScalarPerVector
EXPECT_EQ(instance_str, expected_str);
}

View File

@@ -1,5 +1,5 @@
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <gtest/gtest.h>
#include <ck_tile/builder/reflect/instance_traits.hpp>

View File

@@ -1,5 +1,5 @@
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <gtest/gtest.h>
#include <ck_tile/builder/reflect/instance_traits.hpp>

View File

@@ -1,5 +1,5 @@
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <gtest/gtest.h>
#include <ck_tile/builder/reflect/instance_traits.hpp>

View File

@@ -1,5 +1,5 @@
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <gtest/gtest.h>
#include <ck_tile/builder/reflect/instance_traits.hpp>

View File

@@ -1,5 +1,5 @@
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <gtest/gtest.h>
#include <ck_tile/builder/reflect/instance_traits.hpp>

View File

@@ -1,5 +1,5 @@
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <gtest/gtest.h>

View File

@@ -1,5 +1,5 @@
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <gtest/gtest.h>
#include <gmock/gmock.h>

View File

@@ -1,5 +1,5 @@
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp>

View File

@@ -1,5 +1,5 @@
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "testing_utils.hpp"
#include <gtest/gtest.h>

View File

@@ -1,5 +1,5 @@
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <ck/library/tensor_operation_instance/device_operation_instance_factory.hpp>
#include <gtest/gtest.h>

View File

@@ -1,5 +1,5 @@
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -16,7 +16,7 @@ using namespace test;
// Common test implementation
template <ConvSignature FwdConvSignature,
ThreadBlock FwdThreadBlock,
BlockGemmPipelineVersion FwdPipelineVersion,
PipelineVersion FwdPipelineVersion,
ConvFwdSpecialization FwdConvSpecialization>
constexpr void run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3()
{
@@ -52,7 +52,7 @@ constexpr void run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3()
.src_access_order_b = {1, 0, 2}};
constexpr BlockGemm BlockGemmDesc = {.pipeline_version = FwdPipelineVersion,
.scheduler = BlockGemmPipelineScheduler::INTRAWAVE};
.scheduler = PipelineScheduler::INTRAWAVE};
constexpr ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 FwdConvAlgorithm{
.thread_block = FwdThreadBlock,
@@ -73,13 +73,13 @@ constexpr void run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3()
EXPECT_TRUE(kernel_string.starts_with("DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3"));
// Verify pipeline version is correct
if(FwdPipelineVersion == BlockGemmPipelineVersion::V1)
if(FwdPipelineVersion == PipelineVersion::V1)
EXPECT_TRUE(kernel_string.find("BlkGemmPipelineVersion: v1") != std::string::npos);
else if(FwdPipelineVersion == BlockGemmPipelineVersion::V3)
else if(FwdPipelineVersion == PipelineVersion::V3)
EXPECT_TRUE(kernel_string.find("BlkGemmPipelineVersion: v3") != std::string::npos);
else if(FwdPipelineVersion == BlockGemmPipelineVersion::V4)
else if(FwdPipelineVersion == PipelineVersion::V4)
EXPECT_TRUE(kernel_string.find("BlkGemmPipelineVersion: v4") != std::string::npos);
else if(FwdPipelineVersion == BlockGemmPipelineVersion::V5)
else if(FwdPipelineVersion == PipelineVersion::V5)
EXPECT_TRUE(kernel_string.find("BlkGemmPipelineVersion: v5") != std::string::npos);
// Verify specialization is correct
@@ -140,7 +140,7 @@ constexpr void run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle()
.gemm_specialization = GemmSpecialization::MNKPadding,
.num_gemm_k_prefetch_stages = 1,
.num_groups_to_merge = 2,
.loop_scheduler = LoopScheduler::DEFAULT};
.loop_scheduler = PipelineScheduler::DEFAULT};
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
@@ -176,7 +176,7 @@ constexpr void run_test_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle()
.n_per_wmma = 32,
.m_wmma_per_wave = 2,
.n_wmma_per_wave = 1,
.pipeline_version = GridwiseGemmPipelineVersion::V1};
.pipeline_version = PipelineVersion::V1};
constexpr BlockTransferABC FwdBlockTransfer{.block_transfer_a = {.k0 = 4, .m_n = 32, .k1 = 1},
.block_transfer_b = {.k0 = 4, .m_n = 32, .k1 = 1},
@@ -209,7 +209,7 @@ constexpr void run_test_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle()
.fwd_specialization = FwdConvSpecialization,
.gemm_specialization = GemmSpecialization::MNKPadding,
.num_gemm_k_prefetch_stages = 1,
.loop_scheduler = LoopScheduler::DEFAULT};
.loop_scheduler = PipelineScheduler::DEFAULT};
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
@@ -235,4 +235,149 @@ constexpr void run_test_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle()
EXPECT_NE(invoker_ptr, nullptr);
}
template <ConvSignature FwdConvSignature,
ThreadBlock FwdThreadBlock,
ConvFwdSpecialization FwdConvSpecialization>
constexpr void run_test_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK()
{
// DL thread configuration
constexpr DlThreadConfig DlThreadCfg{
.k0_per_block = 16, .k1 = 2, .m1_per_thread = 4, .n1_per_thread = 4, .k_per_thread = 1};
// DL thread cluster
constexpr DlThreadCluster DlCluster{.m1_xs = {8, 2}, .n1_xs = {8, 2}};
// DL A block transfer - K0_M0_M1_K1 format
constexpr DlBlockTransferK0M0M1K1 DlBlockTransferA{
.thread_slice_lengths = {8, 1, 1, 2},
.thread_cluster_lengths = {2, 1, 128, 1},
.thread_cluster_arrange_order = {1, 2, 0, 3},
.src_access_order = {1, 2, 0, 3},
.src_vector_tensor_lengths = {4, 1, 1, 2},
.src_vector_tensor_contiguous_dim_order = {1, 2, 0, 3},
.dst_vector_tensor_lengths = {1, 1, 1, 2}};
// DL B block transfer - K0_N0_N1_K1 format
constexpr DlBlockTransferK0N0N1K1 DlBlockTransferB{
.thread_slice_lengths = {8, 1, 1, 2},
.thread_cluster_lengths = {2, 1, 128, 1},
.thread_cluster_arrange_order = {1, 2, 0, 3},
.src_access_order = {1, 2, 0, 3},
.src_vector_tensor_lengths = {4, 1, 1, 2},
.src_vector_tensor_contiguous_dim_order = {1, 2, 0, 3},
.dst_vector_tensor_lengths = {1, 1, 1, 2}};
// DL C thread transfer
constexpr DlCThreadTransfer DlCTransfer{.src_dst_access_order = {0, 1, 2, 3, 4, 5},
.src_dst_vector_dim = 5,
.dst_scalar_per_vector = 4};
constexpr ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK FwdConvAlgorithm{
.thread_block = FwdThreadBlock,
.fwd_specialization = FwdConvSpecialization,
.gemm_specialization = GemmSpecialization::MNKPadding,
.dl_thread_config = DlThreadCfg,
.dl_thread_cluster = DlCluster,
.dl_block_transfer_a = DlBlockTransferA,
.dl_block_transfer_b = DlBlockTransferB,
.dl_c_thread_transfer = DlCTransfer};
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
auto instance = typename Builder::Instance{};
const auto kernel_string = instance.GetTypeString();
std::cout << "Generated kernel: " << kernel_string << std::endl;
EXPECT_GT(kernel_string.size(), 0);
EXPECT_TRUE(kernel_string.starts_with("DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK"));
// Verify specialization is correct
if(FwdConvSpecialization == ConvFwdSpecialization::DEFAULT)
EXPECT_TRUE(kernel_string.find("Default") != std::string::npos);
else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_1X1_PAD0)
EXPECT_TRUE(kernel_string.find("Filter1x1Pad0") != std::string::npos);
else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0)
EXPECT_TRUE(kernel_string.find("Filter1x1Stride1Pad0") != std::string::npos);
else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_3x3)
EXPECT_TRUE(kernel_string.find("Filter3x3") != std::string::npos);
const auto invoker_ptr = instance.MakeInvokerPointer();
EXPECT_NE(invoker_ptr, nullptr);
}
// Test helper for DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
// Note: Large_Tensor has identical parameters to regular XDL CShuffle
template <ConvSignature FwdConvSignature,
ThreadBlock FwdThreadBlock,
ConvFwdSpecialization FwdConvSpecialization>
constexpr void run_test_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor()
{
constexpr GridwiseXdlGemm FwdGemmParams{.ak1 = 8,
.bk1 = 8,
.m_per_xdl = 32,
.n_per_xdl = 32,
.m_xdl_per_wave = 2,
.n_xdl_per_wave = 1};
constexpr BlockTransferABC FwdBlockTransfer{.block_transfer_a = {.k0 = 4, .m_n = 16, .k1 = 1},
.block_transfer_b = {.k0 = 4, .m_n = 16, .k1 = 1},
.thread_cluster_dims_c = {.m_block = 1,
.m_wave_per_xdl = 16,
.n_block = 1,
.n_wave_per_xdl = 4},
.lds_transfer_a = {.src_vector_dim = 2,
.src_scalar_per_vector = 8,
.lds_dst_scalar_per_vector = 8,
.is_direct_load = false,
.lds_padding = true},
.lds_transfer_b = {.src_vector_dim = 2,
.src_scalar_per_vector = 8,
.lds_dst_scalar_per_vector = 8,
.is_direct_load = false,
.lds_padding = true},
.epilogue_c = {.m_per_wave_per_shuffle = 1,
.n_per_wave_per_shuffle = 1,
.scalar_per_vector = 8},
.block_transfer_access_order_a = {1, 0, 2},
.block_transfer_access_order_b = {1, 0, 2},
.src_access_order_a = {1, 0, 2},
.src_access_order_b = {1, 0, 2}};
// Large_Tensor uses the same descriptor as regular XDL CShuffle
constexpr ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle FwdConvAlgorithm{
.thread_block = FwdThreadBlock,
.gridwise_gemm = FwdGemmParams,
.block_transfer = FwdBlockTransfer,
.fwd_specialization = FwdConvSpecialization,
.gemm_specialization = GemmSpecialization::MNKPadding,
.num_gemm_k_prefetch_stages = 1,
.num_groups_to_merge = 1,
.loop_scheduler = LoopScheduler::DEFAULT};
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
auto instance = typename Builder::Instance{};
const auto kernel_string = instance.GetTypeString();
std::cout << "Generated kernel: " << kernel_string << std::endl;
EXPECT_GT(kernel_string.size(), 0);
EXPECT_TRUE(
kernel_string.starts_with("DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor"));
// Verify specialization is correct
if(FwdConvSpecialization == ConvFwdSpecialization::DEFAULT)
EXPECT_TRUE(kernel_string.find("Default") != std::string::npos);
else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_1X1_PAD0)
EXPECT_TRUE(kernel_string.find("Filter1x1Pad0") != std::string::npos);
else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0)
EXPECT_TRUE(kernel_string.find("Filter1x1Stride1Pad0") != std::string::npos);
else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_3x3)
EXPECT_TRUE(kernel_string.find("Filter3x3") != std::string::npos);
const auto invoker_ptr = instance.MakeInvokerPointer();
EXPECT_NE(invoker_ptr, nullptr);
}
} // namespace ck_tile::builder::test_utils

View File

@@ -28,6 +28,7 @@ template <BlockGemmPipelineVersion BlkGemmPipelineVer,
index_t MRepeat,
index_t NRepeat,
index_t KPack,
index_t KInner,
bool TransposeC = false>
constexpr auto BlockGemmPipeline_Selector()
{
@@ -52,6 +53,7 @@ constexpr auto BlockGemmPipeline_Selector()
MRepeat,
NRepeat,
KPack,
KInner,
TransposeC>{};
}
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
@@ -75,6 +77,7 @@ constexpr auto BlockGemmPipeline_Selector()
MRepeat,
NRepeat,
KPack,
KInner,
TransposeC>{};
}
else

View File

@@ -30,6 +30,7 @@ template <index_t BlockSize,
index_t MRepeat,
index_t NRepeat,
index_t KPack,
index_t KInner,
bool TransposeC = false>
struct BlockwiseGemmWmmaops_pipeline_base
{
@@ -38,6 +39,7 @@ struct BlockwiseGemmWmmaops_pipeline_base
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I5 = Number<5>{};
static constexpr auto I6 = Number<6>{};
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
@@ -54,15 +56,20 @@ struct BlockwiseGemmWmmaops_pipeline_base
static constexpr index_t B_KRow = 1;
#endif
static constexpr index_t A_K1 = AWmmaTileDesc{}.GetLength(I5);
static constexpr index_t B_K1 = BWmmaTileDesc{}.GetLength(I5);
static constexpr auto wmma_gemm = WmmaGemm<ComputeTypeA,
ComputeTypeB,
AccDataType,
MPerWmma,
NPerWmma,
KPack / KInner,
TransposeC>{};
static constexpr index_t KPerThread = wmma_gemm.wmma_instr.k_per_blk * KInner;
static constexpr index_t A_K1 = ck::math::min(AWmmaTileDesc{}.GetLength(I6), KPerThread);
static constexpr index_t B_K1 = ck::math::min(BWmmaTileDesc{}.GetLength(I6), KPerThread);
static_assert(KPack % (A_K1 * A_KRow) == 0, "wrong!");
static_assert(KPack % (B_K1 * B_KRow) == 0, "wrong!");
static constexpr auto wmma_gemm =
WmmaGemm<ComputeTypeA, ComputeTypeB, AccDataType, MPerWmma, NPerWmma, KPack, TransposeC>{};
static constexpr index_t KRepeat = KPerBlock / KPack;
static constexpr auto WmmaK = Number<wmma_gemm.wmma_instr.k_per_wmma>{};
@@ -191,8 +198,7 @@ struct BlockwiseGemmWmmaops_pipeline_base
const auto wmma_krow = 0;
#endif
// |KRepeat |MRepeat|MWave |KRow |MLane |KPack
return make_tuple(0, 0, waveId_m, wmma_krow, wmma_a_idx, 0);
return make_tuple(0, 0, 0, waveId_m, wmma_krow, wmma_a_idx, 0);
}
__device__ static auto CalculateBThreadOriginDataIndex()
@@ -209,8 +215,7 @@ struct BlockwiseGemmWmmaops_pipeline_base
const auto wmma_krow = 0;
#endif
// |KRepeat |NRepeat|Nwave |KRow |NLane |KPack
return make_tuple(0, 0, waveId_n, wmma_krow, wmma_b_idx, 0);
return make_tuple(0, 0, 0, waveId_n, wmma_krow, wmma_b_idx, 0);
}
template <index_t m0, index_t n0>
@@ -241,7 +246,7 @@ struct BlockwiseGemmWmmaops_pipeline_base
return make_tuple(c_thread_m, c_thread_n);
}
using Tuple6 = decltype(CalculateAThreadOriginDataIndex());
using Tuple7 = decltype(CalculateAThreadOriginDataIndex());
/**
* @brief Constructor for BlockwiseGemmWmmaops_pipeline_base.
@@ -261,8 +266,8 @@ struct BlockwiseGemmWmmaops_pipeline_base
* repeat dimensions.
*/
__host__ __device__
BlockwiseGemmWmmaops_pipeline_base(Tuple6 a_origin = CalculateAThreadOriginDataIndex(),
Tuple6 b_origin = CalculateBThreadOriginDataIndex())
BlockwiseGemmWmmaops_pipeline_base(Tuple7 a_origin = CalculateAThreadOriginDataIndex(),
Tuple7 b_origin = CalculateBThreadOriginDataIndex())
: a_thread_copy_(a_origin), b_thread_copy_(b_origin)
{
static_assert(AWmmaTileDesc::IsKnownAtCompileTime() &&
@@ -343,12 +348,14 @@ struct BlockwiseGemmWmmaops_pipeline_base
Number<KRepeat>{},
I1,
I1,
I1,
Number<A_K1>{}),
make_tuple(Number<A_K1>{},
Number<KPack / A_KRow>{},
Number<KPack / A_KRow * MRepeat>{},
I0,
I0,
I0,
I1));
static constexpr auto b_thread_desc_ =
@@ -357,12 +364,14 @@ struct BlockwiseGemmWmmaops_pipeline_base
Number<KRepeat>{},
I1,
I1,
I1,
Number<B_K1>{}),
make_tuple(Number<B_K1>{},
Number<KPack / B_KRow>{},
Number<KPack / B_KRow * NRepeat>{},
I0,
I0,
I0,
I1));
// C[M, N, NumRegWmma]
@@ -374,9 +383,9 @@ struct BlockwiseGemmWmmaops_pipeline_base
ComputeTypeA,
decltype(a_block_desc_k0_m0_m1_m2_k1),
decltype(a_thread_desc_),
Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, A_K1>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, 1, A_K1>,
Sequence<0, 1, 2, 3, 4, 5, 6>,
6,
A_K1,
A_K1>;
@@ -385,9 +394,9 @@ struct BlockwiseGemmWmmaops_pipeline_base
ComputeTypeB,
decltype(b_block_desc_k0_n0_n1_n2_k1),
decltype(b_thread_desc_),
Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, B_K1>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, 1, B_K1>,
Sequence<0, 1, 2, 3, 4, 5, 6>,
6,
B_K1,
B_K1>;

View File

@@ -32,6 +32,7 @@ template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
index_t MRepeat,
index_t NRepeat,
index_t KPack,
index_t KInner,
bool TransposeC = false>
struct BlockwiseGemmWmmaops_pipeline_v1
{
@@ -55,6 +56,7 @@ template <index_t BlockSize,
index_t MRepeat,
index_t NRepeat,
index_t KPack,
index_t KInner,
bool TransposeC>
struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
BlockSize,
@@ -75,6 +77,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
MRepeat,
NRepeat,
KPack,
KInner,
TransposeC>
: BlockwiseGemmWmmaops_pipeline_base<BlockSize,
ADataType,
@@ -94,6 +97,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
MRepeat,
NRepeat,
KPack,
KInner,
TransposeC>
{
using Base = BlockwiseGemmWmmaops_pipeline_base<BlockSize,
@@ -114,10 +118,10 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
MRepeat,
NRepeat,
KPack,
KInner,
TransposeC>;
using Base::I0;
using Base::I1;
using Base::WaveSize;
using typename Base::HotLoopInstList;
using Base::A_K1;
@@ -187,6 +191,8 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
index_t num_loop,
index_t num_loop_per_scale) const
{
constexpr index_t KPerWaveBlock = wmma_gemm.GetKPerWaveBlk();
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
@@ -211,27 +217,23 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
auto blockwise_gemm_func = [&]() {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(
a_block_desc_k0_m0_m1_m2_k1,
make_tuple(Number<k0 * KPack / A_K1 / A_KRow>{}, m0, I0, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, I0, I0, I0, I0, I0),
a_thread_buf);
a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
make_tuple(I0, m0, k0, I0, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, I0, I0, I0, I0, I0, I0),
a_thread_buf);
if constexpr(m0 == I0)
{
if constexpr(ck::is_same<BScaleStruct, Empty>::value == true)
{
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(
b_block_desc_k0_n0_n1_n2_k1,
make_tuple(
Number<k0 * KPack / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, n0, I0, I0, I0, I0),
b_thread_buf);
b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
make_tuple(I0, n0, k0, I0, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, n0, I0, I0, I0, I0, I0),
b_thread_buf);
});
}
else
@@ -239,45 +241,60 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(
b_block_desc_k0_n0_n1_n2_k1,
make_tuple(
Number<k0 * KPack / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
make_tuple(I0, n0, k0, I0, I0, I0, I0),
b_block_buf,
b_scale_struct.b_scale_thread_bufs(
I0)[Number<n0 * BScaleStruct::num_scale_k_block +
k0 / BScaleStruct::num_scale_krepeat>{}],
b_thread_desc_,
make_tuple(I0, n0, I0, I0, I0, I0),
make_tuple(I0, n0, I0, I0, I0, I0, I0),
b_thread_buf);
});
}
}
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeTypeA, KPack / A_KRow> a_thread_vec;
vector_type<ComputeTypeB, KPack / B_KRow> b_thread_vec;
static_for<0, KInner, 1>{}([&](auto k_inner) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
static_for<0, KPack / A_KRow, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(make_tuple(
Number<ik / A_K1>{}, I0, I0, I0, I0, Number<ik % A_K1>{}))>{}];
static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(Number<kk / A_K1>{},
I0,
I0,
I0,
I0,
I0,
Number<kk % A_K1>{}))>{}];
});
static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(Number<kk / B_K1>{},
n0,
I0,
I0,
I0,
I0,
Number<kk % B_K1>{}))>{}];
});
using wmma_input_type_a =
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
using wmma_input_type_b =
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
b_thread_vec.template AsType<wmma_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
static_for<0, KPack / B_KRow, 1>{}([&](auto ik) {
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(make_tuple(
Number<ik / B_K1>{}, n0, I0, I0, I0, Number<ik % B_K1>{}))>{}];
});
using wmma_input_type_a =
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
using wmma_input_type_b =
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
b_thread_vec.template AsType<wmma_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
@@ -324,8 +341,10 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
});
}
static_for<0, NRepeat, 1>{}([&](auto) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // WMMA
static_for<0, KInner, 1>{}([&](auto) {
static_for<0, NRepeat, 1>{}([&](auto) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // WMMA
});
});
});
});
@@ -348,20 +367,20 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
protected:
// A[MRepeat, I1, I1, KPack]
static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<KPack / A_K1 / A_KRow>{}, I1, I1, I1, I1, Number<A_K1>{}));
make_tuple(Number<KPack / A_K1 / A_KRow>{}, I1, I1, I1, I1, I1, Number<A_K1>{}));
// B[NRepeat, N1, N2, KPack]
static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<KPack / B_K1 / B_KRow>{}, Number<NRepeat>{}, I1, I1, I1, Number<B_K1>{}));
static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed(make_tuple(
Number<KPack / B_K1 / B_KRow>{}, Number<NRepeat>{}, I1, I1, I1, I1, Number<B_K1>{}));
using AThreadCopy =
ThreadwiseTensorSliceTransfer_v4<ADataType,
ComputeTypeA,
decltype(a_block_desc_k0_m0_m1_m2_k1),
decltype(a_thread_desc_),
Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, A_K1>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, 1, A_K1>,
Sequence<0, 1, 2, 3, 4, 5, 6>,
6,
A_K1,
A_K1>;
@@ -370,9 +389,9 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
ComputeTypeB,
decltype(b_block_desc_k0_n0_n1_n2_k1),
decltype(b_thread_desc_),
Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, B_K1>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, 1, B_K1>,
Sequence<0, 1, 2, 3, 4, 5, 6>,
6,
B_K1,
B_K1>;
@@ -399,6 +418,7 @@ template <index_t BlockSize,
index_t MRepeat,
index_t NRepeat,
index_t KPack,
index_t KInner,
bool TransposeC>
struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
BlockSize,
@@ -419,6 +439,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
MRepeat,
NRepeat,
KPack,
KInner,
TransposeC>
: BlockwiseGemmWmmaops_pipeline_base<BlockSize,
ADataType,
@@ -438,6 +459,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
MRepeat,
NRepeat,
KPack,
KInner,
TransposeC>
{
using Base = BlockwiseGemmWmmaops_pipeline_base<BlockSize,
@@ -458,6 +480,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
MRepeat,
NRepeat,
KPack,
KInner,
TransposeC>;
using Base::I0;
using Base::I1;
@@ -532,6 +555,8 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
index_t num_loop,
index_t num_loop_per_scale) const
{
constexpr index_t KPerWaveBlock = wmma_gemm.GetKPerWaveBlk();
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
@@ -557,33 +582,22 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
static_for<0, KRepeat, KRepeatPerCluster>{}([&](auto k0_offset) {
static_for<0, KRepeatPerCluster, 1>{}([&](auto k0_inner) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(
a_block_desc_k0_m0_m1_m2_k1,
make_tuple(Number<(k0_offset + k0_inner) * KPack / A_K1 / A_KRow>{},
m0,
I0,
I0,
I0,
I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, m0, k0_inner, I0, I0, I0),
a_thread_buf);
a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
make_tuple(I0, m0, k0_offset + k0_inner, I0, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, m0, k0_inner, I0, I0, I0, I0),
a_thread_buf);
});
if constexpr(ck::is_same<BScaleStruct, Empty>::value == true)
{
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(
b_block_desc_k0_n0_n1_n2_k1,
make_tuple(Number<(k0_offset + k0_inner) * KPack / B_K1 / B_KRow>{},
n0,
I0,
I0,
I0,
I0),
make_tuple(I0, n0, k0_offset + k0_inner, I0, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, n0, k0_inner, I0, I0, I0),
make_tuple(I0, n0, k0_inner, I0, I0, I0, I0),
b_thread_buf);
});
}
@@ -592,18 +606,13 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(
b_block_desc_k0_n0_n1_n2_k1,
make_tuple(Number<(k0_offset + k0_inner) * KPack / B_K1 / B_KRow>{},
n0,
I0,
I0,
I0,
I0),
make_tuple(I0, n0, k0_offset + k0_inner, I0, I0, I0, I0),
b_block_buf,
b_scale_struct.b_scale_thread_bufs(I0)[Number<
n0 * BScaleStruct::num_scale_k_block +
(k0_offset + k0_inner) / BScaleStruct::num_scale_krepeat>{}],
b_thread_desc_,
make_tuple(I0, n0, k0_inner, I0, I0, I0),
make_tuple(I0, n0, k0_inner, I0, I0, I0, I0),
b_thread_buf);
});
}
@@ -622,62 +631,69 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
__builtin_amdgcn_sched_barrier(0);
}
static_for<0, KRepeatPerCluster, 1>{}([&](auto k0_inner) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeTypeA, KPack / A_KRow> a_thread_vec;
vector_type<ComputeTypeB, KPack / B_KRow> b_thread_vec;
static_for<0, KInner, 1>{}([&](auto k_inner) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
static_for<0, KPack / A_KRow, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(Number<ik / A_K1>{},
m0,
k0_inner,
I0,
I0,
Number<ik % A_K1>{}))>{}];
static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(Number<kk / A_K1>{},
m0,
k0_inner,
I0,
I0,
I0,
Number<kk % A_K1>{}))>{}];
});
static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(Number<kk / B_K1>{},
n0,
k0_inner,
I0,
I0,
I0,
Number<kk % B_K1>{}))>{}];
});
using wmma_input_type_a =
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
using wmma_input_type_b =
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
// The block_sync_lds() here performs double duty:
// A) safeguard against data hazard.
// B) reduce VMEM FIFO congestion by applying small delays to
// different wavefronts.
// It is performed near the end of MAC cluster to minimize lgkmcnt
// penalty
if constexpr(k0_offset + k0_inner == KRepeat - 1 &&
m0 == MRepeat - 1 && n0 == NRepeat - 1)
{
__builtin_amdgcn_sched_barrier(0);
block_sync_lds();
__builtin_amdgcn_sched_barrier(0);
}
wmma_gemm.Run(
a_thread_vec.template AsType<wmma_input_type_a>(),
b_thread_vec.template AsType<wmma_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
if constexpr(k0_inner == 0 && m0 == 0 && n0 == 0)
{
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_setprio(1);
__builtin_amdgcn_sched_barrier(0);
}
});
static_for<0, KPack / B_KRow, 1>{}([&](auto ik) {
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(Number<ik / B_K1>{},
n0,
k0_inner,
I0,
I0,
Number<ik % B_K1>{}))>{}];
});
using wmma_input_type_a =
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
using wmma_input_type_b =
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
// The block_sync_lds() here performs double duty:
// A) safeguard against data hazard.
// B) reduce VMEM FIFO congestion by applying small delays to
// different wavefronts.
// It is performed near the end of MAC cluster to minimize lgkmcnt
// penalty
if constexpr(k0_offset + k0_inner == KRepeat - 1 && m0 == MRepeat - 1 &&
n0 == NRepeat - 1)
{
__builtin_amdgcn_sched_barrier(0);
block_sync_lds();
__builtin_amdgcn_sched_barrier(0);
}
wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
b_thread_vec.template AsType<wmma_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
if constexpr(k0_inner == 0 && m0 == 0 && n0 == 0)
{
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_setprio(1);
__builtin_amdgcn_sched_barrier(0);
}
});
});
});
@@ -729,12 +745,14 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
Number<KRepeatPerCluster>{},
I1,
I1,
I1,
Number<A_K1>{}),
make_tuple(Number<A_K1>{},
Number<KPack / A_KRow>{},
Number<KPack / A_KRow * MRepeat>{},
I0,
I0,
I0,
I1));
static constexpr auto b_thread_desc_ =
@@ -743,12 +761,14 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
Number<KRepeatPerCluster>{},
I1,
I1,
I1,
Number<B_K1>{}),
make_tuple(Number<B_K1>{},
Number<KPack / B_KRow>{},
Number<KPack / B_KRow * NRepeat>{},
I0,
I0,
I0,
I1));
using AThreadCopy =
@@ -756,9 +776,9 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
ComputeTypeA,
decltype(a_block_desc_k0_m0_m1_m2_k1),
decltype(a_thread_desc_),
Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, A_K1>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, 1, A_K1>,
Sequence<0, 1, 2, 3, 4, 5, 6>,
6,
A_K1,
A_K1>;
@@ -767,9 +787,9 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
ComputeTypeB,
decltype(b_block_desc_k0_n0_n1_n2_k1),
decltype(b_thread_desc_),
Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, B_K1>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, 1, B_K1>,
Sequence<0, 1, 2, 3, 4, 5, 6>,
6,
B_K1,
B_K1>;

View File

@@ -32,6 +32,7 @@ template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
index_t MRepeat,
index_t NRepeat,
index_t KPack,
index_t KInner,
bool TransposeC = false>
struct BlockwiseGemmWmmaops_pipeline_v3
{
@@ -55,6 +56,7 @@ template <index_t BlockSize,
index_t MRepeat,
index_t NRepeat,
index_t KPack,
index_t KInner,
bool TransposeC>
struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
BlockSize,
@@ -75,6 +77,7 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
MRepeat,
NRepeat,
KPack,
KInner,
TransposeC>
: BlockwiseGemmWmmaops_pipeline_base<BlockSize,
ADataType,
@@ -94,6 +97,7 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
MRepeat,
NRepeat,
KPack,
KInner,
TransposeC>
{
using Base = BlockwiseGemmWmmaops_pipeline_base<BlockSize,
@@ -114,6 +118,7 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
MRepeat,
NRepeat,
KPack,
KInner,
TransposeC>;
using Base::I0;
@@ -290,40 +295,37 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
{
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(
a_block_desc_k0_m0_m1_m2_k1,
make_tuple(Number<k0 * KPack / A_K1 / A_KRow>{}, m0, I0, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, m0, k0, I0, I0, I0),
a_thread_buf);
a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
make_tuple(I0, m0, k0, I0, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, m0, k0, I0, I0, I0, I0),
a_thread_buf);
});
if constexpr(ck::is_same_v<BScaleStruct, Empty>)
{
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(
b_block_desc_k0_n0_n1_n2_k1,
make_tuple(Number<k0 * KPack / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, n0, k0, I0, I0, I0),
b_thread_buf);
b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
make_tuple(I0, n0, k0, I0, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, n0, k0, I0, I0, I0, I0),
b_thread_buf);
});
}
else
{
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(
b_block_desc_k0_n0_n1_n2_k1,
make_tuple(Number<k0 * KPack / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
b_block_buf,
b_scale_struct.b_scale_thread_bufs(
I0)[Number<n0 * BScaleStruct::num_scale_k_block +
k0 / BScaleStruct::num_scale_krepeat>{}],
b_thread_desc_,
make_tuple(I0, n0, k0, I0, I0, I0),
b_thread_buf);
b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
make_tuple(I0, n0, k0, I0, I0, I0, I0),
b_block_buf,
b_scale_struct.b_scale_thread_bufs(
I0)[Number<n0 * BScaleStruct::num_scale_k_block +
k0 / BScaleStruct::num_scale_krepeat>{}],
b_thread_desc_,
make_tuple(I0, n0, k0, I0, I0, I0, I0),
b_thread_buf);
});
}
});
@@ -364,6 +366,9 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
index_t num_loop_per_scale) const
{
__builtin_amdgcn_sched_barrier(0);
constexpr index_t KPerWaveBlock = wmma_gemm.GetKPerWaveBlk();
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
@@ -424,41 +429,48 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeTypeA, KPack / A_KRow> a_thread_vec;
vector_type<ComputeTypeB, KPack / B_KRow> b_thread_vec;
static_for<0, KInner, 1>{}([&](auto k_inner) {
vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
static_for<0, KPack / A_KRow, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(Number<ik / A_K1>{},
m0,
k0,
I0,
I0,
Number<ik % A_K1>{}))>{}];
static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(Number<kk / A_K1>{},
m0,
k0,
I0,
I0,
I0,
Number<kk % A_K1>{}))>{}];
});
static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(Number<kk / B_K1>{},
n0,
k0,
I0,
I0,
I0,
Number<kk % B_K1>{}))>{}];
});
using wmma_input_type_a =
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
using wmma_input_type_b =
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
wmma_gemm.Run(
a_thread_vec.template AsType<wmma_input_type_a>(),
b_thread_vec.template AsType<wmma_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
static_for<0, KPack / B_KRow, 1>{}([&](auto ik) {
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(Number<ik / B_K1>{},
n0,
k0,
I0,
I0,
Number<ik % B_K1>{}))>{}];
});
using wmma_input_type_a =
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
using wmma_input_type_b =
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
b_thread_vec.template AsType<wmma_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
@@ -489,31 +501,47 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeTypeA, KPack / A_KRow> a_thread_vec;
vector_type<ComputeTypeB, KPack / B_KRow> b_thread_vec;
static_for<0, KInner, 1>{}([&](auto k_inner) {
vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
static_for<0, KPack / A_KRow, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(make_tuple(
Number<ik / A_K1>{}, m0, k0, I0, I0, Number<ik % A_K1>{}))>{}];
static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(Number<kk / A_K1>{},
m0,
k0,
I0,
I0,
I0,
Number<kk % A_K1>{}))>{}];
});
static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(Number<kk / B_K1>{},
n0,
k0,
I0,
I0,
I0,
Number<kk % B_K1>{}))>{}];
});
using wmma_input_type_a =
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
using wmma_input_type_b =
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
b_thread_vec.template AsType<wmma_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
static_for<0, KPack / B_KRow, 1>{}([&](auto ik) {
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(make_tuple(
Number<ik / B_K1>{}, n0, k0, I0, I0, Number<ik % B_K1>{}))>{}];
});
using wmma_input_type_a =
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
using wmma_input_type_b =
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
b_thread_vec.template AsType<wmma_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
@@ -531,31 +559,47 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeTypeA, KPack / A_KRow> a_thread_vec;
vector_type<ComputeTypeB, KPack / B_KRow> b_thread_vec;
static_for<0, KInner, 1>{}([&](auto k_inner) {
vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
static_for<0, KPack / A_KRow, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(make_tuple(
Number<ik / A_K1>{}, m0, k0, I0, I0, Number<ik % A_K1>{}))>{}];
static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(Number<kk / A_K1>{},
m0,
k0,
I0,
I0,
I0,
Number<kk % A_K1>{}))>{}];
});
static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(Number<kk / B_K1>{},
n0,
k0,
I0,
I0,
I0,
Number<kk % B_K1>{}))>{}];
});
using wmma_input_type_a =
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
using wmma_input_type_b =
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
b_thread_vec.template AsType<wmma_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
static_for<0, KPack / B_KRow, 1>{}([&](auto ik) {
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(make_tuple(
Number<ik / B_K1>{}, n0, k0, I0, I0, Number<ik % B_K1>{}))>{}];
});
using wmma_input_type_a =
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
using wmma_input_type_b =
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
b_thread_vec.template AsType<wmma_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});

View File

@@ -727,7 +727,8 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_gufusion_v3<
});
});
HotLoopScheduler();
if constexpr(MPerBlock >= 64)
HotLoopScheduler();
__builtin_amdgcn_sched_barrier(0);
};

Some files were not shown because too many files have changed in this diff Show More