mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-02 21:27:45 +00:00
Merge branch 'develop' into moe_xcd_remap
This commit is contained in:
6
.gitignore
vendored
6
.gitignore
vendored
@@ -66,6 +66,12 @@ docs/doxygen/xml
|
||||
cmake-build*/
|
||||
build*/
|
||||
|
||||
# LSP configuration
|
||||
.clangd
|
||||
|
||||
# User-defined CMake presets
|
||||
CMakeUserPresets.json
|
||||
|
||||
# Python virtualenv
|
||||
.venv/
|
||||
|
||||
|
||||
@@ -181,7 +181,7 @@ constexpr ck::index_t ScaleBlockSize = 32; // scaling block
|
||||
constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 f4 = 128 fp4x2
|
||||
static constexpr ck::index_t Nswizzle = false;
|
||||
static constexpr ck::index_t ActOP = 0; // 0: gelu_and_mul, 1: silu_and_mul
|
||||
static constexpr ck::index_t MPerBlock = 128;
|
||||
static constexpr ck::index_t MPerBlock = 32;
|
||||
static constexpr bool MulRoutedWeight = true;
|
||||
|
||||
// clang-format off
|
||||
@@ -190,10 +190,10 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmMXBPreShuffl
|
||||
A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, BElementOp, CDEElementOp, GemmSpec,
|
||||
ScaleBlockSize, 256,
|
||||
MPerBlock, 64, KPerBlock,
|
||||
MPerBlock, 128, KPerBlock,
|
||||
16, 16,
|
||||
16, 16,
|
||||
4, 2,
|
||||
2, 2,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1,
|
||||
2, 2, S<1, 32, 1, 8>, S<8, 1, 1, 1>,
|
||||
@@ -213,10 +213,10 @@ int main(int argc, char* argv[])
|
||||
ck::index_t sorted_size = sorted_tile_num * MPerBlock;
|
||||
ck::index_t valid_size = valid_tile_num * MPerBlock;
|
||||
|
||||
ck::index_t N = 6144;
|
||||
ck::index_t K = 4096;
|
||||
ck::index_t N = 7168;
|
||||
ck::index_t K = 256;
|
||||
ck::index_t experts = 8;
|
||||
ck::index_t tokens = 832;
|
||||
ck::index_t tokens = 208;
|
||||
ck::index_t topk = 2;
|
||||
|
||||
if(argc == 1)
|
||||
|
||||
@@ -14,19 +14,11 @@
|
||||
|
||||
struct ConvConfigBase
|
||||
{
|
||||
static constexpr bool kPadM = true;
|
||||
static constexpr bool kPadN = true;
|
||||
static constexpr bool kPadK = true;
|
||||
|
||||
static constexpr bool TransposeC = false;
|
||||
|
||||
static constexpr ck_tile::index_t VectorSizeA = 4;
|
||||
static constexpr ck_tile::index_t VectorSizeB = 8;
|
||||
static constexpr ck_tile::index_t VectorSizeC = 8;
|
||||
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
|
||||
static constexpr ck_tile::index_t TileParitionerM01 = 4;
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
|
||||
static constexpr ck_tile::index_t NumWaveGroups = 1;
|
||||
@@ -210,9 +202,9 @@ struct ConvConfigComputeV5 : public ConvConfigBase
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V5;
|
||||
static constexpr ck_tile::index_t NumWaNumWaveGroups = 2;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V5;
|
||||
static constexpr ck_tile::index_t NumWaveGroups = 2;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
|
||||
@@ -22,8 +22,6 @@ struct GroupedConvolutionBackwardDataInvoker
|
||||
static float grouped_conv_bwd_data(const ck_tile::GroupedConvBwdDataHostArgs& args,
|
||||
const ck_tile::stream_config& s)
|
||||
{
|
||||
constexpr int kBlockPerCu = 1;
|
||||
|
||||
// Implicit GEMM Traits
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<ConvConfig::M_Tile, ConvConfig::N_Tile, ConvConfig::K_Tile>,
|
||||
@@ -32,36 +30,33 @@ struct GroupedConvolutionBackwardDataInvoker
|
||||
ConvConfig::N_Warp_Tile,
|
||||
ConvConfig::K_Warp_Tile>>;
|
||||
|
||||
constexpr ck_tile::index_t VectorSizeA = 8;
|
||||
constexpr ck_tile::index_t VectorSizeB = 8;
|
||||
constexpr ck_tile::index_t VectorSizeC = 8;
|
||||
|
||||
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
|
||||
using TilePartitioner =
|
||||
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
|
||||
ConvConfig::TileParitionerGroupNum,
|
||||
ConvConfig::TileParitionerM01>;
|
||||
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
|
||||
using GroupedConvTraitsType = ck_tile::GroupedConvTraits<NDimSpatial,
|
||||
ConvSpec,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
DsLayout,
|
||||
OutLayout,
|
||||
VectorSizeA,
|
||||
VectorSizeB,
|
||||
VectorSizeC>;
|
||||
ConvConfig::VectorSizeA,
|
||||
ConvConfig::VectorSizeB,
|
||||
ConvConfig::VectorSizeC>;
|
||||
|
||||
using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner<
|
||||
GemmShape,
|
||||
GroupedConvTraitsType::FixedGemmParams::TilePartitionerGroupNum,
|
||||
GroupedConvTraitsType::FixedGemmParams::TilePartitionerM01>;
|
||||
|
||||
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<
|
||||
ConvConfig::kPadM,
|
||||
ConvConfig::kPadN,
|
||||
ConvConfig::kPadK,
|
||||
GroupedConvTraitsType::FixedGemmParams::kPadM,
|
||||
GroupedConvTraitsType::FixedGemmParams::kPadN,
|
||||
GroupedConvTraitsType::FixedGemmParams::kPadK,
|
||||
ConvConfig::DoubleSmemBuffer,
|
||||
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdData::AsLayout,
|
||||
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdData::BsLayout,
|
||||
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdData::CLayout,
|
||||
ConvConfig::TransposeC,
|
||||
false,
|
||||
false, // Persistent,
|
||||
typename GroupedConvTraitsType::AsLayoutBwdData,
|
||||
typename GroupedConvTraitsType::BsLayoutBwdData,
|
||||
typename GroupedConvTraitsType::CLayoutBwdData,
|
||||
GroupedConvTraitsType::FixedGemmParams::TransposeC,
|
||||
GroupedConvTraitsType::FixedGemmParams::UseStructuredSparsity,
|
||||
GroupedConvTraitsType::FixedGemmParams::Persistent,
|
||||
ConvConfig::NumWaveGroups>;
|
||||
|
||||
using GemmPipelineProblem = ck_tile::GemmPipelineProblem<
|
||||
@@ -69,13 +64,14 @@ struct GroupedConvolutionBackwardDataInvoker
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdData,
|
||||
typename GroupedConvTraitsType::template GroupedConvImplicitGemmTraitsBwdData<
|
||||
ConvConfig::NumWaveGroups>,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
InDataType,
|
||||
true,
|
||||
VectorSizeA,
|
||||
VectorSizeB>;
|
||||
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
|
||||
GroupedConvTraitsType::VectorSizeA,
|
||||
GroupedConvTraitsType::VectorSizeB>;
|
||||
|
||||
using BaseGemmPipeline = typename PipelineTypeTraits<
|
||||
ConvConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
|
||||
@@ -93,95 +89,96 @@ struct GroupedConvolutionBackwardDataInvoker
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
float ave_time{0};
|
||||
|
||||
const auto Run =
|
||||
[&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = ConvConfig::Scheduler;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
const auto Run = [&](const auto has_hot_loop_,
|
||||
const auto tail_number_,
|
||||
const auto memory_operation_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = ConvConfig::Scheduler;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
using UniversalGemmProblem =
|
||||
ck_tile::UniversalGemmPipelineProblem<OutDataType,
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
InDataType,
|
||||
true,
|
||||
VectorSizeA,
|
||||
VectorSizeB>;
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
|
||||
OutDataType,
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
InDataType,
|
||||
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
|
||||
GroupedConvTraitsType::VectorSizeA,
|
||||
GroupedConvTraitsType::VectorSizeB>;
|
||||
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
ConvConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
ConvConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
|
||||
using ConvEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
|
||||
OutDataType,
|
||||
WeiDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
InDataType,
|
||||
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
|
||||
ck_tile::tensor_layout::gemm::RowMajor,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
ConvConfig::M_Warp,
|
||||
ConvConfig::N_Warp,
|
||||
ConvConfig::M_Warp_Tile,
|
||||
ConvConfig::N_Warp_Tile,
|
||||
ConvConfig::K_Warp_Tile,
|
||||
ConvConfig::TransposeC,
|
||||
memory_operation,
|
||||
1,
|
||||
true,
|
||||
GroupedConvTraitsType::VectorSizeC>>;
|
||||
using ConvEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
|
||||
OutDataType,
|
||||
WeiDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
InDataType,
|
||||
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
|
||||
typename GroupedConvTraitsType::FixedGemmParams::ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
ConvConfig::M_Warp,
|
||||
ConvConfig::N_Warp,
|
||||
ConvConfig::M_Warp_Tile,
|
||||
ConvConfig::N_Warp_Tile,
|
||||
ConvConfig::K_Warp_Tile,
|
||||
GroupedConvTraitsType::FixedGemmParams::TransposeC,
|
||||
memory_operation,
|
||||
ConvConfig::NumWaveGroups,
|
||||
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
|
||||
GroupedConvTraitsType::VectorSizeC>>;
|
||||
|
||||
using Kernel = ck_tile::GroupedConvolutionBackwardDataKernel<GroupedConvTraitsType,
|
||||
TilePartitioner,
|
||||
GemmPipeline,
|
||||
ConvEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
using Kernel = ck_tile::GroupedConvolutionBackwardDataKernel<GroupedConvTraitsType,
|
||||
TilePartitioner,
|
||||
GemmPipeline,
|
||||
ConvEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(args);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 grids = Kernel::GridSize(args);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n");
|
||||
}
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n");
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << GemmShape::GetName() << '\n'
|
||||
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
|
||||
<< "pipeline: " << GemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
|
||||
<< "}" << '\n'
|
||||
<< "Vector size A: " << GemmPipeline::GetVectorSizeA()
|
||||
<< ", Vector size B: " << GemmPipeline::GetVectorSizeB()
|
||||
<< ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl;
|
||||
}
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << GemmShape::GetName() << '\n'
|
||||
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
|
||||
<< "pipeline: " << GemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
|
||||
<< "}" << '\n'
|
||||
<< "Vector size A: " << GemmPipeline::GetVectorSizeA()
|
||||
<< ", Vector size B: " << GemmPipeline::GetVectorSizeB()
|
||||
<< ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl;
|
||||
}
|
||||
|
||||
auto preprocess = [&]() {
|
||||
ck_tile::hip_check_error(hipMemsetAsync(
|
||||
kargs.in_ptr, 0, args.template GetInputByte<InDataType>(), s.stream_id_));
|
||||
};
|
||||
|
||||
ave_time = ck_tile::launch_kernel_time_mask(
|
||||
s,
|
||||
preprocess,
|
||||
ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
|
||||
return ave_time;
|
||||
auto preprocess = [&]() {
|
||||
ck_tile::hip_check_error(hipMemsetAsync(
|
||||
kargs.in_ptr, 0, args.template GetInputByte<InDataType>(), s.stream_id_));
|
||||
};
|
||||
|
||||
ave_time = ck_tile::launch_kernel_time_mask(
|
||||
s,
|
||||
preprocess,
|
||||
ck_tile::make_kernel<ConvConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
if(args.k_batch == 1)
|
||||
{
|
||||
|
||||
@@ -21,8 +21,6 @@ struct GroupedConvolutionBackwardWeightInvoker
|
||||
static float grouped_conv_bwd_weight(const ck_tile::GroupedConvBwdWeightHostArgs& args,
|
||||
const ck_tile::stream_config& s)
|
||||
{
|
||||
constexpr int kBlockPerCu = 1;
|
||||
|
||||
// Implicit GEMM Traits
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<ConvConfig::M_Tile, ConvConfig::N_Tile, ConvConfig::K_Tile>,
|
||||
@@ -31,37 +29,34 @@ struct GroupedConvolutionBackwardWeightInvoker
|
||||
ConvConfig::N_Warp_Tile,
|
||||
ConvConfig::K_Warp_Tile>>;
|
||||
|
||||
constexpr ck_tile::index_t VectorSizeA = ConvConfig::VectorSizeA;
|
||||
constexpr ck_tile::index_t VectorSizeB = ConvConfig::VectorSizeB;
|
||||
constexpr ck_tile::index_t VectorSizeC = ConvConfig::VectorSizeC;
|
||||
|
||||
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
|
||||
using TilePartitioner =
|
||||
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
|
||||
ConvConfig::TileParitionerGroupNum,
|
||||
ConvConfig::TileParitionerM01>;
|
||||
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
|
||||
using GroupedConvTraitsType = ck_tile::GroupedConvTraits<NDimSpatial,
|
||||
ConvSpec,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
DsLayout,
|
||||
OutLayout,
|
||||
VectorSizeA,
|
||||
VectorSizeB,
|
||||
VectorSizeC,
|
||||
ConvConfig::VectorSizeA,
|
||||
ConvConfig::VectorSizeB,
|
||||
ConvConfig::VectorSizeC,
|
||||
ConvConfig::NumGroupsToMerge>;
|
||||
|
||||
using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner<
|
||||
GemmShape,
|
||||
GroupedConvTraitsType::FixedGemmParams::TilePartitionerGroupNum,
|
||||
GroupedConvTraitsType::FixedGemmParams::TilePartitionerM01>;
|
||||
|
||||
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<
|
||||
ConvConfig::kPadM,
|
||||
ConvConfig::kPadN,
|
||||
ConvConfig::kPadK,
|
||||
GroupedConvTraitsType::FixedGemmParams::kPadM,
|
||||
GroupedConvTraitsType::FixedGemmParams::kPadN,
|
||||
GroupedConvTraitsType::FixedGemmParams::kPadK,
|
||||
ConvConfig::DoubleSmemBuffer,
|
||||
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::AsLayout,
|
||||
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::BsLayout,
|
||||
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::CLayout,
|
||||
ConvConfig::TransposeC,
|
||||
false,
|
||||
false, // Persistent,
|
||||
typename GroupedConvTraitsType::AsLayoutBwdWeight,
|
||||
typename GroupedConvTraitsType::BsLayoutBwdWeight,
|
||||
typename GroupedConvTraitsType::CLayoutBwdWeight,
|
||||
GroupedConvTraitsType::FixedGemmParams::TransposeC,
|
||||
GroupedConvTraitsType::FixedGemmParams::UseStructuredSparsity,
|
||||
GroupedConvTraitsType::FixedGemmParams::Persistent,
|
||||
ConvConfig::NumWaveGroups>;
|
||||
|
||||
using GemmPipelineProblem = ck_tile::GemmPipelineProblem<
|
||||
@@ -69,13 +64,14 @@ struct GroupedConvolutionBackwardWeightInvoker
|
||||
InDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight,
|
||||
typename GroupedConvTraitsType::template GroupedConvImplicitGemmTraitsBwdWeight<
|
||||
ConvConfig::NumWaveGroups>,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
WeiDataType,
|
||||
true,
|
||||
VectorSizeA,
|
||||
VectorSizeB>;
|
||||
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
|
||||
GroupedConvTraitsType::VectorSizeA,
|
||||
GroupedConvTraitsType::VectorSizeB>;
|
||||
|
||||
using BaseGemmPipeline = typename PipelineTypeTraits<
|
||||
ConvConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
|
||||
@@ -101,21 +97,21 @@ struct GroupedConvolutionBackwardWeightInvoker
|
||||
constexpr auto scheduler = ConvConfig::Scheduler;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
using UniversalGemmProblem =
|
||||
ck_tile::UniversalGemmPipelineProblem<OutDataType,
|
||||
InDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
WeiDataType,
|
||||
true,
|
||||
VectorSizeA,
|
||||
VectorSizeB>;
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
|
||||
OutDataType,
|
||||
InDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
WeiDataType,
|
||||
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
|
||||
GroupedConvTraitsType::VectorSizeA,
|
||||
GroupedConvTraitsType::VectorSizeB>;
|
||||
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
ConvConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
@@ -127,7 +123,7 @@ struct GroupedConvolutionBackwardWeightInvoker
|
||||
AccDataType,
|
||||
WeiDataType,
|
||||
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
|
||||
ck_tile::tensor_layout::gemm::RowMajor,
|
||||
typename GroupedConvTraitsType::FixedGemmParams::ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
@@ -136,10 +132,10 @@ struct GroupedConvolutionBackwardWeightInvoker
|
||||
ConvConfig::M_Warp_Tile,
|
||||
ConvConfig::N_Warp_Tile,
|
||||
ConvConfig::K_Warp_Tile,
|
||||
ConvConfig::TransposeC,
|
||||
GroupedConvTraitsType::FixedGemmParams::TransposeC,
|
||||
memory_operation,
|
||||
1,
|
||||
true,
|
||||
ConvConfig::NumWaveGroups,
|
||||
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
|
||||
GroupedConvTraitsType::VectorSizeC>>;
|
||||
|
||||
using Kernel = ck_tile::GroupedConvolutionBackwardWeightKernel<GroupedConvTraitsType,
|
||||
@@ -184,7 +180,7 @@ struct GroupedConvolutionBackwardWeightInvoker
|
||||
ave_time = ck_tile::launch_kernel_time_mask(
|
||||
s,
|
||||
preprocess,
|
||||
ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
ck_tile::make_kernel<ConvConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
@@ -23,8 +23,6 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
|
||||
{
|
||||
using WorkspaceDataType = float;
|
||||
|
||||
constexpr int kBlockPerCu = 1;
|
||||
|
||||
// Implicit GEMM Traits
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<ConvConfig::M_Tile, ConvConfig::N_Tile, ConvConfig::K_Tile>,
|
||||
@@ -33,36 +31,34 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
|
||||
ConvConfig::N_Warp_Tile,
|
||||
ConvConfig::K_Warp_Tile>>;
|
||||
|
||||
constexpr ck_tile::index_t VectorSizeA = 4;
|
||||
constexpr ck_tile::index_t VectorSizeB = 8;
|
||||
constexpr ck_tile::index_t VectorSizeC = 8;
|
||||
|
||||
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
|
||||
using TilePartitioner =
|
||||
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
|
||||
ConvConfig::TileParitionerGroupNum,
|
||||
ConvConfig::TileParitionerM01>;
|
||||
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
|
||||
using GroupedConvTraitsType = ck_tile::GroupedConvTraits<NDimSpatial,
|
||||
ConvSpec,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
DsLayout,
|
||||
OutLayout,
|
||||
VectorSizeA,
|
||||
VectorSizeB,
|
||||
VectorSizeC>;
|
||||
ConvConfig::VectorSizeA,
|
||||
ConvConfig::VectorSizeB,
|
||||
ConvConfig::VectorSizeC,
|
||||
ConvConfig::NumGroupsToMerge>;
|
||||
|
||||
using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner<
|
||||
GemmShape,
|
||||
GroupedConvTraitsType::FixedGemmParams::TilePartitionerGroupNum,
|
||||
GroupedConvTraitsType::FixedGemmParams::TilePartitionerM01>;
|
||||
|
||||
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<
|
||||
ConvConfig::kPadM,
|
||||
ConvConfig::kPadN,
|
||||
ConvConfig::kPadK,
|
||||
GroupedConvTraitsType::FixedGemmParams::kPadM,
|
||||
GroupedConvTraitsType::FixedGemmParams::kPadN,
|
||||
GroupedConvTraitsType::FixedGemmParams::kPadK,
|
||||
ConvConfig::DoubleSmemBuffer,
|
||||
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::AsLayout,
|
||||
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::BsLayout,
|
||||
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::CLayout,
|
||||
ConvConfig::TransposeC,
|
||||
false,
|
||||
false, // Persistent,
|
||||
typename GroupedConvTraitsType::AsLayoutBwdWeight,
|
||||
typename GroupedConvTraitsType::BsLayoutBwdWeight,
|
||||
typename GroupedConvTraitsType::CLayoutBwdWeight,
|
||||
GroupedConvTraitsType::FixedGemmParams::TransposeC,
|
||||
GroupedConvTraitsType::FixedGemmParams::UseStructuredSparsity,
|
||||
GroupedConvTraitsType::FixedGemmParams::Persistent,
|
||||
ConvConfig::NumWaveGroups>;
|
||||
|
||||
using GemmPipelineProblem = ck_tile::GemmPipelineProblem<
|
||||
@@ -70,13 +66,14 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
|
||||
InDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight,
|
||||
typename GroupedConvTraitsType::template GroupedConvImplicitGemmTraitsBwdWeight<
|
||||
ConvConfig::NumWaveGroups>,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
WeiDataType,
|
||||
true,
|
||||
VectorSizeA,
|
||||
VectorSizeB>;
|
||||
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
|
||||
GroupedConvTraitsType::VectorSizeA,
|
||||
GroupedConvTraitsType::VectorSizeB>;
|
||||
|
||||
using BaseGemmPipeline = typename PipelineTypeTraits<
|
||||
ConvConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
|
||||
@@ -102,21 +99,21 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
|
||||
constexpr auto scheduler = ConvConfig::Scheduler;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
using UniversalGemmProblem =
|
||||
ck_tile::UniversalGemmPipelineProblem<OutDataType,
|
||||
InDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
WeiDataType,
|
||||
true,
|
||||
VectorSizeA,
|
||||
VectorSizeB>;
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
|
||||
OutDataType,
|
||||
InDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
WeiDataType,
|
||||
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
|
||||
GroupedConvTraitsType::VectorSizeA,
|
||||
GroupedConvTraitsType::VectorSizeB>;
|
||||
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
ConvConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
@@ -128,7 +125,7 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
|
||||
AccDataType,
|
||||
WorkspaceDataType, // C: Workspace normally Out
|
||||
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
|
||||
ck_tile::tensor_layout::gemm::RowMajor,
|
||||
typename GroupedConvTraitsType::FixedGemmParams::ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
@@ -139,8 +136,8 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
|
||||
ConvConfig::K_Warp_Tile,
|
||||
GemmPipelineProblem::TransposeC,
|
||||
memory_operation,
|
||||
1,
|
||||
true,
|
||||
ConvConfig::NumWaveGroups,
|
||||
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
|
||||
GroupedConvTraitsType::VectorSizeC>>;
|
||||
|
||||
using Kernel = ck_tile::GroupedConvolutionBackwardWeightKernel<GroupedConvTraitsType,
|
||||
@@ -235,16 +232,17 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
|
||||
ave_time = ck_tile::launch_kernel_time_mask(
|
||||
s,
|
||||
preprocess,
|
||||
ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs),
|
||||
ck_tile::make_kernel<kBlockPerCu>(ElementwiseKernel{},
|
||||
kGridSize,
|
||||
kBlockSize,
|
||||
0,
|
||||
input_size,
|
||||
ck_tile::make_tuple(shape[1], 1), // Input Stride
|
||||
ck_tile::make_tuple(shape[1], 1), // Output Stride
|
||||
input_tensors,
|
||||
static_cast<WeiDataType*>(c_ptr)));
|
||||
ck_tile::make_kernel<ConvConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs),
|
||||
ck_tile::make_kernel<ConvConfig::kBlockPerCu>(
|
||||
ElementwiseKernel{},
|
||||
kGridSize,
|
||||
kBlockSize,
|
||||
0,
|
||||
input_size,
|
||||
ck_tile::make_tuple(shape[1], 1), // Input Stride
|
||||
ck_tile::make_tuple(shape[1], 1), // Output Stride
|
||||
input_tensors,
|
||||
static_cast<WeiDataType*>(c_ptr)));
|
||||
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
@@ -32,8 +32,6 @@ struct GroupedConvolutionForwardInvoker
|
||||
{
|
||||
std::cout << "[INVOKER] grouped_conv_fwd called, NDimSpatial=" << NDimSpatial << "\n";
|
||||
}
|
||||
constexpr int kBlockPerCu = 1;
|
||||
|
||||
// Implicit GEMM Traits
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<ConvConfig::M_Tile, ConvConfig::N_Tile, ConvConfig::K_Tile>,
|
||||
@@ -42,38 +40,34 @@ struct GroupedConvolutionForwardInvoker
|
||||
ConvConfig::N_Warp_Tile,
|
||||
ConvConfig::K_Warp_Tile>>;
|
||||
|
||||
constexpr ck_tile::index_t VectorSizeA = 8;
|
||||
constexpr ck_tile::index_t VectorSizeB = 8;
|
||||
constexpr ck_tile::index_t VectorSizeC = 8;
|
||||
constexpr ck_tile::index_t NumGroupsToMerge = 1;
|
||||
|
||||
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
|
||||
using TilePartitioner =
|
||||
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
|
||||
ConvConfig::TileParitionerGroupNum,
|
||||
ConvConfig::TileParitionerM01>;
|
||||
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
|
||||
using GroupedConvTraitsType = ck_tile::GroupedConvTraits<NDimSpatial,
|
||||
ConvSpec,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
DsLayout,
|
||||
OutLayout,
|
||||
VectorSizeA,
|
||||
VectorSizeB,
|
||||
VectorSizeC,
|
||||
NumGroupsToMerge>;
|
||||
ConvConfig::VectorSizeA,
|
||||
ConvConfig::VectorSizeB,
|
||||
ConvConfig::VectorSizeC,
|
||||
ConvConfig::NumGroupsToMerge>;
|
||||
|
||||
using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner<
|
||||
GemmShape,
|
||||
GroupedConvTraitsType::FixedGemmParams::TilePartitionerGroupNum,
|
||||
GroupedConvTraitsType::FixedGemmParams::TilePartitionerM01>;
|
||||
|
||||
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<
|
||||
ConvConfig::kPadM,
|
||||
ConvConfig::kPadN,
|
||||
ConvConfig::kPadK,
|
||||
GroupedConvTraitsType::FixedGemmParams::kPadM,
|
||||
GroupedConvTraitsType::FixedGemmParams::kPadN,
|
||||
GroupedConvTraitsType::FixedGemmParams::kPadK,
|
||||
ConvConfig::DoubleSmemBuffer,
|
||||
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsFwd::AsLayout,
|
||||
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsFwd::BsLayout,
|
||||
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsFwd::CLayout,
|
||||
ConvConfig::TransposeC,
|
||||
false,
|
||||
false, // Persistent,
|
||||
typename GroupedConvTraitsType::AsLayoutFwd,
|
||||
typename GroupedConvTraitsType::BsLayoutFwd,
|
||||
typename GroupedConvTraitsType::CLayoutFwd,
|
||||
GroupedConvTraitsType::FixedGemmParams::TransposeC,
|
||||
GroupedConvTraitsType::FixedGemmParams::UseStructuredSparsity,
|
||||
GroupedConvTraitsType::FixedGemmParams::Persistent,
|
||||
ConvConfig::NumWaveGroups>;
|
||||
|
||||
using GemmPipelineProblem = ck_tile::GemmPipelineProblem<
|
||||
@@ -81,13 +75,14 @@ struct GroupedConvolutionForwardInvoker
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsFwd,
|
||||
typename GroupedConvTraitsType::template GroupedConvImplicitGemmTraitsFwd<
|
||||
ConvConfig::NumWaveGroups>,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
OutDataType,
|
||||
true,
|
||||
VectorSizeA,
|
||||
VectorSizeB>;
|
||||
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
|
||||
GroupedConvTraitsType::VectorSizeA,
|
||||
GroupedConvTraitsType::VectorSizeB>;
|
||||
|
||||
using BaseGemmPipeline = typename PipelineTypeTraits<
|
||||
ConvConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
|
||||
@@ -116,21 +111,21 @@ struct GroupedConvolutionForwardInvoker
|
||||
constexpr auto scheduler = ConvConfig::Scheduler;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
using UniversalGemmProblem =
|
||||
ck_tile::UniversalGemmPipelineProblem<InDataType,
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
OutDataType,
|
||||
true,
|
||||
VectorSizeA,
|
||||
VectorSizeB>;
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
OutDataType,
|
||||
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
|
||||
GroupedConvTraitsType::VectorSizeA,
|
||||
GroupedConvTraitsType::VectorSizeB>;
|
||||
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
ConvConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
@@ -142,7 +137,7 @@ struct GroupedConvolutionForwardInvoker
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
|
||||
ck_tile::tensor_layout::gemm::RowMajor,
|
||||
typename GroupedConvTraitsType::FixedGemmParams::ELayout,
|
||||
CDElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
@@ -151,10 +146,10 @@ struct GroupedConvolutionForwardInvoker
|
||||
ConvConfig::M_Warp_Tile,
|
||||
ConvConfig::N_Warp_Tile,
|
||||
ConvConfig::K_Warp_Tile,
|
||||
ConvConfig::TransposeC,
|
||||
GroupedConvTraitsType::FixedGemmParams::TransposeC,
|
||||
memory_operation,
|
||||
1,
|
||||
true,
|
||||
ConvConfig::NumWaveGroups,
|
||||
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
|
||||
GroupedConvTraitsType::VectorSizeC>>;
|
||||
|
||||
using Kernel = ck_tile::GroupedConvolutionForwardKernel<GroupedConvTraitsType,
|
||||
@@ -185,8 +180,9 @@ struct GroupedConvolutionForwardInvoker
|
||||
<< ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl;
|
||||
}
|
||||
|
||||
ave_time = ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
ave_time = ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<ConvConfig::kBlockPerCu>(
|
||||
Kernel{}, grids, blocks, 0, kargs));
|
||||
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
@@ -25,7 +25,6 @@ struct GroupedConvolutionForwardInvoker
|
||||
{
|
||||
std::cout << "[INVOKER] grouped_conv_fwd called, NDimSpatial=" << NDimSpatial << "\n";
|
||||
}
|
||||
constexpr int kBlockPerCu = 1;
|
||||
|
||||
// Implicit GEMM Traits
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
@@ -35,27 +34,18 @@ struct GroupedConvolutionForwardInvoker
|
||||
ConvConfig::N_Warp_Tile,
|
||||
ConvConfig::K_Warp_Tile>>;
|
||||
|
||||
constexpr ck_tile::index_t VectorSizeA = 8;
|
||||
constexpr ck_tile::index_t VectorSizeB = 8;
|
||||
constexpr ck_tile::index_t VectorSizeC = 8;
|
||||
|
||||
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
|
||||
using TilePartitioner =
|
||||
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
|
||||
ConvConfig::TileParitionerGroupNum,
|
||||
ConvConfig::TileParitionerM01>;
|
||||
|
||||
using GroupedConvTraitsTypeDefault = ck_tile::GroupedConvTraits<NDimSpatial,
|
||||
ConvSpec,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
DsLayout,
|
||||
OutLayout,
|
||||
VectorSizeA,
|
||||
VectorSizeB,
|
||||
VectorSizeC,
|
||||
1, /*NumGroupsToMerge*/
|
||||
false /*EnableSplitImage*/>;
|
||||
using GroupedConvTraitsTypeDefault =
|
||||
ck_tile::GroupedConvTraits<NDimSpatial,
|
||||
ConvSpec,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
DsLayout,
|
||||
OutLayout,
|
||||
ConvConfig::VectorSizeA,
|
||||
ConvConfig::VectorSizeB,
|
||||
ConvConfig::VectorSizeC,
|
||||
ConvConfig::NumGroupsToMerge>;
|
||||
|
||||
using GroupedConvTraitsTypeLargeTensor =
|
||||
ck_tile::GroupedConvTraits<NDimSpatial,
|
||||
@@ -64,23 +54,28 @@ struct GroupedConvolutionForwardInvoker
|
||||
WeiLayout,
|
||||
DsLayout,
|
||||
OutLayout,
|
||||
VectorSizeA,
|
||||
VectorSizeB,
|
||||
VectorSizeC,
|
||||
1, /*NumGroupsToMerge*/
|
||||
ConvConfig::VectorSizeA,
|
||||
ConvConfig::VectorSizeB,
|
||||
ConvConfig::VectorSizeC,
|
||||
ConvConfig::NumGroupsToMerge,
|
||||
true /*EnableSplitImage*/>;
|
||||
|
||||
using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner<
|
||||
GemmShape,
|
||||
GroupedConvTraitsTypeDefault::FixedGemmParams::TilePartitionerGroupNum,
|
||||
GroupedConvTraitsTypeDefault::FixedGemmParams::TilePartitionerM01>;
|
||||
|
||||
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<
|
||||
ConvConfig::kPadM,
|
||||
ConvConfig::kPadN,
|
||||
ConvConfig::kPadK,
|
||||
GroupedConvTraitsTypeDefault::FixedGemmParams::kPadM,
|
||||
GroupedConvTraitsTypeDefault::FixedGemmParams::kPadN,
|
||||
GroupedConvTraitsTypeDefault::FixedGemmParams::kPadK,
|
||||
ConvConfig::DoubleSmemBuffer,
|
||||
typename GroupedConvTraitsTypeDefault::GroupedConvImplicitGemmTraitsFwd::AsLayout,
|
||||
typename GroupedConvTraitsTypeDefault::GroupedConvImplicitGemmTraitsFwd::BsLayout,
|
||||
typename GroupedConvTraitsTypeDefault::GroupedConvImplicitGemmTraitsFwd::CLayout,
|
||||
ConvConfig::TransposeC,
|
||||
false,
|
||||
false, // Persistent,
|
||||
typename GroupedConvTraitsTypeDefault::AsLayoutFwd,
|
||||
typename GroupedConvTraitsTypeDefault::BsLayoutFwd,
|
||||
typename GroupedConvTraitsTypeDefault::CLayoutFwd,
|
||||
GroupedConvTraitsTypeDefault::FixedGemmParams::TransposeC,
|
||||
GroupedConvTraitsTypeDefault::FixedGemmParams::UseStructuredSparsity,
|
||||
GroupedConvTraitsTypeDefault::FixedGemmParams::Persistent,
|
||||
ConvConfig::NumWaveGroups>;
|
||||
|
||||
using GemmPipelineProblem = ck_tile::GemmPipelineProblem<
|
||||
@@ -88,13 +83,14 @@ struct GroupedConvolutionForwardInvoker
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
typename GroupedConvTraitsTypeDefault::GroupedConvImplicitGemmTraitsFwd,
|
||||
typename GroupedConvTraitsTypeDefault::template GroupedConvImplicitGemmTraitsFwd<
|
||||
ConvConfig::NumWaveGroups>,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
OutDataType,
|
||||
true,
|
||||
VectorSizeA,
|
||||
VectorSizeB>;
|
||||
GroupedConvTraitsTypeDefault::FixedGemmParams::FixedVectorSize,
|
||||
GroupedConvTraitsTypeDefault::VectorSizeA,
|
||||
GroupedConvTraitsTypeDefault::VectorSizeB>;
|
||||
|
||||
using BaseGemmPipeline = typename PipelineTypeTraits<
|
||||
ConvConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
|
||||
@@ -116,9 +112,9 @@ struct GroupedConvolutionForwardInvoker
|
||||
using TransformType =
|
||||
ck_tile::TransformConvFwdToGemm<NDimSpatial,
|
||||
ck_tile::ConvolutionSpecialization::Default,
|
||||
VectorSizeA,
|
||||
VectorSizeB,
|
||||
VectorSizeC,
|
||||
GroupedConvTraitsTypeDefault::VectorSizeA,
|
||||
GroupedConvTraitsTypeDefault::VectorSizeB,
|
||||
GroupedConvTraitsTypeDefault::VectorSizeC,
|
||||
1, // NumGroupsToMerge
|
||||
false, // SplitN
|
||||
InDataType,
|
||||
@@ -264,21 +260,21 @@ struct GroupedConvolutionForwardInvoker
|
||||
GroupedConvTraitsTypeLargeTensor,
|
||||
GroupedConvTraitsTypeDefault>;
|
||||
|
||||
using UniversalGemmProblem =
|
||||
ck_tile::UniversalGemmPipelineProblem<InDataType,
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
OutDataType,
|
||||
true,
|
||||
VectorSizeA,
|
||||
VectorSizeB>;
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
OutDataType,
|
||||
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
|
||||
GroupedConvTraitsType::VectorSizeA,
|
||||
GroupedConvTraitsType::VectorSizeB>;
|
||||
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
ConvConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
@@ -290,7 +286,7 @@ struct GroupedConvolutionForwardInvoker
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
|
||||
ck_tile::tensor_layout::gemm::RowMajor,
|
||||
typename GroupedConvTraitsType::FixedGemmParams::ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
@@ -299,10 +295,10 @@ struct GroupedConvolutionForwardInvoker
|
||||
ConvConfig::M_Warp_Tile,
|
||||
ConvConfig::N_Warp_Tile,
|
||||
ConvConfig::K_Warp_Tile,
|
||||
ConvConfig::TransposeC,
|
||||
GroupedConvTraitsType::FixedGemmParams::TransposeC,
|
||||
memory_operation,
|
||||
1,
|
||||
true,
|
||||
ConvConfig::NumWaveGroups,
|
||||
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
|
||||
GroupedConvTraitsType::VectorSizeC>>;
|
||||
|
||||
// Use split-image kernel if layout supports it, otherwise use regular kernel
|
||||
@@ -368,7 +364,8 @@ struct GroupedConvolutionForwardInvoker
|
||||
}
|
||||
|
||||
ave_time = ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
s,
|
||||
ck_tile::make_kernel<ConvConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
@@ -78,66 +78,4 @@ struct UnsupportedEnumValue
|
||||
{
|
||||
};
|
||||
|
||||
// Helper functions to convert enums to strings
|
||||
constexpr std::string_view ConvDirectionToString(ConvDirection dir)
|
||||
{
|
||||
switch(dir)
|
||||
{
|
||||
case ConvDirection::FORWARD: return "Forward";
|
||||
case ConvDirection::BACKWARD_DATA: return "Backward Data";
|
||||
case ConvDirection::BACKWARD_WEIGHT: return "Backward Weight";
|
||||
default: return "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
constexpr std::string_view DataTypeToString(DataType dt)
|
||||
{
|
||||
switch(dt)
|
||||
{
|
||||
case DataType::FP16: return "FP16";
|
||||
case DataType::FP32: return "FP32";
|
||||
case DataType::BF16: return "BF16";
|
||||
case DataType::FP8: return "FP8";
|
||||
case DataType::I8: return "I8";
|
||||
case DataType::U8: return "U8";
|
||||
default: return "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
constexpr std::string_view LayoutToString(GroupConvLayout1D layout)
|
||||
{
|
||||
switch(layout)
|
||||
{
|
||||
case GroupConvLayout1D::GNWC_GKXC_GNWK: return "GNWC_GKXC_GNWK";
|
||||
case GroupConvLayout1D::NWGC_GKXC_NWGK: return "NWGC_GKXC_NWGK";
|
||||
case GroupConvLayout1D::NGCW_GKXC_NGKW: return "NGCW_GKXC_NGKW";
|
||||
case GroupConvLayout1D::NGCW_GKCX_NGKW: return "NGCW_GKCX_NGKW";
|
||||
default: return "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
constexpr std::string_view LayoutToString(GroupConvLayout2D layout)
|
||||
{
|
||||
switch(layout)
|
||||
{
|
||||
case GroupConvLayout2D::GNHWC_GKYXC_GNHWK: return "GNHWC_GKYXC_GNHWK";
|
||||
case GroupConvLayout2D::NHWGC_GKYXC_NHWGK: return "NHWGC_GKYXC_NHWGK";
|
||||
case GroupConvLayout2D::NGCHW_GKYXC_NGKHW: return "NGCHW_GKYXC_NGKHW";
|
||||
case GroupConvLayout2D::NGCHW_GKCYX_NGKHW: return "NGCHW_GKCYX_NGKHW";
|
||||
default: return "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
constexpr std::string_view LayoutToString(GroupConvLayout3D layout)
|
||||
{
|
||||
switch(layout)
|
||||
{
|
||||
case GroupConvLayout3D::GNDHWC_GKZYXC_GNDHWK: return "GNDHWC_GKZYXC_GNDHWK";
|
||||
case GroupConvLayout3D::NDHWGC_GKZYXC_NDHWGK: return "NDHWGC_GKZYXC_NDHWGK";
|
||||
case GroupConvLayout3D::NGCDHW_GKZYXC_NGKDHW: return "NGCDHW_GKZYXC_NGKDHW";
|
||||
case GroupConvLayout3D::NGCDHW_GKCZYX_NGKDHW: return "NGCDHW_GKCZYX_NGKDHW";
|
||||
default: return "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder
|
||||
|
||||
@@ -38,8 +38,8 @@ concept GridwiseXdlGemmDescriptor = requires(T t) {
|
||||
// Concept for parameter that describe block GEMM problem.
|
||||
template <typename T>
|
||||
concept BlockGemmDescriptor = requires(T t) {
|
||||
{ t.pipeline_version } -> std::convertible_to<BlockGemmPipelineVersion>;
|
||||
{ t.scheduler } -> std::convertible_to<BlockGemmPipelineScheduler>;
|
||||
{ t.pipeline_version } -> std::convertible_to<PipelineVersion>;
|
||||
{ t.scheduler } -> std::convertible_to<PipelineScheduler>;
|
||||
};
|
||||
|
||||
// Concept for parameters that describe a gridwise WMMA GEMM problem.
|
||||
@@ -50,7 +50,7 @@ concept GridwiseWmmaGemmDescriptor = requires(T t) {
|
||||
{ t.n_per_wmma } -> std::convertible_to<size_t>;
|
||||
{ t.m_wmma_per_wave } -> std::convertible_to<size_t>;
|
||||
{ t.n_wmma_per_wave } -> std::convertible_to<size_t>;
|
||||
{ t.pipeline_version } -> std::convertible_to<GridwiseGemmPipelineVersion>;
|
||||
{ t.pipeline_version } -> std::convertible_to<PipelineVersion>;
|
||||
};
|
||||
|
||||
// Concept for vectorized data transfer for convolution input tensors.
|
||||
@@ -154,8 +154,8 @@ concept SpecifiesSourceAccessOrder = requires(T t) {
|
||||
// Concept to check if struct specifies block GEMM.
|
||||
template <typename T>
|
||||
concept SpecifiesBlockGemm = requires {
|
||||
{ T::block_gemm.pipeline_version } -> std::convertible_to<BlockGemmPipelineVersion>;
|
||||
{ T::block_gemm.scheduler } -> std::convertible_to<BlockGemmPipelineScheduler>;
|
||||
{ T::block_gemm.pipeline_version } -> std::convertible_to<PipelineVersion>;
|
||||
{ T::block_gemm.scheduler } -> std::convertible_to<PipelineScheduler>;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
@@ -180,7 +180,90 @@ concept SpecifiesNumGroupsToMerge = requires {
|
||||
|
||||
template <typename T>
|
||||
concept SpecifiesLoopScheduler = requires {
|
||||
{ T::loop_scheduler } -> std::convertible_to<LoopScheduler>;
|
||||
{ T::loop_scheduler } -> std::convertible_to<PipelineScheduler>;
|
||||
};
|
||||
|
||||
/******************************************** */
|
||||
/* DL-specific descriptors and requirements */
|
||||
/******************************************** */
|
||||
|
||||
// Concept for DL thread configuration
|
||||
template <typename T>
|
||||
concept DlThreadConfigDescriptor = requires(T t) {
|
||||
{ t.k0_per_block } -> std::convertible_to<size_t>;
|
||||
{ t.k1 } -> std::convertible_to<size_t>;
|
||||
{ t.m1_per_thread } -> std::convertible_to<size_t>;
|
||||
{ t.n1_per_thread } -> std::convertible_to<size_t>;
|
||||
{ t.k_per_thread } -> std::convertible_to<size_t>;
|
||||
};
|
||||
|
||||
// Concept for DL thread cluster
|
||||
template <typename T>
|
||||
concept DlThreadClusterDescriptor = requires(T t) {
|
||||
{ t.m1_xs } -> std::convertible_to<std::array<size_t, 2>>;
|
||||
{ t.n1_xs } -> std::convertible_to<std::array<size_t, 2>>;
|
||||
};
|
||||
|
||||
// Concept for DL block transfer K0_M0_M1_K1 format
|
||||
template <typename T>
|
||||
concept DlBlockTransferK0M0M1K1Descriptor = requires(T t) {
|
||||
{ t.thread_slice_lengths } -> std::convertible_to<std::array<size_t, 4>>;
|
||||
{ t.thread_cluster_lengths } -> std::convertible_to<std::array<size_t, 4>>;
|
||||
{ t.thread_cluster_arrange_order } -> std::convertible_to<std::array<size_t, 4>>;
|
||||
{ t.src_access_order } -> std::convertible_to<std::array<size_t, 4>>;
|
||||
{ t.src_vector_tensor_lengths } -> std::convertible_to<std::array<size_t, 4>>;
|
||||
{ t.src_vector_tensor_contiguous_dim_order } -> std::convertible_to<std::array<size_t, 4>>;
|
||||
{ t.dst_vector_tensor_lengths } -> std::convertible_to<std::array<size_t, 4>>;
|
||||
};
|
||||
|
||||
// Concept for DL block transfer K0_N0_N1_K1 format
|
||||
template <typename T>
|
||||
concept DlBlockTransferK0N0N1K1Descriptor = requires(T t) {
|
||||
{ t.thread_slice_lengths } -> std::convertible_to<std::array<size_t, 4>>;
|
||||
{ t.thread_cluster_lengths } -> std::convertible_to<std::array<size_t, 4>>;
|
||||
{ t.thread_cluster_arrange_order } -> std::convertible_to<std::array<size_t, 4>>;
|
||||
{ t.src_access_order } -> std::convertible_to<std::array<size_t, 4>>;
|
||||
{ t.src_vector_tensor_lengths } -> std::convertible_to<std::array<size_t, 4>>;
|
||||
{ t.src_vector_tensor_contiguous_dim_order } -> std::convertible_to<std::array<size_t, 4>>;
|
||||
{ t.dst_vector_tensor_lengths } -> std::convertible_to<std::array<size_t, 4>>;
|
||||
};
|
||||
|
||||
// Concept for DL C thread transfer
|
||||
template <typename T>
|
||||
concept DlCThreadTransferDescriptor = requires(T t) {
|
||||
{ t.src_dst_access_order } -> std::convertible_to<std::array<size_t, 6>>;
|
||||
{ t.src_dst_vector_dim } -> std::convertible_to<size_t>;
|
||||
{ t.dst_scalar_per_vector } -> std::convertible_to<size_t>;
|
||||
};
|
||||
|
||||
// Concept to check if algorithm specifies DL thread config
|
||||
template <typename T>
|
||||
concept SpecifiesDlThreadConfig = requires {
|
||||
{ T::dl_thread_config } -> DlThreadConfigDescriptor;
|
||||
};
|
||||
|
||||
// Concept to check if algorithm specifies DL thread cluster
|
||||
template <typename T>
|
||||
concept SpecifiesDlThreadCluster = requires {
|
||||
{ T::dl_thread_cluster } -> DlThreadClusterDescriptor;
|
||||
};
|
||||
|
||||
// Concept to check if algorithm specifies DL A block transfer
|
||||
template <typename T>
|
||||
concept SpecifiesDlBlockTransferA = requires {
|
||||
{ T::dl_block_transfer_a } -> DlBlockTransferK0M0M1K1Descriptor;
|
||||
};
|
||||
|
||||
// Concept to check if algorithm specifies DL B block transfer
|
||||
template <typename T>
|
||||
concept SpecifiesDlBlockTransferB = requires {
|
||||
{ T::dl_block_transfer_b } -> DlBlockTransferK0N0N1K1Descriptor;
|
||||
};
|
||||
|
||||
// Concept to check if algorithm specifies DL C thread transfer
|
||||
template <typename T>
|
||||
concept SpecifiesDlCThreadTransfer = requires {
|
||||
{ T::dl_c_thread_transfer } -> DlCThreadTransferDescriptor;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder
|
||||
|
||||
@@ -36,9 +36,21 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp"
|
||||
// WORKAROUND: Macro namespace collision in upstream CK device operation headers.
|
||||
// device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp (line 41) and
|
||||
// device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp (line 51) both define
|
||||
// GridwiseGemmTemplateParameters macro without #undef, causing redefinition errors.
|
||||
// Use pragma push/pop to isolate the Large_Tensor header's macro scope.
|
||||
#pragma push_macro("GridwiseGemmTemplateParameters")
|
||||
#ifdef GridwiseGemmTemplateParameters
|
||||
#undef GridwiseGemmTemplateParameters
|
||||
#endif
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp"
|
||||
#pragma pop_macro("GridwiseGemmTemplateParameters")
|
||||
#include "ck_tile/builder/conv_signature_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_limits.hpp"
|
||||
@@ -297,42 +309,42 @@ constexpr BlockGemmSpec SetBlockGemm()
|
||||
ck::BlockGemmPipelineScheduler scheduler;
|
||||
ck::BlockGemmPipelineVersion version;
|
||||
|
||||
if constexpr(BG.scheduler == BlockGemmPipelineScheduler::INTRAWAVE)
|
||||
if constexpr(BG.scheduler == PipelineScheduler::INTRAWAVE)
|
||||
{
|
||||
scheduler = ck::BlockGemmPipelineScheduler::Intrawave;
|
||||
}
|
||||
else if constexpr(BG.scheduler == BlockGemmPipelineScheduler::INTERWAVE)
|
||||
else if constexpr(BG.scheduler == PipelineScheduler::INTERWAVE)
|
||||
{
|
||||
scheduler = ck::BlockGemmPipelineScheduler::Interwave;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unknown BlockGemmPipelineScheduler");
|
||||
static_assert(false, "Unknown PipelineScheduler");
|
||||
}
|
||||
|
||||
if constexpr(BG.pipeline_version == BlockGemmPipelineVersion::V1)
|
||||
if constexpr(BG.pipeline_version == PipelineVersion::V1)
|
||||
{
|
||||
version = ck::BlockGemmPipelineVersion::v1;
|
||||
}
|
||||
else if constexpr(BG.pipeline_version == BlockGemmPipelineVersion::V2)
|
||||
else if constexpr(BG.pipeline_version == PipelineVersion::V2)
|
||||
{
|
||||
version = ck::BlockGemmPipelineVersion::v2;
|
||||
}
|
||||
else if constexpr(BG.pipeline_version == BlockGemmPipelineVersion::V3)
|
||||
else if constexpr(BG.pipeline_version == PipelineVersion::V3)
|
||||
{
|
||||
version = ck::BlockGemmPipelineVersion::v3;
|
||||
}
|
||||
else if constexpr(BG.pipeline_version == BlockGemmPipelineVersion::V4)
|
||||
else if constexpr(BG.pipeline_version == PipelineVersion::V4)
|
||||
{
|
||||
version = ck::BlockGemmPipelineVersion::v4;
|
||||
}
|
||||
else if constexpr(BG.pipeline_version == BlockGemmPipelineVersion::V5)
|
||||
else if constexpr(BG.pipeline_version == PipelineVersion::V5)
|
||||
{
|
||||
version = ck::BlockGemmPipelineVersion::v5;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unknown BlockGemmPipelineVersion");
|
||||
static_assert(false, "Unknown PipelineVersion");
|
||||
}
|
||||
|
||||
return BlockGemmSpec{.pipeline_version = version, .scheduler = scheduler};
|
||||
@@ -442,17 +454,17 @@ consteval ck::LoopScheduler SetLoopScheduler()
|
||||
{
|
||||
constexpr auto loop_scheduler = ALGORITHM.loop_scheduler;
|
||||
|
||||
if constexpr(loop_scheduler == LoopScheduler::DEFAULT)
|
||||
if constexpr(loop_scheduler == PipelineScheduler::DEFAULT)
|
||||
{
|
||||
return ck::LoopScheduler::Default;
|
||||
}
|
||||
else if constexpr(loop_scheduler == LoopScheduler::INTERWAVE)
|
||||
else if constexpr(loop_scheduler == PipelineScheduler::INTERWAVE)
|
||||
{
|
||||
return ck::LoopScheduler::Interwave;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unknown LoopScheduler");
|
||||
static_assert(false, "Unknown PipelineScheduler");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -460,29 +472,29 @@ template <ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
consteval ck::PipelineVersion SetGridwiseGemmPipelineVersion()
|
||||
{
|
||||
constexpr auto pipeline_version = ALGORITHM.gridwise_gemm.pipeline_version;
|
||||
if constexpr(pipeline_version == GridwiseGemmPipelineVersion::V1)
|
||||
if constexpr(pipeline_version == PipelineVersion::V1)
|
||||
{
|
||||
return ck::PipelineVersion::v1;
|
||||
}
|
||||
else if constexpr(pipeline_version == GridwiseGemmPipelineVersion::V2)
|
||||
else if constexpr(pipeline_version == PipelineVersion::V2)
|
||||
{
|
||||
return ck::PipelineVersion::v2;
|
||||
}
|
||||
else if constexpr(pipeline_version == GridwiseGemmPipelineVersion::V3)
|
||||
else if constexpr(pipeline_version == PipelineVersion::V3)
|
||||
{
|
||||
static_assert(false, "V3 is used only for stream-K.");
|
||||
}
|
||||
else if constexpr(pipeline_version == GridwiseGemmPipelineVersion::V4)
|
||||
else if constexpr(pipeline_version == PipelineVersion::V4)
|
||||
{
|
||||
return ck::PipelineVersion::v4;
|
||||
}
|
||||
else if constexpr(pipeline_version == GridwiseGemmPipelineVersion::WEIGHT_ONLY)
|
||||
else if constexpr(pipeline_version == PipelineVersion::WEIGHT_ONLY)
|
||||
{
|
||||
return ck::PipelineVersion::weight_only;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unknown GridwiseGemmPipelineVersion");
|
||||
static_assert(false, "Unknown PipelineVersion");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -566,29 +578,29 @@ consteval ck::BlockGemmPipelineVersion SetBlockGemmPipelineVersion()
|
||||
{
|
||||
constexpr auto version = ALGORITHM.pipeline_version;
|
||||
|
||||
if constexpr(version == BlockGemmPipelineVersion::V1)
|
||||
if constexpr(version == PipelineVersion::V1)
|
||||
{
|
||||
return ck::BlockGemmPipelineVersion::v1;
|
||||
}
|
||||
else if constexpr(version == BlockGemmPipelineVersion::V2)
|
||||
else if constexpr(version == PipelineVersion::V2)
|
||||
{
|
||||
return ck::BlockGemmPipelineVersion::v2;
|
||||
}
|
||||
else if constexpr(version == BlockGemmPipelineVersion::V3)
|
||||
else if constexpr(version == PipelineVersion::V3)
|
||||
{
|
||||
return ck::BlockGemmPipelineVersion::v3;
|
||||
}
|
||||
else if constexpr(version == BlockGemmPipelineVersion::V4)
|
||||
else if constexpr(version == PipelineVersion::V4)
|
||||
{
|
||||
return ck::BlockGemmPipelineVersion::v4;
|
||||
}
|
||||
else if constexpr(version == BlockGemmPipelineVersion::V5)
|
||||
else if constexpr(version == PipelineVersion::V5)
|
||||
{
|
||||
return ck::BlockGemmPipelineVersion::v5;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unknown BlockGemmPipelineVersion");
|
||||
static_assert(false, "Unknown PipelineVersion");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -990,4 +1002,263 @@ struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
|
||||
GRIDWISE_GEMM_PIPELINE_VERSION>;
|
||||
};
|
||||
|
||||
// Factory specialization for DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK instance
|
||||
// of a grouped forward convolution kernel using Direct Load (DL) approach.
|
||||
template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
ConvAlgorithmDescriptor auto ALGORITHM,
|
||||
StringLiteral VERSION>
|
||||
requires ConvDirectionIsForward<SIGNATURE> &&
|
||||
ConvDeviceOpIs_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK<SIGNATURE>
|
||||
struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
|
||||
{
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts = decltype(factory_internal::GetTensorLayout<SIGNATURE.layout,
|
||||
SPATIAL_DIM,
|
||||
ConvDirection::FORWARD>());
|
||||
using Types = factory_internal::ConvTensorTypes<SIGNATURE.data_type>;
|
||||
using Ops = factory_internal::ElementwiseOps<SIGNATURE.elementwise_operation>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
static_assert(SpecifiesThreadBlock<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify thread block info.");
|
||||
static_assert(SpecifiesFwdConcSpecialization<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify forward convolution "
|
||||
"specialization.");
|
||||
static_assert(SpecifiesGemmSpecialization<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify gemm specialization.");
|
||||
static_assert(SpecifiesDlThreadConfig<AlgorithmType>,
|
||||
"DL algorithm must specify thread config.");
|
||||
static_assert(SpecifiesDlThreadCluster<AlgorithmType>,
|
||||
"DL algorithm must specify thread cluster.");
|
||||
static_assert(SpecifiesDlBlockTransferA<AlgorithmType>,
|
||||
"DL algorithm must specify A block transfer.");
|
||||
static_assert(SpecifiesDlBlockTransferB<AlgorithmType>,
|
||||
"DL algorithm must specify B block transfer.");
|
||||
static_assert(SpecifiesDlCThreadTransfer<AlgorithmType>,
|
||||
"DL algorithm must specify C thread transfer.");
|
||||
|
||||
static constexpr auto FWD_CONV_SPECIALIZATION =
|
||||
factory_internal::SetFwdConvSpecialization<ALGORITHM>();
|
||||
static constexpr auto GEMM_SPECIALIZATION =
|
||||
factory_internal::SetGemmSpecialization<ALGORITHM>();
|
||||
|
||||
static constexpr auto BLOCK = factory_internal::SetThreadBlockInfo<ALGORITHM>();
|
||||
|
||||
// DL-specific parameters from algorithm descriptor
|
||||
static constexpr auto DL_THREAD_CFG = ALGORITHM.dl_thread_config;
|
||||
static constexpr ck::index_t K0PerBlock = DL_THREAD_CFG.k0_per_block;
|
||||
static constexpr ck::index_t K1 = DL_THREAD_CFG.k1;
|
||||
static constexpr ck::index_t M1PerThread = DL_THREAD_CFG.m1_per_thread;
|
||||
static constexpr ck::index_t N1PerThread = DL_THREAD_CFG.n1_per_thread;
|
||||
static constexpr ck::index_t KPerThread = DL_THREAD_CFG.k_per_thread;
|
||||
|
||||
// Thread cluster from descriptor
|
||||
static constexpr auto DL_CLUSTER = ALGORITHM.dl_thread_cluster;
|
||||
using M1N1ThreadClusterM1Xs = to_sequence_v<DL_CLUSTER.m1_xs>;
|
||||
using M1N1ThreadClusterN1Xs = to_sequence_v<DL_CLUSTER.n1_xs>;
|
||||
|
||||
// A Block Transfer from descriptor - K0_M0_M1_K1 tensor format
|
||||
static constexpr auto DL_A_TRANSFER = ALGORITHM.dl_block_transfer_a;
|
||||
using ABlockTransferThreadSliceLengths_K0_M0_M1_K1 =
|
||||
to_sequence_v<DL_A_TRANSFER.thread_slice_lengths>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M0_M1_K1 =
|
||||
to_sequence_v<DL_A_TRANSFER.thread_cluster_lengths>;
|
||||
using ABlockTransferThreadClusterArrangeOrder =
|
||||
to_sequence_v<DL_A_TRANSFER.thread_cluster_arrange_order>;
|
||||
using ABlockTransferSrcAccessOrder = to_sequence_v<DL_A_TRANSFER.src_access_order>;
|
||||
using ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1 =
|
||||
to_sequence_v<DL_A_TRANSFER.src_vector_tensor_lengths>;
|
||||
using ABlockTransferSrcVectorTensorContiguousDimOrder =
|
||||
to_sequence_v<DL_A_TRANSFER.src_vector_tensor_contiguous_dim_order>;
|
||||
using ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1 =
|
||||
to_sequence_v<DL_A_TRANSFER.dst_vector_tensor_lengths>;
|
||||
|
||||
// B Block Transfer from descriptor - K0_N0_N1_K1 tensor format
|
||||
static constexpr auto DL_B_TRANSFER = ALGORITHM.dl_block_transfer_b;
|
||||
using BBlockTransferThreadSliceLengths_K0_N0_N1_K1 =
|
||||
to_sequence_v<DL_B_TRANSFER.thread_slice_lengths>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N0_N1_K1 =
|
||||
to_sequence_v<DL_B_TRANSFER.thread_cluster_lengths>;
|
||||
using BBlockTransferThreadClusterArrangeOrder =
|
||||
to_sequence_v<DL_B_TRANSFER.thread_cluster_arrange_order>;
|
||||
using BBlockTransferSrcAccessOrder = to_sequence_v<DL_B_TRANSFER.src_access_order>;
|
||||
using BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1 =
|
||||
to_sequence_v<DL_B_TRANSFER.src_vector_tensor_lengths>;
|
||||
using BBlockTransferSrcVectorTensorContiguousDimOrder =
|
||||
to_sequence_v<DL_B_TRANSFER.src_vector_tensor_contiguous_dim_order>;
|
||||
using BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1 =
|
||||
to_sequence_v<DL_B_TRANSFER.dst_vector_tensor_lengths>;
|
||||
|
||||
// C Thread Transfer from descriptor
|
||||
static constexpr auto DL_C_TRANSFER = ALGORITHM.dl_c_thread_transfer;
|
||||
using CThreadTransferSrcDstAccessOrder = to_sequence_v<DL_C_TRANSFER.src_dst_access_order>;
|
||||
static constexpr ck::index_t CThreadTransferSrcDstVectorDim = DL_C_TRANSFER.src_dst_vector_dim;
|
||||
static constexpr ck::index_t CThreadTransferDstScalarPerVector =
|
||||
DL_C_TRANSFER.dst_scalar_per_vector;
|
||||
|
||||
// The DL forward convolution kernel class instance
|
||||
using Instance = ck::tensor_operation::device::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK<
|
||||
SPATIAL_DIM,
|
||||
typename Types::ADataType,
|
||||
typename Types::BDataType,
|
||||
typename Types::DsDataTypes,
|
||||
typename Types::EDataType,
|
||||
typename Types::AccDataType,
|
||||
typename Layouts::ALayout,
|
||||
typename Layouts::BLayout,
|
||||
typename Layouts::DsLayout,
|
||||
typename Layouts::ELayout,
|
||||
typename Ops::AElementwiseOp,
|
||||
typename Ops::BElementwiseOp,
|
||||
typename Ops::CDEElementwiseOp,
|
||||
FWD_CONV_SPECIALIZATION,
|
||||
GEMM_SPECIALIZATION,
|
||||
BLOCK.block_size,
|
||||
BLOCK.per_block.m,
|
||||
BLOCK.per_block.n,
|
||||
K0PerBlock,
|
||||
K1,
|
||||
M1PerThread,
|
||||
N1PerThread,
|
||||
KPerThread,
|
||||
M1N1ThreadClusterM1Xs,
|
||||
M1N1ThreadClusterN1Xs,
|
||||
ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
|
||||
ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
|
||||
ABlockTransferSrcVectorTensorContiguousDimOrder,
|
||||
ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
|
||||
BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
|
||||
BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
|
||||
BBlockTransferSrcVectorTensorContiguousDimOrder,
|
||||
BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
|
||||
CThreadTransferSrcDstAccessOrder,
|
||||
CThreadTransferSrcDstVectorDim,
|
||||
CThreadTransferDstScalarPerVector>;
|
||||
};
|
||||
|
||||
// Factory specialization for DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor instance
|
||||
// of a grouped forward convolution kernel with large tensor support (N-splitting).
|
||||
template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
ConvAlgorithmDescriptor auto ALGORITHM,
|
||||
StringLiteral VERSION>
|
||||
requires ConvDirectionIsForward<SIGNATURE> &&
|
||||
ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<SIGNATURE>
|
||||
struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
|
||||
{
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts = decltype(factory_internal::GetTensorLayout<SIGNATURE.layout,
|
||||
SPATIAL_DIM,
|
||||
ConvDirection::FORWARD>());
|
||||
using Types = factory_internal::ConvTensorTypes<SIGNATURE.data_type>;
|
||||
using Ops = factory_internal::ElementwiseOps<SIGNATURE.elementwise_operation>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
static_assert(SpecifiesThreadBlock<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify thread block info.");
|
||||
static_assert(SpecifiesGridwiseXdlGemm<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify gridwise GEMM info.");
|
||||
static_assert(SpecifiesBlockTransfer<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify block transfer info.");
|
||||
static_assert(SpecifiesLdsTransfer<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify LDS transfer info.");
|
||||
static_assert(
|
||||
SpecifiesThreadClusterAccessOrder<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify thread cluster access order info.");
|
||||
static_assert(SpecifiesSourceAccessOrder<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify source access order info.");
|
||||
static_assert(SpecifiesFwdConcSpecialization<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify forward convolution "
|
||||
"specialization.");
|
||||
static_assert(SpecifiesGemmSpecialization<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify gemm specialization.");
|
||||
static_assert(SpecifiesNumPrefetchStages<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify number of prefetch stages.");
|
||||
static_assert(SpecifiesLoopScheduler<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify loop scheduler.");
|
||||
|
||||
static constexpr auto FWD_CONV_SPECIALIZATION =
|
||||
factory_internal::SetFwdConvSpecialization<ALGORITHM>();
|
||||
static constexpr auto GEMM_SPECIALIZATION =
|
||||
factory_internal::SetGemmSpecialization<ALGORITHM>();
|
||||
static constexpr factory_internal::ConvSpec SPECIALIZATION{.conv_spec = FWD_CONV_SPECIALIZATION,
|
||||
.gemm_spec = GEMM_SPECIALIZATION};
|
||||
|
||||
static constexpr auto LOOP_SCHEDULER = factory_internal::SetLoopScheduler<ALGORITHM>();
|
||||
static constexpr auto BLOCK = factory_internal::SetThreadBlockInfo<ALGORITHM>();
|
||||
static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
|
||||
static constexpr auto A_BLOCK_TRANSFER =
|
||||
factory_internal::SetFwdConvABlockTransfer<ALGORITHM>();
|
||||
static constexpr auto B_BLOCK_TRANSFER =
|
||||
factory_internal::SetFwdConvBBlockTransfer<ALGORITHM>();
|
||||
static constexpr auto C_BLOCK_TRANSFER =
|
||||
factory_internal::SetCBlockTransfer<SIGNATURE, ALGORITHM>();
|
||||
|
||||
// Check limits for the algorithm parameters.
|
||||
static_assert(InputVectorTransferLimits<A_BLOCK_TRANSFER>);
|
||||
static_assert(InputVectorTransferLimits<B_BLOCK_TRANSFER>);
|
||||
static_assert(OutputVectorTransferLimits<C_BLOCK_TRANSFER>);
|
||||
static_assert(AccessOrderLimits<A_BLOCK_TRANSFER.thread_cluster_order>);
|
||||
static_assert(AccessOrderLimits<B_BLOCK_TRANSFER.thread_cluster_order>);
|
||||
static_assert(AccessOrderLimits<A_BLOCK_TRANSFER.src_access_order>);
|
||||
static_assert(AccessOrderLimits<B_BLOCK_TRANSFER.src_access_order>);
|
||||
|
||||
// The forward convolution kernel class instance with large tensor support.
|
||||
using Instance =
|
||||
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<
|
||||
SPATIAL_DIM,
|
||||
typename Layouts::ALayout,
|
||||
typename Layouts::BLayout,
|
||||
typename Layouts::DsLayout,
|
||||
typename Layouts::ELayout,
|
||||
typename Types::ADataType,
|
||||
typename Types::BDataType,
|
||||
typename Types::AccDataType,
|
||||
typename Types::CShuffleDataType,
|
||||
typename Types::DsDataTypes,
|
||||
typename Types::EDataType,
|
||||
typename Ops::AElementwiseOp,
|
||||
typename Ops::BElementwiseOp,
|
||||
typename Ops::CDEElementwiseOp,
|
||||
SPECIALIZATION.conv_spec,
|
||||
SPECIALIZATION.gemm_spec,
|
||||
ALGORITHM.num_gemm_k_prefetch_stages,
|
||||
BLOCK.block_size,
|
||||
BLOCK.per_block.m,
|
||||
BLOCK.per_block.n,
|
||||
BLOCK.per_block.k,
|
||||
GRIDWISE_GEMM.ak1,
|
||||
GRIDWISE_GEMM.bk1,
|
||||
GRIDWISE_GEMM.m_per_xdl,
|
||||
GRIDWISE_GEMM.n_per_xdl,
|
||||
GRIDWISE_GEMM.m_xdl_per_wave,
|
||||
GRIDWISE_GEMM.n_xdl_per_wave,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,
|
||||
A_BLOCK_TRANSFER.src_vector_dim,
|
||||
A_BLOCK_TRANSFER.src_scalar_per_vector,
|
||||
A_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
|
||||
A_BLOCK_TRANSFER.lds_padding,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.src_access_order>,
|
||||
B_BLOCK_TRANSFER.src_vector_dim,
|
||||
B_BLOCK_TRANSFER.src_scalar_per_vector,
|
||||
B_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
|
||||
B_BLOCK_TRANSFER.lds_padding,
|
||||
C_BLOCK_TRANSFER.m_per_wave_per_shuffle,
|
||||
C_BLOCK_TRANSFER.n_per_wave_per_shuffle,
|
||||
to_sequence_v<C_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
C_BLOCK_TRANSFER.scalar_per_vector,
|
||||
typename Types::AComputeType,
|
||||
typename Types::BComputeType,
|
||||
LOOP_SCHEDULER>;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder
|
||||
|
||||
@@ -33,30 +33,35 @@ concept ConvDirectionIsBackwardWeight = (Sig.direction == ConvDirection::BACKWAR
|
||||
// Predicate for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 =
|
||||
ConvDirectionIsForward<Sig> &&
|
||||
(Sig.device_operation._fwd ==
|
||||
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3);
|
||||
|
||||
// Predicate for DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIs_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK =
|
||||
ConvDirectionIsForward<Sig> &&
|
||||
(Sig.device_operation._fwd ==
|
||||
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK);
|
||||
|
||||
// Predicate for DeviceGroupedConvFwdMultipleD_Wmma_CShuffle operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle =
|
||||
ConvDirectionIsForward<Sig> &&
|
||||
(Sig.device_operation._fwd ==
|
||||
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleD_Wmma_CShuffle);
|
||||
|
||||
// Predicate for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle =
|
||||
ConvDirectionIsForward<Sig> &&
|
||||
(Sig.device_operation._fwd ==
|
||||
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle);
|
||||
|
||||
// Predicate for DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor =
|
||||
ConvDirectionIsForward<Sig> &&
|
||||
(Sig.device_operation._fwd ==
|
||||
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor);
|
||||
|
||||
@@ -76,48 +81,56 @@ concept ConvDeviceOpIsForward =
|
||||
// Predicate for DeviceGroupedConvBwdWeight operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight =
|
||||
ConvDirectionIsBackwardWeight<Sig> &&
|
||||
(Sig.device_operation._bwd_weight ==
|
||||
BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight);
|
||||
|
||||
// Predicate for DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle =
|
||||
ConvDirectionIsBackwardWeight<Sig> &&
|
||||
(Sig.device_operation._bwd_weight ==
|
||||
BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle);
|
||||
|
||||
// Predicate for DeviceGroupedConvBwdWeight_Xdl_CShuffle operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Xdl_CShuffle =
|
||||
ConvDirectionIsBackwardWeight<Sig> &&
|
||||
(Sig.device_operation._bwd_weight ==
|
||||
BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Xdl_CShuffle);
|
||||
|
||||
// Predicate for DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle =
|
||||
ConvDirectionIsBackwardWeight<Sig> &&
|
||||
(Sig.device_operation._bwd_weight ==
|
||||
BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle);
|
||||
|
||||
// Predicate for DeviceGroupedConvBwdWeight_Wmma_CShuffle operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Wmma_CShuffle =
|
||||
ConvDirectionIsBackwardWeight<Sig> &&
|
||||
(Sig.device_operation._bwd_weight ==
|
||||
BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Wmma_CShuffle);
|
||||
|
||||
// Predicate for DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 =
|
||||
ConvDirectionIsBackwardWeight<Sig> &&
|
||||
(Sig.device_operation._bwd_weight ==
|
||||
BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Xdl_CShuffleV3);
|
||||
|
||||
// Predicate for DeviceGroupedConvBwdWeightMultipleD operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeightMultipleD =
|
||||
ConvDirectionIsBackwardWeight<Sig> &&
|
||||
(Sig.device_operation._bwd_weight ==
|
||||
BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeightMultipleD);
|
||||
|
||||
// Predicate for DeviceGroupedConvBwdWeight_Dl operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Dl =
|
||||
ConvDirectionIsBackwardWeight<Sig> &&
|
||||
(Sig.device_operation._bwd_weight ==
|
||||
BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Dl);
|
||||
|
||||
@@ -140,18 +153,21 @@ concept ConvDeviceOpIsBackwardWeight =
|
||||
// Predicate for DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 =
|
||||
ConvDirectionIsBackwardData<Sig> &&
|
||||
(Sig.device_operation._bwd_data ==
|
||||
BwdDataGroupConvDeviceOperation::DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1);
|
||||
|
||||
// Predicate for DeviceGroupedConvBwdDataMultipleD operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD =
|
||||
ConvDirectionIsBackwardData<Sig> &&
|
||||
(Sig.device_operation._bwd_data ==
|
||||
BwdDataGroupConvDeviceOperation::DeviceGroupedConvBwdDataMultipleD);
|
||||
|
||||
// Predicate for DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle =
|
||||
ConvDirectionIsBackwardData<Sig> &&
|
||||
(Sig.device_operation._bwd_data ==
|
||||
BwdDataGroupConvDeviceOperation::DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle);
|
||||
|
||||
|
||||
@@ -0,0 +1,22 @@
|
||||
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck_tile::builder {
|
||||
|
||||
// Enumeration for CK Device Operation types.
|
||||
// This allows the builder to select which device operation template to instantiate
|
||||
// based on the user's requirements.
|
||||
enum class DeviceOpType
|
||||
{
|
||||
// Forward Convolution - Non-grouped
|
||||
CONV_FWD, // Maps to: DeviceConvFwd (TODO: No implementation with tuning params exists yet)
|
||||
|
||||
// Forward Convolution - Grouped
|
||||
GROUPED_CONV_FWD_MULTIPLE_ABD, // Maps to: DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
GROUPED_CONV_FWD_MULTIPLE_ABD_XDL_CSHUFFLE_V3, // Maps to:
|
||||
// DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder
|
||||
@@ -0,0 +1,268 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <concepts>
|
||||
#include <string_view>
|
||||
#include <sstream>
|
||||
#include <type_traits>
|
||||
#include <variant>
|
||||
|
||||
#include <ck_tile/builder/conv_signature_concepts.hpp>
|
||||
#include <ck_tile/builder/reflect/conv_traits.hpp>
|
||||
#include <ck_tile/builder/reflect/tree_formatter.hpp>
|
||||
|
||||
/// @file conv_description.hpp
|
||||
/// @brief Provides human-readable descriptions of ConvBuilder configurations
|
||||
|
||||
namespace ck_tile::reflect::conv {
|
||||
|
||||
struct ConvSignatureInfo
|
||||
{
|
||||
int spatial_dim;
|
||||
builder::ConvDirection direction;
|
||||
std::variant<builder::GroupConvLayout1D, builder::GroupConvLayout2D, builder::GroupConvLayout3D>
|
||||
layout;
|
||||
builder::DataType data_type;
|
||||
builder::ElementwiseOperation input_element_op;
|
||||
builder::ElementwiseOperation weight_element_op;
|
||||
builder::ElementwiseOperation output_element_op;
|
||||
};
|
||||
|
||||
// Algorithm information - groups all algorithm-related configuration
|
||||
struct GemmAlgorithmInfo
|
||||
{
|
||||
int thread_block_size;
|
||||
DataTileInfo tile_dims;
|
||||
WarpGemmParams warp_gemm;
|
||||
InputTileTransferInfo a_tile_transfer;
|
||||
InputTileTransferInfo b_tile_transfer;
|
||||
OutputTileTransferInfo c_tile_transfer;
|
||||
builder::PipelineVersion pipeline_version;
|
||||
builder::PipelineScheduler pipeline_scheduler;
|
||||
std::variant<builder::ConvFwdSpecialization,
|
||||
builder::ConvBwdDataSpecialization,
|
||||
builder::ConvBwdWeightSpecialization>
|
||||
conv_specialization;
|
||||
builder::GemmPadding padding;
|
||||
};
|
||||
|
||||
// Provides human-readable descriptions of ConvBuilder configurations.
|
||||
struct ConvDescription
|
||||
{
|
||||
ConvSignatureInfo signature;
|
||||
GemmAlgorithmInfo algorithm;
|
||||
|
||||
// Brief one-line summary
|
||||
std::string brief() const
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << signature.spatial_dim << "D " << signature.direction << " convolution";
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
// Detailed hierarchical description
|
||||
std::string detailed() const
|
||||
{
|
||||
TreeFormatter f;
|
||||
f.writeLine(0, signature.spatial_dim, "D ", signature.direction, " Convolution Kernel");
|
||||
f.writeLine(1, "Signature");
|
||||
f.writeLine(2, "Tensor Type: ", signature.data_type);
|
||||
f.writeLine(2, "Memory Layout: ", signature.layout);
|
||||
f.writeLine(2, "Input elementwise operation: ", signature.input_element_op);
|
||||
f.writeLine(2, "Weights elementwise operation: ", signature.weight_element_op);
|
||||
f.writeLast(2, "Output elementwise operation: ", signature.output_element_op);
|
||||
|
||||
f.writeLine(1, "Algorithm");
|
||||
// Compute Block section
|
||||
f.writeLine(2, "Thread block size: ", algorithm.thread_block_size);
|
||||
f.writeLine(2,
|
||||
"Data tile size: ",
|
||||
algorithm.tile_dims.m,
|
||||
"×",
|
||||
algorithm.tile_dims.n,
|
||||
"×",
|
||||
algorithm.tile_dims.k);
|
||||
f.writeLine(2, "Gemm padding: ", algorithm.padding);
|
||||
f.writeLine(2, "Convolution specialization: ", algorithm.conv_specialization);
|
||||
// Pipeline section
|
||||
f.writeLine(2, "Pipeline version: ", algorithm.pipeline_version);
|
||||
f.writeLine(2, "Pipeline scheduler: ", algorithm.pipeline_scheduler);
|
||||
f.writeLine(2, "Warp Gemm parameters: ");
|
||||
f.writeLine(
|
||||
3, "subtile size: ", algorithm.warp_gemm.gemm_m, "×", algorithm.warp_gemm.gemm_n);
|
||||
f.writeLast(3,
|
||||
"Number of warp gemm iterations: ",
|
||||
algorithm.warp_gemm.m_iter,
|
||||
"×",
|
||||
algorithm.warp_gemm.n_iter);
|
||||
|
||||
// Memory Access section
|
||||
f.writeLine(2, "Memory access:");
|
||||
|
||||
f.writeLine(3, "A Tile transfer: ");
|
||||
f.writeLine(4,
|
||||
"Tile dimensions: ",
|
||||
algorithm.a_tile_transfer.tile_dimensions.k0,
|
||||
"×",
|
||||
algorithm.a_tile_transfer.tile_dimensions.m_or_n,
|
||||
"×",
|
||||
algorithm.a_tile_transfer.tile_dimensions.k1,
|
||||
"×");
|
||||
f.writeLine(
|
||||
4, "The innermost K subdimension size: ", algorithm.a_tile_transfer.transfer_params.k1);
|
||||
f.writeLine(4,
|
||||
"Spatial thread distribution over the data tile: ",
|
||||
algorithm.a_tile_transfer.transfer_params.thread_cluster_order[0],
|
||||
"×",
|
||||
algorithm.a_tile_transfer.transfer_params.thread_cluster_order[1],
|
||||
"×",
|
||||
algorithm.a_tile_transfer.transfer_params.thread_cluster_order[2]);
|
||||
f.writeLine(4,
|
||||
"The order of accessing data tile axes: ",
|
||||
algorithm.a_tile_transfer.transfer_params.src_access_order[0],
|
||||
"×",
|
||||
algorithm.a_tile_transfer.transfer_params.src_access_order[1],
|
||||
"×",
|
||||
algorithm.a_tile_transfer.transfer_params.src_access_order[2]);
|
||||
f.writeLine(4,
|
||||
"Vectorized memory access axis index (with contiguous memory): ",
|
||||
algorithm.a_tile_transfer.transfer_params.src_vector_dim);
|
||||
f.writeLine(4,
|
||||
"Vector access (GMEM read) instruction size: ",
|
||||
algorithm.a_tile_transfer.transfer_params.src_scalar_per_vector);
|
||||
f.writeLine(4,
|
||||
"Vector access (LDS write) instruction size: ",
|
||||
algorithm.a_tile_transfer.transfer_params.dst_scalar_per_vector_k1);
|
||||
f.writeLast(4,
|
||||
"LDS data layout padding (to prevent bank conflicts): ",
|
||||
algorithm.a_tile_transfer.transfer_params.dst_scalar_per_vector_k1);
|
||||
|
||||
f.writeLine(3, "B Tile transfer: ");
|
||||
f.writeLine(4,
|
||||
"Tile dimensions: ",
|
||||
algorithm.b_tile_transfer.tile_dimensions.k0,
|
||||
"×",
|
||||
algorithm.b_tile_transfer.tile_dimensions.m_or_n,
|
||||
"×",
|
||||
algorithm.b_tile_transfer.tile_dimensions.k1,
|
||||
"×");
|
||||
f.writeLine(
|
||||
4, "The innermost K subdimension size: ", algorithm.b_tile_transfer.transfer_params.k1);
|
||||
f.writeLine(4,
|
||||
"Spatial thread distribution over the data tile: ",
|
||||
algorithm.b_tile_transfer.transfer_params.thread_cluster_order[0],
|
||||
"×",
|
||||
algorithm.b_tile_transfer.transfer_params.thread_cluster_order[1],
|
||||
"×",
|
||||
algorithm.b_tile_transfer.transfer_params.thread_cluster_order[2]);
|
||||
f.writeLine(4,
|
||||
"The order of accessing data tile axes: ",
|
||||
algorithm.b_tile_transfer.transfer_params.src_access_order[0],
|
||||
"×",
|
||||
algorithm.b_tile_transfer.transfer_params.src_access_order[1],
|
||||
"×",
|
||||
algorithm.b_tile_transfer.transfer_params.src_access_order[2]);
|
||||
f.writeLine(4,
|
||||
"Vectorized memory access axis index (with contiguous memory): ",
|
||||
algorithm.b_tile_transfer.transfer_params.src_vector_dim);
|
||||
f.writeLine(4,
|
||||
"Vector access (GMEM read) instruction size: ",
|
||||
algorithm.b_tile_transfer.transfer_params.src_scalar_per_vector);
|
||||
f.writeLine(4,
|
||||
"Vector access (LDS write) instruction size: ",
|
||||
algorithm.b_tile_transfer.transfer_params.dst_scalar_per_vector_k1);
|
||||
f.writeLast(4,
|
||||
"LDS data layout padding (to prevent bank conflicts): ",
|
||||
algorithm.b_tile_transfer.transfer_params.dst_scalar_per_vector_k1);
|
||||
|
||||
f.writeLast(3, "C Tile transfer: ");
|
||||
f.writeLine(4,
|
||||
"Data shuffle (number of gemm instructions per iteration): ",
|
||||
algorithm.c_tile_transfer.shuffle_params.m_gemms_per_shuffle,
|
||||
"×",
|
||||
algorithm.c_tile_transfer.shuffle_params.n_gemms_per_shuffle);
|
||||
f.writeLine(4,
|
||||
"Spatial thread distribution used to store data: ",
|
||||
algorithm.c_tile_transfer.thread_cluster_dims[0],
|
||||
"×",
|
||||
algorithm.c_tile_transfer.thread_cluster_dims[1],
|
||||
"×",
|
||||
algorithm.c_tile_transfer.thread_cluster_dims[2],
|
||||
"×",
|
||||
algorithm.c_tile_transfer.thread_cluster_dims[3]);
|
||||
f.writeLast(4,
|
||||
"Vector access (GMEM write) instruction size: ",
|
||||
algorithm.c_tile_transfer.scalar_per_vector);
|
||||
f.writeLast(2);
|
||||
f.writeLast(1);
|
||||
return f.getString();
|
||||
}
|
||||
|
||||
// Educational explanation of optimization choices
|
||||
std::string explain() const
|
||||
{
|
||||
std::ostringstream oss;
|
||||
// Placeholder for future implementation
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
// Performance characteristics and use case guidance
|
||||
std::string suggest() const
|
||||
{
|
||||
std::ostringstream oss;
|
||||
// Placeholder for future implementation
|
||||
return oss.str();
|
||||
}
|
||||
};
|
||||
|
||||
// Helper concept to detect if a type has InstanceTraits specialization
|
||||
template <typename T>
|
||||
concept HasInstanceTraits = requires { typename InstanceTraits<T>; };
|
||||
|
||||
// Helper concept to detect ConvBuilder types
|
||||
template <typename T>
|
||||
concept IsConvBuilder = requires {
|
||||
typename T::Factory;
|
||||
typename T::Instance;
|
||||
};
|
||||
|
||||
// Primary factory function: Create ConvDescription from Instance type directly
|
||||
template <typename Instance>
|
||||
requires HasInstanceTraits<Instance>
|
||||
ConvDescription Describe()
|
||||
{
|
||||
using Traits = ConvTraits<Instance>;
|
||||
|
||||
return ConvDescription{
|
||||
.signature = ConvSignatureInfo{.spatial_dim = Traits::spatial_dim,
|
||||
.direction = Traits::direction,
|
||||
.layout = Traits::layout,
|
||||
.data_type = Traits::data_type,
|
||||
.input_element_op = Traits::input_element_op,
|
||||
.weight_element_op = Traits::weight_element_op,
|
||||
.output_element_op = Traits::output_element_op},
|
||||
.algorithm = GemmAlgorithmInfo{.thread_block_size = Traits::thread_block_size,
|
||||
.tile_dims = Traits::tile_dims,
|
||||
.warp_gemm = Traits::warp_gemm,
|
||||
.a_tile_transfer = Traits::a_tile_transfer,
|
||||
.b_tile_transfer = Traits::b_tile_transfer,
|
||||
.c_tile_transfer = Traits::c_tile_transfer,
|
||||
.pipeline_version = Traits::pipeline_version,
|
||||
.pipeline_scheduler = Traits::pipeline_scheduler,
|
||||
.conv_specialization = Traits::conv_specialization,
|
||||
.padding = Traits::gemm_padding}};
|
||||
}
|
||||
|
||||
// Backward compatibility: Create ConvDescription from Builder type
|
||||
template <typename Builder>
|
||||
requires IsConvBuilder<Builder> && (!HasInstanceTraits<Builder>)
|
||||
ConvDescription Describe()
|
||||
{
|
||||
// Delegate to Instance-based version
|
||||
using Instance = typename Builder::Instance;
|
||||
return Describe<Instance>();
|
||||
}
|
||||
|
||||
} // namespace ck_tile::reflect::conv
|
||||
@@ -0,0 +1,719 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <concepts>
|
||||
#include <ck_tile/builder/conv_builder.hpp>
|
||||
#include <ck_tile/builder/conv_factory.hpp>
|
||||
#include <ck_tile/builder/conv_signature_concepts.hpp>
|
||||
#include <ck_tile/builder/reflect/instance_traits.hpp>
|
||||
#include <ck_tile/builder/types.hpp>
|
||||
#include <ck/tensor_operation/gpu/device/tensor_layout.hpp>
|
||||
#include <ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp>
|
||||
#include <ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp>
|
||||
#include <ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp>
|
||||
#include <ck/utility/loop_scheduler.hpp>
|
||||
#include <ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp>
|
||||
|
||||
namespace ck_tile::reflect::conv {
|
||||
|
||||
// Helper metafunctions to convert from ck enums to builder enums
|
||||
|
||||
/// @brief Converts a CK BlockGemmPipelineVersion enum to a builder PipelineVersion enum.
|
||||
/// @tparam ck_ver The CK BlockGemmPipelineVersion enum value to convert.
|
||||
/// @return The corresponding builder::PipelineVersion enum value (V1, V2, V3, V4, or V5).
|
||||
/// @details This function maps CK's block GEMM pipeline version identifiers to the
|
||||
/// builder framework's standardized pipeline version enum. The pipeline version
|
||||
/// determines the strategy used for data movement and computation overlap in the
|
||||
/// GEMM kernel's main loop.
|
||||
template <ck::BlockGemmPipelineVersion ck_ver>
|
||||
constexpr auto convert_pipeline_version()
|
||||
{
|
||||
using enum ck::BlockGemmPipelineVersion;
|
||||
using enum builder::PipelineVersion;
|
||||
if constexpr(ck_ver == v1)
|
||||
return V1;
|
||||
else if constexpr(ck_ver == v2)
|
||||
return V2;
|
||||
else if constexpr(ck_ver == v3)
|
||||
return V3;
|
||||
else if constexpr(ck_ver == v4)
|
||||
return V4;
|
||||
else if constexpr(ck_ver == v5)
|
||||
return V5;
|
||||
}
|
||||
|
||||
/// @brief Converts a CK PipelineVersion enum to a builder PipelineVersion enum.
|
||||
/// @tparam ck_ver The CK PipelineVersion enum value to convert.
|
||||
/// @return The corresponding builder::PipelineVersion enum value (V1, V2, V4, or WEIGHT_ONLY).
|
||||
/// @details This function maps CK's general pipeline version identifiers to the
|
||||
/// builder framework's standardized pipeline version enum. Note that this overload
|
||||
/// handles a different set of pipeline versions compared to the BlockGemmPipelineVersion
|
||||
/// variant, including support for specialized weight-only pipelines.
|
||||
template <ck::PipelineVersion ck_ver>
|
||||
constexpr auto convert_pipeline_version()
|
||||
{
|
||||
using enum ck::PipelineVersion;
|
||||
using enum builder::PipelineVersion;
|
||||
if constexpr(ck_ver == v1)
|
||||
return V1;
|
||||
else if constexpr(ck_ver == v2)
|
||||
return V2;
|
||||
else if constexpr(ck_ver == v4)
|
||||
return V4;
|
||||
else if constexpr(ck_ver == weight_only)
|
||||
return WEIGHT_ONLY;
|
||||
}
|
||||
|
||||
/// @brief Converts a CK BlockGemmPipelineScheduler enum to a builder PipelineScheduler enum.
|
||||
/// @tparam ck_sched The CK BlockGemmPipelineScheduler enum value to convert.
|
||||
/// @return The corresponding builder::PipelineScheduler enum value (INTRAWAVE or INTERWAVE).
|
||||
/// @details This function maps CK's block GEMM pipeline scheduler identifiers to the
|
||||
/// builder framework's standardized scheduler enum. The scheduler determines how work
|
||||
/// is distributed and synchronized within and across wavefronts during pipeline execution.
|
||||
/// INTRAWAVE scheduling operates within a single wavefront, while INTERWAVE coordinates
|
||||
/// across multiple wavefronts.
|
||||
template <ck::BlockGemmPipelineScheduler ck_sched>
|
||||
constexpr auto convert_pipeline_scheduler()
|
||||
{
|
||||
using enum ck::BlockGemmPipelineScheduler;
|
||||
using enum builder::PipelineScheduler;
|
||||
if constexpr(ck_sched == Intrawave)
|
||||
return INTRAWAVE;
|
||||
else if constexpr(ck_sched == Interwave)
|
||||
return INTERWAVE;
|
||||
}
|
||||
|
||||
/// @brief Converts a CK LoopScheduler enum to a builder PipelineScheduler enum.
|
||||
/// @tparam ck_sched The CK LoopScheduler enum value to convert.
|
||||
/// @return The corresponding builder::PipelineScheduler enum value (DEFAULT or INTERWAVE).
|
||||
/// @details This function maps CK's loop scheduler identifiers to the builder framework's
|
||||
/// standardized pipeline scheduler enum. The loop scheduler controls how iterations of
|
||||
/// the main computational loop are scheduled across threads. DEFAULT uses the standard
|
||||
/// scheduling strategy, while INTERWAVE enables cross-wavefront coordination for improved
|
||||
/// performance in certain scenarios.
|
||||
template <ck::LoopScheduler ck_sched>
|
||||
constexpr auto convert_pipeline_scheduler()
|
||||
{
|
||||
using enum ck::LoopScheduler;
|
||||
using enum builder::PipelineScheduler;
|
||||
if constexpr(ck_sched == Default)
|
||||
return DEFAULT;
|
||||
else if constexpr(ck_sched == Interwave)
|
||||
return INTERWAVE;
|
||||
}
|
||||
|
||||
/// @brief Helper structures for organizing trait data with domain-specific naming
|
||||
|
||||
/// @brief Data tile dimensions processed by a workgroup.
|
||||
/// @details This struct defines the M, N, and K dimensions of the data tile
|
||||
/// that a single workgroup (thread block) is responsible for processing in the
|
||||
/// underlying GEMM computation.
|
||||
struct DataTileInfo
|
||||
{
|
||||
int m; ///< M dimension of the tile processed by the workgroup (MPerBlock).
|
||||
int n; ///< N dimension of the tile processed by the workgroup (NPerBlock).
|
||||
int k; ///< K dimension of the tile processed by the workgroup (KPerBlock).
|
||||
};
|
||||
|
||||
/// @brief Dimensions for an input data tile transfer.
|
||||
/// @details Defines the shape of the input tile (A or B matrix) as it is
|
||||
/// transferred from global memory to LDS. The tile is conceptually divided
|
||||
/// into k0 and k1 dimensions.
|
||||
struct InputTileTransferDimensions
|
||||
{
|
||||
int k0; ///< The outer dimension of K, where K = k0 * k1.
|
||||
int m_or_n; ///< The M dimension for the A matrix transfer, or the N dimension for the B matrix.
|
||||
int k1; ///< The inner dimension of K, often corresponding to the vector load size from global
|
||||
///< memory.
|
||||
};
|
||||
|
||||
/// @brief Parameters governing the transfer of an input tile.
|
||||
/// @details This struct holds configuration details for how an input tile is
|
||||
/// loaded from global memory into LDS, including thread clustering, memory
|
||||
/// access patterns, and vectorization settings.
|
||||
struct InputTileTransferParams
|
||||
{
|
||||
int k1; ///< The inner K dimension size, often matching the vectorization width.
|
||||
std::array<int, 3>
|
||||
thread_cluster_dims; ///< Spatial thread distribution over the input data tile; defines how
|
||||
///< many threads are arranged on each axis.
|
||||
std::array<int, 3> thread_cluster_order; ///< The order of thread spatial distribution over the
|
||||
///< input tensor dimensions.
|
||||
std::array<int, 3> src_access_order; ///< The order of accessing input tensor axes (e.g., which
|
||||
///< dimension to read first).
|
||||
int src_vector_dim; ///< The index of the axis on which vectorized memory access is performed
|
||||
///< (the contiguous dimension).
|
||||
int src_scalar_per_vector; ///< The size of the vector access instruction; the number of
|
||||
///< elements accessed per thread per instruction.
|
||||
int dst_scalar_per_vector_k1; ///< The size of the vectorized store into LDS memory along the K1
|
||||
///< dimension.
|
||||
bool lds_padding; ///< Flag indicating if padding is used for the LDS tensor to prevent bank
|
||||
///< conflicts.
|
||||
};
|
||||
|
||||
/// @brief Complete information for an input tile transfer.
|
||||
/// @details Combines the dimensional information and transfer parameters for
|
||||
/// a full description of an input tile's journey from global memory to LDS.
|
||||
struct InputTileTransferInfo
|
||||
{
|
||||
InputTileTransferDimensions tile_dimensions; ///< The shape and layout of the tile.
|
||||
InputTileTransferParams transfer_params; ///< The parameters for the memory transfer operation.
|
||||
};
|
||||
|
||||
/// @brief Parameters for the warp-level GEMM computation.
|
||||
/// @details Defines the configuration of the GEMM operation performed by each
|
||||
/// warp using hardware MFMA (Matrix Fused Multiply-Add) instructions.
|
||||
struct WarpGemmParams
|
||||
{
|
||||
int gemm_m; ///< The M dimension of a single MFMA instruction (MPerXdl).
|
||||
int gemm_n; ///< The N dimension of a single MFMA instruction (NPerXdl).
|
||||
int m_iter; ///< The number of MFMA iterations along the M dimension of the output tile per
|
||||
///< wavefront (MXdlPerWave).
|
||||
int n_iter; ///< The number of MFMA iterations along the N dimension of the output tile per
|
||||
///< wavefront (NXdlPerWave).
|
||||
};
|
||||
|
||||
/// @brief Parameters for shuffling data between warps (CShuffle optimization).
|
||||
/// @details Configures how many MFMA instruction results are processed per
|
||||
/// wave in each iteration of the CShuffle routine.
|
||||
struct WarpShuffleParams
|
||||
{
|
||||
int m_gemms_per_shuffle; ///< Number of MFMA results along the M dimension to process per wave
|
||||
///< per shuffle iteration.
|
||||
int n_gemms_per_shuffle; ///< Number of MFMA results along the N dimension to process per wave
|
||||
///< per shuffle iteration.
|
||||
};
|
||||
|
||||
/// @brief Information for the output tile transfer (CShuffle).
|
||||
/// @details Describes how the final computed tile (C matrix) is written out from
|
||||
/// LDS to global memory, including shuffling, thread clustering, and vectorization.
|
||||
struct OutputTileTransferInfo
|
||||
{
|
||||
WarpShuffleParams shuffle_params; ///< Configuration for cross-warp data shuffling.
|
||||
// m_block, m_wave_per_xdl, n_block, n_wave_per_xdl
|
||||
std::array<int, 4> thread_cluster_dims; ///< The spatial thread distribution used for storing
|
||||
///< data into the output tensor.
|
||||
int scalar_per_vector; ///< The size of the vectorized memory access when storing data to the
|
||||
///< output tensor.
|
||||
};
|
||||
|
||||
// Helper metafunctions to derive signature information from Instance types
|
||||
|
||||
/// @brief Derives the convolution direction from a device kernel `Instance` type.
|
||||
/// @tparam Instance The device kernel instance type.
|
||||
/// @return A `builder::ConvDirection` enum value (FORWARD, BACKWARD_DATA, or BACKWARD_WEIGHT).
|
||||
template <typename Instance>
|
||||
constexpr builder::ConvDirection conv_direction()
|
||||
{
|
||||
using InstTraits = InstanceTraits<Instance>;
|
||||
|
||||
if constexpr(requires { &InstTraits::kConvForwardSpecialization; })
|
||||
{
|
||||
return builder::ConvDirection::FORWARD;
|
||||
}
|
||||
else if constexpr(requires { &InstTraits::kConvBwdDataSpecialization; })
|
||||
{
|
||||
return builder::ConvDirection::BACKWARD_DATA;
|
||||
}
|
||||
else if constexpr(requires { &InstTraits::kConvBwdWeightSpecialization; })
|
||||
{
|
||||
return builder::ConvDirection::BACKWARD_WEIGHT;
|
||||
}
|
||||
else
|
||||
{
|
||||
return builder::ConvDirection::FORWARD; // Default fallback
|
||||
}
|
||||
}
|
||||
|
||||
/// @brief Derives the convolution-specific specialization from a device kernel `Instance` type.
|
||||
/// @tparam Instance The device kernel instance type.
|
||||
/// @return A `builder::ConvFwdSpecialization`, `builder::ConvBwdDataSpecialization`, or
|
||||
/// `builder::ConvBwdWeightSpecialization` enum value.
|
||||
template <typename Instance>
|
||||
constexpr auto conv_spec()
|
||||
{
|
||||
using InstTraits = InstanceTraits<Instance>;
|
||||
|
||||
if constexpr(requires { InstTraits::kConvForwardSpecialization; })
|
||||
{
|
||||
using enum ck::tensor_operation::device::ConvolutionForwardSpecialization;
|
||||
|
||||
if constexpr(InstTraits::kConvForwardSpecialization == Default)
|
||||
{
|
||||
return builder::ConvFwdSpecialization::DEFAULT;
|
||||
}
|
||||
else if constexpr(InstTraits::kConvForwardSpecialization == Filter1x1Pad0)
|
||||
{
|
||||
return builder::ConvFwdSpecialization::FILTER_1X1_PAD0;
|
||||
}
|
||||
else if constexpr(InstTraits::kConvForwardSpecialization == Filter1x1Stride1Pad0)
|
||||
{
|
||||
return builder::ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0;
|
||||
}
|
||||
else if constexpr(InstTraits::kConvForwardSpecialization == Filter3x3)
|
||||
{
|
||||
return builder::ConvFwdSpecialization::FILTER_3x3;
|
||||
}
|
||||
}
|
||||
else if constexpr(requires { InstTraits::kConvBwdDataSpecialization; })
|
||||
{
|
||||
using enum ck::tensor_operation::device::ConvolutionBackwardDataSpecialization;
|
||||
|
||||
if constexpr(InstTraits::kConvBwdDataSpecialization == Default)
|
||||
{
|
||||
return builder::ConvBwdDataSpecialization::DEFAULT;
|
||||
}
|
||||
else if constexpr(InstTraits::kConvBwdDataSpecialization == Filter1x1Stride1Pad0)
|
||||
{
|
||||
return builder::ConvBwdDataSpecialization::FILTER_1X1_STRIDE1_PAD0;
|
||||
}
|
||||
}
|
||||
else if constexpr(requires { InstTraits::kConvBwdWeightSpecialization; })
|
||||
{
|
||||
using enum ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization;
|
||||
|
||||
if constexpr(InstTraits::kConvBwdWeightSpecialization == Default)
|
||||
{
|
||||
return builder::ConvBwdWeightSpecialization::DEFAULT;
|
||||
}
|
||||
else if constexpr(InstTraits::kConvBwdWeightSpecialization == Filter1x1Stride1Pad0)
|
||||
{
|
||||
return builder::ConvBwdWeightSpecialization::FILTER_1X1_STRIDE1_PAD0;
|
||||
}
|
||||
else if constexpr(InstTraits::kConvBwdWeightSpecialization == Filter1x1Pad0)
|
||||
{
|
||||
return builder::ConvBwdWeightSpecialization::FILTER_1X1_PAD0;
|
||||
}
|
||||
else if constexpr(InstTraits::kConvBwdWeightSpecialization == OddC)
|
||||
{
|
||||
return builder::ConvBwdWeightSpecialization::ODD_C;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// @brief Derives the grouped convolution layout from a device kernel `Instance` type.
|
||||
/// @tparam Instance The device kernel instance type.
|
||||
/// @return A `builder::GroupConvLayout{1D|2D|3D}` enum value corresponding to the tensor layouts.
|
||||
template <typename Instance>
|
||||
constexpr auto conv_layout()
|
||||
{
|
||||
using InstTraits = InstanceTraits<Instance>;
|
||||
using ALayout = typename InstTraits::ALayout;
|
||||
using BLayout = typename InstTraits::BLayout;
|
||||
using ELayout = typename InstTraits::ELayout;
|
||||
|
||||
namespace ctc = ck::tensor_layout::convolution;
|
||||
|
||||
if constexpr(InstTraits::kSpatialDim == 1)
|
||||
{
|
||||
if constexpr(std::is_same_v<ALayout, ctc::GNWC> && std::is_same_v<BLayout, ctc::GKXC> &&
|
||||
std::is_same_v<ELayout, ctc::GNWK>)
|
||||
{
|
||||
return builder::GroupConvLayout1D::GNWC_GKXC_GNWK;
|
||||
}
|
||||
else if constexpr(std::is_same_v<ALayout, ctc::NWGC> &&
|
||||
std::is_same_v<BLayout, ctc::GKXC> && std::is_same_v<ELayout, ctc::NWGK>)
|
||||
{
|
||||
return builder::GroupConvLayout1D::NWGC_GKXC_NWGK;
|
||||
}
|
||||
else if constexpr(std::is_same_v<ALayout, ctc::NGCW> &&
|
||||
std::is_same_v<BLayout, ctc::GKXC> && std::is_same_v<ELayout, ctc::NGKW>)
|
||||
{
|
||||
return builder::GroupConvLayout1D::NGCW_GKXC_NGKW;
|
||||
}
|
||||
else if constexpr(std::is_same_v<ALayout, ctc::NGCW> &&
|
||||
std::is_same_v<BLayout, ctc::GKCX> && std::is_same_v<ELayout, ctc::NGKW>)
|
||||
{
|
||||
return builder::GroupConvLayout1D::NGCW_GKCX_NGKW;
|
||||
}
|
||||
}
|
||||
else if constexpr(InstTraits::kSpatialDim == 2)
|
||||
{
|
||||
if constexpr(std::is_same_v<ALayout, ctc::GNHWC> && std::is_same_v<BLayout, ctc::GKYXC> &&
|
||||
std::is_same_v<ELayout, ctc::GNHWK>)
|
||||
{
|
||||
return builder::GroupConvLayout2D::GNHWC_GKYXC_GNHWK;
|
||||
}
|
||||
else if constexpr(std::is_same_v<ALayout, ctc::NHWGC> &&
|
||||
std::is_same_v<BLayout, ctc::GKYXC> &&
|
||||
std::is_same_v<ELayout, ctc::NHWGK>)
|
||||
{
|
||||
return builder::GroupConvLayout2D::NHWGC_GKYXC_NHWGK;
|
||||
}
|
||||
else if constexpr(std::is_same_v<ALayout, ctc::NGCHW> &&
|
||||
std::is_same_v<BLayout, ctc::GKYXC> &&
|
||||
std::is_same_v<ELayout, ctc::NGKHW>)
|
||||
{
|
||||
return builder::GroupConvLayout2D::NGCHW_GKYXC_NGKHW;
|
||||
}
|
||||
else if constexpr(std::is_same_v<ALayout, ctc::NGCHW> &&
|
||||
std::is_same_v<BLayout, ctc::GKCYX> &&
|
||||
std::is_same_v<ELayout, ctc::NGKHW>)
|
||||
{
|
||||
return builder::GroupConvLayout2D::NGCHW_GKCYX_NGKHW;
|
||||
}
|
||||
}
|
||||
else if constexpr(InstTraits::kSpatialDim == 3)
|
||||
{
|
||||
if constexpr(std::is_same_v<ALayout, ctc::GNDHWC> && std::is_same_v<BLayout, ctc::GKZYXC> &&
|
||||
std::is_same_v<ELayout, ctc::GNDHWK>)
|
||||
{
|
||||
return builder::GroupConvLayout3D::GNDHWC_GKZYXC_GNDHWK;
|
||||
}
|
||||
else if constexpr(std::is_same_v<ALayout, ctc::NDHWGC> &&
|
||||
std::is_same_v<BLayout, ctc::GKZYXC> &&
|
||||
std::is_same_v<ELayout, ctc::NDHWGK>)
|
||||
{
|
||||
return builder::GroupConvLayout3D::NDHWGC_GKZYXC_NDHWGK;
|
||||
}
|
||||
else if constexpr(std::is_same_v<ALayout, ctc::NGCDHW> &&
|
||||
std::is_same_v<BLayout, ctc::GKZYXC> &&
|
||||
std::is_same_v<ELayout, ctc::NGKDHW>)
|
||||
{
|
||||
return builder::GroupConvLayout3D::NGCDHW_GKZYXC_NGKDHW;
|
||||
}
|
||||
else if constexpr(std::is_same_v<ALayout, ctc::NGCDHW> &&
|
||||
std::is_same_v<BLayout, ctc::GKCZYX> &&
|
||||
std::is_same_v<ELayout, ctc::NGKDHW>)
|
||||
{
|
||||
return builder::GroupConvLayout3D::NGCDHW_GKCZYX_NGKDHW;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// @brief Derives the data type from a device kernel `Instance` type.
|
||||
/// @tparam Instance The device kernel instance type.
|
||||
/// @return A `builder::DataType` enum value (e.g., FP16, BF16, FP32).
|
||||
template <typename Instance>
|
||||
constexpr builder::DataType conv_data_type()
|
||||
{
|
||||
using InstTraits = InstanceTraits<Instance>;
|
||||
using ADataType = typename InstTraits::ADataType;
|
||||
|
||||
if constexpr(std::is_same_v<ADataType, ck::half_t>)
|
||||
{
|
||||
return builder::DataType::FP16;
|
||||
}
|
||||
else if constexpr(std::is_same_v<ADataType, ck::bhalf_t>)
|
||||
{
|
||||
return builder::DataType::BF16;
|
||||
}
|
||||
else if constexpr(std::is_same_v<ADataType, float>)
|
||||
{
|
||||
return builder::DataType::FP32;
|
||||
}
|
||||
else if constexpr(std::is_same_v<ADataType, ck::f8_t>)
|
||||
{
|
||||
return builder::DataType::FP8;
|
||||
}
|
||||
else if constexpr(std::is_same_v<ADataType, int8_t>)
|
||||
{
|
||||
return builder::DataType::I8;
|
||||
}
|
||||
else if constexpr(std::is_same_v<ADataType, uint8_t>)
|
||||
{
|
||||
return builder::DataType::U8;
|
||||
}
|
||||
else
|
||||
{
|
||||
// Default fallback
|
||||
return builder::DataType::FP32;
|
||||
}
|
||||
}
|
||||
|
||||
/// @brief Derives the elementwise operation from op type.
|
||||
/// @tparam ElementwiseOp Elementwise operation functor type.
|
||||
/// @return A `builder::ElementwiseOperation` enum value corresponding to elementwise operation.
|
||||
template <typename ElementwiseOp>
|
||||
constexpr builder::ElementwiseOperation elementwise_op()
|
||||
{
|
||||
constexpr std::string_view name = detail::elementwise_op_name<ElementwiseOp>();
|
||||
if constexpr(detail::case_insensitive_equal(name, "Bias"))
|
||||
{
|
||||
return builder::ElementwiseOperation::BIAS;
|
||||
}
|
||||
else if constexpr(detail::case_insensitive_equal(name, "BiasClamp"))
|
||||
{
|
||||
return builder::ElementwiseOperation::BIAS_CLAMP;
|
||||
}
|
||||
else if constexpr(detail::case_insensitive_equal(name, "BiasBnormClamp"))
|
||||
{
|
||||
return builder::ElementwiseOperation::BIAS_BNORM_CLAMP;
|
||||
}
|
||||
else if constexpr(detail::case_insensitive_equal(name, "Bilinear"))
|
||||
{
|
||||
return builder::ElementwiseOperation::BILINEAR;
|
||||
}
|
||||
else if constexpr(detail::case_insensitive_equal(name, "Clamp"))
|
||||
{
|
||||
return builder::ElementwiseOperation::CLAMP;
|
||||
}
|
||||
else if constexpr(detail::case_insensitive_equal(name, "Scale"))
|
||||
{
|
||||
return builder::ElementwiseOperation::SCALE;
|
||||
}
|
||||
else if constexpr(detail::case_insensitive_equal(name, "PassThrough"))
|
||||
{
|
||||
return builder::ElementwiseOperation::PASS_THROUGH;
|
||||
}
|
||||
}
|
||||
|
||||
/// @brief Derives a gemm padding from a kernel instance type.
|
||||
/// @tparam Instance - A Device Kernel object type.
|
||||
/// @return A `builder::GemmPadding` enum value corresponding to kernel padding.
|
||||
template <typename Instance>
|
||||
constexpr builder::GemmPadding gemm_spec()
|
||||
{
|
||||
using InstTraits = InstanceTraits<Instance>;
|
||||
using enum builder::GemmPadding;
|
||||
using enum ck::tensor_operation::device::GemmSpecialization;
|
||||
|
||||
constexpr auto gemm_spec = InstTraits::kGemmSpecialization;
|
||||
|
||||
if constexpr(gemm_spec == Default)
|
||||
{
|
||||
return DEFAULT;
|
||||
}
|
||||
else if constexpr(gemm_spec == MPadding)
|
||||
{
|
||||
return M_PADDING;
|
||||
}
|
||||
else if constexpr(gemm_spec == NPadding)
|
||||
{
|
||||
return N_PADDING;
|
||||
}
|
||||
else if constexpr(gemm_spec == KPadding)
|
||||
{
|
||||
return K_PADDING;
|
||||
}
|
||||
else if constexpr(gemm_spec == MNPadding)
|
||||
{
|
||||
return MN_PADDING;
|
||||
}
|
||||
else if constexpr(gemm_spec == MKPadding)
|
||||
{
|
||||
return MK_PADDING;
|
||||
}
|
||||
else if constexpr(gemm_spec == NKPadding)
|
||||
{
|
||||
return NK_PADDING;
|
||||
}
|
||||
else if constexpr(gemm_spec == MNKPadding)
|
||||
{
|
||||
return MNK_PADDING;
|
||||
}
|
||||
else if constexpr(gemm_spec == OPadding)
|
||||
{
|
||||
return O_PADDING;
|
||||
}
|
||||
else if constexpr(gemm_spec == MOPadding)
|
||||
{
|
||||
return MO_PADDING;
|
||||
}
|
||||
else if constexpr(gemm_spec == NOPadding)
|
||||
{
|
||||
return NO_PADDING;
|
||||
}
|
||||
else if constexpr(gemm_spec == KOPadding)
|
||||
{
|
||||
return KO_PADDING;
|
||||
}
|
||||
else if constexpr(gemm_spec == MNOPadding)
|
||||
{
|
||||
return MNO_PADDING;
|
||||
}
|
||||
else if constexpr(gemm_spec == MKOPadding)
|
||||
{
|
||||
return MKO_PADDING;
|
||||
}
|
||||
else if constexpr(gemm_spec == NKOPadding)
|
||||
{
|
||||
return NKO_PADDING;
|
||||
}
|
||||
else if constexpr(gemm_spec == MNKOPadding)
|
||||
{
|
||||
return MNKO_PADDING;
|
||||
}
|
||||
}
|
||||
|
||||
/// @brief Primary template for extracting convolution traits.
|
||||
/// @details This struct is the main entry point for reflecting on a convolution
|
||||
/// kernel's properties. It is specialized to handle different kinds of input types.
|
||||
template <typename T>
|
||||
struct ConvTraits;
|
||||
|
||||
/// @brief Specialization of `ConvTraits` for a direct device kernel `Instance`.
|
||||
/// @details This is the primary specialization used to extract a comprehensive
|
||||
/// set of traits directly from a fully-formed device kernel `Instance` type.
|
||||
/// It uses `InstanceTraits` to access the kernel's template parameters.
|
||||
template <typename Instance>
|
||||
requires requires { typename InstanceTraits<Instance>; }
|
||||
struct ConvTraits<Instance>
|
||||
{
|
||||
using InstTraits = InstanceTraits<Instance>;
|
||||
|
||||
// --- Signature Information ---
|
||||
/// @brief The number of spatial dimensions in the convolution (1, 2, or 3).
|
||||
static constexpr int spatial_dim = InstTraits::kSpatialDim;
|
||||
/// @brief The direction of the convolution (Forward, Backward Data, or Backward Weight).
|
||||
static constexpr builder::ConvDirection direction = conv_direction<Instance>();
|
||||
/// @brief The memory layout of the convolution tensors (e.g., GNHWC_GKYXC_GNHWK).
|
||||
static constexpr auto layout = conv_layout<Instance>();
|
||||
/// @brief The primary data type used in the computation (e.g., FP16, FP32).
|
||||
static constexpr builder::DataType data_type = conv_data_type<Instance>();
|
||||
|
||||
static constexpr builder::ElementwiseOperation input_element_op =
|
||||
elementwise_op<typename InstTraits::AElementwiseOperation>();
|
||||
static constexpr builder::ElementwiseOperation weight_element_op =
|
||||
elementwise_op<typename InstTraits::BElementwiseOperation>();
|
||||
static constexpr builder::ElementwiseOperation output_element_op =
|
||||
elementwise_op<typename InstTraits::CDEElementwiseOperation>();
|
||||
|
||||
/// @brief The GEMM specialization used by the kernel - padding
|
||||
static constexpr auto gemm_padding = gemm_spec<Instance>();
|
||||
/// @brief The convolution-specific specialization (e.g., Default, 1x1).
|
||||
static constexpr auto conv_specialization = conv_spec<Instance>();
|
||||
|
||||
// --- Algorithm Information ---
|
||||
/// @brief The total number of threads in a thread block (workgroup).
|
||||
static constexpr int thread_block_size = InstTraits::kBlockSize;
|
||||
/// @brief The dimensions of the data tile processed by the thread block.
|
||||
static constexpr DataTileInfo tile_dims = {
|
||||
.m = InstTraits::kMPerBlock, .n = InstTraits::kNPerBlock, .k = InstTraits::kKPerBlock};
|
||||
|
||||
/// @brief Configuration for the A-matrix (input) tile transfer.
|
||||
static constexpr InputTileTransferInfo a_tile_transfer = {
|
||||
.tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kAK1,
|
||||
.m_or_n = InstTraits::kMPerBlock,
|
||||
.k1 = InstTraits::kAK1},
|
||||
.transfer_params = {.k1 = InstTraits::kAK1,
|
||||
.thread_cluster_dims = InstTraits::kAThreadClusterLengths,
|
||||
.thread_cluster_order = InstTraits::kAThreadClusterArrangeOrder,
|
||||
.src_access_order = InstTraits::kABlockTransferSrcAccessOrder,
|
||||
.src_vector_dim = InstTraits::kABlockTransferSrcVectorDim,
|
||||
.src_scalar_per_vector = InstTraits::kABlockTransferSrcScalarPerVector,
|
||||
.dst_scalar_per_vector_k1 =
|
||||
InstTraits::kABlockTransferDstScalarPerVectorK1,
|
||||
.lds_padding = static_cast<bool>(InstTraits::kABlockLdsExtraM)}};
|
||||
|
||||
/// @brief Configuration for the B-matrix (weights) tile transfer.
|
||||
static constexpr InputTileTransferInfo b_tile_transfer = {
|
||||
.tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kBK1,
|
||||
.m_or_n = InstTraits::kNPerBlock,
|
||||
.k1 = InstTraits::kBK1},
|
||||
.transfer_params = {.k1 = InstTraits::kBK1,
|
||||
.thread_cluster_dims = InstTraits::kBThreadClusterLengths,
|
||||
.thread_cluster_order = InstTraits::kBThreadClusterArrangeOrder,
|
||||
.src_access_order = InstTraits::kBBlockTransferSrcAccessOrder,
|
||||
.src_vector_dim = InstTraits::kBBlockTransferSrcVectorDim,
|
||||
.src_scalar_per_vector = InstTraits::kBBlockTransferSrcScalarPerVector,
|
||||
.dst_scalar_per_vector_k1 =
|
||||
InstTraits::kBBlockTransferDstScalarPerVectorK1,
|
||||
.lds_padding = static_cast<bool>(InstTraits::kBBlockLdsExtraN)}};
|
||||
|
||||
/// @brief Parameters for the warp-level GEMM computation.
|
||||
static constexpr WarpGemmParams warp_gemm = {.gemm_m = InstTraits::kMPerXDL,
|
||||
.gemm_n = InstTraits::kNPerXDL,
|
||||
.m_iter = InstTraits::kMXdlPerWave,
|
||||
.n_iter = InstTraits::kNXdlPerWave};
|
||||
|
||||
/// @brief Configuration for the C-matrix (output) tile transfer.
|
||||
static constexpr OutputTileTransferInfo c_tile_transfer = {
|
||||
.shuffle_params = {.m_gemms_per_shuffle = InstTraits::kCShuffleMXdlPerWavePerShuffle,
|
||||
.n_gemms_per_shuffle = InstTraits::kCShuffleNXdlPerWavePerShuffle},
|
||||
.thread_cluster_dims = {InstTraits::kCThreadClusterLengths[0],
|
||||
InstTraits::kCThreadClusterLengths[1],
|
||||
InstTraits::kCThreadClusterLengths[2],
|
||||
InstTraits::kCThreadClusterLengths[3]},
|
||||
.scalar_per_vector = InstTraits::kCBlockTransferScalarPerVector};
|
||||
|
||||
/// @brief Helper to safely get the pipeline version.
|
||||
/// @details This is only available for some convolutions (e.g., forward).
|
||||
/// If not present in `InstanceTraits`, it returns a default value.
|
||||
template <typename T = InstTraits>
|
||||
static constexpr auto get_pipeline_version()
|
||||
{
|
||||
if constexpr(requires { T::kPipelineVersion; })
|
||||
{
|
||||
return convert_pipeline_version<T::kPipelineVersion>();
|
||||
}
|
||||
else
|
||||
{
|
||||
// Return a default or indicate not available
|
||||
return builder::PipelineVersion::V1;
|
||||
}
|
||||
}
|
||||
|
||||
/// @brief The block GEMM pipeline version used by the kernel.
|
||||
static constexpr auto pipeline_version = get_pipeline_version();
|
||||
|
||||
/// @brief Helper to safely get the pipeline scheduler.
|
||||
/// @details This is only available for some convolutions. If not present
|
||||
/// in `InstanceTraits`, it returns a default value.
|
||||
template <typename T = InstTraits>
|
||||
static constexpr auto get_pipeline_scheduler()
|
||||
{
|
||||
if constexpr(requires { T::kPipelineScheduler; })
|
||||
{
|
||||
return convert_pipeline_scheduler<T::kPipelineScheduler>();
|
||||
}
|
||||
else if constexpr(requires { T::kLoopScheduler; })
|
||||
{
|
||||
return convert_pipeline_scheduler<T::kLoopScheduler>();
|
||||
}
|
||||
else
|
||||
{
|
||||
// Return a default or indicate not available
|
||||
return builder::PipelineScheduler::DEFAULT;
|
||||
}
|
||||
}
|
||||
|
||||
/// @brief The pipeline scheduler used by the kernel.
|
||||
static constexpr auto pipeline_scheduler = get_pipeline_scheduler();
|
||||
};
|
||||
|
||||
/// @brief Specialization of `ConvTraits` for a `ConvBuilder` type.
|
||||
/// @details This specialization provides backward compatibility for reflecting
|
||||
/// on kernels defined via the `ConvBuilder` interface. It works by first
|
||||
/// creating the `Instance` via the builder's factory, and then delegating
|
||||
/// all trait extraction to the `ConvTraits<Instance>` specialization.
|
||||
template <builder::ConvSignatureDescriptor auto SIGNATURE,
|
||||
builder::ConvAlgorithmDescriptor auto ALGORITHM,
|
||||
builder::StringLiteral VERSION>
|
||||
struct ConvTraits<builder::ConvBuilder<SIGNATURE, ALGORITHM, VERSION>>
|
||||
{
|
||||
using Factory = builder::ConvFactory<SIGNATURE, ALGORITHM, VERSION>;
|
||||
using Instance = typename Factory::Instance;
|
||||
|
||||
// Delegate to Instance-based ConvTraits
|
||||
using InstanceConvTraits = ConvTraits<Instance>;
|
||||
|
||||
// Forward all members from Instance-based traits
|
||||
static constexpr int spatial_dim = InstanceConvTraits::spatial_dim;
|
||||
static constexpr builder::ConvDirection direction = InstanceConvTraits::direction;
|
||||
static constexpr auto layout = InstanceConvTraits::layout;
|
||||
static constexpr builder::DataType data_type = InstanceConvTraits::data_type;
|
||||
|
||||
static constexpr builder::ElementwiseOperation input_element_op =
|
||||
InstanceConvTraits::input_element_op;
|
||||
static constexpr builder::ElementwiseOperation weight_element_op =
|
||||
InstanceConvTraits::weight_element_op;
|
||||
static constexpr builder::ElementwiseOperation output_element_op =
|
||||
InstanceConvTraits::output_element_op;
|
||||
|
||||
static constexpr auto gemm_padding = InstanceConvTraits::gemm_padding;
|
||||
static constexpr auto conv_specialization = InstanceConvTraits::conv_specialization;
|
||||
|
||||
static constexpr int thread_block_size = InstanceConvTraits::thread_block_size;
|
||||
static constexpr DataTileInfo tile_dims = InstanceConvTraits::tile_dims;
|
||||
static constexpr InputTileTransferInfo a_tile_transfer = InstanceConvTraits::a_tile_transfer;
|
||||
static constexpr InputTileTransferInfo b_tile_transfer = InstanceConvTraits::b_tile_transfer;
|
||||
static constexpr WarpGemmParams warp_gemm = InstanceConvTraits::warp_gemm;
|
||||
static constexpr OutputTileTransferInfo c_tile_transfer = InstanceConvTraits::c_tile_transfer;
|
||||
static constexpr auto pipeline_version = InstanceConvTraits::pipeline_version;
|
||||
static constexpr auto pipeline_scheduler = InstanceConvTraits::pipeline_scheduler;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::reflect::conv
|
||||
@@ -14,18 +14,9 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <array>
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
#include <type_traits>
|
||||
#include <ck/utility/data_type.hpp>
|
||||
#include <ck/utility/sequence.hpp>
|
||||
#include <ck/utility/blkgemmpipe_scheduler.hpp>
|
||||
#include <ck/tensor_operation/gpu/device/tensor_layout.hpp>
|
||||
#include <ck/tensor_operation/gpu/element/element_wise_operation.hpp>
|
||||
#include <ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp>
|
||||
#include <ck/tensor_operation/gpu/device/gemm_specialization.hpp>
|
||||
#include "instance_traits_util.hpp"
|
||||
#include <concepts>
|
||||
|
||||
namespace ck_tile::reflect {
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "instance_traits.hpp"
|
||||
#include "instance_traits_util.hpp"
|
||||
|
||||
// Forward declaration to avoid circular dependency.
|
||||
// This file will be included by the device implementation header, so we cannot include
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "instance_traits.hpp"
|
||||
#include "instance_traits_util.hpp"
|
||||
|
||||
// Forward declaration to avoid circular dependency.
|
||||
// This file will be included by the device implementation header, so we cannot include
|
||||
|
||||
@@ -9,9 +9,14 @@
|
||||
|
||||
#include <array>
|
||||
#include <string>
|
||||
#include <concepts>
|
||||
#include <string_view>
|
||||
#include <sstream>
|
||||
#include <type_traits>
|
||||
#include <limits.h>
|
||||
#include <cmath>
|
||||
#include <ostream>
|
||||
#include <iostream>
|
||||
#include <ck/utility/data_type.hpp>
|
||||
#include <ck/utility/sequence.hpp>
|
||||
#include <ck/utility/blkgemmpipe_scheduler.hpp>
|
||||
@@ -371,4 +376,30 @@ constexpr std::string type_or_type_tuple_name()
|
||||
}
|
||||
}
|
||||
|
||||
/// @brief Makes a case insensitive comparison of two string views.
|
||||
/// @param a First string view
|
||||
/// @param b Second string view
|
||||
/// @return Whether two string views a equal case insensitive
|
||||
constexpr bool case_insensitive_equal(std::string_view a, std::string_view b)
|
||||
{
|
||||
if(a.size() != b.size())
|
||||
return false;
|
||||
|
||||
for(size_t i = 0; i < a.size(); ++i)
|
||||
{
|
||||
char c1 = a[i];
|
||||
char c2 = b[i];
|
||||
|
||||
// Convert to lowercase for comparison
|
||||
if(c1 >= 'A' && c1 <= 'Z')
|
||||
c1 += 32;
|
||||
if(c2 >= 'A' && c2 <= 'Z')
|
||||
c2 += 32;
|
||||
|
||||
if(c1 != c2)
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace ck_tile::reflect::detail
|
||||
|
||||
@@ -0,0 +1,106 @@
|
||||
#pragma once
|
||||
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
|
||||
namespace ck_tile::reflect {
|
||||
|
||||
// Helper class for formatting hierarchical tree structures with proper indentation
|
||||
// and tree-drawing characters (├─, └─, │, etc.)
|
||||
//
|
||||
// Example Usage:
|
||||
//
|
||||
// TreeFormatter f;
|
||||
// f.writeLine(0, "Root");
|
||||
// f.writeLine(1, "Branch 1");
|
||||
// f.writeLine(2, "Item 1a");
|
||||
// f.writeLast(2, "Item 1b");
|
||||
// f.writeLast(1, "Branch 2");
|
||||
// f.writeLast(2, "Item 2a");
|
||||
// std::cout << f.getString() << "\n";
|
||||
//
|
||||
// Generated Output:
|
||||
//
|
||||
// Root
|
||||
// ├─ Branch 1
|
||||
// │ ├─ Item 1a
|
||||
// │ └─ Item 1b
|
||||
// └─ Branch 2
|
||||
// └─ Item 2a
|
||||
class TreeFormatter
|
||||
{
|
||||
public:
|
||||
TreeFormatter() = default;
|
||||
|
||||
// Write a line at the specified indentation level (branch continues after this)
|
||||
template <typename... Args>
|
||||
void writeLine(int indent_level, Args&&... args)
|
||||
{
|
||||
writeLineImpl(indent_level, false, std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
// Write the last line at the specified indentation level (branch ends)
|
||||
template <typename... Args>
|
||||
void writeLast(int indent_level, Args&&... args)
|
||||
{
|
||||
writeLineImpl(indent_level, true, std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
// Get the formatted string (removes trailing newline if present)
|
||||
std::string getString() const
|
||||
{
|
||||
std::string result = oss_.str();
|
||||
if(!result.empty() && result.back() == '\n')
|
||||
{
|
||||
result.pop_back();
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
private:
|
||||
std::ostringstream oss_;
|
||||
std::vector<bool> is_last_at_level_; // Tracks which levels have ended
|
||||
|
||||
// Implementation of line writing with tree symbols
|
||||
template <typename... Args>
|
||||
void writeLineImpl(int indent_level, bool is_last, Args&&... args)
|
||||
{
|
||||
// Ensure we have enough tracking space
|
||||
if(static_cast<size_t>(indent_level) >= is_last_at_level_.size())
|
||||
{
|
||||
is_last_at_level_.resize(indent_level + 1, false);
|
||||
// Level 0 (root) should always be treated as "last" since it has no tree symbols
|
||||
if(is_last_at_level_.size() > 0)
|
||||
{
|
||||
is_last_at_level_[0] = true;
|
||||
}
|
||||
}
|
||||
|
||||
// Draw the tree structure
|
||||
// Start from level 1 (skip level 0 which is the root with no symbols)
|
||||
for(int i = 1; i < indent_level; ++i)
|
||||
{
|
||||
// For all parent levels, draw vertical line or space based on whether they ended
|
||||
oss_ << (is_last_at_level_[i] ? " " : "│ ");
|
||||
}
|
||||
|
||||
// Draw the branch symbol for the current level
|
||||
if(indent_level > 0)
|
||||
{
|
||||
oss_ << (is_last ? "└─ " : "├─ ");
|
||||
}
|
||||
|
||||
// Write the content using fold expression with direct stream insertion
|
||||
((oss_ << std::forward<Args>(args)), ...);
|
||||
|
||||
oss_ << '\n';
|
||||
|
||||
// Update tracking for this level AFTER writing the line
|
||||
// This ensures future lines at deeper levels know if this level ended
|
||||
is_last_at_level_[indent_level] = is_last;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile::reflect
|
||||
@@ -3,6 +3,10 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ostream>
|
||||
#include <string_view>
|
||||
#include <variant>
|
||||
|
||||
namespace ck_tile::builder {
|
||||
|
||||
enum class DataType
|
||||
@@ -128,29 +132,14 @@ enum class ElementwiseOperation
|
||||
PASS_THROUGH
|
||||
};
|
||||
|
||||
// Enums for the current block GEMM pipeline versions.
|
||||
enum class BlockGemmPipelineVersion
|
||||
// Enums for pipeline versions & schedulers
|
||||
enum class PipelineVersion
|
||||
{
|
||||
V1,
|
||||
V2,
|
||||
V3,
|
||||
V4,
|
||||
V5
|
||||
};
|
||||
|
||||
enum struct BlockGemmPipelineScheduler
|
||||
{
|
||||
INTRAWAVE,
|
||||
INTERWAVE,
|
||||
};
|
||||
|
||||
// Enums for the gridwise GEMM pipeline versions.
|
||||
enum class GridwiseGemmPipelineVersion
|
||||
{
|
||||
V1,
|
||||
V2,
|
||||
V3, // Only used in stream-K implementation
|
||||
V4,
|
||||
V5,
|
||||
WEIGHT_ONLY
|
||||
};
|
||||
|
||||
@@ -186,10 +175,319 @@ enum class ConvFwdSpecialization
|
||||
FILTER_3x3
|
||||
};
|
||||
|
||||
enum class LoopScheduler
|
||||
// Enums for the backward data convolution specialization.
|
||||
enum class ConvBwdDataSpecialization
|
||||
{
|
||||
DEFAULT,
|
||||
FILTER_1X1_STRIDE1_PAD0,
|
||||
};
|
||||
|
||||
// Enums for the backward weight convolution specialization.
|
||||
enum class ConvBwdWeightSpecialization
|
||||
{
|
||||
DEFAULT,
|
||||
FILTER_1X1_STRIDE1_PAD0,
|
||||
FILTER_1X1_PAD0,
|
||||
ODD_C,
|
||||
};
|
||||
|
||||
// Enums for the Gemm padding.
|
||||
enum class GemmPadding
|
||||
{
|
||||
DEFAULT,
|
||||
M_PADDING,
|
||||
N_PADDING,
|
||||
K_PADDING,
|
||||
MN_PADDING,
|
||||
MK_PADDING,
|
||||
NK_PADDING,
|
||||
MNK_PADDING,
|
||||
O_PADDING,
|
||||
MO_PADDING,
|
||||
NO_PADDING,
|
||||
KO_PADDING,
|
||||
MNO_PADDING,
|
||||
MKO_PADDING,
|
||||
NKO_PADDING,
|
||||
MNKO_PADDING,
|
||||
};
|
||||
|
||||
enum class PipelineScheduler
|
||||
{
|
||||
DEFAULT,
|
||||
INTRAWAVE,
|
||||
INTERWAVE
|
||||
};
|
||||
|
||||
// ostream operator overloads for enum classes
|
||||
inline std::ostream& operator<<(std::ostream& os, DataType dt)
|
||||
{
|
||||
using enum DataType;
|
||||
switch(dt)
|
||||
{
|
||||
case FP16: return os << "FP16";
|
||||
case FP32: return os << "FP32";
|
||||
case BF16: return os << "BF16";
|
||||
case FP8: return os << "FP8";
|
||||
case I8: return os << "I8";
|
||||
case U8: return os << "U8";
|
||||
default: return os << "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, ConvDirection dir)
|
||||
{
|
||||
using enum ConvDirection;
|
||||
switch(dir)
|
||||
{
|
||||
case FORWARD: return os << "Forward";
|
||||
case BACKWARD_DATA: return os << "Backward Data";
|
||||
case BACKWARD_WEIGHT: return os << "Backward Weight";
|
||||
default: return os << "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, GroupConvLayout1D layout)
|
||||
{
|
||||
using enum GroupConvLayout1D;
|
||||
switch(layout)
|
||||
{
|
||||
case GNWC_GKXC_GNWK: return os << "GNWC_GKXC_GNWK";
|
||||
case NWGC_GKXC_NWGK: return os << "NWGC_GKXC_NWGK";
|
||||
case NGCW_GKXC_NGKW: return os << "NGCW_GKXC_NGKW";
|
||||
case NGCW_GKCX_NGKW: return os << "NGCW_GKCX_NGKW";
|
||||
default: return os << "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, GroupConvLayout2D layout)
|
||||
{
|
||||
using enum GroupConvLayout2D;
|
||||
switch(layout)
|
||||
{
|
||||
case GNHWC_GKYXC_GNHWK: return os << "GNHWC_GKYXC_GNHWK";
|
||||
case NHWGC_GKYXC_NHWGK: return os << "NHWGC_GKYXC_NHWGK";
|
||||
case NGCHW_GKYXC_NGKHW: return os << "NGCHW_GKYXC_NGKHW";
|
||||
case NGCHW_GKCYX_NGKHW: return os << "NGCHW_GKCYX_NGKHW";
|
||||
default: return os << "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, GroupConvLayout3D layout)
|
||||
{
|
||||
using enum GroupConvLayout3D;
|
||||
switch(layout)
|
||||
{
|
||||
case GNDHWC_GKZYXC_GNDHWK: return os << "GNDHWC_GKZYXC_GNDHWK";
|
||||
case NDHWGC_GKZYXC_NDHWGK: return os << "NDHWGC_GKZYXC_NDHWGK";
|
||||
case NGCDHW_GKZYXC_NGKDHW: return os << "NGCDHW_GKZYXC_NGKDHW";
|
||||
case NGCDHW_GKCZYX_NGKDHW: return os << "NGCDHW_GKCZYX_NGKDHW";
|
||||
default: return os << "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, FwdGroupConvDeviceOperation op)
|
||||
{
|
||||
using enum FwdGroupConvDeviceOperation;
|
||||
switch(op)
|
||||
{
|
||||
case DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK:
|
||||
return os << "DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK";
|
||||
case DeviceGroupedConvFwdMultipleD_Wmma_CShuffle:
|
||||
return os << "DeviceGroupedConvFwdMultipleD_Wmma_CShuffle";
|
||||
case DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle:
|
||||
return os << "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle";
|
||||
case DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3:
|
||||
return os << "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3";
|
||||
case DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor:
|
||||
return os << "DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor";
|
||||
default: return os << "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, BwdDataGroupConvDeviceOperation op)
|
||||
{
|
||||
using enum BwdDataGroupConvDeviceOperation;
|
||||
switch(op)
|
||||
{
|
||||
case DeviceGroupedConvBwdDataMultipleD: return os << "DeviceGroupedConvBwdDataMultipleD";
|
||||
case DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle:
|
||||
return os << "DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle";
|
||||
case DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1:
|
||||
return os << "DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1";
|
||||
default: return os << "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, BwdWeightGroupConvDeviceOperation op)
|
||||
{
|
||||
using enum BwdWeightGroupConvDeviceOperation;
|
||||
switch(op)
|
||||
{
|
||||
case DeviceGroupedConvBwdWeight: return os << "DeviceGroupedConvBwdWeight";
|
||||
case DeviceGroupedConvBwdWeight_Dl: return os << "DeviceGroupedConvBwdWeight_Dl";
|
||||
case DeviceGroupedConvBwdWeight_Xdl_CShuffle:
|
||||
return os << "DeviceGroupedConvBwdWeight_Xdl_CShuffle";
|
||||
case DeviceGroupedConvBwdWeight_Xdl_CShuffleV3:
|
||||
return os << "DeviceGroupedConvBwdWeight_Xdl_CShuffleV3";
|
||||
case DeviceGroupedConvBwdWeight_Wmma_CShuffle:
|
||||
return os << "DeviceGroupedConvBwdWeight_Wmma_CShuffle";
|
||||
case DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle:
|
||||
return os << "DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle";
|
||||
case DeviceGroupedConvBwdWeightMultipleD: return os << "DeviceGroupedConvBwdWeightMultipleD";
|
||||
case DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle:
|
||||
return os << "DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle";
|
||||
default: return os << "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, ElementwiseOperation op)
|
||||
{
|
||||
using enum ElementwiseOperation;
|
||||
switch(op)
|
||||
{
|
||||
case BIAS: return os << "BIAS";
|
||||
case BIAS_CLAMP: return os << "BIAS_CLAMP";
|
||||
case BIAS_BNORM_CLAMP: return os << "BIAS_BNORM_CLAMP";
|
||||
case BILINEAR: return os << "BILINEAR";
|
||||
case CLAMP: return os << "CLAMP";
|
||||
case SCALE: return os << "SCALE";
|
||||
case PASS_THROUGH: return os << "PASS_THROUGH";
|
||||
default: return os << "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, PipelineVersion ver)
|
||||
{
|
||||
using enum PipelineVersion;
|
||||
switch(ver)
|
||||
{
|
||||
case V1: return os << "V1";
|
||||
case V2: return os << "V2";
|
||||
case V3: return os << "V3";
|
||||
case V4: return os << "V4";
|
||||
case V5: return os << "V5";
|
||||
case WEIGHT_ONLY: return os << "WEIGHT_ONLY";
|
||||
default: return os << "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, GemmSpecialization spec)
|
||||
{
|
||||
using enum GemmSpecialization;
|
||||
switch(spec)
|
||||
{
|
||||
case Default: return os << "Default";
|
||||
case MPadding: return os << "MPadding";
|
||||
case NPadding: return os << "NPadding";
|
||||
case KPadding: return os << "KPadding";
|
||||
case MNPadding: return os << "MNPadding";
|
||||
case MKPadding: return os << "MKPadding";
|
||||
case NKPadding: return os << "NKPadding";
|
||||
case MNKPadding: return os << "MNKPadding";
|
||||
case OPadding: return os << "OPadding";
|
||||
case MOPadding: return os << "MOPadding";
|
||||
case NOPadding: return os << "NOPadding";
|
||||
case KOPadding: return os << "KOPadding";
|
||||
case MNOPadding: return os << "MNOPadding";
|
||||
case MKOPadding: return os << "MKOPadding";
|
||||
case NKOPadding: return os << "NKOPadding";
|
||||
case MNKOPadding: return os << "MNKOPadding";
|
||||
default: return os << "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, ConvFwdSpecialization spec)
|
||||
{
|
||||
using enum ConvFwdSpecialization;
|
||||
switch(spec)
|
||||
{
|
||||
case DEFAULT: return os << "DEFAULT";
|
||||
case FILTER_1X1_PAD0: return os << "FILTER_1X1_PAD0";
|
||||
case FILTER_1X1_STRIDE1_PAD0: return os << "FILTER_1X1_STRIDE1_PAD0";
|
||||
case FILTER_3x3: return os << "FILTER_3x3";
|
||||
default: return os << "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, ConvBwdDataSpecialization spec)
|
||||
{
|
||||
using enum ConvBwdDataSpecialization;
|
||||
switch(spec)
|
||||
{
|
||||
case DEFAULT: return os << "DEFAULT";
|
||||
case FILTER_1X1_STRIDE1_PAD0: return os << "FILTER_1X1_STRIDE1_PAD0";
|
||||
default: return os << "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, ConvBwdWeightSpecialization spec)
|
||||
{
|
||||
using enum ConvBwdWeightSpecialization;
|
||||
switch(spec)
|
||||
{
|
||||
case DEFAULT: return os << "DEFAULT";
|
||||
case FILTER_1X1_STRIDE1_PAD0: return os << "FILTER_1X1_STRIDE1_PAD0";
|
||||
case FILTER_1X1_PAD0: return os << "FILTER_1X1_PAD0";
|
||||
case ODD_C: return os << "ODD_C";
|
||||
default: return os << "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, GemmPadding padding)
|
||||
{
|
||||
using enum GemmPadding;
|
||||
switch(padding)
|
||||
{
|
||||
case DEFAULT: return os << "DEFAULT";
|
||||
case M_PADDING: return os << "M_PADDING";
|
||||
case N_PADDING: return os << "N_PADDING";
|
||||
case K_PADDING: return os << "K_PADDING";
|
||||
case MN_PADDING: return os << "MN_PADDING";
|
||||
case MK_PADDING: return os << "MK_PADDING";
|
||||
case NK_PADDING: return os << "NK_PADDING";
|
||||
case MNK_PADDING: return os << "MNK_PADDING";
|
||||
case O_PADDING: return os << "O_PADDING";
|
||||
case MO_PADDING: return os << "MO_PADDING";
|
||||
case NO_PADDING: return os << "NO_PADDING";
|
||||
case KO_PADDING: return os << "KO_PADDING";
|
||||
case MNO_PADDING: return os << "MNO_PADDING";
|
||||
case MKO_PADDING: return os << "MKO_PADDING";
|
||||
case NKO_PADDING: return os << "NKO_PADDING";
|
||||
case MNKO_PADDING: return os << "MNKO_PADDING";
|
||||
default: return os << "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, PipelineScheduler sched)
|
||||
{
|
||||
using enum PipelineScheduler;
|
||||
switch(sched)
|
||||
{
|
||||
case DEFAULT: return os << "DEFAULT";
|
||||
case INTRAWAVE: return os << "INTRAWAVE";
|
||||
case INTERWAVE: return os << "INTERWAVE";
|
||||
default: return os << "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
// ostream operator overload for std::variant of layout types
|
||||
inline std::ostream&
|
||||
operator<<(std::ostream& os,
|
||||
const std::variant<GroupConvLayout1D, GroupConvLayout2D, GroupConvLayout3D>& layout)
|
||||
{
|
||||
std::visit([&os](const auto& l) { os << l; }, layout);
|
||||
return os;
|
||||
}
|
||||
|
||||
// ostream operator overload for std::variant of convolution specializations
|
||||
inline std::ostream& operator<<(std::ostream& os,
|
||||
const std::variant<ConvFwdSpecialization,
|
||||
ConvBwdDataSpecialization,
|
||||
ConvBwdWeightSpecialization>& spec)
|
||||
{
|
||||
std::visit([&os](const auto& s) { os << s; }, spec);
|
||||
return os;
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder
|
||||
|
||||
@@ -43,6 +43,8 @@ add_ck_builder_test(test_ckb_build_fwd_instances
|
||||
conv/test_ckb_conv_fwd_2d_bf16.cpp
|
||||
conv/test_ckb_conv_fwd_2d_fp16.cpp
|
||||
conv/test_ckb_conv_fwd_2d_fp32.cpp
|
||||
conv/test_ckb_conv_fwd_2d_dl_fp16.cpp
|
||||
conv/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp
|
||||
conv/test_ckb_conv_fwd_3d_bf16.cpp
|
||||
conv/test_ckb_conv_fwd_3d_fp16.cpp
|
||||
conv/test_ckb_conv_fwd_3d_fp32.cpp)
|
||||
@@ -64,6 +66,12 @@ add_ck_factory_test(test_ckb_factory_grouped_convolution_forward_bias_bnorm_clam
|
||||
add_ck_factory_test(test_ckb_factory_grouped_convolution_forward_scaleadd_scaleadd_relu test_ck_factory_grouped_convolution_forward_scaleadd_scaleadd_relu.cpp)
|
||||
add_ck_factory_test(test_ckb_factory_grouped_convolution_forward_dynamic_op test_ck_factory_grouped_convolution_forward_dynamic_op.cpp)
|
||||
|
||||
add_ck_builder_test(test_conv_traits
|
||||
conv/test_conv_traits.cpp)
|
||||
|
||||
add_ck_builder_test(test_conv_description
|
||||
test_conv_description.cpp)
|
||||
|
||||
# Function to add all test_ckb targets to a list
|
||||
function(collect_test_ckb_targets result_var)
|
||||
# Get all targets in current directory
|
||||
|
||||
@@ -27,7 +27,7 @@ TEST(FwdConvInstances,
|
||||
run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<
|
||||
FwdConvSignature,
|
||||
FwdThreadBlock,
|
||||
BlockGemmPipelineVersion::V2,
|
||||
PipelineVersion::V2,
|
||||
ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0>();
|
||||
}
|
||||
|
||||
|
||||
@@ -25,7 +25,7 @@ TEST(FwdConvInstances,
|
||||
|
||||
run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<FwdConvSignature,
|
||||
FwdThreadBlock,
|
||||
BlockGemmPipelineVersion::V1,
|
||||
PipelineVersion::V1,
|
||||
ConvFwdSpecialization::DEFAULT>();
|
||||
}
|
||||
|
||||
@@ -47,7 +47,7 @@ TEST(FwdConvInstances,
|
||||
|
||||
run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<FwdConvSignature,
|
||||
FwdThreadBlock,
|
||||
BlockGemmPipelineVersion::V5,
|
||||
PipelineVersion::V5,
|
||||
ConvFwdSpecialization::FILTER_3x3>();
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,69 @@
|
||||
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "utils/ckb_conv_test_common.hpp"
|
||||
|
||||
using namespace ck_tile::builder::test_utils;
|
||||
|
||||
namespace ck_tile::builder::testing {
|
||||
|
||||
TEST(FwdConvInstances, Create_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK_Instance_2D_FP16_GNHWC)
|
||||
{
|
||||
constexpr ConvSignature FwdConvSignature{
|
||||
.spatial_dim = 2,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout2D::GNHWC_GKYXC_GNHWK,
|
||||
.data_type = DataType::FP16,
|
||||
.elementwise_operation = ElementwiseOperation::PASS_THROUGH,
|
||||
.device_operation =
|
||||
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
|
||||
.tile_size = {.m = 128, .n = 128, .k = 16}};
|
||||
|
||||
run_test_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK<FwdConvSignature,
|
||||
FwdThreadBlock,
|
||||
ConvFwdSpecialization::DEFAULT>();
|
||||
}
|
||||
|
||||
TEST(FwdConvInstances, Create_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK_Instance_2D_FP16_NHWGC)
|
||||
{
|
||||
constexpr ConvSignature FwdConvSignature{
|
||||
.spatial_dim = 2,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout2D::NHWGC_GKYXC_NHWGK,
|
||||
.data_type = DataType::FP16,
|
||||
.elementwise_operation = ElementwiseOperation::PASS_THROUGH,
|
||||
.device_operation =
|
||||
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
|
||||
.tile_size = {.m = 128, .n = 128, .k = 16}};
|
||||
|
||||
run_test_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK<FwdConvSignature,
|
||||
FwdThreadBlock,
|
||||
ConvFwdSpecialization::DEFAULT>();
|
||||
}
|
||||
|
||||
TEST(FwdConvInstances,
|
||||
Create_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK_Instance_2D_FP16_FILTER_1X1_PAD0)
|
||||
{
|
||||
constexpr ConvSignature FwdConvSignature{
|
||||
.spatial_dim = 2,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout2D::GNHWC_GKYXC_GNHWK,
|
||||
.data_type = DataType::FP16,
|
||||
.elementwise_operation = ElementwiseOperation::PASS_THROUGH,
|
||||
.device_operation =
|
||||
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
|
||||
.tile_size = {.m = 128, .n = 128, .k = 16}};
|
||||
|
||||
run_test_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK<
|
||||
FwdConvSignature,
|
||||
FwdThreadBlock,
|
||||
ConvFwdSpecialization::FILTER_1X1_PAD0>();
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::testing
|
||||
@@ -25,7 +25,7 @@ TEST(FwdConvInstances,
|
||||
run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<
|
||||
FwdConvSignature,
|
||||
FwdThreadBlock,
|
||||
BlockGemmPipelineVersion::V3,
|
||||
PipelineVersion::V3,
|
||||
ConvFwdSpecialization::FILTER_1X1_PAD0>();
|
||||
}
|
||||
|
||||
|
||||
@@ -25,7 +25,7 @@ TEST(FwdConvInstances,
|
||||
run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<
|
||||
FwdConvSignature,
|
||||
FwdThreadBlock,
|
||||
BlockGemmPipelineVersion::V4,
|
||||
PipelineVersion::V4,
|
||||
ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0>();
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,53 @@
|
||||
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "utils/ckb_conv_test_common.hpp"
|
||||
|
||||
using namespace ck_tile::builder::test_utils;
|
||||
|
||||
namespace ck_tile::builder::testing {
|
||||
|
||||
TEST(FwdConvInstances,
|
||||
Create_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor_Instance_2D_FP16_GNHWC)
|
||||
{
|
||||
constexpr ConvSignature FwdConvSignature{
|
||||
.spatial_dim = 2,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout2D::GNHWC_GKYXC_GNHWK,
|
||||
.data_type = DataType::FP16,
|
||||
.elementwise_operation = ElementwiseOperation::PASS_THROUGH,
|
||||
.device_operation =
|
||||
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
|
||||
.tile_size = {.m = 256, .n = 128, .k = 32}};
|
||||
|
||||
run_test_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<
|
||||
FwdConvSignature,
|
||||
FwdThreadBlock,
|
||||
ConvFwdSpecialization::DEFAULT>();
|
||||
}
|
||||
|
||||
TEST(
|
||||
FwdConvInstances,
|
||||
Create_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor_Instance_2D_FP16_GNHWC_Filter1x1Pad0)
|
||||
{
|
||||
constexpr ConvSignature FwdConvSignature{
|
||||
.spatial_dim = 2,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout2D::GNHWC_GKYXC_GNHWK,
|
||||
.data_type = DataType::FP16,
|
||||
.elementwise_operation = ElementwiseOperation::PASS_THROUGH,
|
||||
.device_operation =
|
||||
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock{.block_size = 128,
|
||||
.tile_size = {.m = 128, .n = 128, .k = 32}};
|
||||
|
||||
run_test_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<
|
||||
FwdConvSignature,
|
||||
FwdThreadBlock,
|
||||
ConvFwdSpecialization::FILTER_1X1_PAD0>();
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::testing
|
||||
@@ -25,7 +25,7 @@ TEST(FwdConvInstances,
|
||||
|
||||
run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<FwdConvSignature,
|
||||
FwdThreadBlock,
|
||||
BlockGemmPipelineVersion::V3,
|
||||
PipelineVersion::V3,
|
||||
ConvFwdSpecialization::DEFAULT>();
|
||||
}
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ TEST(FwdConvInstances,
|
||||
run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<
|
||||
FwdConvSignature,
|
||||
FwdThreadBlock,
|
||||
BlockGemmPipelineVersion::V4,
|
||||
PipelineVersion::V4,
|
||||
ConvFwdSpecialization::FILTER_1X1_PAD0>();
|
||||
}
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ TEST(FwdConvInstances,
|
||||
run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<
|
||||
FwdConvSignature,
|
||||
FwdThreadBlock,
|
||||
BlockGemmPipelineVersion::V1,
|
||||
PipelineVersion::V1,
|
||||
ConvFwdSpecialization::FILTER_1X1_PAD0>();
|
||||
}
|
||||
|
||||
|
||||
316
experimental/builder/test/conv/test_conv_traits.cpp
Normal file
316
experimental/builder/test/conv/test_conv_traits.cpp
Normal file
@@ -0,0 +1,316 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <gmock/gmock.h>
|
||||
#include <concepts>
|
||||
|
||||
#include <ck_tile/builder/reflect/conv_traits.hpp>
|
||||
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp>
|
||||
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp>
|
||||
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp>
|
||||
|
||||
namespace {
|
||||
|
||||
using ::testing::ElementsAre;
|
||||
|
||||
// Test fixture for ConvTraits tests
|
||||
class ConvTraitsTest : public ::testing::Test
|
||||
{
|
||||
};
|
||||
|
||||
// Test ConvTraits with DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
TEST_F(ConvTraitsTest, ConvFwdTraitsExtraction)
|
||||
{
|
||||
// Define a concrete instance type with specific template parameters
|
||||
using DeviceInstance =
|
||||
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<
|
||||
2, // NDimSpatial
|
||||
ck::tensor_layout::convolution::GNHWC, // ALayout
|
||||
ck::tensor_layout::convolution::GKYXC, // BLayout
|
||||
ck::Tuple<>, // DsLayout
|
||||
ck::tensor_layout::convolution::GNHWK, // ELayout
|
||||
ck::half_t, // ADataType
|
||||
ck::half_t, // BDataType
|
||||
float, // AccDataType
|
||||
ck::half_t, // CShuffleDataType
|
||||
ck::Tuple<>, // DsDataType
|
||||
ck::half_t, // EDataType
|
||||
ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation
|
||||
ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation
|
||||
ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::
|
||||
Default, // ConvForwardSpecialization
|
||||
ck::tensor_operation::device::GemmSpecialization::Default, // GemmSpec
|
||||
256, // BlockSize
|
||||
128, // MPerBlock
|
||||
128, // NPerBlock
|
||||
16, // KPerBlock
|
||||
8, // AK1
|
||||
8, // BK1
|
||||
32, // MPerXDL
|
||||
32, // NPerXDL
|
||||
4, // MXdlPerWave
|
||||
4, // NXdlPerWave
|
||||
ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
|
||||
ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
|
||||
ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
2, // ABlockTransferSrcVectorDim
|
||||
8, // ABlockTransferSrcScalarPerVector
|
||||
8, // ABlockTransferDstScalarPerVector_AK1
|
||||
1, // ABlockLdsExtraM
|
||||
ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
|
||||
ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
|
||||
ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder
|
||||
2, // BBlockTransferSrcVectorDim
|
||||
8, // BBlockTransferSrcScalarPerVector
|
||||
8, // BBlockTransferDstScalarPerVector_BK1
|
||||
1, // BBlockLdsExtraN
|
||||
1, // CShuffleMXdlPerWavePerShuffle
|
||||
1, // CShuffleNXdlPerWavePerShuffle
|
||||
ck::Sequence<1,
|
||||
32,
|
||||
1,
|
||||
8>, // CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
8, // CDEBlockTransferScalarPerVector_NPerBlock
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, // BlkGemmPipeSched
|
||||
ck::PipelineVersion::v1, // BlkGemmPipelineVer
|
||||
ck::half_t, // AComputeDataType
|
||||
ck::half_t, // BComputeDataType
|
||||
false>; // DirectLoad
|
||||
|
||||
// Use ConvTraits to extract compile-time information
|
||||
using Traits = ck_tile::reflect::conv::ConvTraits<DeviceInstance>;
|
||||
|
||||
// Verify signature information
|
||||
EXPECT_EQ(Traits::spatial_dim, 2);
|
||||
EXPECT_EQ(Traits::direction, ck_tile::builder::ConvDirection::FORWARD);
|
||||
EXPECT_EQ(Traits::layout, ck_tile::builder::GroupConvLayout2D::GNHWC_GKYXC_GNHWK);
|
||||
EXPECT_EQ(Traits::data_type, ck_tile::builder::DataType::FP16);
|
||||
EXPECT_EQ(Traits::input_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH);
|
||||
EXPECT_EQ(Traits::weight_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH);
|
||||
EXPECT_EQ(Traits::output_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH);
|
||||
|
||||
// Verify specializations
|
||||
EXPECT_EQ(Traits::gemm_padding, ck_tile::builder::GemmPadding::DEFAULT);
|
||||
EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvFwdSpecialization::DEFAULT);
|
||||
|
||||
// Verify algorithm information
|
||||
EXPECT_EQ(Traits::thread_block_size, 256);
|
||||
|
||||
// Verify tile dimensions
|
||||
EXPECT_EQ(Traits::tile_dims.m, 128);
|
||||
EXPECT_EQ(Traits::tile_dims.n, 128);
|
||||
EXPECT_EQ(Traits::tile_dims.k, 16);
|
||||
|
||||
// Verify A tile transfer info
|
||||
EXPECT_EQ(Traits::a_tile_transfer.tile_dimensions.k0, 2);
|
||||
EXPECT_EQ(Traits::a_tile_transfer.tile_dimensions.m_or_n, 128);
|
||||
EXPECT_EQ(Traits::a_tile_transfer.tile_dimensions.k1, 8);
|
||||
EXPECT_EQ(Traits::a_tile_transfer.transfer_params.k1, 8);
|
||||
EXPECT_THAT(Traits::a_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1));
|
||||
EXPECT_THAT(Traits::a_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2));
|
||||
EXPECT_THAT(Traits::a_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2));
|
||||
EXPECT_EQ(Traits::a_tile_transfer.transfer_params.src_vector_dim, 2);
|
||||
EXPECT_EQ(Traits::a_tile_transfer.transfer_params.src_scalar_per_vector, 8);
|
||||
EXPECT_EQ(Traits::a_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8);
|
||||
EXPECT_TRUE(Traits::a_tile_transfer.transfer_params.lds_padding);
|
||||
|
||||
// Verify B tile transfer info
|
||||
EXPECT_EQ(Traits::b_tile_transfer.tile_dimensions.k0, 2);
|
||||
EXPECT_EQ(Traits::b_tile_transfer.tile_dimensions.m_or_n, 128);
|
||||
EXPECT_EQ(Traits::b_tile_transfer.tile_dimensions.k1, 8);
|
||||
EXPECT_EQ(Traits::b_tile_transfer.transfer_params.k1, 8);
|
||||
EXPECT_THAT(Traits::b_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1));
|
||||
EXPECT_THAT(Traits::b_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2));
|
||||
EXPECT_THAT(Traits::b_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2));
|
||||
EXPECT_EQ(Traits::b_tile_transfer.transfer_params.src_vector_dim, 2);
|
||||
EXPECT_EQ(Traits::b_tile_transfer.transfer_params.src_scalar_per_vector, 8);
|
||||
EXPECT_EQ(Traits::b_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8);
|
||||
EXPECT_TRUE(Traits::b_tile_transfer.transfer_params.lds_padding);
|
||||
|
||||
// Verify warp GEMM params
|
||||
EXPECT_EQ(Traits::warp_gemm.gemm_m, 32);
|
||||
EXPECT_EQ(Traits::warp_gemm.gemm_n, 32);
|
||||
EXPECT_EQ(Traits::warp_gemm.m_iter, 4);
|
||||
EXPECT_EQ(Traits::warp_gemm.n_iter, 4);
|
||||
|
||||
// Verify output tile transfer info
|
||||
EXPECT_EQ(Traits::c_tile_transfer.shuffle_params.m_gemms_per_shuffle, 1);
|
||||
EXPECT_EQ(Traits::c_tile_transfer.shuffle_params.n_gemms_per_shuffle, 1);
|
||||
EXPECT_THAT(Traits::c_tile_transfer.thread_cluster_dims, ElementsAre(1, 32, 1, 8));
|
||||
EXPECT_EQ(Traits::c_tile_transfer.scalar_per_vector, 8);
|
||||
|
||||
// Verify pipeline configuration
|
||||
EXPECT_EQ(Traits::pipeline_scheduler, ck_tile::builder::PipelineScheduler::INTRAWAVE);
|
||||
EXPECT_EQ(Traits::pipeline_version, ck_tile::builder::PipelineVersion::V1);
|
||||
}
|
||||
|
||||
// Test ConvTraits with DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
TEST_F(ConvTraitsTest, ConvFwdBaseTraitsExtraction)
|
||||
{
|
||||
// Define a concrete instance type with specific template parameters
|
||||
using DeviceInstance =
|
||||
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<
|
||||
2, // NDimSpatial
|
||||
ck::tensor_layout::convolution::GNHWC, // ALayout
|
||||
ck::tensor_layout::convolution::GKYXC, // BLayout
|
||||
ck::Tuple<>, // DsLayout
|
||||
ck::tensor_layout::convolution::GNHWK, // ELayout
|
||||
ck::half_t, // ADataType
|
||||
ck::half_t, // BDataType
|
||||
float, // AccDataType
|
||||
ck::half_t, // CShuffleDataType
|
||||
ck::Tuple<>, // DsDataType
|
||||
ck::half_t, // EDataType
|
||||
ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation
|
||||
ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation
|
||||
ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::
|
||||
Default, // ConvForwardSpecialization
|
||||
ck::tensor_operation::device::GemmSpecialization::Default, // GemmSpec
|
||||
1, // NumGemmKPrefetchStage
|
||||
256, // BlockSize
|
||||
128, // MPerBlock
|
||||
128, // NPerBlock
|
||||
16, // KPerBlock
|
||||
8, // AK1
|
||||
8, // BK1
|
||||
32, // MPerXDL
|
||||
32, // NPerXDL
|
||||
4, // MXdlPerWave
|
||||
4, // NXdlPerWave
|
||||
ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
|
||||
ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
|
||||
ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
2, // ABlockTransferSrcVectorDim
|
||||
8, // ABlockTransferSrcScalarPerVector
|
||||
8, // ABlockTransferDstScalarPerVector_AK1
|
||||
1, // ABlockLdsExtraM
|
||||
ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
|
||||
ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
|
||||
ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder
|
||||
2, // BBlockTransferSrcVectorDim
|
||||
8, // BBlockTransferSrcScalarPerVector
|
||||
8, // BBlockTransferDstScalarPerVector_BK1
|
||||
1, // BBlockLdsExtraN
|
||||
1, // CShuffleMXdlPerWavePerShuffle
|
||||
1, // CShuffleNXdlPerWavePerShuffle
|
||||
ck::Sequence<1,
|
||||
32,
|
||||
1,
|
||||
8>, // CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
8, // CDEBlockTransferScalarPerVector_NPerBlock
|
||||
ck::half_t, // AComputeDataType
|
||||
ck::half_t, // BComputeDataType
|
||||
ck::LoopScheduler::Default, // LoopSched
|
||||
1>; // NumGroupsToMerge
|
||||
|
||||
// Use ConvTraits to extract compile-time information
|
||||
using Traits = ck_tile::reflect::conv::ConvTraits<DeviceInstance>;
|
||||
|
||||
// Verify signature information
|
||||
EXPECT_EQ(Traits::spatial_dim, 2);
|
||||
EXPECT_EQ(Traits::direction, ck_tile::builder::ConvDirection::FORWARD);
|
||||
EXPECT_EQ(Traits::layout, ck_tile::builder::GroupConvLayout2D::GNHWC_GKYXC_GNHWK);
|
||||
EXPECT_EQ(Traits::data_type, ck_tile::builder::DataType::FP16);
|
||||
EXPECT_EQ(Traits::input_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH);
|
||||
EXPECT_EQ(Traits::weight_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH);
|
||||
EXPECT_EQ(Traits::output_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH);
|
||||
|
||||
// Verify specializations
|
||||
EXPECT_EQ(Traits::gemm_padding, ck_tile::builder::GemmPadding::DEFAULT);
|
||||
EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvFwdSpecialization::DEFAULT);
|
||||
|
||||
// Verify algorithm information
|
||||
EXPECT_EQ(Traits::thread_block_size, 256);
|
||||
|
||||
// Verify tile dimensions
|
||||
EXPECT_EQ(Traits::tile_dims.m, 128);
|
||||
EXPECT_EQ(Traits::tile_dims.n, 128);
|
||||
EXPECT_EQ(Traits::tile_dims.k, 16);
|
||||
}
|
||||
// Test ConvTraits with DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
|
||||
TEST_F(ConvTraitsTest, ConvFwdLargeTensorTraitsExtraction)
|
||||
{
|
||||
// Define a concrete instance type with specific template parameters
|
||||
using DeviceInstance =
|
||||
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<
|
||||
2, // NDimSpatial
|
||||
ck::tensor_layout::convolution::GNHWC, // ALayout
|
||||
ck::tensor_layout::convolution::GKYXC, // BLayout
|
||||
ck::Tuple<>, // DsLayout
|
||||
ck::tensor_layout::convolution::GNHWK, // ELayout
|
||||
ck::half_t, // ADataType
|
||||
ck::half_t, // BDataType
|
||||
float, // AccDataType
|
||||
ck::half_t, // CShuffleDataType
|
||||
ck::Tuple<>, // DsDataType
|
||||
ck::half_t, // EDataType
|
||||
ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation
|
||||
ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation
|
||||
ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::
|
||||
Default, // ConvForwardSpecialization
|
||||
ck::tensor_operation::device::GemmSpecialization::Default, // GemmSpec
|
||||
1, // NumGemmKPrefetchStage
|
||||
256, // BlockSize
|
||||
128, // MPerBlock
|
||||
128, // NPerBlock
|
||||
16, // KPerBlock
|
||||
8, // AK1
|
||||
8, // BK1
|
||||
32, // MPerXDL
|
||||
32, // NPerXDL
|
||||
4, // MXdlPerWave
|
||||
4, // NXdlPerWave
|
||||
ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
|
||||
ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
|
||||
ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
2, // ABlockTransferSrcVectorDim
|
||||
8, // ABlockTransferSrcScalarPerVector
|
||||
8, // ABlockTransferDstScalarPerVector_AK1
|
||||
1, // ABlockLdsExtraM
|
||||
ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
|
||||
ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
|
||||
ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder
|
||||
2, // BBlockTransferSrcVectorDim
|
||||
8, // BBlockTransferSrcScalarPerVector
|
||||
8, // BBlockTransferDstScalarPerVector_BK1
|
||||
1, // BBlockLdsExtraN
|
||||
1, // CShuffleMXdlPerWavePerShuffle
|
||||
1, // CShuffleNXdlPerWavePerShuffle
|
||||
ck::Sequence<1,
|
||||
32,
|
||||
1,
|
||||
8>, // CDEBlockTransferClusterLengths
|
||||
8, // CDEBlockTransferScalarPerVector_NPerBlock
|
||||
ck::half_t, // AComputeDataType
|
||||
ck::half_t, // BComputeDataType
|
||||
ck::LoopScheduler::Default>; // LoopSched
|
||||
|
||||
// Use ConvTraits to extract compile-time information
|
||||
using Traits = ck_tile::reflect::conv::ConvTraits<DeviceInstance>;
|
||||
|
||||
// Verify signature information
|
||||
EXPECT_EQ(Traits::spatial_dim, 2);
|
||||
EXPECT_EQ(Traits::direction, ck_tile::builder::ConvDirection::FORWARD);
|
||||
EXPECT_EQ(Traits::layout, ck_tile::builder::GroupConvLayout2D::GNHWC_GKYXC_GNHWK);
|
||||
EXPECT_EQ(Traits::data_type, ck_tile::builder::DataType::FP16);
|
||||
EXPECT_EQ(Traits::input_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH);
|
||||
EXPECT_EQ(Traits::weight_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH);
|
||||
EXPECT_EQ(Traits::output_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH);
|
||||
|
||||
// Verify specializations
|
||||
EXPECT_EQ(Traits::gemm_padding, ck_tile::builder::GemmPadding::DEFAULT);
|
||||
EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvFwdSpecialization::DEFAULT);
|
||||
|
||||
// Verify algorithm information
|
||||
EXPECT_EQ(Traits::thread_block_size, 256);
|
||||
|
||||
// Verify tile dimensions
|
||||
EXPECT_EQ(Traits::tile_dims.m, 128);
|
||||
EXPECT_EQ(Traits::tile_dims.n, 128);
|
||||
EXPECT_EQ(Traits::tile_dims.k, 16);
|
||||
}
|
||||
} // anonymous namespace
|
||||
@@ -49,14 +49,14 @@ struct GridwiseWmmaGemm
|
||||
size_t n_per_wmma = 0;
|
||||
size_t m_wmma_per_wave = 0;
|
||||
size_t n_wmma_per_wave = 0;
|
||||
GridwiseGemmPipelineVersion pipeline_version;
|
||||
PipelineVersion pipeline_version;
|
||||
};
|
||||
static_assert(ckb::GridwiseWmmaGemmDescriptor<GridwiseWmmaGemm>);
|
||||
|
||||
struct BlockGemm
|
||||
{
|
||||
BlockGemmPipelineVersion pipeline_version;
|
||||
BlockGemmPipelineScheduler scheduler;
|
||||
PipelineVersion pipeline_version;
|
||||
PipelineScheduler scheduler;
|
||||
};
|
||||
static_assert(ckb::BlockGemmDescriptor<BlockGemm>);
|
||||
|
||||
@@ -156,7 +156,7 @@ struct ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
GemmSpecialization gemm_specialization;
|
||||
size_t num_gemm_k_prefetch_stages;
|
||||
size_t num_groups_to_merge;
|
||||
LoopScheduler loop_scheduler;
|
||||
PipelineScheduler loop_scheduler;
|
||||
};
|
||||
static_assert(
|
||||
ckb::ConvAlgorithmDescriptor<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>);
|
||||
@@ -191,7 +191,7 @@ struct ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
|
||||
ConvFwdSpecialization fwd_specialization;
|
||||
GemmSpecialization gemm_specialization;
|
||||
size_t num_gemm_k_prefetch_stages;
|
||||
LoopScheduler loop_scheduler;
|
||||
PipelineScheduler loop_scheduler;
|
||||
};
|
||||
static_assert(
|
||||
ckb::ConvAlgorithmDescriptor<ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle>);
|
||||
@@ -214,4 +214,84 @@ static_assert(
|
||||
static_assert(
|
||||
ckb::SpecifiesLoopScheduler<ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle>);
|
||||
|
||||
// DL-specific descriptors
|
||||
struct DlThreadConfig
|
||||
{
|
||||
size_t k0_per_block;
|
||||
size_t k1;
|
||||
size_t m1_per_thread;
|
||||
size_t n1_per_thread;
|
||||
size_t k_per_thread;
|
||||
};
|
||||
static_assert(ckb::DlThreadConfigDescriptor<DlThreadConfig>);
|
||||
|
||||
struct DlThreadCluster
|
||||
{
|
||||
std::array<size_t, 2> m1_xs; // e.g., {8, 2}
|
||||
std::array<size_t, 2> n1_xs; // e.g., {8, 2}
|
||||
};
|
||||
static_assert(ckb::DlThreadClusterDescriptor<DlThreadCluster>);
|
||||
|
||||
struct DlBlockTransferK0M0M1K1
|
||||
{
|
||||
std::array<size_t, 4> thread_slice_lengths;
|
||||
std::array<size_t, 4> thread_cluster_lengths;
|
||||
std::array<size_t, 4> thread_cluster_arrange_order;
|
||||
std::array<size_t, 4> src_access_order;
|
||||
std::array<size_t, 4> src_vector_tensor_lengths;
|
||||
std::array<size_t, 4> src_vector_tensor_contiguous_dim_order;
|
||||
std::array<size_t, 4> dst_vector_tensor_lengths;
|
||||
};
|
||||
static_assert(ckb::DlBlockTransferK0M0M1K1Descriptor<DlBlockTransferK0M0M1K1>);
|
||||
|
||||
struct DlBlockTransferK0N0N1K1
|
||||
{
|
||||
std::array<size_t, 4> thread_slice_lengths;
|
||||
std::array<size_t, 4> thread_cluster_lengths;
|
||||
std::array<size_t, 4> thread_cluster_arrange_order;
|
||||
std::array<size_t, 4> src_access_order;
|
||||
std::array<size_t, 4> src_vector_tensor_lengths;
|
||||
std::array<size_t, 4> src_vector_tensor_contiguous_dim_order;
|
||||
std::array<size_t, 4> dst_vector_tensor_lengths;
|
||||
};
|
||||
static_assert(ckb::DlBlockTransferK0N0N1K1Descriptor<DlBlockTransferK0N0N1K1>);
|
||||
|
||||
struct DlCThreadTransfer
|
||||
{
|
||||
std::array<size_t, 6> src_dst_access_order;
|
||||
size_t src_dst_vector_dim;
|
||||
size_t dst_scalar_per_vector;
|
||||
};
|
||||
static_assert(ckb::DlCThreadTransferDescriptor<DlCThreadTransfer>);
|
||||
|
||||
struct ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
|
||||
{
|
||||
ThreadBlock thread_block;
|
||||
ConvFwdSpecialization fwd_specialization;
|
||||
GemmSpecialization gemm_specialization;
|
||||
DlThreadConfig dl_thread_config;
|
||||
DlThreadCluster dl_thread_cluster;
|
||||
DlBlockTransferK0M0M1K1 dl_block_transfer_a;
|
||||
DlBlockTransferK0N0N1K1 dl_block_transfer_b;
|
||||
DlCThreadTransfer dl_c_thread_transfer;
|
||||
};
|
||||
static_assert(
|
||||
ckb::ConvAlgorithmDescriptor<ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK>);
|
||||
static_assert(
|
||||
ckb::SpecifiesThreadBlock<ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK>);
|
||||
static_assert(ckb::SpecifiesFwdConcSpecialization<
|
||||
ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK>);
|
||||
static_assert(
|
||||
ckb::SpecifiesGemmSpecialization<ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK>);
|
||||
static_assert(
|
||||
ckb::SpecifiesDlThreadConfig<ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK>);
|
||||
static_assert(
|
||||
ckb::SpecifiesDlThreadCluster<ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK>);
|
||||
static_assert(
|
||||
ckb::SpecifiesDlBlockTransferA<ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK>);
|
||||
static_assert(
|
||||
ckb::SpecifiesDlBlockTransferB<ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK>);
|
||||
static_assert(
|
||||
ckb::SpecifiesDlCThreadTransfer<ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK>);
|
||||
|
||||
} // namespace ck_tile::builder::test
|
||||
|
||||
169
experimental/builder/test/test_conv_description.cpp
Normal file
169
experimental/builder/test/test_conv_description.cpp
Normal file
@@ -0,0 +1,169 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <gmock/gmock.h>
|
||||
|
||||
#include <ck_tile/builder/conv_builder.hpp>
|
||||
#include <ck_tile/builder/reflect/conv_description.hpp>
|
||||
#include "testing_utils.hpp"
|
||||
#include "impl/conv_signature_types.hpp"
|
||||
#include "impl/conv_algorithm_types.hpp"
|
||||
|
||||
namespace {
|
||||
|
||||
namespace ckb = ck_tile::builder;
|
||||
namespace ckr = ck_tile::reflect::conv;
|
||||
namespace ckt = ck_tile::test;
|
||||
|
||||
// Defines the signature of the convolution operation to be tested.
|
||||
// This includes dimensionality, direction, data layout, and data type.
|
||||
struct ConvSignature
|
||||
{
|
||||
int spatial_dim = 2;
|
||||
ckb::ConvDirection direction = ckb::ConvDirection::FORWARD;
|
||||
ckb::GroupConvLayout layout = ckb::GroupConvLayout2D::GNHWC_GKYXC_GNHWK;
|
||||
ckb::DataType data_type = ckb::DataType::FP16;
|
||||
ckb::ElementwiseOperation elementwise_operation = ckb::ElementwiseOperation::PASS_THROUGH;
|
||||
ckb::GroupConvDeviceOp device_operation =
|
||||
ckb::FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3;
|
||||
};
|
||||
static_assert(ckb::ConvSignatureDescriptor<ConvSignature>);
|
||||
|
||||
struct DefaultAlgorithm
|
||||
{
|
||||
ckb::test::ThreadBlock thread_block{.block_size = 256,
|
||||
.tile_size = {.m = 256, .n = 256, .k = 32}};
|
||||
|
||||
ckb::test::GridwiseXdlGemm gridwise_gemm{.ak1 = 8,
|
||||
.bk1 = 8,
|
||||
.m_per_xdl = 16,
|
||||
.n_per_xdl = 16,
|
||||
.m_xdl_per_wave = 4,
|
||||
.n_xdl_per_wave = 4};
|
||||
|
||||
ckb::test::BlockTransferABC block_transfer{
|
||||
.block_transfer_a = {.k0 = 4, .m_n = 256, .k1 = 8},
|
||||
.block_transfer_b = {.k0 = 4, .m_n = 256, .k1 = 8},
|
||||
.thread_cluster_dims_c = {.m_block = 1,
|
||||
.m_wave_per_xdl = 32,
|
||||
.n_block = 1,
|
||||
.n_wave_per_xdl = 8},
|
||||
.lds_transfer_a = {.src_vector_dim = 2,
|
||||
.src_scalar_per_vector = 8,
|
||||
.lds_dst_scalar_per_vector = 8,
|
||||
.is_direct_load = true,
|
||||
.lds_padding = false},
|
||||
.lds_transfer_b = {.src_vector_dim = 2,
|
||||
.src_scalar_per_vector = 8,
|
||||
.lds_dst_scalar_per_vector = 8,
|
||||
.is_direct_load = true,
|
||||
.lds_padding = false},
|
||||
.epilogue_c = {.m_per_wave_per_shuffle = 1,
|
||||
.n_per_wave_per_shuffle = 1,
|
||||
.scalar_per_vector = 8},
|
||||
.block_transfer_access_order_a = {.order = {0, 1, 2}},
|
||||
.block_transfer_access_order_b = {.order = {0, 1, 2}},
|
||||
.src_access_order_a = {.order = {0, 1, 2}},
|
||||
.src_access_order_b = {.order = {0, 1, 2}}};
|
||||
|
||||
ckb::ConvFwdSpecialization fwd_specialization = ckb::ConvFwdSpecialization::DEFAULT;
|
||||
ckb::GemmSpecialization gemm_specialization = ckb::GemmSpecialization::Default;
|
||||
ckb::test::BlockGemm block_gemm{.pipeline_version = ckb::PipelineVersion::V4,
|
||||
.scheduler = ckb::PipelineScheduler::INTRAWAVE};
|
||||
};
|
||||
static_assert(ckb::ConvAlgorithmDescriptor<DefaultAlgorithm>);
|
||||
|
||||
TEST(ConvDescriptionTest, DefaultInstanceHasBriefDescription)
|
||||
{
|
||||
static constexpr const ConvSignature SIGNATURE;
|
||||
static constexpr const DefaultAlgorithm ALGORITHM;
|
||||
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
|
||||
EXPECT_THAT(ckr::Describe<Builder>().brief(), ckt::StringEqWithDiff("2D Forward convolution"));
|
||||
}
|
||||
|
||||
TEST(ConvDescriptionTest, DefaultInstanceHasDetailedDescription)
|
||||
{
|
||||
static constexpr const ConvSignature SIGNATURE;
|
||||
static constexpr const DefaultAlgorithm ALGORITHM;
|
||||
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
|
||||
EXPECT_THAT(ckr::Describe<Builder>().detailed(),
|
||||
ckt::StringEqWithDiff( //
|
||||
"2D Forward Convolution Kernel\n"
|
||||
"├─ Signature\n"
|
||||
"│ ├─ Tensor Type: FP16\n"
|
||||
"│ ├─ Memory Layout: GNHWC_GKYXC_GNHWK\n"
|
||||
"│ ├─ Input elementwise operation: PASS_THROUGH\n"
|
||||
"│ ├─ Weights elementwise operation: PASS_THROUGH\n"
|
||||
"│ └─ Output elementwise operation: PASS_THROUGH\n"
|
||||
"├─ Algorithm\n"
|
||||
"│ ├─ Thread block size: 256\n"
|
||||
"│ ├─ Data tile size: 256×256×32\n"
|
||||
"│ ├─ Gemm padding: DEFAULT\n"
|
||||
"│ ├─ Convolution specialization: DEFAULT\n"
|
||||
"│ ├─ Pipeline version: V4\n"
|
||||
"│ ├─ Pipeline scheduler: INTRAWAVE\n"
|
||||
"│ ├─ Warp Gemm parameters: \n"
|
||||
"│ │ ├─ subtile size: 16×16\n"
|
||||
"│ │ └─ Number of warp gemm iterations: 4×4\n"
|
||||
"│ ├─ Memory access:\n"
|
||||
"│ │ ├─ A Tile transfer: \n"
|
||||
"│ │ │ ├─ Tile dimensions: 4×256×8×\n"
|
||||
"│ │ │ ├─ The innermost K subdimension size: 8\n"
|
||||
"│ │ │ ├─ Spatial thread distribution over the data tile: 0×1×2\n"
|
||||
"│ │ │ ├─ The order of accessing data tile axes: 0×1×2\n"
|
||||
"│ │ │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n"
|
||||
"│ │ │ ├─ Vector access (GMEM read) instruction size: 8\n"
|
||||
"│ │ │ ├─ Vector access (LDS write) instruction size: 8\n"
|
||||
"│ │ │ └─ LDS data layout padding (to prevent bank conflicts): 8\n"
|
||||
"│ │ ├─ B Tile transfer: \n"
|
||||
"│ │ │ ├─ Tile dimensions: 4×256×8×\n"
|
||||
"│ │ │ ├─ The innermost K subdimension size: 8\n"
|
||||
"│ │ │ ├─ Spatial thread distribution over the data tile: 0×1×2\n"
|
||||
"│ │ │ ├─ The order of accessing data tile axes: 0×1×2\n"
|
||||
"│ │ │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n"
|
||||
"│ │ │ ├─ Vector access (GMEM read) instruction size: 8\n"
|
||||
"│ │ │ ├─ Vector access (LDS write) instruction size: 8\n"
|
||||
"│ │ │ └─ LDS data layout padding (to prevent bank conflicts): 8\n"
|
||||
"│ │ └─ C Tile transfer: \n"
|
||||
"│ │ ├─ Data shuffle (number of gemm instructions per iteration): 1×1\n"
|
||||
"│ │ ├─ Spatial thread distribution used to store data: 1×32×1×8\n"
|
||||
"│ │ └─ Vector access (GMEM write) instruction size: 8\n"
|
||||
"│ └─ \n"
|
||||
"└─ "));
|
||||
}
|
||||
|
||||
// NOTE: BackwardDataInstanceHasDetailedDescription test is disabled because ConvFactory
|
||||
// does not have a specialization for backward data convolutions. The test fails with:
|
||||
// "implicit instantiation of undefined template 'ck_tile::builder::ConvFactory<...>'"
|
||||
//
|
||||
// To enable this test, a ConvFactory specialization for backward data operations must be
|
||||
// implemented first.
|
||||
//
|
||||
// TEST(ConvDescriptionTest, BackwardDataInstanceHasDetailedDescription)
|
||||
// {
|
||||
// struct BackwardDataSignature
|
||||
// {
|
||||
// int spatial_dim = 2;
|
||||
// ckb::ConvDirection direction = ckb::ConvDirection::BACKWARD_DATA;
|
||||
// ckb::GroupConvLayout layout = ckb::GroupConvLayout2D::GNHWC_GKYXC_GNHWK;
|
||||
// ckb::DataType data_type = ckb::DataType::FP16;
|
||||
// ckb::ElementwiseOperation elementwise_operation =
|
||||
// ckb::ElementwiseOperation::PASS_THROUGH; ckb::GroupConvDeviceOp device_operation =
|
||||
// ckb::BwdDataGroupConvDeviceOperation::DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1;
|
||||
// };
|
||||
// static_assert(ckb::ConvSignatureDescriptor<BackwardDataSignature>);
|
||||
//
|
||||
// static constexpr const BackwardDataSignature SIGNATURE;
|
||||
// static constexpr const DefaultAlgorithm ALGORITHM;
|
||||
// using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
|
||||
//
|
||||
// // Verify Brief works
|
||||
// EXPECT_THAT(ckr::Describe<Builder>().brief(),
|
||||
// ckt::StringEqWithDiff("2D Backward Data convolution"));
|
||||
//
|
||||
// // Verify detailed works - to be updated once ConvFactory is implemented
|
||||
// EXPECT_THAT(ckr::Describe<Builder>().detailed(),
|
||||
// ckt::StringEqWithDiff("PLACEHOLDER"));
|
||||
// }
|
||||
} // namespace
|
||||
@@ -16,7 +16,7 @@ using namespace test;
|
||||
// Common test implementation
|
||||
template <ConvSignature FwdConvSignature,
|
||||
ThreadBlock FwdThreadBlock,
|
||||
BlockGemmPipelineVersion FwdPipelineVersion,
|
||||
PipelineVersion FwdPipelineVersion,
|
||||
ConvFwdSpecialization FwdConvSpecialization>
|
||||
constexpr void run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3()
|
||||
{
|
||||
@@ -52,7 +52,7 @@ constexpr void run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3()
|
||||
.src_access_order_b = {1, 0, 2}};
|
||||
|
||||
constexpr BlockGemm BlockGemmDesc = {.pipeline_version = FwdPipelineVersion,
|
||||
.scheduler = BlockGemmPipelineScheduler::INTRAWAVE};
|
||||
.scheduler = PipelineScheduler::INTRAWAVE};
|
||||
|
||||
constexpr ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 FwdConvAlgorithm{
|
||||
.thread_block = FwdThreadBlock,
|
||||
@@ -73,13 +73,13 @@ constexpr void run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3()
|
||||
EXPECT_TRUE(kernel_string.starts_with("DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3"));
|
||||
|
||||
// Verify pipeline version is correct
|
||||
if(FwdPipelineVersion == BlockGemmPipelineVersion::V1)
|
||||
if(FwdPipelineVersion == PipelineVersion::V1)
|
||||
EXPECT_TRUE(kernel_string.find("BlkGemmPipelineVersion: v1") != std::string::npos);
|
||||
else if(FwdPipelineVersion == BlockGemmPipelineVersion::V3)
|
||||
else if(FwdPipelineVersion == PipelineVersion::V3)
|
||||
EXPECT_TRUE(kernel_string.find("BlkGemmPipelineVersion: v3") != std::string::npos);
|
||||
else if(FwdPipelineVersion == BlockGemmPipelineVersion::V4)
|
||||
else if(FwdPipelineVersion == PipelineVersion::V4)
|
||||
EXPECT_TRUE(kernel_string.find("BlkGemmPipelineVersion: v4") != std::string::npos);
|
||||
else if(FwdPipelineVersion == BlockGemmPipelineVersion::V5)
|
||||
else if(FwdPipelineVersion == PipelineVersion::V5)
|
||||
EXPECT_TRUE(kernel_string.find("BlkGemmPipelineVersion: v5") != std::string::npos);
|
||||
|
||||
// Verify specialization is correct
|
||||
@@ -140,7 +140,7 @@ constexpr void run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle()
|
||||
.gemm_specialization = GemmSpecialization::MNKPadding,
|
||||
.num_gemm_k_prefetch_stages = 1,
|
||||
.num_groups_to_merge = 2,
|
||||
.loop_scheduler = LoopScheduler::DEFAULT};
|
||||
.loop_scheduler = PipelineScheduler::DEFAULT};
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
|
||||
@@ -176,7 +176,7 @@ constexpr void run_test_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle()
|
||||
.n_per_wmma = 32,
|
||||
.m_wmma_per_wave = 2,
|
||||
.n_wmma_per_wave = 1,
|
||||
.pipeline_version = GridwiseGemmPipelineVersion::V1};
|
||||
.pipeline_version = PipelineVersion::V1};
|
||||
|
||||
constexpr BlockTransferABC FwdBlockTransfer{.block_transfer_a = {.k0 = 4, .m_n = 32, .k1 = 1},
|
||||
.block_transfer_b = {.k0 = 4, .m_n = 32, .k1 = 1},
|
||||
@@ -209,7 +209,7 @@ constexpr void run_test_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle()
|
||||
.fwd_specialization = FwdConvSpecialization,
|
||||
.gemm_specialization = GemmSpecialization::MNKPadding,
|
||||
.num_gemm_k_prefetch_stages = 1,
|
||||
.loop_scheduler = LoopScheduler::DEFAULT};
|
||||
.loop_scheduler = PipelineScheduler::DEFAULT};
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
|
||||
@@ -235,4 +235,149 @@ constexpr void run_test_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle()
|
||||
EXPECT_NE(invoker_ptr, nullptr);
|
||||
}
|
||||
|
||||
template <ConvSignature FwdConvSignature,
|
||||
ThreadBlock FwdThreadBlock,
|
||||
ConvFwdSpecialization FwdConvSpecialization>
|
||||
constexpr void run_test_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK()
|
||||
{
|
||||
// DL thread configuration
|
||||
constexpr DlThreadConfig DlThreadCfg{
|
||||
.k0_per_block = 16, .k1 = 2, .m1_per_thread = 4, .n1_per_thread = 4, .k_per_thread = 1};
|
||||
|
||||
// DL thread cluster
|
||||
constexpr DlThreadCluster DlCluster{.m1_xs = {8, 2}, .n1_xs = {8, 2}};
|
||||
|
||||
// DL A block transfer - K0_M0_M1_K1 format
|
||||
constexpr DlBlockTransferK0M0M1K1 DlBlockTransferA{
|
||||
.thread_slice_lengths = {8, 1, 1, 2},
|
||||
.thread_cluster_lengths = {2, 1, 128, 1},
|
||||
.thread_cluster_arrange_order = {1, 2, 0, 3},
|
||||
.src_access_order = {1, 2, 0, 3},
|
||||
.src_vector_tensor_lengths = {4, 1, 1, 2},
|
||||
.src_vector_tensor_contiguous_dim_order = {1, 2, 0, 3},
|
||||
.dst_vector_tensor_lengths = {1, 1, 1, 2}};
|
||||
|
||||
// DL B block transfer - K0_N0_N1_K1 format
|
||||
constexpr DlBlockTransferK0N0N1K1 DlBlockTransferB{
|
||||
.thread_slice_lengths = {8, 1, 1, 2},
|
||||
.thread_cluster_lengths = {2, 1, 128, 1},
|
||||
.thread_cluster_arrange_order = {1, 2, 0, 3},
|
||||
.src_access_order = {1, 2, 0, 3},
|
||||
.src_vector_tensor_lengths = {4, 1, 1, 2},
|
||||
.src_vector_tensor_contiguous_dim_order = {1, 2, 0, 3},
|
||||
.dst_vector_tensor_lengths = {1, 1, 1, 2}};
|
||||
|
||||
// DL C thread transfer
|
||||
constexpr DlCThreadTransfer DlCTransfer{.src_dst_access_order = {0, 1, 2, 3, 4, 5},
|
||||
.src_dst_vector_dim = 5,
|
||||
.dst_scalar_per_vector = 4};
|
||||
|
||||
constexpr ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK FwdConvAlgorithm{
|
||||
.thread_block = FwdThreadBlock,
|
||||
.fwd_specialization = FwdConvSpecialization,
|
||||
.gemm_specialization = GemmSpecialization::MNKPadding,
|
||||
.dl_thread_config = DlThreadCfg,
|
||||
.dl_thread_cluster = DlCluster,
|
||||
.dl_block_transfer_a = DlBlockTransferA,
|
||||
.dl_block_transfer_b = DlBlockTransferB,
|
||||
.dl_c_thread_transfer = DlCTransfer};
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
|
||||
auto instance = typename Builder::Instance{};
|
||||
|
||||
const auto kernel_string = instance.GetTypeString();
|
||||
std::cout << "Generated kernel: " << kernel_string << std::endl;
|
||||
EXPECT_GT(kernel_string.size(), 0);
|
||||
|
||||
EXPECT_TRUE(kernel_string.starts_with("DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK"));
|
||||
|
||||
// Verify specialization is correct
|
||||
if(FwdConvSpecialization == ConvFwdSpecialization::DEFAULT)
|
||||
EXPECT_TRUE(kernel_string.find("Default") != std::string::npos);
|
||||
else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_1X1_PAD0)
|
||||
EXPECT_TRUE(kernel_string.find("Filter1x1Pad0") != std::string::npos);
|
||||
else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0)
|
||||
EXPECT_TRUE(kernel_string.find("Filter1x1Stride1Pad0") != std::string::npos);
|
||||
else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_3x3)
|
||||
EXPECT_TRUE(kernel_string.find("Filter3x3") != std::string::npos);
|
||||
|
||||
const auto invoker_ptr = instance.MakeInvokerPointer();
|
||||
EXPECT_NE(invoker_ptr, nullptr);
|
||||
}
|
||||
|
||||
// Test helper for DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
|
||||
// Note: Large_Tensor has identical parameters to regular XDL CShuffle
|
||||
template <ConvSignature FwdConvSignature,
|
||||
ThreadBlock FwdThreadBlock,
|
||||
ConvFwdSpecialization FwdConvSpecialization>
|
||||
constexpr void run_test_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor()
|
||||
{
|
||||
constexpr GridwiseXdlGemm FwdGemmParams{.ak1 = 8,
|
||||
.bk1 = 8,
|
||||
.m_per_xdl = 32,
|
||||
.n_per_xdl = 32,
|
||||
.m_xdl_per_wave = 2,
|
||||
.n_xdl_per_wave = 1};
|
||||
|
||||
constexpr BlockTransferABC FwdBlockTransfer{.block_transfer_a = {.k0 = 4, .m_n = 16, .k1 = 1},
|
||||
.block_transfer_b = {.k0 = 4, .m_n = 16, .k1 = 1},
|
||||
.thread_cluster_dims_c = {.m_block = 1,
|
||||
.m_wave_per_xdl = 16,
|
||||
.n_block = 1,
|
||||
.n_wave_per_xdl = 4},
|
||||
.lds_transfer_a = {.src_vector_dim = 2,
|
||||
.src_scalar_per_vector = 8,
|
||||
.lds_dst_scalar_per_vector = 8,
|
||||
.is_direct_load = false,
|
||||
.lds_padding = true},
|
||||
.lds_transfer_b = {.src_vector_dim = 2,
|
||||
.src_scalar_per_vector = 8,
|
||||
.lds_dst_scalar_per_vector = 8,
|
||||
.is_direct_load = false,
|
||||
.lds_padding = true},
|
||||
.epilogue_c = {.m_per_wave_per_shuffle = 1,
|
||||
.n_per_wave_per_shuffle = 1,
|
||||
.scalar_per_vector = 8},
|
||||
.block_transfer_access_order_a = {1, 0, 2},
|
||||
.block_transfer_access_order_b = {1, 0, 2},
|
||||
.src_access_order_a = {1, 0, 2},
|
||||
.src_access_order_b = {1, 0, 2}};
|
||||
|
||||
// Large_Tensor uses the same descriptor as regular XDL CShuffle
|
||||
constexpr ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle FwdConvAlgorithm{
|
||||
.thread_block = FwdThreadBlock,
|
||||
.gridwise_gemm = FwdGemmParams,
|
||||
.block_transfer = FwdBlockTransfer,
|
||||
.fwd_specialization = FwdConvSpecialization,
|
||||
.gemm_specialization = GemmSpecialization::MNKPadding,
|
||||
.num_gemm_k_prefetch_stages = 1,
|
||||
.num_groups_to_merge = 1,
|
||||
.loop_scheduler = LoopScheduler::DEFAULT};
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
|
||||
auto instance = typename Builder::Instance{};
|
||||
|
||||
const auto kernel_string = instance.GetTypeString();
|
||||
std::cout << "Generated kernel: " << kernel_string << std::endl;
|
||||
EXPECT_GT(kernel_string.size(), 0);
|
||||
|
||||
EXPECT_TRUE(
|
||||
kernel_string.starts_with("DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor"));
|
||||
|
||||
// Verify specialization is correct
|
||||
if(FwdConvSpecialization == ConvFwdSpecialization::DEFAULT)
|
||||
EXPECT_TRUE(kernel_string.find("Default") != std::string::npos);
|
||||
else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_1X1_PAD0)
|
||||
EXPECT_TRUE(kernel_string.find("Filter1x1Pad0") != std::string::npos);
|
||||
else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0)
|
||||
EXPECT_TRUE(kernel_string.find("Filter1x1Stride1Pad0") != std::string::npos);
|
||||
else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_3x3)
|
||||
EXPECT_TRUE(kernel_string.find("Filter3x3") != std::string::npos);
|
||||
|
||||
const auto invoker_ptr = instance.MakeInvokerPointer();
|
||||
EXPECT_NE(invoker_ptr, nullptr);
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::test_utils
|
||||
|
||||
@@ -727,7 +727,8 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_gufusion_v3<
|
||||
});
|
||||
});
|
||||
|
||||
HotLoopScheduler();
|
||||
if constexpr(MPerBlock >= 64)
|
||||
HotLoopScheduler();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
};
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v3.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v1.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_gufusion_v3.hpp"
|
||||
|
||||
namespace ck {
|
||||
@@ -45,7 +46,28 @@ constexpr auto BlockGemmMXBPreshufflePipeline_Selector()
|
||||
}
|
||||
else
|
||||
{
|
||||
return nullptr;
|
||||
return BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v1<
|
||||
BlkGemmPipeSche,
|
||||
ThreadBlockSize,
|
||||
ScaleBlockSize,
|
||||
ADataType,
|
||||
AScaleDataType,
|
||||
BDataType,
|
||||
BScaleDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>{};
|
||||
}
|
||||
}
|
||||
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
|
||||
|
||||
@@ -0,0 +1,891 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_mx_pipeline_xdlops_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// Naive pipeline with lowest resource request per WGP
|
||||
// GlobalPrefetchStages: 2
|
||||
// LocalPreFillStages: 1
|
||||
// LocalPreFetchStages: 1
|
||||
// LocalSharedMemoryBuffer: 1
|
||||
|
||||
template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
|
||||
index_t ThreadBlockSize,
|
||||
index_t ScaleBlockSize,
|
||||
typename ADataType,
|
||||
typename AScaleDataType,
|
||||
typename BDataType,
|
||||
typename BScaleDataType,
|
||||
typename ATileDesc,
|
||||
typename BTileDesc,
|
||||
typename AMmaTileDesc,
|
||||
typename BMmaTileDesc,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t MPerXDL,
|
||||
index_t NPerXDL,
|
||||
index_t MRepeat, // MXdlPerWave
|
||||
index_t NRepeat, // NXdlPerWave
|
||||
index_t KPack>
|
||||
struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v1
|
||||
{
|
||||
};
|
||||
|
||||
template <index_t ThreadBlockSize,
|
||||
index_t ScaleBlockSize,
|
||||
typename ADataType,
|
||||
typename AScaleDataType,
|
||||
typename BDataType,
|
||||
typename BScaleDataType,
|
||||
typename ATileDesc,
|
||||
typename BTileDesc,
|
||||
typename AMmaTileDesc,
|
||||
typename BMmaTileDesc,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t MPerXDL,
|
||||
index_t NPerXDL,
|
||||
index_t MRepeat, // MXdlPerWave
|
||||
index_t NRepeat, // NXdlPerWave
|
||||
index_t KPack>
|
||||
struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v1<BlockGemmPipelineScheduler::Intrawave,
|
||||
ThreadBlockSize,
|
||||
ScaleBlockSize,
|
||||
ADataType,
|
||||
AScaleDataType,
|
||||
BDataType,
|
||||
BScaleDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>
|
||||
: BlockwiseGemmXdlops_mx_pipeline_base<ThreadBlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>
|
||||
|
||||
{
|
||||
|
||||
using Base = BlockwiseGemmXdlops_mx_pipeline_base<ThreadBlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>;
|
||||
using Base::A_K1;
|
||||
using Base::I0;
|
||||
using Base::I1;
|
||||
using Base::KRepeat;
|
||||
using Base::MWaves;
|
||||
using Base::NWaves;
|
||||
using Base::WaveSize;
|
||||
using Base::xdlops_gemm;
|
||||
using typename Base::HotLoopInstList;
|
||||
|
||||
using Base::CalculateCThreadOriginDataIndex;
|
||||
using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
|
||||
using Base::GetCThreadBuffer;
|
||||
using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
|
||||
using Base::GetWaveIdx;
|
||||
using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
|
||||
using Base::a_block_desc_m0_m1_m2_m3_k;
|
||||
using Base::b_block_desc_n0_n1_n2_n3_k;
|
||||
|
||||
using Base::AMmaKStride;
|
||||
using Base::APackedSize;
|
||||
using Base::BMmaKStride;
|
||||
using Base::BPackedSize;
|
||||
using Base::KThreadChunk;
|
||||
|
||||
using Base::KXdlPack;
|
||||
using Base::MXdlPack;
|
||||
using Base::NXdlPack;
|
||||
|
||||
using AccType = typename Base::AccType;
|
||||
using Tuple5 = typename Base::Tuple5;
|
||||
using ComputeTypeA = typename Base::ComputeTypeA;
|
||||
using ComputeTypeB = typename Base::ComputeTypeB;
|
||||
|
||||
static constexpr index_t PrefetchStages = 2;
|
||||
static constexpr index_t PrefillStages = 1;
|
||||
static constexpr index_t GlobalBufferNum = 2;
|
||||
static constexpr index_t HotloopLocalBufSwitch = MRepeat % 2 == 0 ? 0 : 1;
|
||||
|
||||
static constexpr auto num_buffer_load_a_scale = MRepeat / MXdlPack * KRepeat / KXdlPack;
|
||||
static constexpr auto num_buffer_load_b_scale = NRepeat / NXdlPack * KRepeat / KXdlPack;
|
||||
static constexpr auto async_vmcnt =
|
||||
num_buffer_load_a_scale + num_buffer_load_b_scale + HotLoopInstList::B_Buffer_Load_Inst_Num;
|
||||
static constexpr auto async_vmcnt_encoding = 3952 + async_vmcnt % 16 + async_vmcnt / 16 * 16384;
|
||||
|
||||
static constexpr auto ScalesPerKBlockSize =
|
||||
KPerBlock / ScaleBlockSize; // How many mx-vectors per K block
|
||||
|
||||
//> How many mx-vectors in each row/col is processed in one call to xdlops_gemm.Run()
|
||||
static constexpr auto ScalesPerXdlopsRun =
|
||||
(APackedSize * KPack * xdlops_gemm.K0PerXdlops) / ScaleBlockSize;
|
||||
|
||||
//> How many scales a thread must read to accommodate one call to xdlops_gemm.Run()
|
||||
static constexpr auto ScalesPerXdlopsRunPerThread =
|
||||
ScalesPerXdlopsRun / xdlops_gemm.mfma_instr.num_input_blks;
|
||||
|
||||
using mx_scale_t = e8m0_bexp_t;
|
||||
static constexpr auto scale_pack_size_a = sizeof(AScaleDataType) / sizeof(mx_scale_t);
|
||||
static constexpr auto scale_pack_size_b = sizeof(BScaleDataType) / sizeof(mx_scale_t);
|
||||
static_assert(KXdlPack * MXdlPack % scale_pack_size_a == 0,
|
||||
"A scale pack data type too large!");
|
||||
static_assert(KXdlPack * NXdlPack % scale_pack_size_b == 0,
|
||||
"B scale pack data type too large!");
|
||||
static constexpr auto a_scale_thread_vec_size = KXdlPack * MXdlPack / scale_pack_size_a;
|
||||
static constexpr auto b_scale_thread_vec_size = KXdlPack * NXdlPack / scale_pack_size_b;
|
||||
|
||||
__host__ static constexpr bool BlockHasHotloop(index_t num_loop)
|
||||
{
|
||||
return num_loop > PrefetchStages;
|
||||
}
|
||||
|
||||
__host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
|
||||
{
|
||||
return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd;
|
||||
}
|
||||
|
||||
__device__ static constexpr auto HotLoopScheduler()
|
||||
{
|
||||
constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num;
|
||||
constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
|
||||
constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num * MWaves +
|
||||
num_buffer_load_a_scale + num_buffer_load_b_scale;
|
||||
constexpr auto mfma_interleave = MPerXDL == 32 ? 1 : 2;
|
||||
// B global
|
||||
static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
if constexpr(MPerBlock >= 128 && NPerBlock >= 128)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 2 * mfma_interleave, 0);
|
||||
}
|
||||
else
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, mfma_interleave, 0);
|
||||
}
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
});
|
||||
|
||||
// A global
|
||||
static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
});
|
||||
|
||||
// A local
|
||||
static_for<0, MPerXDL == 32 ? num_ds_read_inst_a / 2 : num_ds_read_inst_a, 1>{}(
|
||||
[&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, MPerXDL == 32 ? 2 : 1, 0); // DS read
|
||||
});
|
||||
}
|
||||
|
||||
template <bool HasMainLoop,
|
||||
TailNumber TailNum,
|
||||
typename AGridDesc,
|
||||
typename ABlockDesc,
|
||||
typename ABlockTransfer,
|
||||
typename AGridBuffer,
|
||||
typename ABlockBuffer,
|
||||
typename ABlockTransferStep,
|
||||
typename BGridDesc,
|
||||
typename BBlockDesc,
|
||||
typename BBlockTransfer,
|
||||
typename BGridBuffer,
|
||||
typename BBlockBuffer,
|
||||
typename BBlockTransferStep,
|
||||
typename CThreadBuffer,
|
||||
typename AScaleGridBuffer,
|
||||
typename AScaleGridDesc,
|
||||
typename AScaleThreadTransfer,
|
||||
typename BScaleGridBuffer,
|
||||
typename BScaleGridDesc,
|
||||
typename BScaleThreadTransfer>
|
||||
__device__ void Run(
|
||||
// ABlockCopy
|
||||
const AGridDesc& a_grid_desc,
|
||||
const ABlockDesc& a_block_desc,
|
||||
ABlockTransfer& a_blockwise_copy,
|
||||
const AGridBuffer& a_grid_buf,
|
||||
ABlockBuffer& a_block_buf,
|
||||
const ABlockTransferStep& a_block_copy_step,
|
||||
// BBlockCopy
|
||||
const BGridDesc& b_grid_desc,
|
||||
const BBlockDesc& b_block_desc,
|
||||
BBlockTransfer& b_blockwise_copy,
|
||||
const BGridBuffer& b_grid_buf,
|
||||
BBlockBuffer& b_block_bufs,
|
||||
const BBlockTransferStep& b_block_copy_step,
|
||||
// CThread
|
||||
CThreadBuffer& c_thread_buf,
|
||||
// A and B scales
|
||||
const AScaleGridDesc& a_scale_grid_desc,
|
||||
AScaleThreadTransfer& a_scale_thread_copy,
|
||||
const AScaleGridBuffer& a_scale_grid_buf,
|
||||
const BScaleGridDesc& b_scale_grid_desc,
|
||||
BScaleThreadTransfer& b_scale_thread_copy,
|
||||
const BScaleGridBuffer& b_scale_grid_buf,
|
||||
index_t num_loop) const
|
||||
{
|
||||
ignore = b_block_bufs;
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
|
||||
a_thread_desc_.GetElementSpaceSize());
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
|
||||
b_thread_desc_.GetElementSpaceSize());
|
||||
|
||||
StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs;
|
||||
constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0, I0);
|
||||
|
||||
auto a_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AScaleDataType>(
|
||||
a_scale_thread_desc.GetElementSpaceSize());
|
||||
auto b_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, BScaleDataType>(
|
||||
b_scale_thread_desc.GetElementSpaceSize());
|
||||
|
||||
StaticallyIndexedArray<decltype(a_scale_thread_buf), Number<2>{}> a_scale_thread_bufs;
|
||||
StaticallyIndexedArray<decltype(b_scale_thread_buf), Number<2>{}> b_scale_thread_bufs;
|
||||
|
||||
// Global prefetch 1
|
||||
a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf);
|
||||
b_blockwise_copy.Run(
|
||||
b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_thread_bufs(I0));
|
||||
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// Prefetch a_scales
|
||||
static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
|
||||
static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
|
||||
a_scale_thread_copy.Run(a_scale_grid_desc,
|
||||
a_scale_grid_buf,
|
||||
a_scale_thread_desc,
|
||||
make_tuple(m0, k0, I0),
|
||||
a_scale_thread_bufs(I0));
|
||||
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
|
||||
make_multi_index(0, I1, 0));
|
||||
});
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0));
|
||||
});
|
||||
|
||||
// restore row id and advance to the next set of scales
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc,
|
||||
make_multi_index(-MWaves * MRepeat / MXdlPack, KRepeat / KXdlPack, 0));
|
||||
|
||||
// Prefetch b_scales
|
||||
static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
|
||||
static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
|
||||
b_scale_thread_copy.Run(b_scale_grid_desc,
|
||||
b_scale_grid_buf,
|
||||
b_scale_thread_desc,
|
||||
make_tuple(n0, k0, I0),
|
||||
b_scale_thread_bufs(I0));
|
||||
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
|
||||
make_multi_index(0, I1, 0));
|
||||
});
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(
|
||||
b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0));
|
||||
});
|
||||
|
||||
// restore col id and advance to the next set of scales
|
||||
// NWaves * NPerXDL * NRepeat == NPerBlock
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(
|
||||
b_scale_grid_desc,
|
||||
make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0));
|
||||
|
||||
// Local prefetch 1, sync the async load
|
||||
__builtin_amdgcn_s_waitcnt(async_vmcnt_encoding);
|
||||
block_sync_lds();
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k) {
|
||||
constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize *
|
||||
(APackedSize * KPack / xdlops_gemm.K1PerXdlops);
|
||||
static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}(
|
||||
[&](auto chunk) {
|
||||
constexpr auto a_k_step_chunk =
|
||||
k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
|
||||
a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k,
|
||||
make_tuple(Number<m0 / MXdlPack>{},
|
||||
I0,
|
||||
Number<m0 % MXdlPack>{},
|
||||
I0,
|
||||
Number<a_k_step_chunk>{}),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(Number<m0 / MXdlPack>{},
|
||||
I0,
|
||||
Number<m0 % MXdlPack>{},
|
||||
k,
|
||||
Number<chunk * KThreadChunk>{}),
|
||||
a_thread_buf);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
// Initialize C
|
||||
c_thread_buf.Clear();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
// main body
|
||||
if constexpr(HasMainLoop)
|
||||
{
|
||||
// loop over k with the step KPerBlock
|
||||
index_t i = 0;
|
||||
do
|
||||
{
|
||||
auto LoopFunc = [&](auto scale_comp_buf, auto scale_mem_buf) {
|
||||
b_blockwise_copy.Run(b_grid_desc,
|
||||
b_grid_buf,
|
||||
b_block_desc,
|
||||
b_block_origin_idx,
|
||||
b_thread_bufs(scale_mem_buf));
|
||||
|
||||
block_sync_lds();
|
||||
a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf);
|
||||
// Prefetch a_scales
|
||||
static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
|
||||
static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
|
||||
a_scale_thread_copy.Run(a_scale_grid_desc,
|
||||
a_scale_grid_buf,
|
||||
a_scale_thread_desc,
|
||||
make_tuple(m0, k0, I0),
|
||||
a_scale_thread_bufs(scale_mem_buf));
|
||||
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
|
||||
make_multi_index(0, I1, 0));
|
||||
});
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0));
|
||||
});
|
||||
|
||||
// restore row id and advance to the next set of scales
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc,
|
||||
make_multi_index(-MWaves * MRepeat / MXdlPack, KRepeat / KXdlPack, 0));
|
||||
|
||||
// Prefetch b_scales
|
||||
static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
|
||||
static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
|
||||
b_scale_thread_copy.Run(b_scale_grid_desc,
|
||||
b_scale_grid_buf,
|
||||
b_scale_thread_desc,
|
||||
make_tuple(n0, k0, I0),
|
||||
b_scale_thread_bufs(scale_mem_buf));
|
||||
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
|
||||
make_multi_index(0, I1, 0));
|
||||
});
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(
|
||||
b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0));
|
||||
});
|
||||
|
||||
// restore col id and advance to the next set of scales
|
||||
// NWaves * NPerXDL * NRepeat == NPerBlock
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(
|
||||
b_scale_grid_desc,
|
||||
make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0));
|
||||
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
constexpr auto im_major = m0 / MXdlPack;
|
||||
constexpr auto im_minor = m0 % MXdlPack;
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
constexpr auto ik_major = k0 / KXdlPack;
|
||||
constexpr auto ik_minor = k0 % KXdlPack;
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
constexpr auto in_major = n0 / NXdlPack;
|
||||
constexpr auto in_minor = n0 % NXdlPack;
|
||||
|
||||
constexpr index_t a_scale_offset =
|
||||
a_scale_thread_desc.CalculateOffset(
|
||||
make_tuple(im_major, ik_major, I0));
|
||||
constexpr index_t b_scale_offset =
|
||||
b_scale_thread_desc.CalculateOffset(
|
||||
make_tuple(in_major, ik_major, I0));
|
||||
|
||||
static_assert(0 < ScalesPerXdlopsRunPerThread,
|
||||
"Must have at least one scale per Xdlops "
|
||||
"per Thread.");
|
||||
|
||||
vector_type<AScaleDataType, a_scale_thread_vec_size>
|
||||
a_scale_thread_vec;
|
||||
vector_type<BScaleDataType, b_scale_thread_vec_size>
|
||||
b_scale_thread_vec;
|
||||
|
||||
// Pack scale_thread_buf into scale_thread_vec
|
||||
static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) {
|
||||
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
|
||||
a_scale_thread_bufs(
|
||||
scale_comp_buf)[Number<a_scale_offset + s>{}];
|
||||
});
|
||||
|
||||
static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) {
|
||||
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
|
||||
b_scale_thread_bufs(
|
||||
scale_comp_buf)[Number<b_scale_offset + s>{}];
|
||||
});
|
||||
|
||||
vector_type<ComputeTypeA, KPack> a_thread_vec;
|
||||
vector_type<ComputeTypeB, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(im_major, I0, im_minor, k0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeTypeB>()(ik) = b_thread_bufs
|
||||
[scale_comp_buf][Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(in_major, I0, in_minor, k0, ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type_a =
|
||||
typename vector_type<ComputeTypeA,
|
||||
xdlops_gemm.K1PerXdlops /
|
||||
APackedSize>::type;
|
||||
|
||||
using mfma_input_type_b =
|
||||
typename vector_type<ComputeTypeB,
|
||||
xdlops_gemm.K1PerXdlops /
|
||||
BPackedSize>::type;
|
||||
|
||||
using mfma_scale_input_type_a =
|
||||
typename vector_type<AScaleDataType,
|
||||
a_scale_thread_vec_size>::type;
|
||||
using mfma_scale_input_type_b =
|
||||
typename vector_type<BScaleDataType,
|
||||
b_scale_thread_vec_size>::type;
|
||||
|
||||
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
|
||||
make_tuple(im_major, in_major, im_minor, in_minor, 0));
|
||||
|
||||
// MFMA accumulation
|
||||
xdlops_gemm.template Run<ik_minor * MXdlPack + im_minor,
|
||||
ik_minor * NXdlPack + in_minor>(
|
||||
a_thread_vec.template AsType<mfma_input_type_a>(),
|
||||
a_scale_thread_vec.template AsType<mfma_scale_input_type_a>(),
|
||||
b_thread_vec.template AsType<mfma_input_type_b>(),
|
||||
b_scale_thread_vec.template AsType<mfma_scale_input_type_b>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k) {
|
||||
constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize *
|
||||
(APackedSize * KPack / xdlops_gemm.K1PerXdlops);
|
||||
static_for<0,
|
||||
xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk),
|
||||
1>{}([&](auto chunk) {
|
||||
constexpr auto a_k_step_chunk =
|
||||
k_step +
|
||||
chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
|
||||
a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k,
|
||||
make_tuple(Number<m0 / MXdlPack>{},
|
||||
I0,
|
||||
Number<m0 % MXdlPack>{},
|
||||
I0,
|
||||
Number<a_k_step_chunk>{}),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(Number<m0 / MXdlPack>{},
|
||||
I0,
|
||||
Number<m0 % MXdlPack>{},
|
||||
k,
|
||||
Number<chunk * KThreadChunk>{}),
|
||||
a_thread_buf);
|
||||
});
|
||||
});
|
||||
});
|
||||
HotLoopScheduler();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
};
|
||||
|
||||
LoopFunc(I0, I1);
|
||||
LoopFunc(I1, I0);
|
||||
|
||||
i += 2;
|
||||
} while(i < (num_loop - 2));
|
||||
}
|
||||
|
||||
// tail
|
||||
if constexpr(TailNum == TailNumber::Even)
|
||||
{
|
||||
b_blockwise_copy.Run(
|
||||
b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_thread_bufs(I1));
|
||||
|
||||
block_sync_lds();
|
||||
a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf);
|
||||
// Prefetch a_scales
|
||||
static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
|
||||
static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
|
||||
a_scale_thread_copy.Run(a_scale_grid_desc,
|
||||
a_scale_grid_buf,
|
||||
a_scale_thread_desc,
|
||||
make_tuple(m0, k0, I0),
|
||||
a_scale_thread_bufs(I1));
|
||||
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
|
||||
make_multi_index(0, I1, 0));
|
||||
});
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0));
|
||||
});
|
||||
|
||||
// Prefetch b_scales
|
||||
static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
|
||||
static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
|
||||
b_scale_thread_copy.Run(b_scale_grid_desc,
|
||||
b_scale_grid_buf,
|
||||
b_scale_thread_desc,
|
||||
make_tuple(n0, k0, I0),
|
||||
b_scale_thread_bufs(I1));
|
||||
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
|
||||
make_multi_index(0, I1, 0));
|
||||
});
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(
|
||||
b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0));
|
||||
});
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
constexpr auto im_major = m0 / MXdlPack;
|
||||
constexpr auto im_minor = m0 % MXdlPack;
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
constexpr auto ik_major = k0 / KXdlPack;
|
||||
constexpr auto ik_minor = k0 % KXdlPack;
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
constexpr auto in_major = n0 / NXdlPack;
|
||||
constexpr auto in_minor = n0 % NXdlPack;
|
||||
|
||||
constexpr index_t a_scale_offset =
|
||||
a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0));
|
||||
constexpr index_t b_scale_offset =
|
||||
b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0));
|
||||
|
||||
static_assert(0 < ScalesPerXdlopsRunPerThread,
|
||||
"Must have at least one scale per Xdlops "
|
||||
"per Thread.");
|
||||
|
||||
vector_type<AScaleDataType, a_scale_thread_vec_size> a_scale_thread_vec;
|
||||
vector_type<BScaleDataType, b_scale_thread_vec_size> b_scale_thread_vec;
|
||||
|
||||
// Pack scale_thread_buf into scale_thread_vec
|
||||
static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) {
|
||||
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
|
||||
a_scale_thread_bufs(I0)[Number<a_scale_offset + s>{}];
|
||||
});
|
||||
|
||||
static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) {
|
||||
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
|
||||
b_scale_thread_bufs(I0)[Number<b_scale_offset + s>{}];
|
||||
});
|
||||
|
||||
vector_type<ComputeTypeA, KPack> a_thread_vec;
|
||||
vector_type<ComputeTypeB, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(im_major, I0, im_minor, k0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
|
||||
b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(in_major, I0, in_minor, k0, ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type_a =
|
||||
typename vector_type<ComputeTypeA,
|
||||
xdlops_gemm.K1PerXdlops / APackedSize>::type;
|
||||
|
||||
using mfma_input_type_b =
|
||||
typename vector_type<ComputeTypeB,
|
||||
xdlops_gemm.K1PerXdlops / BPackedSize>::type;
|
||||
|
||||
using mfma_scale_input_type_a =
|
||||
typename vector_type<AScaleDataType, a_scale_thread_vec_size>::type;
|
||||
using mfma_scale_input_type_b =
|
||||
typename vector_type<BScaleDataType, b_scale_thread_vec_size>::type;
|
||||
|
||||
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
|
||||
make_tuple(im_major, in_major, im_minor, in_minor, 0));
|
||||
|
||||
// MFMA accumulation
|
||||
xdlops_gemm.template Run<ik_minor * MXdlPack + im_minor,
|
||||
ik_minor * NXdlPack + in_minor>(
|
||||
a_thread_vec.template AsType<mfma_input_type_a>(),
|
||||
a_scale_thread_vec.template AsType<mfma_scale_input_type_a>(),
|
||||
b_thread_vec.template AsType<mfma_input_type_b>(),
|
||||
b_scale_thread_vec.template AsType<mfma_scale_input_type_b>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
});
|
||||
});
|
||||
|
||||
// constexpr auto lds_buf = m0.value >= SwitchM ? I1 : I0;
|
||||
});
|
||||
__builtin_amdgcn_s_waitcnt(async_vmcnt_encoding);
|
||||
block_sync_lds();
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k) {
|
||||
constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize *
|
||||
(APackedSize * KPack / xdlops_gemm.K1PerXdlops);
|
||||
static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}(
|
||||
[&](auto chunk) {
|
||||
constexpr auto a_k_step_chunk =
|
||||
k_step +
|
||||
chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
|
||||
a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k,
|
||||
make_tuple(Number<m0 / MXdlPack>{},
|
||||
I0,
|
||||
Number<m0 % MXdlPack>{},
|
||||
I0,
|
||||
Number<a_k_step_chunk>{}),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(Number<m0 / MXdlPack>{},
|
||||
I0,
|
||||
Number<m0 % MXdlPack>{},
|
||||
k,
|
||||
Number<chunk * KThreadChunk>{}),
|
||||
a_thread_buf);
|
||||
});
|
||||
});
|
||||
});
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
constexpr auto im_major = m0 / MXdlPack;
|
||||
constexpr auto im_minor = m0 % MXdlPack;
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
constexpr auto ik_major = k0 / KXdlPack;
|
||||
constexpr auto ik_minor = k0 % KXdlPack;
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
constexpr auto in_major = n0 / NXdlPack;
|
||||
constexpr auto in_minor = n0 % NXdlPack;
|
||||
|
||||
constexpr index_t a_scale_offset =
|
||||
a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0));
|
||||
constexpr index_t b_scale_offset =
|
||||
b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0));
|
||||
|
||||
static_assert(0 < ScalesPerXdlopsRunPerThread,
|
||||
"Must have at least one scale per Xdlops "
|
||||
"per Thread.");
|
||||
|
||||
vector_type<AScaleDataType, a_scale_thread_vec_size> a_scale_thread_vec;
|
||||
vector_type<BScaleDataType, b_scale_thread_vec_size> b_scale_thread_vec;
|
||||
|
||||
// Pack scale_thread_buf into scale_thread_vec
|
||||
static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) {
|
||||
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
|
||||
a_scale_thread_bufs(I1)[Number<a_scale_offset + s>{}];
|
||||
});
|
||||
|
||||
static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) {
|
||||
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
|
||||
b_scale_thread_bufs(I1)[Number<b_scale_offset + s>{}];
|
||||
});
|
||||
|
||||
vector_type<ComputeTypeA, KPack> a_thread_vec;
|
||||
vector_type<ComputeTypeB, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(im_major, I0, im_minor, k0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
|
||||
b_thread_bufs[I1][Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(in_major, I0, in_minor, k0, ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type_a =
|
||||
typename vector_type<ComputeTypeA,
|
||||
xdlops_gemm.K1PerXdlops / APackedSize>::type;
|
||||
|
||||
using mfma_input_type_b =
|
||||
typename vector_type<ComputeTypeB,
|
||||
xdlops_gemm.K1PerXdlops / BPackedSize>::type;
|
||||
|
||||
using mfma_scale_input_type_a =
|
||||
typename vector_type<AScaleDataType, a_scale_thread_vec_size>::type;
|
||||
using mfma_scale_input_type_b =
|
||||
typename vector_type<BScaleDataType, b_scale_thread_vec_size>::type;
|
||||
|
||||
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
|
||||
make_tuple(im_major, in_major, im_minor, in_minor, 0));
|
||||
|
||||
// MFMA accumulation
|
||||
xdlops_gemm.template Run<ik_minor * MXdlPack + im_minor,
|
||||
ik_minor * NXdlPack + in_minor>(
|
||||
a_thread_vec.template AsType<mfma_input_type_a>(),
|
||||
a_scale_thread_vec.template AsType<mfma_scale_input_type_a>(),
|
||||
b_thread_vec.template AsType<mfma_input_type_b>(),
|
||||
b_scale_thread_vec.template AsType<mfma_scale_input_type_b>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
else if constexpr(TailNum == TailNumber::Odd)
|
||||
{
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
constexpr auto im_major = m0 / MXdlPack;
|
||||
constexpr auto im_minor = m0 % MXdlPack;
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
constexpr auto ik_major = k0 / KXdlPack;
|
||||
constexpr auto ik_minor = k0 % KXdlPack;
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
constexpr auto in_major = n0 / NXdlPack;
|
||||
constexpr auto in_minor = n0 % NXdlPack;
|
||||
|
||||
constexpr index_t a_scale_offset =
|
||||
a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0));
|
||||
constexpr index_t b_scale_offset =
|
||||
b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0));
|
||||
|
||||
static_assert(0 < ScalesPerXdlopsRunPerThread,
|
||||
"Must have at least one scale per Xdlops "
|
||||
"per Thread.");
|
||||
|
||||
vector_type<AScaleDataType, a_scale_thread_vec_size> a_scale_thread_vec;
|
||||
vector_type<BScaleDataType, b_scale_thread_vec_size> b_scale_thread_vec;
|
||||
|
||||
// Pack scale_thread_buf into scale_thread_vec
|
||||
static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) {
|
||||
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
|
||||
a_scale_thread_bufs(I0)[Number<a_scale_offset + s>{}];
|
||||
});
|
||||
|
||||
static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) {
|
||||
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
|
||||
b_scale_thread_bufs(I0)[Number<b_scale_offset + s>{}];
|
||||
});
|
||||
|
||||
vector_type<ComputeTypeA, KPack> a_thread_vec;
|
||||
vector_type<ComputeTypeB, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(im_major, I0, im_minor, k0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
|
||||
b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(in_major, I0, in_minor, k0, ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type_a =
|
||||
typename vector_type<ComputeTypeA,
|
||||
xdlops_gemm.K1PerXdlops / APackedSize>::type;
|
||||
|
||||
using mfma_input_type_b =
|
||||
typename vector_type<ComputeTypeB,
|
||||
xdlops_gemm.K1PerXdlops / BPackedSize>::type;
|
||||
|
||||
using mfma_scale_input_type_a =
|
||||
typename vector_type<AScaleDataType, a_scale_thread_vec_size>::type;
|
||||
using mfma_scale_input_type_b =
|
||||
typename vector_type<BScaleDataType, b_scale_thread_vec_size>::type;
|
||||
|
||||
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
|
||||
make_tuple(im_major, in_major, im_minor, in_minor, 0));
|
||||
|
||||
// MFMA accumulation
|
||||
xdlops_gemm.template Run<ik_minor * MXdlPack + im_minor,
|
||||
ik_minor * NXdlPack + in_minor>(
|
||||
a_thread_vec.template AsType<mfma_input_type_a>(),
|
||||
a_scale_thread_vec.template AsType<mfma_scale_input_type_a>(),
|
||||
b_thread_vec.template AsType<mfma_input_type_b>(),
|
||||
b_scale_thread_vec.template AsType<mfma_scale_input_type_b>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: make this field protected when a_scale_thread_copy_ is moved
|
||||
// here
|
||||
static constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<MRepeat / MXdlPack>{},
|
||||
Number<KRepeat / KXdlPack>{},
|
||||
Number<ScalesPerXdlopsRunPerThread * a_scale_thread_vec_size>{}));
|
||||
|
||||
// TODO: make this field protected when b_scale_thread_copy_ is moved
|
||||
// here
|
||||
static constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<NRepeat / NXdlPack>{},
|
||||
Number<KRepeat / KXdlPack>{},
|
||||
Number<ScalesPerXdlopsRunPerThread * b_scale_thread_vec_size>{}));
|
||||
|
||||
protected:
|
||||
using Base::a_thread_copy_;
|
||||
using Base::a_thread_desc_;
|
||||
using Base::b_thread_copy_;
|
||||
using Base::b_thread_desc_;
|
||||
using Base::c_thread_desc_;
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
@@ -226,85 +226,197 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v3<BlockGemmPipelineSched
|
||||
// constexpr auto num_dsread_a_mfma =
|
||||
// (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate;
|
||||
|
||||
constexpr auto num_total_stages = MRepeat;
|
||||
constexpr auto num_total_stages = std::max(2, MRepeat);
|
||||
if constexpr(num_total_stages > 2)
|
||||
{
|
||||
|
||||
// Group num_mfma_perstage num_ds_read_a_perstage
|
||||
// since we want to reuse a local register buffer
|
||||
constexpr auto num_mfma_perstage = num_mfma_inst / num_total_stages;
|
||||
constexpr auto num_ds_read_a_perstage = num_ds_read_inst_a / num_total_stages;
|
||||
// Group num_mfma_perstage num_ds_read_a_perstage
|
||||
// since we want to reuse a local register buffer
|
||||
constexpr auto num_mfma_perstage = num_mfma_inst / num_total_stages;
|
||||
constexpr auto num_ds_read_a_perstage = num_ds_read_inst_a / num_total_stages;
|
||||
|
||||
constexpr auto num_ds_read_a_mfma_perstage =
|
||||
math::integer_divide_ceil(num_ds_read_a_perstage, ds_read_a_mfma_rate);
|
||||
constexpr auto num_ds_read_a_mfma_perstage =
|
||||
math::integer_divide_ceil(num_ds_read_a_perstage, ds_read_a_mfma_rate);
|
||||
|
||||
constexpr auto num_ds_read_a_prefetch_stages = 2;
|
||||
constexpr auto num_ds_read_a_prefetch_stages = 2;
|
||||
|
||||
constexpr auto buffer_load_perstage_more =
|
||||
math::integer_divide_ceil((num_buffer_load_stage1), (num_total_stages - 2));
|
||||
constexpr auto buffer_load_perstage_less =
|
||||
math::integer_divide_floor((num_buffer_load_stage1), (num_total_stages - 2));
|
||||
constexpr auto buffer_load_perstage_stage2 =
|
||||
math::integer_divide_floor((num_buffer_load_stage2), 2);
|
||||
constexpr auto buffer_load_perstage_more =
|
||||
math::integer_divide_ceil((num_buffer_load_stage1), (num_total_stages - 2));
|
||||
constexpr auto buffer_load_perstage_less =
|
||||
math::integer_divide_floor((num_buffer_load_stage1), (num_total_stages - 2));
|
||||
constexpr auto buffer_load_perstage_stage2 =
|
||||
math::integer_divide_floor((num_buffer_load_stage2), 2);
|
||||
|
||||
constexpr auto buffer_load_stages_more =
|
||||
num_buffer_load_stage1 -
|
||||
math::integer_divide_floor(num_buffer_load_stage1, (num_total_stages - 2)) *
|
||||
((num_total_stages - 2));
|
||||
constexpr auto buffer_load_stages_more =
|
||||
num_buffer_load_stage1 -
|
||||
math::integer_divide_floor(num_buffer_load_stage1, (num_total_stages - 2)) *
|
||||
((num_total_stages - 2));
|
||||
|
||||
constexpr auto buffer_load_issue_point_interval_more =
|
||||
num_mfma_perstage / buffer_load_perstage_more;
|
||||
constexpr auto buffer_load_issue_point_interval_less =
|
||||
num_mfma_perstage / buffer_load_perstage_less;
|
||||
constexpr auto buffer_load_issue_point_interval_stage2 =
|
||||
num_mfma_perstage / buffer_load_perstage_stage2;
|
||||
constexpr auto buffer_load_issue_point_interval_more =
|
||||
num_mfma_perstage / buffer_load_perstage_more;
|
||||
constexpr auto buffer_load_issue_point_interval_less =
|
||||
num_mfma_perstage / buffer_load_perstage_less;
|
||||
constexpr auto buffer_load_issue_point_interval_stage2 =
|
||||
num_mfma_perstage / buffer_load_perstage_stage2;
|
||||
|
||||
// Stage 1
|
||||
// global read more
|
||||
static_for<0, buffer_load_stages_more, 1>{}([&](auto /*i*/) {
|
||||
static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
// Stage 1
|
||||
// global read more
|
||||
static_for<0, buffer_load_stages_more, 1>{}([&](auto /*i*/) {
|
||||
static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
|
||||
if constexpr(imfma % buffer_load_issue_point_interval_more == 0)
|
||||
if constexpr(imfma % buffer_load_issue_point_interval_more == 0)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
}
|
||||
|
||||
if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(
|
||||
0x100, ds_read_a_mfma_rate, 0); // DS read
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
// global read less
|
||||
static_for<0, (num_total_stages - 2 - buffer_load_stages_more), 1>{}([&](auto /*i*/) {
|
||||
static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
if constexpr(imfma % buffer_load_issue_point_interval_less == 0)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
}
|
||||
if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(
|
||||
0x100, ds_read_a_mfma_rate, 0); // DS read
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
// Stage 2, Sync
|
||||
// lds synchronization, prefetch next loop local A
|
||||
static_for<0, num_ds_read_a_prefetch_stages, 1>{}([&](auto /*i*/) {
|
||||
static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
if constexpr(imfma % buffer_load_issue_point_interval_stage2 == 0)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
}
|
||||
if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(
|
||||
0x100, ds_read_a_mfma_rate, 0); // DS read
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr auto num_buffer_load_total = num_buffer_load_inst_a + num_buffer_load_inst_b +
|
||||
num_buffer_load_a_scale +
|
||||
num_buffer_load_b_scale;
|
||||
constexpr auto num_dsread_a_mfma = math::integer_divide_ceil(
|
||||
num_ds_read_inst_a, ds_read_a_mfma_rate); // how many mfma per dsread_a
|
||||
|
||||
// stage 1
|
||||
constexpr auto num_mfma_stage1 = num_mfma_inst - num_dsread_a_mfma;
|
||||
|
||||
constexpr auto mfma_perstage_more =
|
||||
math::integer_divide_ceil(num_mfma_stage1, num_buffer_load_total);
|
||||
constexpr auto mfma_perstage_less =
|
||||
math::integer_divide_floor(num_mfma_stage1, num_buffer_load_total);
|
||||
|
||||
constexpr auto mfma_stages_more =
|
||||
num_mfma_stage1 - mfma_perstage_less * num_buffer_load_total;
|
||||
|
||||
static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) {
|
||||
if constexpr(i < mfma_stages_more)
|
||||
{
|
||||
static_for<0, mfma_perstage_more, 1>{}([&](auto) {
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
}
|
||||
else
|
||||
{
|
||||
static_for<0, mfma_perstage_less, 1>{}([&](auto) {
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
}
|
||||
});
|
||||
|
||||
if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
|
||||
static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) {
|
||||
if constexpr((i + num_buffer_load_inst_a) < mfma_stages_more)
|
||||
{
|
||||
static_for<0, mfma_perstage_more, 1>{}([&](auto /*imfma*/) {
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
}
|
||||
else
|
||||
{
|
||||
static_for<0, mfma_perstage_less, 1>{}([&](auto /*imfma*/) {
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
}
|
||||
});
|
||||
|
||||
static_for<0, num_buffer_load_a_scale, 1>{}([&](auto i) {
|
||||
if constexpr((i + num_buffer_load_inst_a + num_buffer_load_inst_b) <
|
||||
mfma_stages_more)
|
||||
{
|
||||
static_for<0, mfma_perstage_more, 1>{}([&](auto /*imfma*/) {
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
}
|
||||
else
|
||||
{
|
||||
static_for<0, mfma_perstage_less, 1>{}([&](auto /*imfma*/) {
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
}
|
||||
});
|
||||
|
||||
static_for<0, num_buffer_load_b_scale, 1>{}([&](auto i) {
|
||||
if constexpr((i + num_buffer_load_inst_a + num_buffer_load_inst_b +
|
||||
num_buffer_load_a_scale) < mfma_stages_more)
|
||||
{
|
||||
static_for<0, mfma_perstage_more, 1>{}([&](auto /*imfma*/) {
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
}
|
||||
else
|
||||
{
|
||||
static_for<0, mfma_perstage_less, 1>{}([&](auto /*imfma*/) {
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
}
|
||||
});
|
||||
|
||||
// stage 2
|
||||
static_for<0, num_dsread_a_mfma, 1>{}([&](auto i) {
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >=
|
||||
ds_read_a_mfma_rate)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
// global read less
|
||||
static_for<0, (num_total_stages - 2 - buffer_load_stages_more), 1>{}([&](auto /*i*/) {
|
||||
static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
if constexpr(imfma % buffer_load_issue_point_interval_less == 0)
|
||||
else
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
}
|
||||
if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(
|
||||
0x100,
|
||||
num_ds_read_inst_a - (num_dsread_a_mfma - 1) * ds_read_a_mfma_rate,
|
||||
0); // DS read
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
// Stage 2, Sync
|
||||
// lds synchronization, prefetch next loop local A
|
||||
static_for<0, num_ds_read_a_prefetch_stages, 1>{}([&](auto /*i*/) {
|
||||
static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
if constexpr(imfma % buffer_load_issue_point_interval_stage2 == 0)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
}
|
||||
if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
template <bool HasMainLoop,
|
||||
|
||||
@@ -3,6 +3,8 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
@@ -122,7 +122,7 @@ struct DeviceMoeGemmMXBPreShuffle : public DeviceMoEGemmMXBPreShuffle<ALayout,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave_,
|
||||
math::max(2, NXdlPerWave_),
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
|
||||
@@ -429,8 +429,8 @@ struct GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle
|
||||
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
|
||||
constexpr index_t WaveSize = BlockSize / (MWave * NWave);
|
||||
constexpr index_t NkSwizzleNumber = Number<WaveSize * KPack>{};
|
||||
return make_naive_tensor_descriptor_packed(
|
||||
make_tuple(N0 / NWave / NXdlPack, NWave, NXdlPack, K0, NkSwizzleNumber));
|
||||
return make_naive_tensor_descriptor_packed(make_tuple(
|
||||
math::integer_divide_ceil(N0, NWave * NXdlPack), NWave, NXdlPack, K0, NkSwizzleNumber));
|
||||
}
|
||||
|
||||
__host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
|
||||
|
||||
@@ -48,28 +48,25 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
kernel_moe_mxgemm(typename GridwiseGemm::Argument karg)
|
||||
{
|
||||
#if defined(__gfx9__)
|
||||
if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
|
||||
{
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
|
||||
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
|
||||
karg.p_sorted_token_ids,
|
||||
karg.p_sorted_expert_ids,
|
||||
karg.p_max_token_id,
|
||||
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
|
||||
karg.p_a_scale_grid + splitk_batch_offset.a_k_split_offset,
|
||||
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
|
||||
karg.p_b_scale_grid + splitk_batch_offset.b_k_split_offset,
|
||||
karg.p_ds_grid,
|
||||
karg.p_c_grid,
|
||||
p_shared,
|
||||
karg,
|
||||
karg.a_element_op,
|
||||
karg.b_element_op,
|
||||
karg.c_element_op);
|
||||
}
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
|
||||
karg.p_sorted_token_ids,
|
||||
karg.p_sorted_expert_ids,
|
||||
karg.p_max_token_id,
|
||||
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
|
||||
karg.p_a_scale_grid + splitk_batch_offset.a_scale_k_split_offset,
|
||||
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
|
||||
karg.p_b_scale_grid + splitk_batch_offset.b_scale_k_split_offset,
|
||||
karg.p_ds_grid,
|
||||
karg.p_c_grid,
|
||||
p_shared,
|
||||
karg,
|
||||
karg.a_element_op,
|
||||
karg.b_element_op,
|
||||
karg.c_element_op);
|
||||
#else
|
||||
ignore = karg;
|
||||
#endif // end of if (defined(__gfx9__))
|
||||
@@ -1249,7 +1246,6 @@ struct GridwiseMoeGemmMX_BPreshuffle
|
||||
__host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
|
||||
{
|
||||
const index_t num_loop = K / KPerBlock;
|
||||
|
||||
return BlockwiseGemmPipe::BlockHasHotloop(num_loop);
|
||||
}
|
||||
|
||||
@@ -1279,7 +1275,6 @@ struct GridwiseMoeGemmMX_BPreshuffle
|
||||
// using Block2CTileMapDefault = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock,
|
||||
// NPerBlock>;
|
||||
|
||||
#if 0
|
||||
template <bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
|
||||
TailNumber TailNum = TailNumber::Odd>
|
||||
@@ -1298,9 +1293,10 @@ struct GridwiseMoeGemmMX_BPreshuffle
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op)
|
||||
{
|
||||
ignore = a_element_op;
|
||||
ignore = b_element_op;
|
||||
index_t BN0Shuffled = CalculateBN0Shuffled(problem.N);
|
||||
index_t BK0Shuffled = CalculateBK0Shuffled(problem.K);
|
||||
index_t BN0Shuffled = CalculateBN0Shuffled(problem.N);
|
||||
index_t BK0Shuffled = CalculateBK0Shuffled(problem.K);
|
||||
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
|
||||
IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK,
|
||||
problem.MPadded,
|
||||
@@ -1317,29 +1313,41 @@ struct GridwiseMoeGemmMX_BPreshuffle
|
||||
problem.NPadded,
|
||||
problem.StrideC);
|
||||
|
||||
const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor_packed(
|
||||
make_tuple((IsInputGemm ? problem.NumTokens : problem.M) / (MXdlPack * MPerBlock),
|
||||
// We pad the M unconditionaly for Scale
|
||||
const auto Padded_Scale_M =
|
||||
math::integer_divide_ceil(problem.M, ScaleBlockSize) * ScaleBlockSize;
|
||||
const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor(
|
||||
make_tuple(Padded_Scale_M / (MXdlPack * MPerXdl),
|
||||
math::integer_divide_ceil(problem.K, (ScaleBlockSize / APackedSize)) /
|
||||
(KXdlPack * 64 / MPerXdl),
|
||||
64 * KXdlPack * MXdlPack / scale_pack_size_a));
|
||||
64 * KXdlPack * MXdlPack / scale_pack_size_a),
|
||||
make_tuple(math::integer_divide_ceil(problem.K * problem.KBatch,
|
||||
(ScaleBlockSize / APackedSize)) *
|
||||
MPerXdl * MXdlPack / scale_pack_size_a,
|
||||
64 * KXdlPack * MXdlPack / scale_pack_size_a,
|
||||
1));
|
||||
|
||||
const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor_packed(
|
||||
const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor(
|
||||
make_tuple(problem.N / (NXdlPack * NPerXdl),
|
||||
math::integer_divide_ceil(problem.K, (ScaleBlockSize / BPackedSize)) /
|
||||
(KXdlPack * 64 / NPerXdl),
|
||||
64 * KXdlPack * NXdlPack / scale_pack_size_b));
|
||||
64 * KXdlPack * NXdlPack / scale_pack_size_b),
|
||||
make_tuple(math::integer_divide_ceil(problem.K * problem.KBatch,
|
||||
(ScaleBlockSize / BPackedSize)) *
|
||||
NPerXdl * NXdlPack / scale_pack_size_b,
|
||||
64 * KXdlPack * NXdlPack / scale_pack_size_b,
|
||||
1));
|
||||
|
||||
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
|
||||
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
c_grid_desc_m_n, problem.MBlock, problem.NBlock);
|
||||
const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
|
||||
// static_assert(NSwizzle == false, "to do fix: need another pr in sorting merged");
|
||||
|
||||
const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
|
||||
const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
|
||||
if(expert_block_id * MPerBlock >= max_token_id)
|
||||
return;
|
||||
const index_t expert_id =
|
||||
__builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);
|
||||
|
||||
const auto block_mn = [&]() -> std::pair<int, int> {
|
||||
if constexpr(NSwizzle)
|
||||
{
|
||||
@@ -1372,86 +1380,78 @@ struct GridwiseMoeGemmMX_BPreshuffle
|
||||
constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
|
||||
constexpr auto AKThreads = AK0Threads * AK1Threads;
|
||||
constexpr auto AMRepeats = MPerBlock / AMThreads;
|
||||
const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats;
|
||||
const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads;
|
||||
|
||||
if(token_pos >= max_token_id || token0 >= problem.NumTokens)
|
||||
return;
|
||||
StaticallyIndexedArray<IndexType, AMRepeats> gather_offsets;
|
||||
static_for<0, AMRepeats, 1>{}([&](auto m0) {
|
||||
const index_t fused_token = p_sorted_token_ids[token_pos + m0];
|
||||
const index_t fused_token = p_sorted_token_ids[token_pos + m0 * AMThreads];
|
||||
index_t token_offset = fused_token & 0xffffff;
|
||||
if constexpr(!IsInputGemm)
|
||||
{
|
||||
token_offset = token_offset * problem.TopK + (fused_token >> 24);
|
||||
}
|
||||
gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K / APackedSize;
|
||||
gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
|
||||
});
|
||||
|
||||
const index_t expert_stride =
|
||||
__builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
|
||||
const index_t expert_scale_stride =
|
||||
__builtin_amdgcn_readfirstlane(problem.N * (IsInputGemm ? 2 : 1) *
|
||||
math::integer_divide_ceil(problem.K, ScaleBlockSize));
|
||||
const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
|
||||
problem.N * (IsInputGemm ? 2 : 1) *
|
||||
math::integer_divide_ceil(problem.K, ScaleBlockSize / BPackedSize));
|
||||
|
||||
// N0, K0, Blocksize*KPack
|
||||
const index_t n_block_data_idx_on_grid =
|
||||
__builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);
|
||||
__builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave / NXdlPack);
|
||||
|
||||
// Gride buffer creation
|
||||
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
|
||||
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_b_grid + expert_id * expert_stride / BPackedSize,
|
||||
b_grid_desc_bpreshuffled.GetElementSpaceSize());
|
||||
p_b_grid + expert_id * expert_stride, b_grid_desc_bpreshuffled.GetElementSpaceSize());
|
||||
|
||||
// A, B scale buffer
|
||||
const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
|
||||
const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_b_scale_grid + expert_id * expert_scale_stride,
|
||||
p_b_scale_grid + (expert_id * expert_scale_stride) / sizeof(BScaleDataType),
|
||||
b_scale_grid_desc_bn_ak.GetElementSpaceSize());
|
||||
|
||||
// A matrix in LDS memory, dst of blockwise copy
|
||||
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
|
||||
|
||||
// B matrix in LDS memory, dst of blockwise copy
|
||||
// dummy
|
||||
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
|
||||
// A matrix blockwise copy
|
||||
auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1_gather<
|
||||
|
||||
// A matrix blockwise direct to LDS copy
|
||||
auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_Gather_DirectLoad<
|
||||
ThisThreadBlock,
|
||||
AElementwiseOperation,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
Sequence<AK0Number, MPerBlock, AK1Number>,
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ADataType,
|
||||
LDSTypeA,
|
||||
ADataType,
|
||||
decltype(a_grid_desc_ak0_m_ak1),
|
||||
decltype(a_block_desc_ak0_m_ak1),
|
||||
ABlockTransferSrcAccessOrder,
|
||||
Sequence<0, 1, 2>,
|
||||
ABlockTransferSrcVectorDim,
|
||||
2,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_AK1,
|
||||
1,
|
||||
1,
|
||||
AThreadTransferSrcResetCoordinateAfterRun,
|
||||
true,
|
||||
IndexType,
|
||||
1,
|
||||
BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1,
|
||||
make_multi_index(0, 0, 0),
|
||||
a_element_op,
|
||||
a_block_desc_ak0_m_ak1,
|
||||
make_multi_index(0, 0, 0),
|
||||
ck::tensor_operation::element_wise::PassThrough{},
|
||||
gather_offsets);
|
||||
1>(a_grid_desc_ak0_m_ak1,
|
||||
make_multi_index(0, 0, 0),
|
||||
a_block_desc_ak0_m_ak1,
|
||||
make_multi_index(0, 0, 0),
|
||||
gather_offsets);
|
||||
|
||||
// Thread-wise copy
|
||||
// K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack
|
||||
auto b_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
|
||||
auto b_block_buf_ping = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
|
||||
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
|
||||
auto b_block_buf_pong = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
|
||||
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
|
||||
auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong);
|
||||
|
||||
auto b_blockwise_copy =
|
||||
ThreadwiseTensorSliceTransfer_v2<BDataType,
|
||||
@@ -1463,7 +1463,7 @@ struct GridwiseMoeGemmMX_BPreshuffle
|
||||
Number<NXdlPack>{},
|
||||
Number<KRepeat>{},
|
||||
Number<BK1Value>{}>,
|
||||
Sequence<1, 2, 0, 3>,
|
||||
Sequence<0, 1, 2, 3, 4>,
|
||||
4,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BThreadTransferSrcResetCoordinateAfterRun,
|
||||
@@ -1472,16 +1472,16 @@ struct GridwiseMoeGemmMX_BPreshuffle
|
||||
make_multi_index(n_block_data_idx_on_grid,
|
||||
get_warp_local_1d_id() % NWave,
|
||||
0,
|
||||
KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
|
||||
0,
|
||||
KPack * (get_thread_local_1d_id() % WarpSize)));
|
||||
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
// Cast after lds
|
||||
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
static_cast<LDSTypeA*>(p_shared),
|
||||
a_block_desc_ak0_m_ak1.GetElementSpaceSize() / APackedSize);
|
||||
static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
|
||||
|
||||
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
|
||||
constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, KRepeat, 0);
|
||||
constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, 0, KRepeat, 0);
|
||||
|
||||
// Blockwise GEMM pipeline
|
||||
static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
|
||||
@@ -1505,13 +1505,16 @@ struct GridwiseMoeGemmMX_BPreshuffle
|
||||
const auto waveId_m = wave_idx[I0];
|
||||
const auto waveId_n = wave_idx[I1];
|
||||
|
||||
static constexpr auto mfma = BlockwiseGemmPipe::xdlops_gemm.mfma;
|
||||
|
||||
auto thread_offset_shuffled =
|
||||
get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize * KXdlPack * MXdlPack;
|
||||
|
||||
auto a_thread_offset_m = waveId_m;
|
||||
|
||||
// get each thread's offset int the scale tensor
|
||||
const index_t token_scale_pos = block_m_id * MPerBlock;
|
||||
if(token_scale_pos >= max_token_id || token0 >= problem.NumTokens)
|
||||
return;
|
||||
|
||||
auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
|
||||
AScaleDataType,
|
||||
AScaleDataType,
|
||||
@@ -1538,7 +1541,7 @@ struct GridwiseMoeGemmMX_BPreshuffle
|
||||
Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths
|
||||
Sequence<0, 1, 2>, // DimAccessOrder
|
||||
2, // SrcVectorDim
|
||||
KXdlPack * MXdlPack / scale_pack_size_b, // SrcScalarPerVector
|
||||
KXdlPack * NXdlPack / scale_pack_size_b, // SrcScalarPerVector
|
||||
1, // SrcScalarStrideInVector
|
||||
true>(b_scale_grid_desc_bn_ak,
|
||||
make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n,
|
||||
@@ -1547,29 +1550,37 @@ struct GridwiseMoeGemmMX_BPreshuffle
|
||||
|
||||
if constexpr(IsInputGemm)
|
||||
{
|
||||
const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize;
|
||||
const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2;
|
||||
const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_b_grid_up + expert_id * expert_stride / BPackedSize,
|
||||
p_b_grid_up + expert_id * expert_stride,
|
||||
b_grid_desc_bpreshuffled.GetElementSpaceSize());
|
||||
auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2<
|
||||
BDataType,
|
||||
BDataType,
|
||||
decltype(b_grid_desc_bpreshuffled),
|
||||
decltype(b_block_desc_bk0_n_bk1),
|
||||
Sequence<Number<NXdlPerWave>{}, I1, Number<KRepeat>{}, Number<BK1Value>{}>,
|
||||
Sequence<1, 2, 0, 3>,
|
||||
3,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BThreadTransferSrcResetCoordinateAfterRun,
|
||||
true>(b_grid_desc_bpreshuffled,
|
||||
make_multi_index(n_block_data_idx_on_grid,
|
||||
get_warp_local_1d_id() % NWave,
|
||||
0,
|
||||
KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
|
||||
const BScaleDataType* p_b_scale_grid_up = p_b_scale_grid + expert_scale_stride / 2;
|
||||
const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_b_scale_grid_up + expert_id * expert_scale_stride,
|
||||
auto b_blockwise_copy_up =
|
||||
ThreadwiseTensorSliceTransfer_v2<BDataType,
|
||||
BDataType,
|
||||
decltype(b_grid_desc_bpreshuffled),
|
||||
decltype(b_block_desc_bk0_n_bk1),
|
||||
Sequence<Number<NXdlPerWave / NXdlPack>{},
|
||||
I1,
|
||||
Number<NXdlPack>{},
|
||||
Number<KRepeat>{},
|
||||
Number<BK1Value>{}>,
|
||||
Sequence<0, 1, 2, 3, 4>,
|
||||
4,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BThreadTransferSrcResetCoordinateAfterRun,
|
||||
true>(
|
||||
b_grid_desc_bpreshuffled,
|
||||
make_multi_index(n_block_data_idx_on_grid,
|
||||
get_warp_local_1d_id() % NWave,
|
||||
0,
|
||||
0,
|
||||
KPack * (get_thread_local_1d_id() % WarpSize)));
|
||||
const BScaleDataType* p_b_scale_grid_up =
|
||||
p_b_scale_grid + expert_scale_stride / 2 / sizeof(BScaleDataType);
|
||||
const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_b_scale_grid_up + expert_id * expert_scale_stride / sizeof(BScaleDataType),
|
||||
b_scale_grid_desc_bn_ak.GetElementSpaceSize());
|
||||
|
||||
auto b_scale_thread_copy_up = ThreadwiseTensorSliceTransfer_v2<
|
||||
BScaleDataType,
|
||||
BScaleDataType,
|
||||
@@ -1587,25 +1598,30 @@ struct GridwiseMoeGemmMX_BPreshuffle
|
||||
thread_offset_shuffled / scale_pack_size_b));
|
||||
|
||||
blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
|
||||
// A
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
a_block_desc_ak0_m_ak1,
|
||||
a_blockwise_copy,
|
||||
a_grid_buf,
|
||||
a_block_buf,
|
||||
a_block_slice_copy_step,
|
||||
// Gate and Up
|
||||
b_grid_desc_bpreshuffled,
|
||||
b_block_desc_bk0_n_bk1,
|
||||
b_blockwise_copy,
|
||||
b_blockwise_copy_up,
|
||||
b_grid_buf,
|
||||
b_grid_buf_up,
|
||||
b_block_buf,
|
||||
b_block_bufs,
|
||||
b_block_slice_copy_step,
|
||||
// C
|
||||
c_thread_buf,
|
||||
c_thread_buf_up,
|
||||
// A scale
|
||||
a_scale_grid_desc_am_ak,
|
||||
a_scale_thread_copy,
|
||||
a_scale_grid_buf,
|
||||
// B scale
|
||||
b_scale_grid_desc_bn_ak,
|
||||
b_scale_thread_copy,
|
||||
b_scale_thread_copy_up,
|
||||
@@ -1616,23 +1632,23 @@ struct GridwiseMoeGemmMX_BPreshuffle
|
||||
else
|
||||
{
|
||||
blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
a_grid_desc_ak0_m_ak1, // A
|
||||
a_block_desc_ak0_m_ak1,
|
||||
a_blockwise_copy,
|
||||
a_grid_buf,
|
||||
a_block_buf,
|
||||
a_block_slice_copy_step,
|
||||
b_grid_desc_bpreshuffled,
|
||||
b_grid_desc_bpreshuffled, // B
|
||||
b_block_desc_bk0_n_bk1,
|
||||
b_blockwise_copy,
|
||||
b_grid_buf,
|
||||
b_block_buf,
|
||||
b_block_bufs,
|
||||
b_block_slice_copy_step,
|
||||
c_thread_buf,
|
||||
a_scale_grid_desc_am_ak,
|
||||
c_thread_buf, // C
|
||||
a_scale_grid_desc_am_ak, // A scale
|
||||
a_scale_thread_copy,
|
||||
a_scale_grid_buf,
|
||||
b_scale_grid_desc_bn_ak,
|
||||
b_scale_grid_desc_bn_ak, // B scale
|
||||
b_scale_thread_copy,
|
||||
b_scale_grid_buf,
|
||||
num_k_block_main_loop);
|
||||
@@ -1643,84 +1659,101 @@ struct GridwiseMoeGemmMX_BPreshuffle
|
||||
static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
|
||||
NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
|
||||
"wrong!");
|
||||
static_assert(CShuffleMXdlPerWavePerShuffle % MXdlPack == 0 &&
|
||||
CShuffleNXdlPerWavePerShuffle % NXdlPack == 0,
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
|
||||
|
||||
// TODO: hacky, fix it!
|
||||
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
|
||||
blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
|
||||
blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();
|
||||
|
||||
// TODO: hacky, fix it!
|
||||
// c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
|
||||
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
|
||||
blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
|
||||
blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();
|
||||
|
||||
constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
|
||||
constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
|
||||
constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
|
||||
constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
|
||||
constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
|
||||
constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
|
||||
constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
|
||||
constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
|
||||
constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
|
||||
constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
|
||||
constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
|
||||
constexpr auto M5 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I8);
|
||||
constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I9);
|
||||
|
||||
// mul scales
|
||||
static_assert(M0 * M1 * M2 * M3 * M4 == MPerBlock);
|
||||
static_assert(M4 == 4);
|
||||
|
||||
static_assert(M0 * M1 * M2 * M3 * M4 * M5 == MPerBlock);
|
||||
static_assert(M5 == 4);
|
||||
const index_t m1 = get_warp_local_1d_id() / NWave;
|
||||
const index_t m3 = threadIdx.x % get_warp_size() / MPerXdl;
|
||||
const index_t m4 = threadIdx.x % get_warp_size() / MPerXdl;
|
||||
|
||||
vector_type<float, 4> topk_weights; // for gemm2 only
|
||||
static_for<0, NXdlPerWave, 1>{}([&](auto n0) {
|
||||
static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave
|
||||
static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk
|
||||
const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 +
|
||||
m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4;
|
||||
if constexpr(MulRoutedWeight)
|
||||
{
|
||||
topk_weights = *c_style_pointer_cast<const vector_type<float, M4>*>(
|
||||
p_ds_grid[I2] + m_pos);
|
||||
}
|
||||
static_for<0, M4, 1>{}([&](auto m4) { // m_inst_group_size
|
||||
constexpr index_t c_offset =
|
||||
blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
|
||||
make_tuple(m0, n0, m2 * M4 + m4));
|
||||
constexpr auto cidx = Number<c_offset>{};
|
||||
|
||||
if constexpr(IsInputGemm) // gu fusion
|
||||
{
|
||||
if constexpr(ActivationOperation == Activation::silu_and_mul)
|
||||
{
|
||||
float gate = c_thread_buf[cidx];
|
||||
float up = c_thread_buf_up[cidx];
|
||||
if constexpr(MulRoutedWeight)
|
||||
{
|
||||
gate = gate * topk_weights.AsType<float>()[m4];
|
||||
up = up * topk_weights.AsType<float>()[m4];
|
||||
}
|
||||
tensor_operation::element_wise::Silu{}(gate, gate);
|
||||
c_thread_buf_fp32(cidx) = gate * up;
|
||||
}
|
||||
else if(ActivationOperation == Activation::gelu_and_mul)
|
||||
{
|
||||
float gate = c_thread_buf[cidx];
|
||||
float up = c_thread_buf_up[cidx];
|
||||
if constexpr(MulRoutedWeight)
|
||||
{
|
||||
gate = gate * topk_weights.AsType<float>()[m4];
|
||||
up = up * topk_weights.AsType<float>()[m4];
|
||||
}
|
||||
tensor_operation::element_wise::Gelu{}(gate, gate);
|
||||
c_thread_buf_fp32(cidx) = gate * up;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
c_thread_buf_fp32(cidx) = c_thread_buf[cidx];
|
||||
static_for<0, NXdlPerWave / NXdlPack, 1>{}([&](auto n0) {
|
||||
static_for<0, NXdlPack, 1>{}([&](auto inxdl) { // NXdlPack
|
||||
static_for<0, MXdlPerWave / MXdlPack, 1>{}([&](auto m0) { // MXDLPerWave
|
||||
static_for<0, MXdlPack, 1>{}([&](auto imxdl) { // MXdlPack
|
||||
static_for<0, M3, 1>{}([&](auto m3) { // m_inst_num_groups_per_blk
|
||||
const index_t m_pos = block_m_id * MPerBlock +
|
||||
m0 * M2 * M1 * M3 * M4 * M5 +
|
||||
m1 * M2 * M3 * M4 * M5 +
|
||||
imxdl * M3 * M4 * M5 + m3 * M4 * M5 + m4 * M5;
|
||||
if constexpr(MulRoutedWeight)
|
||||
{
|
||||
c_thread_buf_fp32(cidx) =
|
||||
topk_weights.AsType<float>()[m4] * c_thread_buf_fp32[cidx];
|
||||
topk_weights =
|
||||
*c_style_pointer_cast<const vector_type<float, M5>*>(
|
||||
p_ds_grid[I2] + m_pos);
|
||||
}
|
||||
}
|
||||
static_for<0, M5, 1>{}([&](auto m5) { // m_inst_group_size
|
||||
constexpr index_t c_offset =
|
||||
blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
|
||||
make_tuple(m0, n0, imxdl, inxdl, m3 * M5 + m5));
|
||||
constexpr auto cidx = Number<c_offset>{};
|
||||
|
||||
if constexpr(IsInputGemm) // gu fusion
|
||||
{
|
||||
if constexpr(ActivationOperation ==
|
||||
Activation::silu_and_mul)
|
||||
{
|
||||
float gate = c_thread_buf[cidx];
|
||||
float up = c_thread_buf_up[cidx];
|
||||
if constexpr(MulRoutedWeight)
|
||||
{
|
||||
gate = gate * topk_weights.AsType<float>()[m5];
|
||||
up = up * topk_weights.AsType<float>()[m5];
|
||||
}
|
||||
tensor_operation::element_wise::Silu{}(gate, gate);
|
||||
c_thread_buf_fp32(cidx) = gate * up;
|
||||
}
|
||||
else if(ActivationOperation == Activation::gelu_and_mul)
|
||||
{
|
||||
float gate = c_thread_buf[cidx];
|
||||
float up = c_thread_buf_up[cidx];
|
||||
if constexpr(MulRoutedWeight)
|
||||
{
|
||||
gate = gate * topk_weights.AsType<float>()[m5];
|
||||
up = up * topk_weights.AsType<float>()[m5];
|
||||
}
|
||||
tensor_operation::element_wise::Gelu{}(gate, gate);
|
||||
c_thread_buf_fp32(cidx) = gate * up;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
c_thread_buf_fp32(cidx) = c_thread_buf[cidx];
|
||||
if constexpr(MulRoutedWeight)
|
||||
{
|
||||
c_thread_buf_fp32(cidx) =
|
||||
topk_weights.AsType<float>()[m5] *
|
||||
c_thread_buf_fp32[cidx];
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1738,19 +1771,25 @@ struct GridwiseMoeGemmMX_BPreshuffle
|
||||
make_tuple(
|
||||
make_freeze_transform(I0),
|
||||
make_unmerge_transform(make_tuple(
|
||||
Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
|
||||
M1, // M1 = MWave
|
||||
M2, // M2 * M3 * M4 = MPerXdl
|
||||
Number<CShuffleMXdlPerWavePerShuffle / MXdlPack>{}, // M0 (MXdlPerWave) per
|
||||
// shuffle
|
||||
M1, // M1 = MWave
|
||||
M2, // M2 * M3 * M4 = MPerXdl
|
||||
M3,
|
||||
M4)),
|
||||
M4,
|
||||
M5)),
|
||||
make_freeze_transform(I0),
|
||||
make_unmerge_transform(make_tuple(
|
||||
Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
|
||||
N1, // N1 = NWave
|
||||
N2))), // N2 = NPerXdl
|
||||
Number<CShuffleNXdlPerWavePerShuffle / NXdlPack>{}, // N0 (NXdlPerWave)
|
||||
// per shuffle
|
||||
N1, // N1 = NWave
|
||||
N2, // N2 = NXdlPack
|
||||
N3))), // N3 = NPerXdl
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(
|
||||
Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));
|
||||
make_tuple(Sequence<>{},
|
||||
Sequence<0, 2, 4, 6, 7, 8>{},
|
||||
Sequence<>{},
|
||||
Sequence<1, 3, 5, 9>{}));
|
||||
|
||||
// calculate origin of thread output tensor on global memory
|
||||
// blockwise GEMM c matrix starting index
|
||||
@@ -1762,8 +1801,8 @@ struct GridwiseMoeGemmMX_BPreshuffle
|
||||
|
||||
const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
|
||||
make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
|
||||
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
|
||||
make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4, M5))),
|
||||
make_tuple(Sequence<0, 1, 2, 3, 4, 5>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
const auto m_thread_data_on_block_idx =
|
||||
@@ -1772,8 +1811,8 @@ struct GridwiseMoeGemmMX_BPreshuffle
|
||||
|
||||
const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
|
||||
make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
|
||||
make_tuple(Sequence<0, 1, 2>{}),
|
||||
make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3))),
|
||||
make_tuple(Sequence<0, 1, 2, 3>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
const auto n_thread_data_on_block_idx =
|
||||
@@ -1781,36 +1820,39 @@ struct GridwiseMoeGemmMX_BPreshuffle
|
||||
make_multi_index(n_thread_data_on_block));
|
||||
|
||||
// shuffle: threadwise copy C from VGPR to LDS
|
||||
auto c_thread_copy_vgpr_to_lds =
|
||||
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
|
||||
CShuffleDataType,
|
||||
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
|
||||
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
Sequence<CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
I1,
|
||||
I1,
|
||||
M2,
|
||||
I1,
|
||||
M4,
|
||||
I1>,
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
|
||||
7,
|
||||
1,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
1,
|
||||
true>{
|
||||
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
|
||||
make_multi_index(0,
|
||||
0,
|
||||
m_thread_data_on_block_idx[I1],
|
||||
n_thread_data_on_block_idx[I1],
|
||||
m_thread_data_on_block_idx[I2],
|
||||
m_thread_data_on_block_idx[I3],
|
||||
m_thread_data_on_block_idx[I4],
|
||||
n_thread_data_on_block_idx[I2]),
|
||||
ck::tensor_operation::element_wise::PassThrough{}};
|
||||
auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
|
||||
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
Sequence<CShuffleMXdlPerWavePerShuffle / MXdlPack,
|
||||
CShuffleNXdlPerWavePerShuffle / NXdlPack,
|
||||
I1,
|
||||
I1,
|
||||
M2,
|
||||
N2,
|
||||
M3,
|
||||
I1,
|
||||
M5,
|
||||
I1>,
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
|
||||
9,
|
||||
1,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
1,
|
||||
true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
|
||||
make_multi_index(0,
|
||||
0,
|
||||
m_thread_data_on_block_idx[I1],
|
||||
n_thread_data_on_block_idx[I1],
|
||||
m_thread_data_on_block_idx[I2],
|
||||
n_thread_data_on_block_idx[I2],
|
||||
m_thread_data_on_block_idx[I3],
|
||||
m_thread_data_on_block_idx[I4],
|
||||
m_thread_data_on_block_idx[I5],
|
||||
n_thread_data_on_block_idx[I3]),
|
||||
ck::tensor_operation::element_wise::PassThrough{}};
|
||||
|
||||
using EDataType = CDataType;
|
||||
|
||||
@@ -1859,7 +1901,7 @@ struct GridwiseMoeGemmMX_BPreshuffle
|
||||
using CDEBlockTransferCluster =
|
||||
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
|
||||
const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
|
||||
constexpr index_t scatter_weight_idx = 1; // hack fix felix
|
||||
constexpr index_t scatter_weight_idx = 3; // hack fix felix
|
||||
auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
|
||||
ThisThreadBlock,
|
||||
decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
|
||||
@@ -1867,8 +1909,9 @@ struct GridwiseMoeGemmMX_BPreshuffle
|
||||
decltype(c_ds_desc_refs),
|
||||
decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
|
||||
CElementwiseOperation,
|
||||
Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make Sequence
|
||||
// support arbitray type
|
||||
Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make
|
||||
// Sequence support
|
||||
// arbitray type
|
||||
Sequence<1,
|
||||
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
|
||||
1,
|
||||
@@ -1898,13 +1941,25 @@ struct GridwiseMoeGemmMX_BPreshuffle
|
||||
|
||||
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
|
||||
|
||||
constexpr auto sfc_c_vgpr =
|
||||
SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
|
||||
Sequence<CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
SpaceFillingCurve<Sequence<MXdlPerWave / MXdlPack,
|
||||
NXdlPerWave / NXdlPack,
|
||||
1,
|
||||
1,
|
||||
MXdlPack,
|
||||
NXdlPack,
|
||||
M2,
|
||||
1,
|
||||
M4,
|
||||
1>,
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
|
||||
Sequence<CShuffleMXdlPerWavePerShuffle / MXdlPack,
|
||||
CShuffleNXdlPerWavePerShuffle / NXdlPack,
|
||||
1,
|
||||
1,
|
||||
MXdlPack,
|
||||
NXdlPack,
|
||||
M2,
|
||||
1,
|
||||
M4,
|
||||
@@ -1984,7 +2039,6 @@ struct GridwiseMoeGemmMX_BPreshuffle
|
||||
});
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
template <bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
|
||||
|
||||
129
include/ck_tile/core/arch/arch.hpp
Normal file → Executable file
129
include/ck_tile/core/arch/arch.hpp
Normal file → Executable file
@@ -136,66 +136,103 @@ CK_TILE_DEVICE void block_sync_load_raw(index_t cnt = 0)
|
||||
#endif
|
||||
}
|
||||
|
||||
// https://llvm.org/docs/AMDGPU/gfx9_waitcnt.html
|
||||
struct WaitcntLayoutGfx12
|
||||
{ // s_wait_loadcnt_dscnt: mem[13:8], ds[5:0]
|
||||
CK_TILE_DEVICE static constexpr index_t VM_MASK = 0x3F; // mem
|
||||
CK_TILE_DEVICE static constexpr index_t LGKM_MASK = 0x3F; // ds
|
||||
CK_TILE_DEVICE static constexpr bool HAS_EXP = false;
|
||||
|
||||
CK_TILE_DEVICE static constexpr index_t pack_vm(index_t c) { return ((c & VM_MASK) << 8); }
|
||||
CK_TILE_DEVICE static constexpr index_t pack_lgkm(index_t c) { return ((c & LGKM_MASK) << 0); }
|
||||
CK_TILE_DEVICE static constexpr index_t pack_exp(index_t) { return 0; }
|
||||
};
|
||||
|
||||
struct WaitcntLayoutGfx11
|
||||
{ // vm[15:10] (6), lgkm[9:4] (6), exp unused
|
||||
CK_TILE_DEVICE static constexpr index_t VM_MASK = 0x3F;
|
||||
CK_TILE_DEVICE static constexpr index_t LGKM_MASK = 0x3F;
|
||||
CK_TILE_DEVICE static constexpr bool HAS_EXP = false;
|
||||
|
||||
CK_TILE_DEVICE static constexpr index_t pack_vm(index_t c) { return ((c & VM_MASK) << 10); }
|
||||
CK_TILE_DEVICE static constexpr index_t pack_lgkm(index_t c) { return ((c & LGKM_MASK) << 4); }
|
||||
CK_TILE_DEVICE static constexpr index_t pack_exp(index_t) { return 0; }
|
||||
};
|
||||
|
||||
struct WaitcntLayoutLegacy
|
||||
{ // FE'DC'BA98'7'654'3210 => VV'UU'LLLL'U'EEE'VVVV
|
||||
CK_TILE_DEVICE static constexpr index_t VM_MASK = 0x3F; // split: low4 + hi2
|
||||
CK_TILE_DEVICE static constexpr index_t LGKM_MASK = 0x0F; // [11:8]
|
||||
CK_TILE_DEVICE static constexpr index_t EXP_MASK = 0x07; // [6:4]
|
||||
CK_TILE_DEVICE static constexpr bool HAS_EXP = true;
|
||||
|
||||
CK_TILE_DEVICE static constexpr index_t pack_vm(index_t c)
|
||||
{
|
||||
c &= VM_MASK;
|
||||
return ((c & 0xF) << 0) | ((c & 0x30) << 10);
|
||||
}
|
||||
CK_TILE_DEVICE static constexpr index_t pack_lgkm(index_t c) { return ((c & LGKM_MASK) << 8); }
|
||||
CK_TILE_DEVICE static constexpr index_t pack_exp(index_t c) { return ((c & EXP_MASK) << 4); }
|
||||
};
|
||||
|
||||
// Select active layout
|
||||
#if defined(__gfx12__)
|
||||
using Waitcnt = WaitcntLayoutGfx12;
|
||||
#elif defined(__gfx11__)
|
||||
using Waitcnt = WaitcntLayoutGfx11;
|
||||
#else
|
||||
using Waitcnt = WaitcntLayoutLegacy;
|
||||
#endif
|
||||
|
||||
//----------------------------------------------
|
||||
// Public API: only from_* (constexpr templates)
|
||||
//----------------------------------------------
|
||||
struct waitcnt_arg
|
||||
{
|
||||
#if defined(__gfx12__)
|
||||
// use s_wait_loadcnt_dscnt in this instruction; in this instruction, ds [5:0]; mem [13:8]
|
||||
CK_TILE_DEVICE static constexpr index_t MAX = 0b00'111111'00'111111;
|
||||
|
||||
CK_TILE_DEVICE static constexpr index_t kMaxVmCnt = 0b111111;
|
||||
CK_TILE_DEVICE static constexpr index_t kMaxExpCnt = 0b111;
|
||||
CK_TILE_DEVICE static constexpr index_t kMaxLgkmCnt = 0b111111;
|
||||
|
||||
template <index_t cnt>
|
||||
CK_TILE_DEVICE static constexpr index_t from_vmcnt()
|
||||
{
|
||||
static_assert(cnt >= 0 && !(cnt >> 6), "valid range is [0..63]");
|
||||
return MAX & (cnt << 8);
|
||||
}
|
||||
|
||||
template <index_t cnt>
|
||||
CK_TILE_DEVICE static constexpr index_t from_expcnt()
|
||||
{
|
||||
return 0; // no export in MI series
|
||||
}
|
||||
|
||||
template <index_t cnt>
|
||||
CK_TILE_DEVICE static constexpr index_t from_lgkmcnt()
|
||||
{
|
||||
static_assert(cnt >= 0 && !(cnt >> 6), "valid range is [0..63]");
|
||||
return MAX & cnt;
|
||||
}
|
||||
// kMax* exposed for callers; match field widths per-arch
|
||||
#if defined(__gfx12__) || defined(__gfx11__)
|
||||
CK_TILE_DEVICE static constexpr index_t kMaxVmCnt = 0x3F; // 6 bits
|
||||
CK_TILE_DEVICE static constexpr index_t kMaxLgkmCnt = 0x3F; // 6 bits
|
||||
CK_TILE_DEVICE static constexpr index_t kMaxExpCnt = 0x0; // none
|
||||
#else
|
||||
// bit numbers (hex) -------------------------> FE'DC'BA98'7'654'3210
|
||||
// [V]M [E]XP [L]GKM counters and [U]NUSED ---> VV'UU'LLLL'U'EEE'VVVV
|
||||
CK_TILE_DEVICE static constexpr index_t MAX = 0b11'00'1111'0'111'1111;
|
||||
|
||||
CK_TILE_DEVICE static constexpr index_t kMaxVmCnt = 0b111111;
|
||||
CK_TILE_DEVICE static constexpr index_t kMaxExpCnt = 0b111;
|
||||
CK_TILE_DEVICE static constexpr index_t kMaxLgkmCnt = 0b1111;
|
||||
CK_TILE_DEVICE static constexpr index_t kMaxVmCnt = 0x3F; // 6 bits (split)
|
||||
CK_TILE_DEVICE static constexpr index_t kMaxLgkmCnt = 0x0F; // 4 bits
|
||||
CK_TILE_DEVICE static constexpr index_t kMaxExpCnt = 0x07; // 3 bits
|
||||
#endif
|
||||
|
||||
template <index_t cnt>
|
||||
CK_TILE_DEVICE static constexpr index_t from_vmcnt()
|
||||
{
|
||||
static_assert(cnt >= 0 && !(cnt >> 6), "valid range is [0..63]");
|
||||
return MAX & ((cnt & 0b1111) | ((cnt & 0b110000) << 10));
|
||||
}
|
||||
|
||||
template <index_t cnt>
|
||||
CK_TILE_DEVICE static constexpr index_t from_expcnt()
|
||||
{
|
||||
static_assert(cnt >= 0 && !(cnt >> 3), "valid range is [0..7]");
|
||||
return MAX & (cnt << 4);
|
||||
static_assert((cnt & ~Waitcnt::VM_MASK) == 0, "vmcnt out of range");
|
||||
return Waitcnt::pack_vm(cnt);
|
||||
}
|
||||
|
||||
template <index_t cnt>
|
||||
CK_TILE_DEVICE static constexpr index_t from_lgkmcnt()
|
||||
{
|
||||
static_assert(cnt >= 0 && !(cnt >> 4), "valid range is [0..15]");
|
||||
return MAX & (cnt << 8);
|
||||
static_assert((cnt & ~Waitcnt::LGKM_MASK) == 0, "lgkmcnt out of range");
|
||||
return Waitcnt::pack_lgkm(cnt);
|
||||
}
|
||||
|
||||
template <index_t cnt>
|
||||
CK_TILE_DEVICE static constexpr index_t from_expcnt()
|
||||
{
|
||||
if constexpr(Waitcnt::HAS_EXP)
|
||||
{
|
||||
// EXP_MASK only exists on legacy
|
||||
#if !defined(__gfx12__) && !defined(__gfx11__)
|
||||
static_assert((cnt & ~Waitcnt::EXP_MASK) == 0, "expcnt out of range");
|
||||
return Waitcnt::pack_exp(cnt);
|
||||
#else
|
||||
(void)cnt;
|
||||
return 0;
|
||||
#endif
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(cnt == 0, "expcnt unsupported on this arch");
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t vmcnt = waitcnt_arg::kMaxVmCnt,
|
||||
|
||||
@@ -102,11 +102,14 @@ struct static_counter_uniq_;
|
||||
}
|
||||
|
||||
#define MAKE_SC() \
|
||||
ck_tile::static_counter<ck_tile::impl::static_counter_uniq_<__COUNTER__>> {}
|
||||
#define MAKE_SC_WITH(start_, step_) \
|
||||
ck_tile::static_counter<ck_tile::impl::static_counter_uniq_<__COUNTER__>, start_, step_> {}
|
||||
#define NEXT_SC(c_) c_.next<__COUNTER__>()
|
||||
#define NEXT_SCI(c_, static_i_) c_.next<__COUNTER__ + static_i_>()
|
||||
__extension__ ck_tile::static_counter<ck_tile::impl::static_counter_uniq_<__COUNTER__>> {}
|
||||
#define MAKE_SC_WITH(start_, step_) \
|
||||
__extension__ ck_tile:: \
|
||||
static_counter<ck_tile::impl::static_counter_uniq_<__COUNTER__>, start_, step_> \
|
||||
{ \
|
||||
}
|
||||
#define NEXT_SC(c_) __extension__ c_.next<__COUNTER__>()
|
||||
#define NEXT_SCI(c_, static_i_) __extension__ c_.next<__COUNTER__ + static_i_>()
|
||||
|
||||
// Usage:
|
||||
// constexpr auto c = MAKE_SC()
|
||||
|
||||
@@ -74,6 +74,21 @@ struct GroupedConvTraits
|
||||
}
|
||||
|
||||
public:
|
||||
// Fixed values for Implicit GEMM
|
||||
struct FixedGemmParams
|
||||
{
|
||||
static constexpr ck_tile::index_t TilePartitionerGroupNum = 8;
|
||||
static constexpr ck_tile::index_t TilePartitionerM01 = 4;
|
||||
static constexpr bool kPadM = true;
|
||||
static constexpr bool kPadN = true;
|
||||
static constexpr bool kPadK = true;
|
||||
static constexpr bool TransposeC = false;
|
||||
static constexpr bool FixedVectorSize = true;
|
||||
static constexpr bool UseStructuredSparsity = false;
|
||||
static constexpr bool Persistent = false;
|
||||
using ELayout = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
};
|
||||
// Compile time parameters
|
||||
static constexpr bool EnableSplitImage = EnableSplitImage_;
|
||||
static constexpr index_t NumGroupsToMerge = NumGroupsToMerge_;
|
||||
static constexpr index_t NDimSpatial = NDimSpatial_;
|
||||
@@ -82,31 +97,43 @@ struct GroupedConvTraits
|
||||
using WeiLayout = WeiLayout_;
|
||||
using DsLayout = DsLayout_;
|
||||
using OutLayout = OutLayout_;
|
||||
|
||||
// Forward Gemm Layouts
|
||||
using AsLayoutFwd = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using BsLayoutFwd = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using CLayoutFwd = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
// Backward Data Gemm Layouts
|
||||
using AsLayoutBwdData = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using BsLayoutBwdData = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using CLayoutBwdData = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
// Backward Weight Gemm Layouts
|
||||
using AsLayoutBwdWeight = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using BsLayoutBwdWeight = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using CLayoutBwdWeight = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
|
||||
template <ck_tile::index_t NumWaveGroups = 1>
|
||||
using GroupedConvImplicitGemmTraitsFwd =
|
||||
TileGemmTraits<true,
|
||||
true,
|
||||
true,
|
||||
ck_tile::tensor_layout::gemm::RowMajor,
|
||||
ck_tile::tensor_layout::gemm::ColumnMajor,
|
||||
ck_tile::tensor_layout::gemm::RowMajor>;
|
||||
using GroupedConvImplicitGemmTraitsBwdData =
|
||||
TileGemmTraits<true,
|
||||
true,
|
||||
true,
|
||||
ck_tile::tensor_layout::gemm::RowMajor,
|
||||
ck_tile::tensor_layout::gemm::RowMajor,
|
||||
ck_tile::tensor_layout::gemm::RowMajor>;
|
||||
using GroupedConvImplicitGemmTraitsBwdWeight =
|
||||
TileGemmTraits<true,
|
||||
true,
|
||||
true,
|
||||
ck_tile::tensor_layout::gemm::ColumnMajor,
|
||||
ck_tile::tensor_layout::gemm::RowMajor,
|
||||
ck_tile::tensor_layout::gemm::RowMajor>;
|
||||
TileGemmTraits<true, true, true, AsLayoutFwd, BsLayoutFwd, CLayoutFwd, NumWaveGroups>;
|
||||
template <ck_tile::index_t NumWaveGroups = 1>
|
||||
using GroupedConvImplicitGemmTraitsBwdData = TileGemmTraits<true,
|
||||
true,
|
||||
true,
|
||||
AsLayoutBwdData,
|
||||
BsLayoutBwdData,
|
||||
CLayoutBwdData,
|
||||
NumWaveGroups>;
|
||||
template <ck_tile::index_t NumWaveGroups = 1>
|
||||
using GroupedConvImplicitGemmTraitsBwdWeight = TileGemmTraits<true,
|
||||
true,
|
||||
true,
|
||||
AsLayoutBwdWeight,
|
||||
BsLayoutBwdWeight,
|
||||
CLayoutBwdWeight,
|
||||
NumWaveGroups>;
|
||||
static constexpr ck_tile::index_t VectorSizeA = VectorSizeA_;
|
||||
static constexpr ck_tile::index_t VectorSizeB = VectorSizeB_;
|
||||
static constexpr ck_tile::index_t VectorSizeC = VectorSizeC_;
|
||||
static constexpr index_t NumDTensor = DsLayout::size();
|
||||
static constexpr ck_tile::index_t NumDTensor = DsLayout::size();
|
||||
using ImplicitGemmDsLayout = decltype(generate_implicit_gemm_layout());
|
||||
};
|
||||
|
||||
|
||||
@@ -74,6 +74,6 @@ class ProfilerOperationRegistry final
|
||||
#define PP_CONCAT(x, y) PP_CONCAT_IMPL(x, y)
|
||||
#define PP_CONCAT_IMPL(x, y) x##y
|
||||
|
||||
#define REGISTER_PROFILER_OPERATION(name, description, operation) \
|
||||
static const bool PP_CONCAT(operation_registration_result_, __COUNTER__) = \
|
||||
#define REGISTER_PROFILER_OPERATION(name, description, operation) \
|
||||
__extension__ static const bool PP_CONCAT(operation_registration_result_, __COUNTER__) = \
|
||||
::ProfilerOperationRegistry::GetInstance().Add(name, description, operation)
|
||||
|
||||
Reference in New Issue
Block a user