mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[CK_TILE] Switch into universal gemms for conv bwds (#2981)
* switch into universal gemms for conv bwds * some fixes and support universal gemm in conv fwd * add reviewer comments
This commit is contained in:
303
example/ck_tile/20_grouped_convolution/gemm_configs.hpp
Normal file
303
example/ck_tile/20_grouped_convolution/gemm_configs.hpp
Normal file
@@ -0,0 +1,303 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <variant>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/utility/json_dump.hpp"
|
||||
|
||||
#define CK_TILE_PIPELINE_COMPUTE_V3 1
|
||||
#define CK_TILE_PIPELINE_MEMORY 2
|
||||
#define CK_TILE_PIPELINE_COMPUTE_V4 3
|
||||
#define CK_TILE_PIPELINE_COMPUTE_V5 4
|
||||
|
||||
struct GemmConfigBase
|
||||
{
|
||||
static constexpr bool kPadM = true;
|
||||
static constexpr bool kPadN = true;
|
||||
static constexpr bool kPadK = true;
|
||||
|
||||
static constexpr bool PermuteA = false;
|
||||
static constexpr bool PermuteB = false;
|
||||
|
||||
static constexpr bool TransposeC = false;
|
||||
static constexpr bool UseStructuredSparsity = false;
|
||||
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
|
||||
static constexpr ck_tile::index_t TileParitionerM01 = 4;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
|
||||
static constexpr ck_tile::index_t NumWaveGroups = 1;
|
||||
static constexpr bool Preshuffle = false;
|
||||
static constexpr bool TiledMMAPermuteN = false;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigMemoryInterwave : public GemmConfigBase
|
||||
{
|
||||
// Memory friendly for Interwave scheduler
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 4;
|
||||
static constexpr ck_tile::index_t N_Warp = 1;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigMemoryIntrawave : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 4;
|
||||
static constexpr ck_tile::index_t N_Warp = 1;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigComputeV3 : public GemmConfigBase
|
||||
{
|
||||
// Compute V3 only support Intrawave scheduler
|
||||
static constexpr ck_tile::index_t M_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Tile = 64;
|
||||
static constexpr ck_tile::index_t K_Tile = 64;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 32;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigComputeV3_1 : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 256;
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 2;
|
||||
static constexpr ck_tile::index_t N_Warp = 2;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigComputeV3_2 : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 2;
|
||||
static constexpr ck_tile::index_t N_Warp = 2;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 32;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
|
||||
|
||||
static constexpr int kBlockPerCu = 2;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigComputeV3_WMMA : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 4;
|
||||
static constexpr ck_tile::index_t N_Warp = 2;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
|
||||
|
||||
static constexpr int kBlockPerCu = 2;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigComputeV4 : public GemmConfigBase
|
||||
{
|
||||
// Compute V4 only support Intrawave scheduler
|
||||
// Using the ping pong reader in the lds level
|
||||
static constexpr ck_tile::index_t M_Tile = 256;
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 2;
|
||||
static constexpr ck_tile::index_t N_Warp = 2;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigComputeV4_1 : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 256;
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 2;
|
||||
static constexpr ck_tile::index_t N_Warp = 2;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigComputeV5 : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 1;
|
||||
static constexpr ck_tile::index_t K_Warp = 2;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V5;
|
||||
static constexpr ck_tile::index_t NumWaNumWaveGroups = 2;
|
||||
};
|
||||
|
||||
template <typename InDataType, typename WeiDataType = InDataType, typename OutDataType = InDataType>
|
||||
struct ConvTypeConfig;
|
||||
|
||||
template <>
|
||||
struct ConvTypeConfig<ck_tile::half_t>
|
||||
{
|
||||
using InDataType = ck_tile::half_t;
|
||||
using WeiDataType = ck_tile::half_t;
|
||||
using AccDataType = float;
|
||||
using OutDataType = ck_tile::half_t;
|
||||
// ToDo: Add more bias config to support different categories of GEMM.
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ConvTypeConfig<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t>
|
||||
{
|
||||
using InDataType = ck_tile::bf16_t;
|
||||
using WeiDataType = ck_tile::bf16_t;
|
||||
using AccDataType = float;
|
||||
using OutDataType = ck_tile::bf16_t;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct DataTypeTraits;
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<float>
|
||||
{
|
||||
static constexpr const char* name = "fp32";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::half_t>
|
||||
{
|
||||
static constexpr const char* name = "fp16";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::bf16_t>
|
||||
{
|
||||
static constexpr const char* name = "bf16";
|
||||
};
|
||||
|
||||
template <ck_tile::index_t PipelineId>
|
||||
struct PipelineTypeTraits;
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<CK_TILE_PIPELINE_MEMORY>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem<PipelineProblem>;
|
||||
template <typename PipelineProblem>
|
||||
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem<PipelineProblem>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V3>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>;
|
||||
template <typename PipelineProblem>
|
||||
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<PipelineProblem>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V4>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4<PipelineProblem>;
|
||||
template <typename PipelineProblem>
|
||||
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV4<PipelineProblem>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V5>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV5<PipelineProblem>;
|
||||
template <typename PipelineProblem>
|
||||
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV5<PipelineProblem>;
|
||||
};
|
||||
@@ -14,7 +14,7 @@
|
||||
#include "grouped_convolution_backward_data_invoker.hpp"
|
||||
#include "run_grouped_convolution_bwd_data_example.inc"
|
||||
|
||||
template <typename GemmWarpConfig>
|
||||
template <template <typename PrecType> typename GemmConfig>
|
||||
int run_grouped_conv_bwd_data_example(int argc, char* argv[])
|
||||
{
|
||||
using Invoker = GroupedConvolutionBackwardDataInvoker;
|
||||
@@ -31,14 +31,14 @@ int run_grouped_conv_bwd_data_example(int argc, char* argv[])
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
return run_grouped_conv_bwd_data_example_prec_type<Invoker,
|
||||
GemmWarpConfig,
|
||||
GemmConfig<ck_tile::half_t>,
|
||||
ck_tile::half_t>(
|
||||
in_layout, wei_layout, out_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "bf16")
|
||||
{
|
||||
return run_grouped_conv_bwd_data_example_prec_type<Invoker,
|
||||
GemmWarpConfig,
|
||||
GemmConfig<ck_tile::bf16_t>,
|
||||
ck_tile::bf16_t>(
|
||||
in_layout, wei_layout, out_layout, argc, argv);
|
||||
}
|
||||
@@ -51,8 +51,8 @@ int run_grouped_conv_bwd_data_example(int argc, char* argv[])
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
#if CK_TILE_USE_WMMA
|
||||
return !run_grouped_conv_bwd_data_example<GemmWarpConfig_Wmma>(argc, argv);
|
||||
return !run_grouped_conv_bwd_data_example<GemmConfigComputeV3_WMMA>(argc, argv);
|
||||
#else
|
||||
return !run_grouped_conv_bwd_data_example<GemmWarpConfig_Mfma>(argc, argv);
|
||||
return !run_grouped_conv_bwd_data_example<GemmConfigComputeV3>(argc, argv);
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -8,7 +8,7 @@ struct GroupedConvolutionBackwardDataInvoker
|
||||
{
|
||||
|
||||
template <ck_tile::index_t NDimSpatial,
|
||||
typename GemmWarpConfig,
|
||||
typename GemmConfig,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename AccDataType,
|
||||
@@ -24,121 +24,170 @@ struct GroupedConvolutionBackwardDataInvoker
|
||||
{
|
||||
constexpr int kBlockPerCu = 1;
|
||||
|
||||
constexpr ck_tile::index_t M_Tile = 64;
|
||||
constexpr ck_tile::index_t N_Tile = 64;
|
||||
constexpr ck_tile::index_t K_Tile = 32;
|
||||
// Implicit GEMM Traits
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
|
||||
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
|
||||
ck_tile::
|
||||
sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>,
|
||||
GemmConfig::PermuteA,
|
||||
GemmConfig::PermuteB>;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp = 2;
|
||||
constexpr ck_tile::index_t N_Warp = 2;
|
||||
constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp_Tile = GemmWarpConfig::M_Warp_Tile;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = GemmWarpConfig::N_Warp_Tile;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = GemmWarpConfig::K_Warp_Tile;
|
||||
|
||||
constexpr ck_tile::index_t VectorSizeA = 1;
|
||||
constexpr ck_tile::index_t VectorSizeB = 1;
|
||||
constexpr ck_tile::index_t VectorSizeA = 8;
|
||||
constexpr ck_tile::index_t VectorSizeB = 8;
|
||||
constexpr ck_tile::index_t VectorSizeC = 8;
|
||||
|
||||
// Implicit GEMM Traits
|
||||
using CodegenShape =
|
||||
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
|
||||
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
|
||||
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
|
||||
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
|
||||
using TilePartitioner =
|
||||
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
|
||||
GemmConfig::TileParitionerGroupNum,
|
||||
GemmConfig::TileParitionerM01>;
|
||||
using GroupedConvTraitsType = ck_tile::GroupedConvTraits<NDimSpatial,
|
||||
ConvSpec,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
DsLayout,
|
||||
OutLayout,
|
||||
VectorSizeA,
|
||||
VectorSizeB,
|
||||
VectorSizeC>;
|
||||
|
||||
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
|
||||
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenShape>;
|
||||
using GroupedConvTraitsType = ck_tile::GroupedConvTraits<NDimSpatial,
|
||||
ConvSpec,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
DsLayout,
|
||||
OutLayout,
|
||||
VectorSizeA,
|
||||
VectorSizeB,
|
||||
VectorSizeC>;
|
||||
using CodegenPipelineProblem = ck_tile::GemmPipelineProblem<
|
||||
InDataType,
|
||||
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<
|
||||
GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
GemmConfig::DoubleSmemBuffer,
|
||||
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdData::AsLayout,
|
||||
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdData::BsLayout,
|
||||
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdData::CLayout,
|
||||
GemmConfig::TransposeC,
|
||||
GemmConfig::UseStructuredSparsity,
|
||||
false, // Persistent,
|
||||
GemmConfig::NumWaveGroups>;
|
||||
|
||||
using GemmPipelineProblem = ck_tile::GemmPipelineProblem<
|
||||
OutDataType,
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
CodegenShape,
|
||||
GemmShape,
|
||||
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdData,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
InDataType,
|
||||
true,
|
||||
GroupedConvTraitsType::VectorSizeA,
|
||||
GroupedConvTraitsType::VectorSizeB>;
|
||||
using CodegenPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
|
||||
VectorSizeA,
|
||||
VectorSizeB>;
|
||||
|
||||
const auto Run = [&](const auto memory_operation_) {
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
using BaseGemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
|
||||
|
||||
using ConvEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
|
||||
ck_tile::tensor_layout::gemm::RowMajor,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
M_Warp,
|
||||
N_Warp,
|
||||
M_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile,
|
||||
CodegenPipelineProblem::TransposeC,
|
||||
memory_operation,
|
||||
1,
|
||||
true,
|
||||
GroupedConvTraitsType::VectorSizeC>>;
|
||||
const ck_tile::index_t gemm_k =
|
||||
args.K_ * std::accumulate(args.filter_spatial_lengths_.begin(),
|
||||
args.filter_spatial_lengths_.end(),
|
||||
1,
|
||||
std::multiplies<ck_tile::index_t>());
|
||||
|
||||
using Kernel = ck_tile::GroupedConvolutionBackwardDataKernel<GroupedConvTraitsType,
|
||||
TilePartitioner,
|
||||
CodegenPipeline,
|
||||
ConvEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
const ck_tile::index_t k_grain = args.k_batch * GemmConfig::K_Tile;
|
||||
const ck_tile::index_t K_split = (gemm_k + k_grain - 1) / k_grain * GemmConfig::K_Tile;
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
float ave_time{0};
|
||||
|
||||
const dim3 grids = Kernel::GridSize(args);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
const auto Run =
|
||||
[&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = GemmConfig::Scheduler;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
using UniversalGemmProblem =
|
||||
ck_tile::UniversalGemmPipelineProblem<OutDataType,
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
InDataType,
|
||||
true,
|
||||
VectorSizeA,
|
||||
VectorSizeB>;
|
||||
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
|
||||
using ConvEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
|
||||
OutDataType,
|
||||
WeiDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
InDataType,
|
||||
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
|
||||
ck_tile::tensor_layout::gemm::RowMajor,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
GemmConfig::TransposeC,
|
||||
memory_operation,
|
||||
1,
|
||||
true,
|
||||
GroupedConvTraitsType::VectorSizeC>>;
|
||||
|
||||
using Kernel = ck_tile::GroupedConvolutionBackwardDataKernel<GroupedConvTraitsType,
|
||||
TilePartitioner,
|
||||
GemmPipeline,
|
||||
ConvEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(args);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n");
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << GemmShape::GetName() << '\n'
|
||||
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
|
||||
<< "pipeline: " << GemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
|
||||
<< "}" << '\n'
|
||||
<< "Vector size A: " << GemmPipeline::GetVectorSizeA()
|
||||
<< ", Vector size B: " << GemmPipeline::GetVectorSizeB()
|
||||
<< ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl;
|
||||
}
|
||||
|
||||
ave_time = ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
if(args.k_batch == 1)
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n");
|
||||
Run(has_hot_loop_, tail_number_, MemoryOpSet{});
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
else
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << CodegenShape::GetName() << '\n'
|
||||
<< "problem: " << CodegenPipelineProblem::GetName() << '\n'
|
||||
<< "pipeline: " << CodegenPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
|
||||
<< "}" << '\n'
|
||||
<< "Vector size A: " << CodegenPipeline::GetVectorSizeA()
|
||||
<< ", Vector size B: " << CodegenPipeline::GetVectorSizeB()
|
||||
<< ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl;
|
||||
Run(has_hot_loop_, tail_number_, MemoryOpAtomicAdd{});
|
||||
}
|
||||
|
||||
float ave_time = ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
if(args.k_batch == 1)
|
||||
{
|
||||
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
}
|
||||
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
|
||||
return ave_time;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
#include "grouped_convolution_backward_weight_invoker.hpp"
|
||||
#include "run_grouped_convolution_bwd_weight_example.inc"
|
||||
|
||||
template <typename GemmWarpConfig>
|
||||
template <template <typename PrecType> typename GemmConfig>
|
||||
int run_grouped_conv_bwd_weight_example(ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
using Invoker = GroupedConvolutionBackwardWeightInvoker;
|
||||
@@ -27,14 +27,14 @@ int run_grouped_conv_bwd_weight_example(ck_tile::ArgParser& arg_parser)
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
return run_grouped_conv_bwd_weight_example_prec_type<Invoker,
|
||||
GemmWarpConfig,
|
||||
GemmConfig<ck_tile::half_t>,
|
||||
ck_tile::half_t>(
|
||||
in_layout, wei_layout, out_layout, arg_parser);
|
||||
}
|
||||
else if(data_type == "bf16")
|
||||
{
|
||||
return run_grouped_conv_bwd_weight_example_prec_type<Invoker,
|
||||
GemmWarpConfig,
|
||||
GemmConfig<ck_tile::bf16_t>,
|
||||
ck_tile::bf16_t>(
|
||||
in_layout, wei_layout, out_layout, arg_parser);
|
||||
}
|
||||
@@ -54,9 +54,9 @@ int main(int argc, char* argv[])
|
||||
try
|
||||
{
|
||||
#if CK_TILE_USE_WMMA
|
||||
return !run_grouped_conv_bwd_weight_example<GemmWarpConfig_Wmma>(arg_parser);
|
||||
return !run_grouped_conv_bwd_weight_example<GemmConfigComputeV3_WMMA>(arg_parser);
|
||||
#else
|
||||
return !run_grouped_conv_bwd_weight_example<GemmWarpConfig_Mfma>(arg_parser);
|
||||
return !run_grouped_conv_bwd_weight_example<GemmConfigComputeV3>(arg_parser);
|
||||
#endif
|
||||
}
|
||||
catch(const std::runtime_error& e)
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
struct GroupedConvolutionBackwardWeightInvoker
|
||||
{
|
||||
template <ck_tile::index_t NDimSpatial,
|
||||
typename GemmWarpConfig,
|
||||
typename GemmConfig,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename AccDataType,
|
||||
@@ -23,73 +23,120 @@ struct GroupedConvolutionBackwardWeightInvoker
|
||||
{
|
||||
constexpr int kBlockPerCu = 1;
|
||||
|
||||
constexpr ck_tile::index_t M_Tile = 64;
|
||||
constexpr ck_tile::index_t N_Tile = 64;
|
||||
constexpr ck_tile::index_t K_Tile = 64;
|
||||
// Implicit GEMM Traits
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
|
||||
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
|
||||
ck_tile::
|
||||
sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>,
|
||||
GemmConfig::PermuteA,
|
||||
GemmConfig::PermuteB>;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp = 2;
|
||||
constexpr ck_tile::index_t N_Warp = 2;
|
||||
constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp_Tile = GemmWarpConfig::M_Warp_Tile;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = GemmWarpConfig::N_Warp_Tile;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = GemmWarpConfig::K_Warp_Tile;
|
||||
|
||||
constexpr ck_tile::index_t VectorSizeA = 1;
|
||||
constexpr ck_tile::index_t VectorSizeB = 1;
|
||||
constexpr ck_tile::index_t VectorSizeA = 4;
|
||||
constexpr ck_tile::index_t VectorSizeB = 8;
|
||||
constexpr ck_tile::index_t VectorSizeC = 8;
|
||||
|
||||
// Implicit GEMM Traits
|
||||
using CodegenShape =
|
||||
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
|
||||
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
|
||||
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
|
||||
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
|
||||
using TilePartitioner =
|
||||
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
|
||||
GemmConfig::TileParitionerGroupNum,
|
||||
GemmConfig::TileParitionerM01>;
|
||||
using GroupedConvTraitsType = ck_tile::GroupedConvTraits<NDimSpatial,
|
||||
ConvSpec,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
DsLayout,
|
||||
OutLayout,
|
||||
VectorSizeA,
|
||||
VectorSizeB,
|
||||
VectorSizeC>;
|
||||
|
||||
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
|
||||
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenShape>;
|
||||
using GroupedConvTraitsType = ck_tile::GroupedConvTraits<NDimSpatial,
|
||||
ConvSpec,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
DsLayout,
|
||||
OutLayout,
|
||||
VectorSizeA,
|
||||
VectorSizeB,
|
||||
VectorSizeC>;
|
||||
using CodegenPipelineProblem = ck_tile::GemmPipelineProblem<
|
||||
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<
|
||||
GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
GemmConfig::DoubleSmemBuffer,
|
||||
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::AsLayout,
|
||||
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::BsLayout,
|
||||
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::CLayout,
|
||||
GemmConfig::TransposeC,
|
||||
GemmConfig::UseStructuredSparsity,
|
||||
false, // Persistent,
|
||||
GemmConfig::NumWaveGroups>;
|
||||
|
||||
using GemmPipelineProblem = ck_tile::GemmPipelineProblem<
|
||||
OutDataType,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
CodegenShape,
|
||||
GemmShape,
|
||||
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
true,
|
||||
GroupedConvTraitsType::VectorSizeA,
|
||||
GroupedConvTraitsType::VectorSizeB>;
|
||||
using CodegenPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
|
||||
VectorSizeA,
|
||||
VectorSizeB>;
|
||||
|
||||
const auto Run = [&](const auto memory_operation_) {
|
||||
using BaseGemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
|
||||
|
||||
const ck_tile::index_t gemm_k =
|
||||
args.N_ * std::accumulate(args.output_spatial_lengths_.begin(),
|
||||
args.output_spatial_lengths_.end(),
|
||||
1,
|
||||
std::multiplies<ck_tile::index_t>());
|
||||
|
||||
const ck_tile::index_t k_grain = args.k_batch * GemmConfig::K_Tile;
|
||||
const ck_tile::index_t K_split = (gemm_k + k_grain - 1) / k_grain * GemmConfig::K_Tile;
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
float ave_time{0};
|
||||
|
||||
const auto Run = [&](const auto has_hot_loop_,
|
||||
const auto tail_number_,
|
||||
const auto memory_operation_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = GemmConfig::Scheduler;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
using UniversalGemmProblem =
|
||||
ck_tile::UniversalGemmPipelineProblem<OutDataType,
|
||||
InDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
WeiDataType,
|
||||
true,
|
||||
VectorSizeA,
|
||||
VectorSizeB>;
|
||||
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
|
||||
using ConvEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
|
||||
OutDataType,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
WeiDataType,
|
||||
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
|
||||
ck_tile::tensor_layout::gemm::RowMajor,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
M_Warp,
|
||||
N_Warp,
|
||||
M_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile,
|
||||
CodegenPipelineProblem::TransposeC,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
GemmConfig::TransposeC,
|
||||
memory_operation,
|
||||
1,
|
||||
true,
|
||||
@@ -97,11 +144,11 @@ struct GroupedConvolutionBackwardWeightInvoker
|
||||
|
||||
using Kernel = ck_tile::GroupedConvolutionBackwardWeightKernel<GroupedConvTraitsType,
|
||||
TilePartitioner,
|
||||
CodegenPipeline,
|
||||
GemmPipeline,
|
||||
ConvEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(kargs);
|
||||
const dim3 grids = Kernel::GridSize(args);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
@@ -112,34 +159,35 @@ struct GroupedConvolutionBackwardWeightInvoker
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << CodegenShape::GetName() << '\n'
|
||||
<< "problem: " << CodegenPipelineProblem::GetName() << '\n'
|
||||
<< "pipeline: " << CodegenPipeline::GetName() << '\n'
|
||||
<< "shape: " << GemmShape::GetName() << '\n'
|
||||
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
|
||||
<< "pipeline: " << GemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
|
||||
<< "}" << '\n'
|
||||
<< "Vector size A: " << CodegenPipeline::GetVectorSizeA()
|
||||
<< ", Vector size B: " << CodegenPipeline::GetVectorSizeB()
|
||||
<< "Vector size A: " << GemmPipeline::GetVectorSizeA()
|
||||
<< ", Vector size B: " << GemmPipeline::GetVectorSizeB()
|
||||
<< ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl;
|
||||
}
|
||||
|
||||
float ave_time = ck_tile::launch_kernel_time_mask(
|
||||
s,
|
||||
Kernel::Preprocess(kargs, s),
|
||||
ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
ave_time = ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
if(args.k_batch == 1)
|
||||
{
|
||||
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
}
|
||||
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
if(args.k_batch == 1)
|
||||
{
|
||||
Run(has_hot_loop_, tail_number_, MemoryOpSet{});
|
||||
}
|
||||
else
|
||||
{
|
||||
Run(has_hot_loop_, tail_number_, MemoryOpAtomicAdd{});
|
||||
}
|
||||
};
|
||||
|
||||
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
|
||||
return ave_time;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -13,8 +13,9 @@
|
||||
#include "grouped_convolution_utils.hpp"
|
||||
#include "grouped_convolution_backward_weight_two_stage_invoker.hpp"
|
||||
#include "run_grouped_convolution_bwd_weight_example.inc"
|
||||
#include "gemm_configs.hpp"
|
||||
|
||||
template <typename GemmWarpConfig>
|
||||
template <template <typename PrecType> typename GemmConfig>
|
||||
int run_grouped_conv_bwd_weight_example(ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
using Invoker = GroupedConvolutionBackwardWeightTwoStageInvoker;
|
||||
@@ -27,14 +28,14 @@ int run_grouped_conv_bwd_weight_example(ck_tile::ArgParser& arg_parser)
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
return run_grouped_conv_bwd_weight_example_prec_type<Invoker,
|
||||
GemmWarpConfig,
|
||||
GemmConfig<ck_tile::half_t>,
|
||||
ck_tile::half_t>(
|
||||
in_layout, wei_layout, out_layout, arg_parser);
|
||||
}
|
||||
else if(data_type == "bf16")
|
||||
{
|
||||
return run_grouped_conv_bwd_weight_example_prec_type<Invoker,
|
||||
GemmWarpConfig,
|
||||
GemmConfig<ck_tile::bf16_t>,
|
||||
ck_tile::bf16_t>(
|
||||
in_layout, wei_layout, out_layout, arg_parser);
|
||||
}
|
||||
@@ -54,9 +55,9 @@ int main(int argc, char* argv[])
|
||||
try
|
||||
{
|
||||
#if CK_TILE_USE_WMMA
|
||||
return !run_grouped_conv_bwd_weight_example<GemmWarpConfig_Wmma>(arg_parser);
|
||||
return !run_grouped_conv_bwd_weight_example<GemmConfigComputeV3_WMMA>(arg_parser);
|
||||
#else
|
||||
return !run_grouped_conv_bwd_weight_example<GemmWarpConfig_Mfma>(arg_parser);
|
||||
return !run_grouped_conv_bwd_weight_example<GemmConfigComputeV3>(arg_parser);
|
||||
#endif
|
||||
}
|
||||
catch(const std::runtime_error& e)
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
struct GroupedConvolutionBackwardWeightTwoStageInvoker
|
||||
{
|
||||
template <ck_tile::index_t NDimSpatial,
|
||||
typename GemmWarpConfig,
|
||||
typename GemmConfig,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename AccDataType,
|
||||
@@ -25,56 +25,103 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
|
||||
|
||||
constexpr int kBlockPerCu = 1;
|
||||
|
||||
constexpr ck_tile::index_t M_Tile = 64;
|
||||
constexpr ck_tile::index_t N_Tile = 64;
|
||||
constexpr ck_tile::index_t K_Tile = 64;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp = 2;
|
||||
constexpr ck_tile::index_t N_Warp = 2;
|
||||
constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp_Tile = GemmWarpConfig::M_Warp_Tile;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = GemmWarpConfig::N_Warp_Tile;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = GemmWarpConfig::K_Warp_Tile;
|
||||
|
||||
constexpr ck_tile::index_t VectorSizeA = 1;
|
||||
constexpr ck_tile::index_t VectorSizeB = 1;
|
||||
constexpr ck_tile::index_t VectorSizeC = 1;
|
||||
|
||||
// Implicit GEMM Traits
|
||||
using CodegenShape =
|
||||
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
|
||||
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
|
||||
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
|
||||
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
|
||||
ck_tile::
|
||||
sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>,
|
||||
GemmConfig::PermuteA,
|
||||
GemmConfig::PermuteB>;
|
||||
|
||||
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
|
||||
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenShape>;
|
||||
using GroupedConvTraitsType = ck_tile::GroupedConvTraits<NDimSpatial,
|
||||
ConvSpec,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
DsLayout,
|
||||
OutLayout,
|
||||
VectorSizeA,
|
||||
VectorSizeB,
|
||||
VectorSizeC>;
|
||||
using CodegenPipelineProblem = ck_tile::GemmPipelineProblem<
|
||||
OutDataType, // A: Out
|
||||
InDataType, // B: In
|
||||
constexpr ck_tile::index_t VectorSizeA = 4;
|
||||
constexpr ck_tile::index_t VectorSizeB = 8;
|
||||
constexpr ck_tile::index_t VectorSizeC = 8;
|
||||
|
||||
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
|
||||
using TilePartitioner =
|
||||
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
|
||||
GemmConfig::TileParitionerGroupNum,
|
||||
GemmConfig::TileParitionerM01>;
|
||||
using GroupedConvTraitsType = ck_tile::GroupedConvTraits<NDimSpatial,
|
||||
ConvSpec,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
DsLayout,
|
||||
OutLayout,
|
||||
VectorSizeA,
|
||||
VectorSizeB,
|
||||
VectorSizeC>;
|
||||
|
||||
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<
|
||||
GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
GemmConfig::DoubleSmemBuffer,
|
||||
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::AsLayout,
|
||||
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::BsLayout,
|
||||
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::CLayout,
|
||||
GemmConfig::TransposeC,
|
||||
GemmConfig::UseStructuredSparsity,
|
||||
false, // Persistent,
|
||||
GemmConfig::NumWaveGroups>;
|
||||
|
||||
using GemmPipelineProblem = ck_tile::GemmPipelineProblem<
|
||||
OutDataType,
|
||||
InDataType,
|
||||
AccDataType,
|
||||
CodegenShape,
|
||||
GemmShape,
|
||||
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
true,
|
||||
GroupedConvTraitsType::VectorSizeA,
|
||||
GroupedConvTraitsType::VectorSizeB>;
|
||||
using CodegenPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
|
||||
VectorSizeA,
|
||||
VectorSizeB>;
|
||||
|
||||
const auto Run = [&](const auto memory_operation_) {
|
||||
using BaseGemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
|
||||
|
||||
const ck_tile::index_t gemm_k =
|
||||
args.N_ * std::accumulate(args.output_spatial_lengths_.begin(),
|
||||
args.output_spatial_lengths_.end(),
|
||||
1,
|
||||
std::multiplies<ck_tile::index_t>());
|
||||
|
||||
const ck_tile::index_t k_grain = args.k_batch * GemmConfig::K_Tile;
|
||||
const ck_tile::index_t K_split = (gemm_k + k_grain - 1) / k_grain * GemmConfig::K_Tile;
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
float ave_time{0};
|
||||
|
||||
const auto Run = [&](const auto has_hot_loop_,
|
||||
const auto tail_number_,
|
||||
const auto memory_operation_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = GemmConfig::Scheduler;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
using UniversalGemmProblem =
|
||||
ck_tile::UniversalGemmPipelineProblem<OutDataType,
|
||||
InDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
WeiDataType,
|
||||
true,
|
||||
VectorSizeA,
|
||||
VectorSizeB>;
|
||||
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
|
||||
using ConvEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
|
||||
OutDataType, // A: Out
|
||||
InDataType, // B: In
|
||||
@@ -86,12 +133,12 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
M_Warp,
|
||||
N_Warp,
|
||||
M_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile,
|
||||
CodegenPipelineProblem::TransposeC,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
GemmPipelineProblem::TransposeC,
|
||||
memory_operation,
|
||||
1,
|
||||
true,
|
||||
@@ -99,7 +146,7 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
|
||||
|
||||
using Kernel = ck_tile::GroupedConvolutionBackwardWeightKernel<GroupedConvTraitsType,
|
||||
TilePartitioner,
|
||||
CodegenPipeline,
|
||||
GemmPipeline,
|
||||
ConvEpilogue>;
|
||||
|
||||
const ck_tile::index_t spatial_lengths_accum =
|
||||
@@ -166,14 +213,14 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << CodegenShape::GetName() << '\n'
|
||||
<< "problem: " << CodegenPipelineProblem::GetName() << '\n'
|
||||
<< "pipeline: " << CodegenPipeline::GetName() << '\n'
|
||||
<< "shape: " << GemmShape::GetName() << '\n'
|
||||
<< "problem: " << GemmPipelineProblem::GetName() << '\n'
|
||||
<< "pipeline: " << GemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
|
||||
<< "}" << '\n'
|
||||
<< "Vector size A: " << CodegenPipeline::GetVectorSizeA()
|
||||
<< ", Vector size B: " << CodegenPipeline::GetVectorSizeB()
|
||||
<< "Vector size A: " << GemmPipeline::GetVectorSizeA()
|
||||
<< ", Vector size B: " << GemmPipeline::GetVectorSizeB()
|
||||
<< ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl;
|
||||
}
|
||||
|
||||
@@ -186,7 +233,7 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
|
||||
s.stream_id_));
|
||||
};
|
||||
|
||||
return ck_tile::launch_kernel_time_mask(
|
||||
ave_time = ck_tile::launch_kernel_time_mask(
|
||||
s,
|
||||
preprocess,
|
||||
ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs),
|
||||
@@ -199,17 +246,22 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
|
||||
ck_tile::make_tuple(shape[1], 1), // Output Stride
|
||||
input_tensors,
|
||||
static_cast<WeiDataType*>(c_ptr)));
|
||||
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
if(args.k_batch == 1)
|
||||
{
|
||||
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
}
|
||||
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
if(args.k_batch == 1)
|
||||
{
|
||||
Run(has_hot_loop_, tail_number_, MemoryOpSet{});
|
||||
}
|
||||
else
|
||||
{
|
||||
Run(has_hot_loop_, tail_number_, MemoryOpAtomicAdd{});
|
||||
}
|
||||
};
|
||||
|
||||
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
|
||||
return ave_time;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
#include "grouped_convolution_forward_invoker.hpp"
|
||||
#include "run_grouped_convolution_fwd_example.inc"
|
||||
|
||||
template <typename GemmWarpConfig>
|
||||
template <template <typename PrecType> typename GemmConfig>
|
||||
int run_grouped_conv_fwd_example(int argc, char* argv[])
|
||||
{
|
||||
using Invoker = GroupedConvolutionForwardInvoker;
|
||||
@@ -30,12 +30,16 @@ int run_grouped_conv_fwd_example(int argc, char* argv[])
|
||||
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
return run_grouped_conv_fwd_example_prec_type<Invoker, GemmWarpConfig, ck_tile::half_t>(
|
||||
return run_grouped_conv_fwd_example_prec_type<Invoker,
|
||||
GemmConfig<ck_tile::half_t>,
|
||||
ck_tile::half_t>(
|
||||
in_layout, wei_layout, out_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "bf16")
|
||||
{
|
||||
return run_grouped_conv_fwd_example_prec_type<Invoker, GemmWarpConfig, ck_tile::bf16_t>(
|
||||
return run_grouped_conv_fwd_example_prec_type<Invoker,
|
||||
GemmConfig<ck_tile::bf16_t>,
|
||||
ck_tile::bf16_t>(
|
||||
in_layout, wei_layout, out_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
@@ -47,8 +51,8 @@ int run_grouped_conv_fwd_example(int argc, char* argv[])
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
#if CK_TILE_USE_WMMA
|
||||
return !run_grouped_conv_fwd_example<GemmWarpConfig_Wmma>(argc, argv);
|
||||
return !run_grouped_conv_fwd_example<GemmConfigComputeV3_WMMA>(argc, argv);
|
||||
#else
|
||||
return !run_grouped_conv_fwd_example<GemmWarpConfig_Mfma>(argc, argv);
|
||||
return !run_grouped_conv_fwd_example<GemmConfigComputeV3>(argc, argv);
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
struct GroupedConvolutionForwardInvoker
|
||||
{
|
||||
template <ck_tile::index_t NDimSpatial,
|
||||
typename GemmWarpConfig,
|
||||
typename GemmConfig,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename AccDataType,
|
||||
@@ -23,113 +23,171 @@ struct GroupedConvolutionForwardInvoker
|
||||
{
|
||||
constexpr int kBlockPerCu = 1;
|
||||
|
||||
constexpr ck_tile::index_t M_Tile = 64;
|
||||
constexpr ck_tile::index_t N_Tile = 64;
|
||||
constexpr ck_tile::index_t K_Tile = 64;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp = 2;
|
||||
constexpr ck_tile::index_t N_Warp = 2;
|
||||
constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp_Tile = GemmWarpConfig::M_Warp_Tile;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = GemmWarpConfig::N_Warp_Tile;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = GemmWarpConfig::K_Warp_Tile;
|
||||
// Implicit GEMM Traits
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
|
||||
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
|
||||
ck_tile::
|
||||
sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>,
|
||||
GemmConfig::PermuteA,
|
||||
GemmConfig::PermuteB>;
|
||||
|
||||
constexpr ck_tile::index_t VectorSizeA = 8;
|
||||
constexpr ck_tile::index_t VectorSizeB = 8;
|
||||
constexpr ck_tile::index_t VectorSizeC = 8;
|
||||
|
||||
// Implicit GEMM Traits
|
||||
using CodegenShape =
|
||||
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
|
||||
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
|
||||
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
|
||||
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
|
||||
using TilePartitioner =
|
||||
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
|
||||
GemmConfig::TileParitionerGroupNum,
|
||||
GemmConfig::TileParitionerM01>;
|
||||
using GroupedConvTraitsType = ck_tile::GroupedConvTraits<NDimSpatial,
|
||||
ConvSpec,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
DsLayout,
|
||||
OutLayout,
|
||||
VectorSizeA,
|
||||
VectorSizeB,
|
||||
VectorSizeC>;
|
||||
|
||||
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
|
||||
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenShape>;
|
||||
using GroupedConvTraitsType = ck_tile::GroupedConvTraits<NDimSpatial,
|
||||
ConvSpec,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
DsLayout,
|
||||
OutLayout,
|
||||
VectorSizeA,
|
||||
VectorSizeB,
|
||||
VectorSizeC>;
|
||||
using CodegenPipelineProblem = ck_tile::GemmPipelineProblem<
|
||||
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<
|
||||
GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
GemmConfig::DoubleSmemBuffer,
|
||||
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsFwd::AsLayout,
|
||||
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsFwd::BsLayout,
|
||||
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsFwd::CLayout,
|
||||
GemmConfig::TransposeC,
|
||||
GemmConfig::UseStructuredSparsity,
|
||||
false, // Persistent,
|
||||
GemmConfig::NumWaveGroups,
|
||||
GemmConfig::Preshuffle>;
|
||||
|
||||
using GemmPipelineProblem = ck_tile::GemmPipelineProblem<
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
CodegenShape,
|
||||
GemmShape,
|
||||
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsFwd,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
InDataType,
|
||||
OutDataType,
|
||||
true,
|
||||
GroupedConvTraitsType::VectorSizeA,
|
||||
GroupedConvTraitsType::VectorSizeB>;
|
||||
using CodegenPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
|
||||
VectorSizeA,
|
||||
VectorSizeB>;
|
||||
|
||||
const auto Run = [&](const auto memory_operation_) {
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
using BaseGemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
|
||||
|
||||
using ConvEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
|
||||
ck_tile::tensor_layout::gemm::RowMajor,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
M_Warp,
|
||||
N_Warp,
|
||||
M_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile,
|
||||
CodegenPipelineProblem::TransposeC,
|
||||
memory_operation,
|
||||
1,
|
||||
true,
|
||||
GroupedConvTraitsType::VectorSizeC>>;
|
||||
const ck_tile::index_t gemm_k =
|
||||
args.C_ * std::accumulate(args.filter_spatial_lengths_.begin(),
|
||||
args.filter_spatial_lengths_.end(),
|
||||
1,
|
||||
std::multiplies<ck_tile::index_t>());
|
||||
|
||||
using Kernel = ck_tile::GroupedConvolutionForwardKernel<GroupedConvTraitsType,
|
||||
TilePartitioner,
|
||||
CodegenPipeline,
|
||||
ConvEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
const ck_tile::index_t k_grain = args.k_batch * GemmConfig::K_Tile;
|
||||
const ck_tile::index_t K_split = (gemm_k + k_grain - 1) / k_grain * GemmConfig::K_Tile;
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
float ave_time{0};
|
||||
|
||||
const dim3 grids = Kernel::GridSize(kargs);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
const auto Run =
|
||||
[&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = GemmConfig::Scheduler;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
using UniversalGemmProblem =
|
||||
ck_tile::UniversalGemmPipelineProblem<InDataType,
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
OutDataType,
|
||||
true,
|
||||
VectorSizeA,
|
||||
VectorSizeB>;
|
||||
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
|
||||
using ConvEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
|
||||
ck_tile::tensor_layout::gemm::RowMajor,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
GemmConfig::TransposeC,
|
||||
memory_operation,
|
||||
1,
|
||||
true,
|
||||
GroupedConvTraitsType::VectorSizeC>>;
|
||||
|
||||
using Kernel = ck_tile::GroupedConvolutionForwardKernel<GroupedConvTraitsType,
|
||||
TilePartitioner,
|
||||
GemmPipeline,
|
||||
ConvEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(kargs);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n");
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << GemmShape::GetName() << '\n'
|
||||
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
|
||||
<< "pipeline: " << GemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
|
||||
<< "}" << '\n'
|
||||
<< "Vector size A: " << GemmPipeline::GetVectorSizeA()
|
||||
<< ", Vector size B: " << GemmPipeline::GetVectorSizeB()
|
||||
<< ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl;
|
||||
}
|
||||
|
||||
ave_time = ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
if(args.k_batch == 1)
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n");
|
||||
Run(has_hot_loop_, tail_number_, MemoryOpSet{});
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
else
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << CodegenShape::GetName() << '\n'
|
||||
<< "problem: " << CodegenPipelineProblem::GetName() << '\n'
|
||||
<< "pipeline: " << CodegenPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
|
||||
<< "}" << '\n'
|
||||
<< "Vector size A: " << CodegenPipeline::GetVectorSizeA()
|
||||
<< ", Vector size B: " << CodegenPipeline::GetVectorSizeB()
|
||||
<< ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl;
|
||||
Run(has_hot_loop_, tail_number_, MemoryOpAtomicAdd{});
|
||||
}
|
||||
|
||||
float ave_time = ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
|
||||
return ave_time;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -11,7 +11,11 @@
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/ops/grouped_convolution.hpp"
|
||||
|
||||
#include "gemm_configs.hpp"
|
||||
using MemoryOpSet =
|
||||
std::integral_constant<ck_tile::memory_operation_enum, ck_tile::memory_operation_enum::set>;
|
||||
using MemoryOpAtomicAdd = std::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>;
|
||||
struct GemmWarpConfig_Mfma
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
#pragma once
|
||||
|
||||
template <ck_tile::index_t NDimSpatial,
|
||||
typename GemmWarpConfig,
|
||||
typename GemmConfig,
|
||||
typename Invoker,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
@@ -17,7 +17,7 @@ float invoke_grouped_conv_bwd_data(ck_tile::GroupedConvBwdDataHostArgs& args,
|
||||
int n_repeat)
|
||||
{
|
||||
float ave_time = Invoker::template grouped_conv_bwd_data<NDimSpatial,
|
||||
GemmWarpConfig,
|
||||
GemmConfig,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
@@ -39,7 +39,7 @@ float invoke_grouped_conv_bwd_data(ck_tile::GroupedConvBwdDataHostArgs& args,
|
||||
}
|
||||
|
||||
template <ck_tile::index_t NDimSpatial,
|
||||
typename GemmWarpConfig,
|
||||
typename GemmConfig,
|
||||
typename Invoker,
|
||||
typename InDataType,
|
||||
typename WeiDataType = InDataType,
|
||||
@@ -141,7 +141,7 @@ int run_grouped_conv_bwd_data_example_with_layouts(
|
||||
std::cout << "output: " << output.mDesc << std::endl;
|
||||
|
||||
invoke_grouped_conv_bwd_data<NDimSpatial,
|
||||
GemmWarpConfig,
|
||||
GemmConfig,
|
||||
Invoker,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
@@ -193,7 +193,7 @@ int run_grouped_conv_bwd_data_example_with_layouts(
|
||||
}
|
||||
|
||||
template <typename Invoker,
|
||||
typename GemmWarpConfig,
|
||||
typename GemmConfig,
|
||||
typename InPrecType,
|
||||
typename WeiPrecType = InPrecType,
|
||||
typename OutPrecType = InPrecType>
|
||||
@@ -215,7 +215,7 @@ int run_grouped_conv_bwd_data_example_prec_type(
|
||||
if(in_layout == "NWGC" && wei_layout == "GKXC" && out_layout == "NWGK")
|
||||
{
|
||||
return run_grouped_conv_bwd_data_example_with_layouts<ck_tile::number<1>{},
|
||||
GemmWarpConfig,
|
||||
GemmConfig,
|
||||
Invoker,
|
||||
InPrecType,
|
||||
WeiPrecType,
|
||||
@@ -225,7 +225,7 @@ int run_grouped_conv_bwd_data_example_prec_type(
|
||||
else if(in_layout == "NHWGC" && wei_layout == "GKYXC" && out_layout == "NHWGK")
|
||||
{
|
||||
return run_grouped_conv_bwd_data_example_with_layouts<ck_tile::number<2>{},
|
||||
GemmWarpConfig,
|
||||
GemmConfig,
|
||||
Invoker,
|
||||
InPrecType,
|
||||
WeiPrecType,
|
||||
@@ -235,7 +235,7 @@ int run_grouped_conv_bwd_data_example_prec_type(
|
||||
else if(in_layout == "NDHWGC" && wei_layout == "GKZYXC" && out_layout == "NDHWGK")
|
||||
{
|
||||
return run_grouped_conv_bwd_data_example_with_layouts<ck_tile::number<3>{},
|
||||
GemmWarpConfig,
|
||||
GemmConfig,
|
||||
Invoker,
|
||||
InPrecType,
|
||||
WeiPrecType,
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
#pragma once
|
||||
|
||||
template <ck_tile::index_t NDimSpatial,
|
||||
typename GemmWarpConfig,
|
||||
typename GemmConfig,
|
||||
typename Invoker,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
@@ -17,7 +17,7 @@ float invoke_grouped_conv_bwd_weight(ck_tile::GroupedConvBwdWeightHostArgs& args
|
||||
int n_repeat)
|
||||
{
|
||||
float ave_time = Invoker::template grouped_conv_bwd_weight<NDimSpatial,
|
||||
GemmWarpConfig,
|
||||
GemmConfig,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
@@ -31,7 +31,7 @@ float invoke_grouped_conv_bwd_weight(ck_tile::GroupedConvBwdWeightHostArgs& args
|
||||
}
|
||||
|
||||
template <ck_tile::index_t NDimSpatial,
|
||||
typename GemmWarpConfig,
|
||||
typename GemmConfig,
|
||||
typename Invoker,
|
||||
typename InDataType,
|
||||
typename WeiDataType = InDataType,
|
||||
@@ -131,7 +131,7 @@ int run_grouped_conv_bwd_weight_example_with_layouts(ck_tile::ArgParser& arg_par
|
||||
std::cout << "output: " << output.mDesc << std::endl;
|
||||
|
||||
float ave_time = invoke_grouped_conv_bwd_weight<NDimSpatial,
|
||||
GemmWarpConfig,
|
||||
GemmConfig,
|
||||
Invoker,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
@@ -193,7 +193,7 @@ int run_grouped_conv_bwd_weight_example_with_layouts(ck_tile::ArgParser& arg_par
|
||||
}
|
||||
|
||||
template <typename Invoker,
|
||||
typename GemmWarpConfig,
|
||||
typename GemmConfig,
|
||||
typename InPrecType,
|
||||
typename WeiPrecType = InPrecType,
|
||||
typename OutPrecType = InPrecType>
|
||||
@@ -217,7 +217,7 @@ int run_grouped_conv_bwd_weight_example_prec_type(std::string in_layout,
|
||||
if(in_layout == "NWGC" && wei_layout == "GKXC" && out_layout == "NWGK")
|
||||
{
|
||||
return run_grouped_conv_bwd_weight_example_with_layouts<ck_tile::number<1>{},
|
||||
GemmWarpConfig,
|
||||
GemmConfig,
|
||||
Invoker,
|
||||
InPrecType,
|
||||
WeiPrecType,
|
||||
@@ -227,7 +227,7 @@ int run_grouped_conv_bwd_weight_example_prec_type(std::string in_layout,
|
||||
else if(in_layout == "NHWGC" && wei_layout == "GKYXC" && out_layout == "NHWGK")
|
||||
{
|
||||
return run_grouped_conv_bwd_weight_example_with_layouts<ck_tile::number<2>{},
|
||||
GemmWarpConfig,
|
||||
GemmConfig,
|
||||
Invoker,
|
||||
InPrecType,
|
||||
WeiPrecType,
|
||||
@@ -237,7 +237,7 @@ int run_grouped_conv_bwd_weight_example_prec_type(std::string in_layout,
|
||||
else if(in_layout == "NDHWGC" && wei_layout == "GKZYXC" && out_layout == "NDHWGK")
|
||||
{
|
||||
return run_grouped_conv_bwd_weight_example_with_layouts<ck_tile::number<3>{},
|
||||
GemmWarpConfig,
|
||||
GemmConfig,
|
||||
Invoker,
|
||||
InPrecType,
|
||||
WeiPrecType,
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
#pragma once
|
||||
|
||||
template <ck_tile::index_t NDimSpatial,
|
||||
typename GemmWarpConfig,
|
||||
typename GemmConfig,
|
||||
typename Invoker,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
@@ -17,7 +17,7 @@ float invoke_grouped_conv_fwd(const ck_tile::GroupedConvFwdHostArgs& args,
|
||||
int n_repeat)
|
||||
{
|
||||
float ave_time = Invoker::template grouped_conv_fwd<NDimSpatial,
|
||||
GemmWarpConfig,
|
||||
GemmConfig,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
@@ -39,7 +39,7 @@ float invoke_grouped_conv_fwd(const ck_tile::GroupedConvFwdHostArgs& args,
|
||||
}
|
||||
|
||||
template <ck_tile::index_t NDimSpatial,
|
||||
typename GemmWarpConfig,
|
||||
typename GemmConfig,
|
||||
typename Invoker,
|
||||
typename InDataType,
|
||||
typename WeiDataType = InDataType,
|
||||
@@ -141,7 +141,7 @@ int run_grouped_conv_fwd_example_with_layouts(
|
||||
std::cout << "output: " << output.mDesc << std::endl;
|
||||
|
||||
invoke_grouped_conv_fwd<NDimSpatial,
|
||||
GemmWarpConfig,
|
||||
GemmConfig,
|
||||
Invoker,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
@@ -193,7 +193,7 @@ int run_grouped_conv_fwd_example_with_layouts(
|
||||
}
|
||||
|
||||
template <typename Invoker,
|
||||
typename GemmWarpConfig,
|
||||
typename GemmConfig,
|
||||
typename InPrecType,
|
||||
typename WeiPrecType = InPrecType,
|
||||
typename OutPrecType = InPrecType>
|
||||
@@ -215,7 +215,7 @@ int run_grouped_conv_fwd_example_prec_type(
|
||||
if(in_layout == "NWGC" && wei_layout == "GKXC" && out_layout == "NWGK")
|
||||
{
|
||||
return run_grouped_conv_fwd_example_with_layouts<ck_tile::number<1>{},
|
||||
GemmWarpConfig,
|
||||
GemmConfig,
|
||||
Invoker,
|
||||
InPrecType,
|
||||
WeiPrecType,
|
||||
@@ -225,7 +225,7 @@ int run_grouped_conv_fwd_example_prec_type(
|
||||
else if(in_layout == "NHWGC" && wei_layout == "GKYXC" && out_layout == "NHWGK")
|
||||
{
|
||||
return run_grouped_conv_fwd_example_with_layouts<ck_tile::number<2>{},
|
||||
GemmWarpConfig,
|
||||
GemmConfig,
|
||||
Invoker,
|
||||
InPrecType,
|
||||
WeiPrecType,
|
||||
@@ -235,7 +235,7 @@ int run_grouped_conv_fwd_example_prec_type(
|
||||
else if(in_layout == "NDHWGC" && wei_layout == "GKZYXC" && out_layout == "NDHWGK")
|
||||
{
|
||||
return run_grouped_conv_fwd_example_with_layouts<ck_tile::number<3>{},
|
||||
GemmWarpConfig,
|
||||
GemmConfig,
|
||||
Invoker,
|
||||
InPrecType,
|
||||
WeiPrecType,
|
||||
|
||||
Reference in New Issue
Block a user