Merge branch 'develop' into jograner/conv-bwd-weight-kbatch-in-2gb-limit

This commit is contained in:
Graner, Johannes
2025-11-07 12:23:45 +00:00
32 changed files with 3287 additions and 695 deletions

6
.gitignore vendored
View File

@@ -66,6 +66,12 @@ docs/doxygen/xml
cmake-build*/
build*/
# LSP configuration
.clangd
# User-defined CMake presets
CMakeUserPresets.json
# Python virtualenv
.venv/

View File

@@ -181,7 +181,7 @@ constexpr ck::index_t ScaleBlockSize = 32; // scaling block
constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 f4 = 128 fp4x2
static constexpr ck::index_t Nswizzle = false;
static constexpr ck::index_t ActOP = 0; // 0: gelu_and_mul, 1: silu_and_mul
static constexpr ck::index_t MPerBlock = 128;
static constexpr ck::index_t MPerBlock = 32;
static constexpr bool MulRoutedWeight = true;
// clang-format off
@@ -190,10 +190,10 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmMXBPreShuffl
A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CDEElementOp, GemmSpec,
ScaleBlockSize, 256,
MPerBlock, 64, KPerBlock,
MPerBlock, 128, KPerBlock,
16, 16,
16, 16,
4, 2,
2, 2,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1,
2, 2, S<1, 32, 1, 8>, S<8, 1, 1, 1>,
@@ -213,10 +213,10 @@ int main(int argc, char* argv[])
ck::index_t sorted_size = sorted_tile_num * MPerBlock;
ck::index_t valid_size = valid_tile_num * MPerBlock;
ck::index_t N = 6144;
ck::index_t K = 4096;
ck::index_t N = 7168;
ck::index_t K = 256;
ck::index_t experts = 8;
ck::index_t tokens = 832;
ck::index_t tokens = 208;
ck::index_t topk = 2;
if(argc == 1)

View File

@@ -14,19 +14,11 @@
struct ConvConfigBase
{
static constexpr bool kPadM = true;
static constexpr bool kPadN = true;
static constexpr bool kPadK = true;
static constexpr bool TransposeC = false;
static constexpr ck_tile::index_t VectorSizeA = 4;
static constexpr ck_tile::index_t VectorSizeB = 8;
static constexpr ck_tile::index_t VectorSizeC = 8;
static constexpr int kBlockPerCu = 1;
static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
static constexpr ck_tile::index_t TileParitionerM01 = 4;
static constexpr int kBlockPerCu = 1;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
static constexpr ck_tile::index_t NumWaveGroups = 1;
@@ -210,9 +202,9 @@ struct ConvConfigComputeV5 : public ConvConfigBase
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = 16;
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V5;
static constexpr ck_tile::index_t NumWaNumWaveGroups = 2;
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V5;
static constexpr ck_tile::index_t NumWaveGroups = 2;
};
template <typename PrecType>

View File

@@ -22,8 +22,6 @@ struct GroupedConvolutionBackwardDataInvoker
static float grouped_conv_bwd_data(const ck_tile::GroupedConvBwdDataHostArgs& args,
const ck_tile::stream_config& s)
{
constexpr int kBlockPerCu = 1;
// Implicit GEMM Traits
using GemmShape = ck_tile::TileGemmShape<
ck_tile::sequence<ConvConfig::M_Tile, ConvConfig::N_Tile, ConvConfig::K_Tile>,
@@ -32,36 +30,33 @@ struct GroupedConvolutionBackwardDataInvoker
ConvConfig::N_Warp_Tile,
ConvConfig::K_Warp_Tile>>;
constexpr ck_tile::index_t VectorSizeA = 8;
constexpr ck_tile::index_t VectorSizeB = 8;
constexpr ck_tile::index_t VectorSizeC = 8;
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
using TilePartitioner =
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
ConvConfig::TileParitionerGroupNum,
ConvConfig::TileParitionerM01>;
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
using GroupedConvTraitsType = ck_tile::GroupedConvTraits<NDimSpatial,
ConvSpec,
InLayout,
WeiLayout,
DsLayout,
OutLayout,
VectorSizeA,
VectorSizeB,
VectorSizeC>;
ConvConfig::VectorSizeA,
ConvConfig::VectorSizeB,
ConvConfig::VectorSizeC>;
using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner<
GemmShape,
GroupedConvTraitsType::FixedGemmParams::TilePartitionerGroupNum,
GroupedConvTraitsType::FixedGemmParams::TilePartitionerM01>;
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<
ConvConfig::kPadM,
ConvConfig::kPadN,
ConvConfig::kPadK,
GroupedConvTraitsType::FixedGemmParams::kPadM,
GroupedConvTraitsType::FixedGemmParams::kPadN,
GroupedConvTraitsType::FixedGemmParams::kPadK,
ConvConfig::DoubleSmemBuffer,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdData::AsLayout,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdData::BsLayout,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdData::CLayout,
ConvConfig::TransposeC,
false,
false, // Persistent,
typename GroupedConvTraitsType::AsLayoutBwdData,
typename GroupedConvTraitsType::BsLayoutBwdData,
typename GroupedConvTraitsType::CLayoutBwdData,
GroupedConvTraitsType::FixedGemmParams::TransposeC,
GroupedConvTraitsType::FixedGemmParams::UseStructuredSparsity,
GroupedConvTraitsType::FixedGemmParams::Persistent,
ConvConfig::NumWaveGroups>;
using GemmPipelineProblem = ck_tile::GemmPipelineProblem<
@@ -69,13 +64,14 @@ struct GroupedConvolutionBackwardDataInvoker
WeiDataType,
AccDataType,
GemmShape,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdData,
typename GroupedConvTraitsType::template GroupedConvImplicitGemmTraitsBwdData<
ConvConfig::NumWaveGroups>,
ck_tile::element_wise::PassThrough,
ck_tile::element_wise::PassThrough,
InDataType,
true,
VectorSizeA,
VectorSizeB>;
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeA,
GroupedConvTraitsType::VectorSizeB>;
using BaseGemmPipeline = typename PipelineTypeTraits<
ConvConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
@@ -93,95 +89,96 @@ struct GroupedConvolutionBackwardDataInvoker
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
float ave_time{0};
const auto Run =
[&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = ConvConfig::Scheduler;
constexpr auto memory_operation = memory_operation_.value;
const auto Run = [&](const auto has_hot_loop_,
const auto tail_number_,
const auto memory_operation_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = ConvConfig::Scheduler;
constexpr auto memory_operation = memory_operation_.value;
using UniversalGemmProblem =
ck_tile::UniversalGemmPipelineProblem<OutDataType,
WeiDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler,
has_hot_loop_v,
tail_number_v,
ck_tile::element_wise::PassThrough,
ck_tile::element_wise::PassThrough,
InDataType,
true,
VectorSizeA,
VectorSizeB>;
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
OutDataType,
WeiDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler,
has_hot_loop_v,
tail_number_v,
ck_tile::element_wise::PassThrough,
ck_tile::element_wise::PassThrough,
InDataType,
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeA,
GroupedConvTraitsType::VectorSizeB>;
using GemmPipeline = typename PipelineTypeTraits<
ConvConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
using GemmPipeline = typename PipelineTypeTraits<
ConvConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
using ConvEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
OutDataType,
WeiDataType,
DsDataType,
AccDataType,
InDataType,
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
ck_tile::tensor_layout::gemm::RowMajor,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
ConvConfig::M_Warp,
ConvConfig::N_Warp,
ConvConfig::M_Warp_Tile,
ConvConfig::N_Warp_Tile,
ConvConfig::K_Warp_Tile,
ConvConfig::TransposeC,
memory_operation,
1,
true,
GroupedConvTraitsType::VectorSizeC>>;
using ConvEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
OutDataType,
WeiDataType,
DsDataType,
AccDataType,
InDataType,
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
typename GroupedConvTraitsType::FixedGemmParams::ELayout,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
ConvConfig::M_Warp,
ConvConfig::N_Warp,
ConvConfig::M_Warp_Tile,
ConvConfig::N_Warp_Tile,
ConvConfig::K_Warp_Tile,
GroupedConvTraitsType::FixedGemmParams::TransposeC,
memory_operation,
ConvConfig::NumWaveGroups,
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeC>>;
using Kernel = ck_tile::GroupedConvolutionBackwardDataKernel<GroupedConvTraitsType,
TilePartitioner,
GemmPipeline,
ConvEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
using Kernel = ck_tile::GroupedConvolutionBackwardDataKernel<GroupedConvTraitsType,
TilePartitioner,
GemmPipeline,
ConvEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(args);
const dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::GridSize(args);
const dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n");
}
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n");
}
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
<< "shape: " << GemmShape::GetName() << '\n'
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
<< "pipeline: " << GemmPipeline::GetName() << '\n'
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
<< "}" << '\n'
<< "Vector size A: " << GemmPipeline::GetVectorSizeA()
<< ", Vector size B: " << GemmPipeline::GetVectorSizeB()
<< ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl;
}
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
<< "shape: " << GemmShape::GetName() << '\n'
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
<< "pipeline: " << GemmPipeline::GetName() << '\n'
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
<< "}" << '\n'
<< "Vector size A: " << GemmPipeline::GetVectorSizeA()
<< ", Vector size B: " << GemmPipeline::GetVectorSizeB()
<< ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl;
}
auto preprocess = [&]() {
ck_tile::hip_check_error(hipMemsetAsync(
kargs.in_ptr, 0, args.template GetInputByte<InDataType>(), s.stream_id_));
};
ave_time = ck_tile::launch_kernel_time_mask(
s,
preprocess,
ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
return ave_time;
auto preprocess = [&]() {
ck_tile::hip_check_error(hipMemsetAsync(
kargs.in_ptr, 0, args.template GetInputByte<InDataType>(), s.stream_id_));
};
ave_time = ck_tile::launch_kernel_time_mask(
s,
preprocess,
ck_tile::make_kernel<ConvConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
return ave_time;
};
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
if(args.k_batch == 1)
{

View File

@@ -21,8 +21,6 @@ struct GroupedConvolutionBackwardWeightInvoker
static float grouped_conv_bwd_weight(const ck_tile::GroupedConvBwdWeightHostArgs& args,
const ck_tile::stream_config& s)
{
constexpr int kBlockPerCu = 1;
// Implicit GEMM Traits
using GemmShape = ck_tile::TileGemmShape<
ck_tile::sequence<ConvConfig::M_Tile, ConvConfig::N_Tile, ConvConfig::K_Tile>,
@@ -31,37 +29,34 @@ struct GroupedConvolutionBackwardWeightInvoker
ConvConfig::N_Warp_Tile,
ConvConfig::K_Warp_Tile>>;
constexpr ck_tile::index_t VectorSizeA = ConvConfig::VectorSizeA;
constexpr ck_tile::index_t VectorSizeB = ConvConfig::VectorSizeB;
constexpr ck_tile::index_t VectorSizeC = ConvConfig::VectorSizeC;
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
using TilePartitioner =
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
ConvConfig::TileParitionerGroupNum,
ConvConfig::TileParitionerM01>;
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
using GroupedConvTraitsType = ck_tile::GroupedConvTraits<NDimSpatial,
ConvSpec,
InLayout,
WeiLayout,
DsLayout,
OutLayout,
VectorSizeA,
VectorSizeB,
VectorSizeC,
ConvConfig::VectorSizeA,
ConvConfig::VectorSizeB,
ConvConfig::VectorSizeC,
ConvConfig::NumGroupsToMerge>;
using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner<
GemmShape,
GroupedConvTraitsType::FixedGemmParams::TilePartitionerGroupNum,
GroupedConvTraitsType::FixedGemmParams::TilePartitionerM01>;
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<
ConvConfig::kPadM,
ConvConfig::kPadN,
ConvConfig::kPadK,
GroupedConvTraitsType::FixedGemmParams::kPadM,
GroupedConvTraitsType::FixedGemmParams::kPadN,
GroupedConvTraitsType::FixedGemmParams::kPadK,
ConvConfig::DoubleSmemBuffer,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::AsLayout,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::BsLayout,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::CLayout,
ConvConfig::TransposeC,
false,
false, // Persistent,
typename GroupedConvTraitsType::AsLayoutBwdWeight,
typename GroupedConvTraitsType::BsLayoutBwdWeight,
typename GroupedConvTraitsType::CLayoutBwdWeight,
GroupedConvTraitsType::FixedGemmParams::TransposeC,
GroupedConvTraitsType::FixedGemmParams::UseStructuredSparsity,
GroupedConvTraitsType::FixedGemmParams::Persistent,
ConvConfig::NumWaveGroups>;
using GemmPipelineProblem = ck_tile::GemmPipelineProblem<
@@ -69,13 +64,14 @@ struct GroupedConvolutionBackwardWeightInvoker
InDataType,
AccDataType,
GemmShape,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight,
typename GroupedConvTraitsType::template GroupedConvImplicitGemmTraitsBwdWeight<
ConvConfig::NumWaveGroups>,
ck_tile::element_wise::PassThrough,
ck_tile::element_wise::PassThrough,
WeiDataType,
true,
VectorSizeA,
VectorSizeB>;
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeA,
GroupedConvTraitsType::VectorSizeB>;
using BaseGemmPipeline = typename PipelineTypeTraits<
ConvConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
@@ -101,21 +97,21 @@ struct GroupedConvolutionBackwardWeightInvoker
constexpr auto scheduler = ConvConfig::Scheduler;
constexpr auto memory_operation = memory_operation_.value;
using UniversalGemmProblem =
ck_tile::UniversalGemmPipelineProblem<OutDataType,
InDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler,
has_hot_loop_v,
tail_number_v,
ck_tile::element_wise::PassThrough,
ck_tile::element_wise::PassThrough,
WeiDataType,
true,
VectorSizeA,
VectorSizeB>;
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
OutDataType,
InDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler,
has_hot_loop_v,
tail_number_v,
ck_tile::element_wise::PassThrough,
ck_tile::element_wise::PassThrough,
WeiDataType,
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeA,
GroupedConvTraitsType::VectorSizeB>;
using GemmPipeline = typename PipelineTypeTraits<
ConvConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
@@ -127,7 +123,7 @@ struct GroupedConvolutionBackwardWeightInvoker
AccDataType,
WeiDataType,
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
ck_tile::tensor_layout::gemm::RowMajor,
typename GroupedConvTraitsType::FixedGemmParams::ELayout,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
@@ -136,10 +132,10 @@ struct GroupedConvolutionBackwardWeightInvoker
ConvConfig::M_Warp_Tile,
ConvConfig::N_Warp_Tile,
ConvConfig::K_Warp_Tile,
ConvConfig::TransposeC,
GroupedConvTraitsType::FixedGemmParams::TransposeC,
memory_operation,
1,
true,
ConvConfig::NumWaveGroups,
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeC>>;
using Kernel = ck_tile::GroupedConvolutionBackwardWeightKernel<GroupedConvTraitsType,
@@ -184,7 +180,7 @@ struct GroupedConvolutionBackwardWeightInvoker
ave_time = ck_tile::launch_kernel_time_mask(
s,
preprocess,
ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
ck_tile::make_kernel<ConvConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
return ave_time;
};

View File

@@ -23,8 +23,6 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
{
using WorkspaceDataType = float;
constexpr int kBlockPerCu = 1;
// Implicit GEMM Traits
using GemmShape = ck_tile::TileGemmShape<
ck_tile::sequence<ConvConfig::M_Tile, ConvConfig::N_Tile, ConvConfig::K_Tile>,
@@ -33,36 +31,34 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
ConvConfig::N_Warp_Tile,
ConvConfig::K_Warp_Tile>>;
constexpr ck_tile::index_t VectorSizeA = 4;
constexpr ck_tile::index_t VectorSizeB = 8;
constexpr ck_tile::index_t VectorSizeC = 8;
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
using TilePartitioner =
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
ConvConfig::TileParitionerGroupNum,
ConvConfig::TileParitionerM01>;
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
using GroupedConvTraitsType = ck_tile::GroupedConvTraits<NDimSpatial,
ConvSpec,
InLayout,
WeiLayout,
DsLayout,
OutLayout,
VectorSizeA,
VectorSizeB,
VectorSizeC>;
ConvConfig::VectorSizeA,
ConvConfig::VectorSizeB,
ConvConfig::VectorSizeC,
ConvConfig::NumGroupsToMerge>;
using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner<
GemmShape,
GroupedConvTraitsType::FixedGemmParams::TilePartitionerGroupNum,
GroupedConvTraitsType::FixedGemmParams::TilePartitionerM01>;
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<
ConvConfig::kPadM,
ConvConfig::kPadN,
ConvConfig::kPadK,
GroupedConvTraitsType::FixedGemmParams::kPadM,
GroupedConvTraitsType::FixedGemmParams::kPadN,
GroupedConvTraitsType::FixedGemmParams::kPadK,
ConvConfig::DoubleSmemBuffer,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::AsLayout,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::BsLayout,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::CLayout,
ConvConfig::TransposeC,
false,
false, // Persistent,
typename GroupedConvTraitsType::AsLayoutBwdWeight,
typename GroupedConvTraitsType::BsLayoutBwdWeight,
typename GroupedConvTraitsType::CLayoutBwdWeight,
GroupedConvTraitsType::FixedGemmParams::TransposeC,
GroupedConvTraitsType::FixedGemmParams::UseStructuredSparsity,
GroupedConvTraitsType::FixedGemmParams::Persistent,
ConvConfig::NumWaveGroups>;
using GemmPipelineProblem = ck_tile::GemmPipelineProblem<
@@ -70,13 +66,14 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
InDataType,
AccDataType,
GemmShape,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight,
typename GroupedConvTraitsType::template GroupedConvImplicitGemmTraitsBwdWeight<
ConvConfig::NumWaveGroups>,
ck_tile::element_wise::PassThrough,
ck_tile::element_wise::PassThrough,
WeiDataType,
true,
VectorSizeA,
VectorSizeB>;
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeA,
GroupedConvTraitsType::VectorSizeB>;
using BaseGemmPipeline = typename PipelineTypeTraits<
ConvConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
@@ -102,21 +99,21 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
constexpr auto scheduler = ConvConfig::Scheduler;
constexpr auto memory_operation = memory_operation_.value;
using UniversalGemmProblem =
ck_tile::UniversalGemmPipelineProblem<OutDataType,
InDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler,
has_hot_loop_v,
tail_number_v,
ck_tile::element_wise::PassThrough,
ck_tile::element_wise::PassThrough,
WeiDataType,
true,
VectorSizeA,
VectorSizeB>;
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
OutDataType,
InDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler,
has_hot_loop_v,
tail_number_v,
ck_tile::element_wise::PassThrough,
ck_tile::element_wise::PassThrough,
WeiDataType,
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeA,
GroupedConvTraitsType::VectorSizeB>;
using GemmPipeline = typename PipelineTypeTraits<
ConvConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
@@ -128,7 +125,7 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
AccDataType,
WorkspaceDataType, // C: Workspace normally Out
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
ck_tile::tensor_layout::gemm::RowMajor,
typename GroupedConvTraitsType::FixedGemmParams::ELayout,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
@@ -139,8 +136,8 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
ConvConfig::K_Warp_Tile,
GemmPipelineProblem::TransposeC,
memory_operation,
1,
true,
ConvConfig::NumWaveGroups,
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeC>>;
using Kernel = ck_tile::GroupedConvolutionBackwardWeightKernel<GroupedConvTraitsType,
@@ -235,16 +232,17 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
ave_time = ck_tile::launch_kernel_time_mask(
s,
preprocess,
ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs),
ck_tile::make_kernel<kBlockPerCu>(ElementwiseKernel{},
kGridSize,
kBlockSize,
0,
input_size,
ck_tile::make_tuple(shape[1], 1), // Input Stride
ck_tile::make_tuple(shape[1], 1), // Output Stride
input_tensors,
static_cast<WeiDataType*>(c_ptr)));
ck_tile::make_kernel<ConvConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs),
ck_tile::make_kernel<ConvConfig::kBlockPerCu>(
ElementwiseKernel{},
kGridSize,
kBlockSize,
0,
input_size,
ck_tile::make_tuple(shape[1], 1), // Input Stride
ck_tile::make_tuple(shape[1], 1), // Output Stride
input_tensors,
static_cast<WeiDataType*>(c_ptr)));
return ave_time;
};

View File

@@ -32,8 +32,6 @@ struct GroupedConvolutionForwardInvoker
{
std::cout << "[INVOKER] grouped_conv_fwd called, NDimSpatial=" << NDimSpatial << "\n";
}
constexpr int kBlockPerCu = 1;
// Implicit GEMM Traits
using GemmShape = ck_tile::TileGemmShape<
ck_tile::sequence<ConvConfig::M_Tile, ConvConfig::N_Tile, ConvConfig::K_Tile>,
@@ -42,38 +40,34 @@ struct GroupedConvolutionForwardInvoker
ConvConfig::N_Warp_Tile,
ConvConfig::K_Warp_Tile>>;
constexpr ck_tile::index_t VectorSizeA = 8;
constexpr ck_tile::index_t VectorSizeB = 8;
constexpr ck_tile::index_t VectorSizeC = 8;
constexpr ck_tile::index_t NumGroupsToMerge = 1;
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
using TilePartitioner =
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
ConvConfig::TileParitionerGroupNum,
ConvConfig::TileParitionerM01>;
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
using GroupedConvTraitsType = ck_tile::GroupedConvTraits<NDimSpatial,
ConvSpec,
InLayout,
WeiLayout,
DsLayout,
OutLayout,
VectorSizeA,
VectorSizeB,
VectorSizeC,
NumGroupsToMerge>;
ConvConfig::VectorSizeA,
ConvConfig::VectorSizeB,
ConvConfig::VectorSizeC,
ConvConfig::NumGroupsToMerge>;
using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner<
GemmShape,
GroupedConvTraitsType::FixedGemmParams::TilePartitionerGroupNum,
GroupedConvTraitsType::FixedGemmParams::TilePartitionerM01>;
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<
ConvConfig::kPadM,
ConvConfig::kPadN,
ConvConfig::kPadK,
GroupedConvTraitsType::FixedGemmParams::kPadM,
GroupedConvTraitsType::FixedGemmParams::kPadN,
GroupedConvTraitsType::FixedGemmParams::kPadK,
ConvConfig::DoubleSmemBuffer,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsFwd::AsLayout,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsFwd::BsLayout,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsFwd::CLayout,
ConvConfig::TransposeC,
false,
false, // Persistent,
typename GroupedConvTraitsType::AsLayoutFwd,
typename GroupedConvTraitsType::BsLayoutFwd,
typename GroupedConvTraitsType::CLayoutFwd,
GroupedConvTraitsType::FixedGemmParams::TransposeC,
GroupedConvTraitsType::FixedGemmParams::UseStructuredSparsity,
GroupedConvTraitsType::FixedGemmParams::Persistent,
ConvConfig::NumWaveGroups>;
using GemmPipelineProblem = ck_tile::GemmPipelineProblem<
@@ -81,13 +75,14 @@ struct GroupedConvolutionForwardInvoker
WeiDataType,
AccDataType,
GemmShape,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsFwd,
typename GroupedConvTraitsType::template GroupedConvImplicitGemmTraitsFwd<
ConvConfig::NumWaveGroups>,
ck_tile::element_wise::PassThrough,
ck_tile::element_wise::PassThrough,
OutDataType,
true,
VectorSizeA,
VectorSizeB>;
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeA,
GroupedConvTraitsType::VectorSizeB>;
using BaseGemmPipeline = typename PipelineTypeTraits<
ConvConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
@@ -116,21 +111,21 @@ struct GroupedConvolutionForwardInvoker
constexpr auto scheduler = ConvConfig::Scheduler;
constexpr auto memory_operation = memory_operation_.value;
using UniversalGemmProblem =
ck_tile::UniversalGemmPipelineProblem<InDataType,
WeiDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler,
has_hot_loop_v,
tail_number_v,
ck_tile::element_wise::PassThrough,
ck_tile::element_wise::PassThrough,
OutDataType,
true,
VectorSizeA,
VectorSizeB>;
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
InDataType,
WeiDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler,
has_hot_loop_v,
tail_number_v,
ck_tile::element_wise::PassThrough,
ck_tile::element_wise::PassThrough,
OutDataType,
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeA,
GroupedConvTraitsType::VectorSizeB>;
using GemmPipeline = typename PipelineTypeTraits<
ConvConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
@@ -142,7 +137,7 @@ struct GroupedConvolutionForwardInvoker
AccDataType,
OutDataType,
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
ck_tile::tensor_layout::gemm::RowMajor,
typename GroupedConvTraitsType::FixedGemmParams::ELayout,
CDElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
@@ -151,10 +146,10 @@ struct GroupedConvolutionForwardInvoker
ConvConfig::M_Warp_Tile,
ConvConfig::N_Warp_Tile,
ConvConfig::K_Warp_Tile,
ConvConfig::TransposeC,
GroupedConvTraitsType::FixedGemmParams::TransposeC,
memory_operation,
1,
true,
ConvConfig::NumWaveGroups,
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeC>>;
using Kernel = ck_tile::GroupedConvolutionForwardKernel<GroupedConvTraitsType,
@@ -185,8 +180,9 @@ struct GroupedConvolutionForwardInvoker
<< ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl;
}
ave_time = ck_tile::launch_kernel(
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
ave_time = ck_tile::launch_kernel(s,
ck_tile::make_kernel<ConvConfig::kBlockPerCu>(
Kernel{}, grids, blocks, 0, kargs));
return ave_time;
};

View File

@@ -25,7 +25,6 @@ struct GroupedConvolutionForwardInvoker
{
std::cout << "[INVOKER] grouped_conv_fwd called, NDimSpatial=" << NDimSpatial << "\n";
}
constexpr int kBlockPerCu = 1;
// Implicit GEMM Traits
using GemmShape = ck_tile::TileGemmShape<
@@ -35,27 +34,18 @@ struct GroupedConvolutionForwardInvoker
ConvConfig::N_Warp_Tile,
ConvConfig::K_Warp_Tile>>;
constexpr ck_tile::index_t VectorSizeA = 8;
constexpr ck_tile::index_t VectorSizeB = 8;
constexpr ck_tile::index_t VectorSizeC = 8;
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
using TilePartitioner =
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
ConvConfig::TileParitionerGroupNum,
ConvConfig::TileParitionerM01>;
using GroupedConvTraitsTypeDefault = ck_tile::GroupedConvTraits<NDimSpatial,
ConvSpec,
InLayout,
WeiLayout,
DsLayout,
OutLayout,
VectorSizeA,
VectorSizeB,
VectorSizeC,
1, /*NumGroupsToMerge*/
false /*EnableSplitImage*/>;
using GroupedConvTraitsTypeDefault =
ck_tile::GroupedConvTraits<NDimSpatial,
ConvSpec,
InLayout,
WeiLayout,
DsLayout,
OutLayout,
ConvConfig::VectorSizeA,
ConvConfig::VectorSizeB,
ConvConfig::VectorSizeC,
ConvConfig::NumGroupsToMerge>;
using GroupedConvTraitsTypeLargeTensor =
ck_tile::GroupedConvTraits<NDimSpatial,
@@ -64,23 +54,28 @@ struct GroupedConvolutionForwardInvoker
WeiLayout,
DsLayout,
OutLayout,
VectorSizeA,
VectorSizeB,
VectorSizeC,
1, /*NumGroupsToMerge*/
ConvConfig::VectorSizeA,
ConvConfig::VectorSizeB,
ConvConfig::VectorSizeC,
ConvConfig::NumGroupsToMerge,
true /*EnableSplitImage*/>;
using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner<
GemmShape,
GroupedConvTraitsTypeDefault::FixedGemmParams::TilePartitionerGroupNum,
GroupedConvTraitsTypeDefault::FixedGemmParams::TilePartitionerM01>;
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<
ConvConfig::kPadM,
ConvConfig::kPadN,
ConvConfig::kPadK,
GroupedConvTraitsTypeDefault::FixedGemmParams::kPadM,
GroupedConvTraitsTypeDefault::FixedGemmParams::kPadN,
GroupedConvTraitsTypeDefault::FixedGemmParams::kPadK,
ConvConfig::DoubleSmemBuffer,
typename GroupedConvTraitsTypeDefault::GroupedConvImplicitGemmTraitsFwd::AsLayout,
typename GroupedConvTraitsTypeDefault::GroupedConvImplicitGemmTraitsFwd::BsLayout,
typename GroupedConvTraitsTypeDefault::GroupedConvImplicitGemmTraitsFwd::CLayout,
ConvConfig::TransposeC,
false,
false, // Persistent,
typename GroupedConvTraitsTypeDefault::AsLayoutFwd,
typename GroupedConvTraitsTypeDefault::BsLayoutFwd,
typename GroupedConvTraitsTypeDefault::CLayoutFwd,
GroupedConvTraitsTypeDefault::FixedGemmParams::TransposeC,
GroupedConvTraitsTypeDefault::FixedGemmParams::UseStructuredSparsity,
GroupedConvTraitsTypeDefault::FixedGemmParams::Persistent,
ConvConfig::NumWaveGroups>;
using GemmPipelineProblem = ck_tile::GemmPipelineProblem<
@@ -88,13 +83,14 @@ struct GroupedConvolutionForwardInvoker
WeiDataType,
AccDataType,
GemmShape,
typename GroupedConvTraitsTypeDefault::GroupedConvImplicitGemmTraitsFwd,
typename GroupedConvTraitsTypeDefault::template GroupedConvImplicitGemmTraitsFwd<
ConvConfig::NumWaveGroups>,
ck_tile::element_wise::PassThrough,
ck_tile::element_wise::PassThrough,
OutDataType,
true,
VectorSizeA,
VectorSizeB>;
GroupedConvTraitsTypeDefault::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsTypeDefault::VectorSizeA,
GroupedConvTraitsTypeDefault::VectorSizeB>;
using BaseGemmPipeline = typename PipelineTypeTraits<
ConvConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
@@ -116,9 +112,9 @@ struct GroupedConvolutionForwardInvoker
using TransformType =
ck_tile::TransformConvFwdToGemm<NDimSpatial,
ck_tile::ConvolutionSpecialization::Default,
VectorSizeA,
VectorSizeB,
VectorSizeC,
GroupedConvTraitsTypeDefault::VectorSizeA,
GroupedConvTraitsTypeDefault::VectorSizeB,
GroupedConvTraitsTypeDefault::VectorSizeC,
1, // NumGroupsToMerge
false, // SplitN
InDataType,
@@ -264,21 +260,21 @@ struct GroupedConvolutionForwardInvoker
GroupedConvTraitsTypeLargeTensor,
GroupedConvTraitsTypeDefault>;
using UniversalGemmProblem =
ck_tile::UniversalGemmPipelineProblem<InDataType,
WeiDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler,
has_hot_loop_v,
tail_number_v,
ck_tile::element_wise::PassThrough,
ck_tile::element_wise::PassThrough,
OutDataType,
true,
VectorSizeA,
VectorSizeB>;
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
InDataType,
WeiDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler,
has_hot_loop_v,
tail_number_v,
ck_tile::element_wise::PassThrough,
ck_tile::element_wise::PassThrough,
OutDataType,
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeA,
GroupedConvTraitsType::VectorSizeB>;
using GemmPipeline = typename PipelineTypeTraits<
ConvConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
@@ -290,7 +286,7 @@ struct GroupedConvolutionForwardInvoker
AccDataType,
OutDataType,
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
ck_tile::tensor_layout::gemm::RowMajor,
typename GroupedConvTraitsType::FixedGemmParams::ELayout,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
@@ -299,10 +295,10 @@ struct GroupedConvolutionForwardInvoker
ConvConfig::M_Warp_Tile,
ConvConfig::N_Warp_Tile,
ConvConfig::K_Warp_Tile,
ConvConfig::TransposeC,
GroupedConvTraitsType::FixedGemmParams::TransposeC,
memory_operation,
1,
true,
ConvConfig::NumWaveGroups,
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeC>>;
// Use split-image kernel if layout supports it, otherwise use regular kernel
@@ -368,7 +364,8 @@ struct GroupedConvolutionForwardInvoker
}
ave_time = ck_tile::launch_kernel(
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
s,
ck_tile::make_kernel<ConvConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
return ave_time;
};

View File

@@ -78,66 +78,4 @@ struct UnsupportedEnumValue
{
};
// Helper functions to convert enums to strings
// Returns the display name of a convolution direction.
// Any value other than the three listed enumerators yields "Unknown".
constexpr std::string_view ConvDirectionToString(ConvDirection dir)
{
    if(dir == ConvDirection::FORWARD)
    {
        return "Forward";
    }
    if(dir == ConvDirection::BACKWARD_DATA)
    {
        return "Backward Data";
    }
    if(dir == ConvDirection::BACKWARD_WEIGHT)
    {
        return "Backward Weight";
    }
    return "Unknown";
}
// Returns the display name of a tensor data type.
// Any value other than the six listed enumerators yields "Unknown".
constexpr std::string_view DataTypeToString(DataType dt)
{
    if(dt == DataType::FP16)
    {
        return "FP16";
    }
    if(dt == DataType::FP32)
    {
        return "FP32";
    }
    if(dt == DataType::BF16)
    {
        return "BF16";
    }
    if(dt == DataType::FP8)
    {
        return "FP8";
    }
    if(dt == DataType::I8)
    {
        return "I8";
    }
    if(dt == DataType::U8)
    {
        return "U8";
    }
    return "Unknown";
}
// Returns the display name of a 1D grouped-convolution layout.
// The returned text is spelled exactly like the enumerator; unlisted values yield "Unknown".
constexpr std::string_view LayoutToString(GroupConvLayout1D layout)
{
    if(layout == GroupConvLayout1D::GNWC_GKXC_GNWK)
    {
        return "GNWC_GKXC_GNWK";
    }
    if(layout == GroupConvLayout1D::NWGC_GKXC_NWGK)
    {
        return "NWGC_GKXC_NWGK";
    }
    if(layout == GroupConvLayout1D::NGCW_GKXC_NGKW)
    {
        return "NGCW_GKXC_NGKW";
    }
    if(layout == GroupConvLayout1D::NGCW_GKCX_NGKW)
    {
        return "NGCW_GKCX_NGKW";
    }
    return "Unknown";
}
// Returns the display name of a 2D grouped-convolution layout.
// The returned text is spelled exactly like the enumerator; unlisted values yield "Unknown".
constexpr std::string_view LayoutToString(GroupConvLayout2D layout)
{
    if(layout == GroupConvLayout2D::GNHWC_GKYXC_GNHWK)
    {
        return "GNHWC_GKYXC_GNHWK";
    }
    if(layout == GroupConvLayout2D::NHWGC_GKYXC_NHWGK)
    {
        return "NHWGC_GKYXC_NHWGK";
    }
    if(layout == GroupConvLayout2D::NGCHW_GKYXC_NGKHW)
    {
        return "NGCHW_GKYXC_NGKHW";
    }
    if(layout == GroupConvLayout2D::NGCHW_GKCYX_NGKHW)
    {
        return "NGCHW_GKCYX_NGKHW";
    }
    return "Unknown";
}
// Returns the display name of a 3D grouped-convolution layout.
// The returned text is spelled exactly like the enumerator; unlisted values yield "Unknown".
constexpr std::string_view LayoutToString(GroupConvLayout3D layout)
{
    if(layout == GroupConvLayout3D::GNDHWC_GKZYXC_GNDHWK)
    {
        return "GNDHWC_GKZYXC_GNDHWK";
    }
    if(layout == GroupConvLayout3D::NDHWGC_GKZYXC_NDHWGK)
    {
        return "NDHWGC_GKZYXC_NDHWGK";
    }
    if(layout == GroupConvLayout3D::NGCDHW_GKZYXC_NGKDHW)
    {
        return "NGCDHW_GKZYXC_NGKDHW";
    }
    if(layout == GroupConvLayout3D::NGCDHW_GKCZYX_NGKDHW)
    {
        return "NGCDHW_GKCZYX_NGKDHW";
    }
    return "Unknown";
}
} // namespace ck_tile::builder

View File

@@ -183,4 +183,87 @@ concept SpecifiesLoopScheduler = requires {
{ T::loop_scheduler } -> std::convertible_to<PipelineScheduler>;
};
/******************************************** */
/* DL-specific descriptors and requirements */
/******************************************** */
// Concept for DL thread configuration.
// Satisfied when an object of type T has the members k0_per_block, k1, m1_per_thread,
// n1_per_thread and k_per_thread, each convertible to size_t. Only member presence and
// convertibility are checked; the values themselves are not constrained here.
template <typename T>
concept DlThreadConfigDescriptor = requires(T t) {
{ t.k0_per_block } -> std::convertible_to<size_t>;
{ t.k1 } -> std::convertible_to<size_t>;
{ t.m1_per_thread } -> std::convertible_to<size_t>;
{ t.n1_per_thread } -> std::convertible_to<size_t>;
{ t.k_per_thread } -> std::convertible_to<size_t>;
};
// Concept for DL thread cluster.
// Satisfied when an object of type T has the members m1_xs and n1_xs, each convertible
// to std::array<size_t, 2>. Element values are not constrained here.
template <typename T>
concept DlThreadClusterDescriptor = requires(T t) {
{ t.m1_xs } -> std::convertible_to<std::array<size_t, 2>>;
{ t.n1_xs } -> std::convertible_to<std::array<size_t, 2>>;
};
// Concept for DL block transfer K0_M0_M1_K1 format.
// Satisfied when an object of type T has the seven members listed below, each convertible
// to std::array<size_t, 4>. Only member presence and convertibility are checked.
// NOTE(review): the requirements are textually the same as those of
// DlBlockTransferK0N0N1K1Descriptor; the two concepts differ only by name.
template <typename T>
concept DlBlockTransferK0M0M1K1Descriptor = requires(T t) {
{ t.thread_slice_lengths } -> std::convertible_to<std::array<size_t, 4>>;
{ t.thread_cluster_lengths } -> std::convertible_to<std::array<size_t, 4>>;
{ t.thread_cluster_arrange_order } -> std::convertible_to<std::array<size_t, 4>>;
{ t.src_access_order } -> std::convertible_to<std::array<size_t, 4>>;
{ t.src_vector_tensor_lengths } -> std::convertible_to<std::array<size_t, 4>>;
{ t.src_vector_tensor_contiguous_dim_order } -> std::convertible_to<std::array<size_t, 4>>;
{ t.dst_vector_tensor_lengths } -> std::convertible_to<std::array<size_t, 4>>;
};
// Concept for DL block transfer K0_N0_N1_K1 format.
// Satisfied when an object of type T has the seven members listed below, each convertible
// to std::array<size_t, 4>. Only member presence and convertibility are checked.
// NOTE(review): the requirements are textually the same as those of
// DlBlockTransferK0M0M1K1Descriptor; the two concepts differ only by name.
template <typename T>
concept DlBlockTransferK0N0N1K1Descriptor = requires(T t) {
{ t.thread_slice_lengths } -> std::convertible_to<std::array<size_t, 4>>;
{ t.thread_cluster_lengths } -> std::convertible_to<std::array<size_t, 4>>;
{ t.thread_cluster_arrange_order } -> std::convertible_to<std::array<size_t, 4>>;
{ t.src_access_order } -> std::convertible_to<std::array<size_t, 4>>;
{ t.src_vector_tensor_lengths } -> std::convertible_to<std::array<size_t, 4>>;
{ t.src_vector_tensor_contiguous_dim_order } -> std::convertible_to<std::array<size_t, 4>>;
{ t.dst_vector_tensor_lengths } -> std::convertible_to<std::array<size_t, 4>>;
};
// Concept for DL C thread transfer.
// Satisfied when an object of type T has src_dst_access_order convertible to
// std::array<size_t, 6>, and src_dst_vector_dim and dst_scalar_per_vector each
// convertible to size_t.
template <typename T>
concept DlCThreadTransferDescriptor = requires(T t) {
{ t.src_dst_access_order } -> std::convertible_to<std::array<size_t, 6>>;
{ t.src_dst_vector_dim } -> std::convertible_to<size_t>;
{ t.dst_scalar_per_vector } -> std::convertible_to<size_t>;
};
// Concept to check if algorithm specifies DL thread config.
// Satisfied when T has a static member dl_thread_config whose type models
// DlThreadConfigDescriptor.
template <typename T>
concept SpecifiesDlThreadConfig = requires {
{ T::dl_thread_config } -> DlThreadConfigDescriptor;
};
// Concept to check if algorithm specifies DL thread cluster.
// Satisfied when T has a static member dl_thread_cluster whose type models
// DlThreadClusterDescriptor.
template <typename T>
concept SpecifiesDlThreadCluster = requires {
{ T::dl_thread_cluster } -> DlThreadClusterDescriptor;
};
// Concept to check if algorithm specifies DL A block transfer.
// Satisfied when T has a static member dl_block_transfer_a whose type models
// DlBlockTransferK0M0M1K1Descriptor.
template <typename T>
concept SpecifiesDlBlockTransferA = requires {
{ T::dl_block_transfer_a } -> DlBlockTransferK0M0M1K1Descriptor;
};
// Concept to check if algorithm specifies DL B block transfer.
// Satisfied when T has a static member dl_block_transfer_b whose type models
// DlBlockTransferK0N0N1K1Descriptor.
template <typename T>
concept SpecifiesDlBlockTransferB = requires {
{ T::dl_block_transfer_b } -> DlBlockTransferK0N0N1K1Descriptor;
};
// Concept to check if algorithm specifies DL C thread transfer.
// Satisfied when T has a static member dl_c_thread_transfer whose type models
// DlCThreadTransferDescriptor.
template <typename T>
concept SpecifiesDlCThreadTransfer = requires {
{ T::dl_c_thread_transfer } -> DlCThreadTransferDescriptor;
};
} // namespace ck_tile::builder

View File

@@ -36,9 +36,21 @@
#pragma once
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp"
// WORKAROUND: Macro namespace collision in upstream CK device operation headers.
// device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp (line 41) and
// device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp (line 51) both define
// GridwiseGemmTemplateParameters macro without #undef, causing redefinition errors.
// Use pragma push/pop to isolate the Large_Tensor header's macro scope.
#pragma push_macro("GridwiseGemmTemplateParameters")
#ifdef GridwiseGemmTemplateParameters
#undef GridwiseGemmTemplateParameters
#endif
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp"
#pragma pop_macro("GridwiseGemmTemplateParameters")
#include "ck_tile/builder/conv_signature_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_limits.hpp"
@@ -990,4 +1002,263 @@ struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
GRIDWISE_GEMM_PIPELINE_VERSION>;
};
// Factory specialization for DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK instance
// of a grouped forward convolution kernel using the DL approach.
// NOTE(review): the original comment expanded "DL" as "Direct Load"; nothing in this file
// establishes that expansion — confirm against the CK device-op documentation.
//
// Selected when SIGNATURE is a forward convolution whose device operation is
// DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK. The specialization:
//   1. derives layouts, data types and elementwise ops from SIGNATURE,
//   2. statically checks that ALGORITHM's type provides every DL-specific descriptor,
//   3. converts the descriptor fields into the integral constants and sequence types
//      the device op expects, and
//   4. exposes the fully parameterized device op as `Instance`.
// VERSION is part of the specialization's parameter list but is not referenced in the body.
template <ConvSignatureDescriptor auto SIGNATURE,
ConvAlgorithmDescriptor auto ALGORITHM,
StringLiteral VERSION>
requires ConvDirectionIsForward<SIGNATURE> &&
ConvDeviceOpIs_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK<SIGNATURE>
struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
{
// Signature-derived types.
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
using Layouts = decltype(factory_internal::GetTensorLayout<SIGNATURE.layout,
SPATIAL_DIM,
ConvDirection::FORWARD>());
using Types = factory_internal::ConvTensorTypes<SIGNATURE.data_type>;
using Ops = factory_internal::ElementwiseOps<SIGNATURE.elementwise_operation>;
using AlgorithmType = decltype(ALGORITHM);
// Compile-time validation: each descriptor read below must be present on the algorithm.
static_assert(SpecifiesThreadBlock<AlgorithmType>,
"The convolution algorithm descriptor must specify thread block info.");
static_assert(SpecifiesFwdConcSpecialization<AlgorithmType>,
"The convolution algorithm descriptor must specify forward convolution "
"specialization.");
static_assert(SpecifiesGemmSpecialization<AlgorithmType>,
"The convolution algorithm descriptor must specify gemm specialization.");
static_assert(SpecifiesDlThreadConfig<AlgorithmType>,
"DL algorithm must specify thread config.");
static_assert(SpecifiesDlThreadCluster<AlgorithmType>,
"DL algorithm must specify thread cluster.");
static_assert(SpecifiesDlBlockTransferA<AlgorithmType>,
"DL algorithm must specify A block transfer.");
static_assert(SpecifiesDlBlockTransferB<AlgorithmType>,
"DL algorithm must specify B block transfer.");
static_assert(SpecifiesDlCThreadTransfer<AlgorithmType>,
"DL algorithm must specify C thread transfer.");
static constexpr auto FWD_CONV_SPECIALIZATION =
factory_internal::SetFwdConvSpecialization<ALGORITHM>();
static constexpr auto GEMM_SPECIALIZATION =
factory_internal::SetGemmSpecialization<ALGORITHM>();
static constexpr auto BLOCK = factory_internal::SetThreadBlockInfo<ALGORITHM>();
// DL-specific parameters from algorithm descriptor
// (descriptor fields are converted to ck::index_t here).
static constexpr auto DL_THREAD_CFG = ALGORITHM.dl_thread_config;
static constexpr ck::index_t K0PerBlock = DL_THREAD_CFG.k0_per_block;
static constexpr ck::index_t K1 = DL_THREAD_CFG.k1;
static constexpr ck::index_t M1PerThread = DL_THREAD_CFG.m1_per_thread;
static constexpr ck::index_t N1PerThread = DL_THREAD_CFG.n1_per_thread;
static constexpr ck::index_t KPerThread = DL_THREAD_CFG.k_per_thread;
// Thread cluster from descriptor
static constexpr auto DL_CLUSTER = ALGORITHM.dl_thread_cluster;
using M1N1ThreadClusterM1Xs = to_sequence_v<DL_CLUSTER.m1_xs>;
using M1N1ThreadClusterN1Xs = to_sequence_v<DL_CLUSTER.n1_xs>;
// A Block Transfer from descriptor - K0_M0_M1_K1 tensor format
static constexpr auto DL_A_TRANSFER = ALGORITHM.dl_block_transfer_a;
using ABlockTransferThreadSliceLengths_K0_M0_M1_K1 =
to_sequence_v<DL_A_TRANSFER.thread_slice_lengths>;
using ABlockTransferThreadClusterLengths_K0_M0_M1_K1 =
to_sequence_v<DL_A_TRANSFER.thread_cluster_lengths>;
using ABlockTransferThreadClusterArrangeOrder =
to_sequence_v<DL_A_TRANSFER.thread_cluster_arrange_order>;
using ABlockTransferSrcAccessOrder = to_sequence_v<DL_A_TRANSFER.src_access_order>;
using ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1 =
to_sequence_v<DL_A_TRANSFER.src_vector_tensor_lengths>;
using ABlockTransferSrcVectorTensorContiguousDimOrder =
to_sequence_v<DL_A_TRANSFER.src_vector_tensor_contiguous_dim_order>;
using ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1 =
to_sequence_v<DL_A_TRANSFER.dst_vector_tensor_lengths>;
// B Block Transfer from descriptor - K0_N0_N1_K1 tensor format
static constexpr auto DL_B_TRANSFER = ALGORITHM.dl_block_transfer_b;
using BBlockTransferThreadSliceLengths_K0_N0_N1_K1 =
to_sequence_v<DL_B_TRANSFER.thread_slice_lengths>;
using BBlockTransferThreadClusterLengths_K0_N0_N1_K1 =
to_sequence_v<DL_B_TRANSFER.thread_cluster_lengths>;
using BBlockTransferThreadClusterArrangeOrder =
to_sequence_v<DL_B_TRANSFER.thread_cluster_arrange_order>;
using BBlockTransferSrcAccessOrder = to_sequence_v<DL_B_TRANSFER.src_access_order>;
using BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1 =
to_sequence_v<DL_B_TRANSFER.src_vector_tensor_lengths>;
using BBlockTransferSrcVectorTensorContiguousDimOrder =
to_sequence_v<DL_B_TRANSFER.src_vector_tensor_contiguous_dim_order>;
using BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1 =
to_sequence_v<DL_B_TRANSFER.dst_vector_tensor_lengths>;
// C Thread Transfer from descriptor
static constexpr auto DL_C_TRANSFER = ALGORITHM.dl_c_thread_transfer;
using CThreadTransferSrcDstAccessOrder = to_sequence_v<DL_C_TRANSFER.src_dst_access_order>;
static constexpr ck::index_t CThreadTransferSrcDstVectorDim = DL_C_TRANSFER.src_dst_vector_dim;
static constexpr ck::index_t CThreadTransferDstScalarPerVector =
DL_C_TRANSFER.dst_scalar_per_vector;
// The DL forward convolution kernel class instance.
// Template arguments are positional: their order must match the device op's
// template parameter list.
using Instance = ck::tensor_operation::device::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK<
SPATIAL_DIM,
typename Types::ADataType,
typename Types::BDataType,
typename Types::DsDataTypes,
typename Types::EDataType,
typename Types::AccDataType,
typename Layouts::ALayout,
typename Layouts::BLayout,
typename Layouts::DsLayout,
typename Layouts::ELayout,
typename Ops::AElementwiseOp,
typename Ops::BElementwiseOp,
typename Ops::CDEElementwiseOp,
FWD_CONV_SPECIALIZATION,
GEMM_SPECIALIZATION,
BLOCK.block_size,
BLOCK.per_block.m,
BLOCK.per_block.n,
K0PerBlock,
K1,
M1PerThread,
N1PerThread,
KPerThread,
M1N1ThreadClusterM1Xs,
M1N1ThreadClusterN1Xs,
ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
ABlockTransferSrcVectorTensorContiguousDimOrder,
ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
BBlockTransferSrcVectorTensorContiguousDimOrder,
BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector>;
};
// Factory specialization for DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor instance
// of a grouped forward convolution kernel with large tensor support (N-splitting).
//
// Selected when SIGNATURE is a forward convolution whose device operation is
// DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor. The specialization:
//   1. derives layouts, data types and elementwise ops from SIGNATURE,
//   2. statically checks that ALGORITHM's type provides every descriptor read below,
//   3. builds the block/transfer parameter objects through factory_internal helpers and
//      checks them against the vector-transfer and access-order limits, and
//   4. exposes the fully parameterized device op as `Instance`.
// VERSION is part of the specialization's parameter list but is not referenced in the body.
template <ConvSignatureDescriptor auto SIGNATURE,
ConvAlgorithmDescriptor auto ALGORITHM,
StringLiteral VERSION>
requires ConvDirectionIsForward<SIGNATURE> &&
ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<SIGNATURE>
struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
{
// Signature-derived types.
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
using Layouts = decltype(factory_internal::GetTensorLayout<SIGNATURE.layout,
SPATIAL_DIM,
ConvDirection::FORWARD>());
using Types = factory_internal::ConvTensorTypes<SIGNATURE.data_type>;
using Ops = factory_internal::ElementwiseOps<SIGNATURE.elementwise_operation>;
using AlgorithmType = decltype(ALGORITHM);
// Compile-time validation of the algorithm descriptor's shape.
static_assert(SpecifiesThreadBlock<AlgorithmType>,
"The convolution algorithm descriptor must specify thread block info.");
static_assert(SpecifiesGridwiseXdlGemm<AlgorithmType>,
"The convolution algorithm descriptor must specify gridwise GEMM info.");
static_assert(SpecifiesBlockTransfer<AlgorithmType>,
"The convolution algorithm descriptor must specify block transfer info.");
static_assert(SpecifiesLdsTransfer<AlgorithmType>,
"The convolution algorithm descriptor must specify LDS transfer info.");
static_assert(
SpecifiesThreadClusterAccessOrder<AlgorithmType>,
"The convolution algorithm descriptor must specify thread cluster access order info.");
static_assert(SpecifiesSourceAccessOrder<AlgorithmType>,
"The convolution algorithm descriptor must specify source access order info.");
static_assert(SpecifiesFwdConcSpecialization<AlgorithmType>,
"The convolution algorithm descriptor must specify forward convolution "
"specialization.");
static_assert(SpecifiesGemmSpecialization<AlgorithmType>,
"The convolution algorithm descriptor must specify gemm specialization.");
static_assert(SpecifiesNumPrefetchStages<AlgorithmType>,
"The convolution algorithm descriptor must specify number of prefetch stages.");
static_assert(SpecifiesLoopScheduler<AlgorithmType>,
"The convolution algorithm descriptor must specify loop scheduler.");
static constexpr auto FWD_CONV_SPECIALIZATION =
factory_internal::SetFwdConvSpecialization<ALGORITHM>();
static constexpr auto GEMM_SPECIALIZATION =
factory_internal::SetGemmSpecialization<ALGORITHM>();
// Bundles the two specializations; both members are passed to Instance below.
static constexpr factory_internal::ConvSpec SPECIALIZATION{.conv_spec = FWD_CONV_SPECIALIZATION,
.gemm_spec = GEMM_SPECIALIZATION};
static constexpr auto LOOP_SCHEDULER = factory_internal::SetLoopScheduler<ALGORITHM>();
static constexpr auto BLOCK = factory_internal::SetThreadBlockInfo<ALGORITHM>();
static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
static constexpr auto A_BLOCK_TRANSFER =
factory_internal::SetFwdConvABlockTransfer<ALGORITHM>();
static constexpr auto B_BLOCK_TRANSFER =
factory_internal::SetFwdConvBBlockTransfer<ALGORITHM>();
static constexpr auto C_BLOCK_TRANSFER =
factory_internal::SetCBlockTransfer<SIGNATURE, ALGORITHM>();
// Check limits for the algorithm parameters.
static_assert(InputVectorTransferLimits<A_BLOCK_TRANSFER>);
static_assert(InputVectorTransferLimits<B_BLOCK_TRANSFER>);
static_assert(OutputVectorTransferLimits<C_BLOCK_TRANSFER>);
static_assert(AccessOrderLimits<A_BLOCK_TRANSFER.thread_cluster_order>);
static_assert(AccessOrderLimits<B_BLOCK_TRANSFER.thread_cluster_order>);
static_assert(AccessOrderLimits<A_BLOCK_TRANSFER.src_access_order>);
static_assert(AccessOrderLimits<B_BLOCK_TRANSFER.src_access_order>);
// The forward convolution kernel class instance with large tensor support.
// Template arguments are positional: their order must match the device op's
// template parameter list.
using Instance =
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<
SPATIAL_DIM,
typename Layouts::ALayout,
typename Layouts::BLayout,
typename Layouts::DsLayout,
typename Layouts::ELayout,
typename Types::ADataType,
typename Types::BDataType,
typename Types::AccDataType,
typename Types::CShuffleDataType,
typename Types::DsDataTypes,
typename Types::EDataType,
typename Ops::AElementwiseOp,
typename Ops::BElementwiseOp,
typename Ops::CDEElementwiseOp,
SPECIALIZATION.conv_spec,
SPECIALIZATION.gemm_spec,
ALGORITHM.num_gemm_k_prefetch_stages,
BLOCK.block_size,
BLOCK.per_block.m,
BLOCK.per_block.n,
BLOCK.per_block.k,
GRIDWISE_GEMM.ak1,
GRIDWISE_GEMM.bk1,
GRIDWISE_GEMM.m_per_xdl,
GRIDWISE_GEMM.n_per_xdl,
GRIDWISE_GEMM.m_xdl_per_wave,
GRIDWISE_GEMM.n_xdl_per_wave,
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,
A_BLOCK_TRANSFER.src_vector_dim,
A_BLOCK_TRANSFER.src_scalar_per_vector,
A_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
A_BLOCK_TRANSFER.lds_padding,
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_dims>,
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_order>,
to_sequence_v<B_BLOCK_TRANSFER.src_access_order>,
B_BLOCK_TRANSFER.src_vector_dim,
B_BLOCK_TRANSFER.src_scalar_per_vector,
B_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
B_BLOCK_TRANSFER.lds_padding,
C_BLOCK_TRANSFER.m_per_wave_per_shuffle,
C_BLOCK_TRANSFER.n_per_wave_per_shuffle,
to_sequence_v<C_BLOCK_TRANSFER.thread_cluster_dims>,
C_BLOCK_TRANSFER.scalar_per_vector,
typename Types::AComputeType,
typename Types::BComputeType,
LOOP_SCHEDULER>;
};
} // namespace ck_tile::builder

View File

@@ -33,30 +33,35 @@ concept ConvDirectionIsBackwardWeight = (Sig.direction == ConvDirection::BACKWAR
// Predicate for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 operation.
// Holds when Sig is a forward signature and Sig.device_operation._fwd names this operation.
// The direction constraint is the left operand of &&, so the _fwd comparison is only
// checked after the direction has matched.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 =
ConvDirectionIsForward<Sig> &&
(Sig.device_operation._fwd ==
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3);
// Predicate for DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK operation.
// Holds when Sig is a forward signature and Sig.device_operation._fwd names this operation.
// The direction constraint is the left operand of &&, so the _fwd comparison is only
// checked after the direction has matched.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK =
ConvDirectionIsForward<Sig> &&
(Sig.device_operation._fwd ==
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK);
// Predicate for DeviceGroupedConvFwdMultipleD_Wmma_CShuffle operation.
// Holds when Sig is a forward signature and Sig.device_operation._fwd names this operation.
// The direction constraint is the left operand of &&, so the _fwd comparison is only
// checked after the direction has matched.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle =
ConvDirectionIsForward<Sig> &&
(Sig.device_operation._fwd ==
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleD_Wmma_CShuffle);
// Predicate for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle operation.
// Holds when Sig is a forward signature and Sig.device_operation._fwd names this operation.
// The direction constraint is the left operand of &&, so the _fwd comparison is only
// checked after the direction has matched.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle =
ConvDirectionIsForward<Sig> &&
(Sig.device_operation._fwd ==
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle);
// Predicate for DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor operation.
// Holds when Sig is a forward signature and Sig.device_operation._fwd names this operation.
// The direction constraint is the left operand of &&, so the _fwd comparison is only
// checked after the direction has matched.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor =
ConvDirectionIsForward<Sig> &&
(Sig.device_operation._fwd ==
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor);
@@ -76,48 +81,56 @@ concept ConvDeviceOpIsForward =
// Predicate for DeviceGroupedConvBwdWeight operation.
// Holds when Sig is a backward-weight signature and Sig.device_operation._bwd_weight names
// this operation. The direction constraint is the left operand of &&, so the _bwd_weight
// comparison is only checked after the direction has matched.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight =
ConvDirectionIsBackwardWeight<Sig> &&
(Sig.device_operation._bwd_weight ==
BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight);
// Predicate for DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle operation.
// Holds when Sig is a backward-weight signature and Sig.device_operation._bwd_weight names
// this operation. The direction constraint is the left operand of &&, so the _bwd_weight
// comparison is only checked after the direction has matched.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle =
ConvDirectionIsBackwardWeight<Sig> &&
(Sig.device_operation._bwd_weight ==
BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle);
// Predicate for DeviceGroupedConvBwdWeight_Xdl_CShuffle operation.
// Holds when Sig is a backward-weight signature and Sig.device_operation._bwd_weight names
// this operation. The direction constraint is the left operand of &&, so the _bwd_weight
// comparison is only checked after the direction has matched.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Xdl_CShuffle =
ConvDirectionIsBackwardWeight<Sig> &&
(Sig.device_operation._bwd_weight ==
BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Xdl_CShuffle);
// Predicate for DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle operation.
// Holds when Sig is a backward-weight signature and Sig.device_operation._bwd_weight names
// this operation. The direction constraint is the left operand of &&, so the _bwd_weight
// comparison is only checked after the direction has matched.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle =
ConvDirectionIsBackwardWeight<Sig> &&
(Sig.device_operation._bwd_weight ==
BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle);
// Predicate for DeviceGroupedConvBwdWeight_Wmma_CShuffle operation.
// Holds when Sig is a backward-weight signature and Sig.device_operation._bwd_weight names
// this operation. The direction constraint is the left operand of &&, so the _bwd_weight
// comparison is only checked after the direction has matched.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Wmma_CShuffle =
ConvDirectionIsBackwardWeight<Sig> &&
(Sig.device_operation._bwd_weight ==
BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Wmma_CShuffle);
// Predicate for DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 operation.
// Holds when Sig is a backward-weight signature and Sig.device_operation._bwd_weight names
// this operation. The direction constraint is the left operand of &&, so the _bwd_weight
// comparison is only checked after the direction has matched.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 =
ConvDirectionIsBackwardWeight<Sig> &&
(Sig.device_operation._bwd_weight ==
BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Xdl_CShuffleV3);
// Predicate for DeviceGroupedConvBwdWeightMultipleD operation.
// Holds when Sig is a backward-weight signature and Sig.device_operation._bwd_weight names
// this operation. The direction constraint is the left operand of &&, so the _bwd_weight
// comparison is only checked after the direction has matched.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeightMultipleD =
ConvDirectionIsBackwardWeight<Sig> &&
(Sig.device_operation._bwd_weight ==
BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeightMultipleD);
// Predicate for DeviceGroupedConvBwdWeight_Dl operation.
// Holds when Sig is a backward-weight signature and Sig.device_operation._bwd_weight names
// this operation. The direction constraint is the left operand of &&, so the _bwd_weight
// comparison is only checked after the direction has matched.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Dl =
ConvDirectionIsBackwardWeight<Sig> &&
(Sig.device_operation._bwd_weight ==
BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Dl);
@@ -140,18 +153,21 @@ concept ConvDeviceOpIsBackwardWeight =
// Predicate for DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 operation.
// Holds when Sig is a backward-data signature and Sig.device_operation._bwd_data names
// this operation. The direction constraint is the left operand of &&, so the _bwd_data
// comparison is only checked after the direction has matched.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 =
ConvDirectionIsBackwardData<Sig> &&
(Sig.device_operation._bwd_data ==
BwdDataGroupConvDeviceOperation::DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1);
// Predicate for DeviceGroupedConvBwdDataMultipleD operation.
// Holds when Sig is a backward-data signature and Sig.device_operation._bwd_data names
// this operation. The direction constraint is the left operand of &&, so the _bwd_data
// comparison is only checked after the direction has matched.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD =
ConvDirectionIsBackwardData<Sig> &&
(Sig.device_operation._bwd_data ==
BwdDataGroupConvDeviceOperation::DeviceGroupedConvBwdDataMultipleD);
// Predicate for DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle operation.
// Holds when Sig is a backward-data signature and Sig.device_operation._bwd_data names
// this operation. The direction constraint is the left operand of &&, so the _bwd_data
// comparison is only checked after the direction has matched.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle =
ConvDirectionIsBackwardData<Sig> &&
(Sig.device_operation._bwd_data ==
BwdDataGroupConvDeviceOperation::DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle);

View File

@@ -0,0 +1,22 @@
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
namespace ck_tile::builder {
// Enumeration for CK Device Operation types.
// This allows the builder to select which device operation template to instantiate
// based on the user's requirements.
enum class DeviceOpType
{
// Forward Convolution - Non-grouped
CONV_FWD, // Maps to: DeviceConvFwd (TODO: No implementation with tuning params exists yet)
// Forward Convolution - Grouped
GROUPED_CONV_FWD_MULTIPLE_ABD, // Maps to: DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
GROUPED_CONV_FWD_MULTIPLE_ABD_XDL_CSHUFFLE_V3, // Maps to:
// DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
};
} // namespace ck_tile::builder

View File

@@ -0,0 +1,268 @@
// SPDX-License-Identifier: MIT
// Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <concepts>
#include <string_view>
#include <sstream>
#include <type_traits>
#include <variant>
#include <ck_tile/builder/conv_signature_concepts.hpp>
#include <ck_tile/builder/reflect/conv_traits.hpp>
#include <ck_tile/builder/reflect/tree_formatter.hpp>
/// @file conv_description.hpp
/// @brief Provides human-readable descriptions of ConvBuilder configurations
namespace ck_tile::reflect::conv {
// Signature information - describes which convolution is computed.
// Plain aggregate; member order matters because Describe() fills it with designated
// initializers.
struct ConvSignatureInfo
{
// Number of spatial dimensions (printed as "<n>D" in the descriptions).
int spatial_dim;
// Convolution direction.
builder::ConvDirection direction;
// Memory layout; holds the 1D, 2D or 3D layout enumeration.
std::variant<builder::GroupConvLayout1D, builder::GroupConvLayout2D, builder::GroupConvLayout3D>
layout;
// Tensor data type.
builder::DataType data_type;
// Elementwise operations applied to input, weights and output respectively.
builder::ElementwiseOperation input_element_op;
builder::ElementwiseOperation weight_element_op;
builder::ElementwiseOperation output_element_op;
};
// Algorithm information - groups all algorithm-related configuration.
// Plain aggregate; member order matters because Describe() fills it with designated
// initializers.
struct GemmAlgorithmInfo
{
// Thread block size.
int thread_block_size;
// Data tile size (its m, n and k members are printed by ConvDescription::detailed()).
DataTileInfo tile_dims;
// Warp GEMM parameters (sub-tile size and iteration counts).
WarpGemmParams warp_gemm;
// Transfer descriptions for the A and B input tiles and the C output tile.
InputTileTransferInfo a_tile_transfer;
InputTileTransferInfo b_tile_transfer;
OutputTileTransferInfo c_tile_transfer;
// Pipeline version and scheduler.
builder::PipelineVersion pipeline_version;
builder::PipelineScheduler pipeline_scheduler;
// Convolution specialization; holds the forward, backward-data or backward-weight
// specialization enumeration.
std::variant<builder::ConvFwdSpecialization,
builder::ConvBwdDataSpecialization,
builder::ConvBwdWeightSpecialization>
conv_specialization;
// GEMM padding mode.
builder::GemmPadding padding;
};
// Provides human-readable descriptions of ConvBuilder configurations.
struct ConvDescription
{
ConvSignatureInfo signature;
GemmAlgorithmInfo algorithm;
// Brief one-line summary
std::string brief() const
{
std::ostringstream oss;
oss << signature.spatial_dim << "D " << signature.direction << " convolution";
return oss.str();
}
// Detailed hierarchical description
std::string detailed() const
{
TreeFormatter f;
f.writeLine(0, signature.spatial_dim, "D ", signature.direction, " Convolution Kernel");
f.writeLine(1, "Signature");
f.writeLine(2, "Tensor Type: ", signature.data_type);
f.writeLine(2, "Memory Layout: ", signature.layout);
f.writeLine(2, "Input elementwise operation: ", signature.input_element_op);
f.writeLine(2, "Weights elementwise operation: ", signature.weight_element_op);
f.writeLast(2, "Output elementwise operation: ", signature.output_element_op);
f.writeLine(1, "Algorithm");
// Compute Block section
f.writeLine(2, "Thread block size: ", algorithm.thread_block_size);
f.writeLine(2,
"Data tile size: ",
algorithm.tile_dims.m,
"×",
algorithm.tile_dims.n,
"×",
algorithm.tile_dims.k);
f.writeLine(2, "Gemm padding: ", algorithm.padding);
f.writeLine(2, "Convolution specialization: ", algorithm.conv_specialization);
// Pipeline section
f.writeLine(2, "Pipeline version: ", algorithm.pipeline_version);
f.writeLine(2, "Pipeline scheduler: ", algorithm.pipeline_scheduler);
f.writeLine(2, "Warp Gemm parameters: ");
f.writeLine(
3, "subtile size: ", algorithm.warp_gemm.gemm_m, "×", algorithm.warp_gemm.gemm_n);
f.writeLast(3,
"Number of warp gemm iterations: ",
algorithm.warp_gemm.m_iter,
"×",
algorithm.warp_gemm.n_iter);
// Memory Access section
f.writeLine(2, "Memory access:");
f.writeLine(3, "A Tile transfer: ");
f.writeLine(4,
"Tile dimensions: ",
algorithm.a_tile_transfer.tile_dimensions.k0,
"×",
algorithm.a_tile_transfer.tile_dimensions.m_or_n,
"×",
algorithm.a_tile_transfer.tile_dimensions.k1,
"×");
f.writeLine(
4, "The innermost K subdimension size: ", algorithm.a_tile_transfer.transfer_params.k1);
f.writeLine(4,
"Spatial thread distribution over the data tile: ",
algorithm.a_tile_transfer.transfer_params.thread_cluster_order[0],
"×",
algorithm.a_tile_transfer.transfer_params.thread_cluster_order[1],
"×",
algorithm.a_tile_transfer.transfer_params.thread_cluster_order[2]);
f.writeLine(4,
"The order of accessing data tile axes: ",
algorithm.a_tile_transfer.transfer_params.src_access_order[0],
"×",
algorithm.a_tile_transfer.transfer_params.src_access_order[1],
"×",
algorithm.a_tile_transfer.transfer_params.src_access_order[2]);
f.writeLine(4,
"Vectorized memory access axis index (with contiguous memory): ",
algorithm.a_tile_transfer.transfer_params.src_vector_dim);
f.writeLine(4,
"Vector access (GMEM read) instruction size: ",
algorithm.a_tile_transfer.transfer_params.src_scalar_per_vector);
f.writeLine(4,
"Vector access (LDS write) instruction size: ",
algorithm.a_tile_transfer.transfer_params.dst_scalar_per_vector_k1);
f.writeLast(4,
"LDS data layout padding (to prevent bank conflicts): ",
algorithm.a_tile_transfer.transfer_params.dst_scalar_per_vector_k1);
f.writeLine(3, "B Tile transfer: ");
f.writeLine(4,
"Tile dimensions: ",
algorithm.b_tile_transfer.tile_dimensions.k0,
"×",
algorithm.b_tile_transfer.tile_dimensions.m_or_n,
"×",
algorithm.b_tile_transfer.tile_dimensions.k1,
"×");
f.writeLine(
4, "The innermost K subdimension size: ", algorithm.b_tile_transfer.transfer_params.k1);
f.writeLine(4,
"Spatial thread distribution over the data tile: ",
algorithm.b_tile_transfer.transfer_params.thread_cluster_order[0],
"×",
algorithm.b_tile_transfer.transfer_params.thread_cluster_order[1],
"×",
algorithm.b_tile_transfer.transfer_params.thread_cluster_order[2]);
f.writeLine(4,
"The order of accessing data tile axes: ",
algorithm.b_tile_transfer.transfer_params.src_access_order[0],
"×",
algorithm.b_tile_transfer.transfer_params.src_access_order[1],
"×",
algorithm.b_tile_transfer.transfer_params.src_access_order[2]);
f.writeLine(4,
"Vectorized memory access axis index (with contiguous memory): ",
algorithm.b_tile_transfer.transfer_params.src_vector_dim);
f.writeLine(4,
"Vector access (GMEM read) instruction size: ",
algorithm.b_tile_transfer.transfer_params.src_scalar_per_vector);
f.writeLine(4,
"Vector access (LDS write) instruction size: ",
algorithm.b_tile_transfer.transfer_params.dst_scalar_per_vector_k1);
f.writeLast(4,
"LDS data layout padding (to prevent bank conflicts): ",
algorithm.b_tile_transfer.transfer_params.dst_scalar_per_vector_k1);
f.writeLast(3, "C Tile transfer: ");
f.writeLine(4,
"Data shuffle (number of gemm instructions per iteration): ",
algorithm.c_tile_transfer.shuffle_params.m_gemms_per_shuffle,
"×",
algorithm.c_tile_transfer.shuffle_params.n_gemms_per_shuffle);
f.writeLine(4,
"Spatial thread distribution used to store data: ",
algorithm.c_tile_transfer.thread_cluster_dims[0],
"×",
algorithm.c_tile_transfer.thread_cluster_dims[1],
"×",
algorithm.c_tile_transfer.thread_cluster_dims[2],
"×",
algorithm.c_tile_transfer.thread_cluster_dims[3]);
f.writeLast(4,
"Vector access (GMEM write) instruction size: ",
algorithm.c_tile_transfer.scalar_per_vector);
f.writeLast(2);
f.writeLast(1);
return f.getString();
}
// Educational explanation of optimization choices.
// Not implemented yet: the stream is never written to, so the result is an empty string.
std::string explain() const
{
    // Placeholder for future implementation
    std::ostringstream text;
    return text.str();
}
// Performance characteristics and use case guidance.
// Not implemented yet: the stream is never written to, so the result is an empty string.
std::string suggest() const
{
    // Placeholder for future implementation
    std::ostringstream text;
    return text.str();
}
};
// Helper concept to detect if a type has InstanceTraits specialization.
// Satisfied when `InstanceTraits<T>` names a valid type.
// NOTE(review): a `typename` type requirement only checks that the type name is well-formed;
// it does not require the named type to be complete. Confirm that this actually rejects types
// for which no InstanceTraits specialization is defined.
template <typename T>
concept HasInstanceTraits = requires { typename InstanceTraits<T>; };
// Helper concept to detect ConvBuilder types.
// Satisfied by any type exposing both nested type names `Factory` and `Instance`; the check is
// purely structural and does not inspect what those nested types are.
template <typename T>
concept IsConvBuilder = requires {
typename T::Factory;
typename T::Instance;
};
// Primary factory function: builds a ConvDescription directly from an Instance type.
// Every field is copied from the compile-time constants exposed by ConvTraits<Instance>.
template <typename Instance>
    requires HasInstanceTraits<Instance>
ConvDescription Describe()
{
    using T = ConvTraits<Instance>;

    // Signature part: what the convolution computes.
    const ConvSignatureInfo signature_info{.spatial_dim       = T::spatial_dim,
                                           .direction         = T::direction,
                                           .layout            = T::layout,
                                           .data_type         = T::data_type,
                                           .input_element_op  = T::input_element_op,
                                           .weight_element_op = T::weight_element_op,
                                           .output_element_op = T::output_element_op};

    // Algorithm part: how the kernel is configured.
    const GemmAlgorithmInfo algorithm_info{.thread_block_size   = T::thread_block_size,
                                           .tile_dims           = T::tile_dims,
                                           .warp_gemm           = T::warp_gemm,
                                           .a_tile_transfer     = T::a_tile_transfer,
                                           .b_tile_transfer     = T::b_tile_transfer,
                                           .c_tile_transfer     = T::c_tile_transfer,
                                           .pipeline_version    = T::pipeline_version,
                                           .pipeline_scheduler  = T::pipeline_scheduler,
                                           .conv_specialization = T::conv_specialization,
                                           .padding             = T::gemm_padding};

    return ConvDescription{.signature = signature_info, .algorithm = algorithm_info};
}
// Backward compatibility: builds a ConvDescription from a Builder type by forwarding to the
// Instance-based overload with the builder's nested `Instance` type.
template <typename Builder>
    requires IsConvBuilder<Builder> && (!HasInstanceTraits<Builder>)
ConvDescription Describe()
{
    return Describe<typename Builder::Instance>();
}
} // namespace ck_tile::reflect::conv

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
#pragma once

View File

@@ -13,7 +13,10 @@
#include <string_view>
#include <sstream>
#include <type_traits>
#include <climits>
#include <limits.h>
#include <cmath>
#include <ostream>
#include <iostream>
#include <ck/utility/data_type.hpp>
#include <ck/utility/sequence.hpp>
#include <ck/utility/blkgemmpipe_scheduler.hpp>

View File

@@ -0,0 +1,106 @@
#pragma once
#include <sstream>
#include <string>
#include <type_traits>
#include <vector>
namespace ck_tile::reflect {
// Helper class for formatting hierarchical tree structures with proper indentation
// and tree-drawing characters (├─, └─, │, etc.)
//
// Each call writes one row. `indent_level` 0 is the root and gets no tree symbols; every deeper
// level is prefixed with one three-column cell per ancestor level ("│  " while that ancestor
// still has siblings to come, "   " once it has ended) followed by its own branch symbol.
//
// Example Usage:
//
//   TreeFormatter f;
//   f.writeLine(0, "Root");
//   f.writeLine(1, "Branch 1");
//   f.writeLine(2, "Item 1a");
//   f.writeLast(2, "Item 1b");
//   f.writeLast(1, "Branch 2");
//   f.writeLast(2, "Item 2a");
//   std::cout << f.getString() << "\n";
//
// Generated Output:
//
//   Root
//   ├─ Branch 1
//   │  ├─ Item 1a
//   │  └─ Item 1b
//   └─ Branch 2
//      └─ Item 2a
class TreeFormatter
{
    public:
    TreeFormatter() = default;

    // Write a line at the specified indentation level (branch continues after this).
    // All `args` are streamed in order with operator<<; passing none yields a bare branch symbol.
    template <typename... Args>
    void writeLine(int indent_level, Args&&... args)
    {
        writeLineImpl(indent_level, false, std::forward<Args>(args)...);
    }

    // Write the last line at the specified indentation level (branch ends).
    template <typename... Args>
    void writeLast(int indent_level, Args&&... args)
    {
        writeLineImpl(indent_level, true, std::forward<Args>(args)...);
    }

    // Get the formatted string (removes trailing newline if present).
    std::string getString() const
    {
        std::string result = oss_.str();
        if(!result.empty() && result.back() == '\n')
        {
            result.pop_back();
        }
        return result;
    }

    private:
    using Level = std::vector<bool>::size_type;

    std::ostringstream oss_;
    std::vector<bool> is_last_at_level_; // Tracks which levels have ended

    // Implementation of line writing with tree symbols.
    template <typename... Args>
    void writeLineImpl(int indent_level, bool is_last, Args&&... args)
    {
        // A negative level would index out of bounds below; treat it as the root level.
        const Level level = indent_level > 0 ? static_cast<Level>(indent_level) : Level{0};

        // Ensure we have enough tracking space. Levels never written to default to "not ended".
        if(level >= is_last_at_level_.size())
        {
            is_last_at_level_.resize(level + 1, false);
        }

        // Draw one cell per ancestor level. Level 0 is the root and has no symbols, so start at 1.
        // Both alternatives are three columns wide so deeper rows line up under "├─ " / "└─ ".
        for(Level i = 1; i < level; ++i)
        {
            oss_ << (is_last_at_level_[i] ? "   " : "│  ");
        }

        // Draw the branch symbol for the current level.
        if(level > 0)
        {
            oss_ << (is_last ? "└─ " : "├─ ");
        }

        // Write the content using a fold expression with direct stream insertion.
        ((oss_ << std::forward<Args>(args)), ...);
        oss_ << '\n';

        // Update tracking AFTER writing so later, deeper rows know whether this level ended.
        is_last_at_level_[level] = is_last;
    }
};
} // namespace ck_tile::reflect

View File

@@ -3,6 +3,10 @@
#pragma once
#include <ostream>
#include <string_view>
#include <variant>
namespace ck_tile::builder {
enum class DataType
@@ -215,4 +219,275 @@ enum class PipelineScheduler
INTERWAVE
};
// ostream operator overloads for enum classes
// Streams the symbolic name of a DataType; any value not listed prints "Unknown".
inline std::ostream& operator<<(std::ostream& os, DataType dt)
{
    const char* text = "Unknown";
    switch(dt)
    {
    case DataType::FP16: text = "FP16"; break;
    case DataType::FP32: text = "FP32"; break;
    case DataType::BF16: text = "BF16"; break;
    case DataType::FP8: text = "FP8"; break;
    case DataType::I8: text = "I8"; break;
    case DataType::U8: text = "U8"; break;
    default: break;
    }
    return os << text;
}
// Streams a human-readable name for a ConvDirection; any value not listed prints "Unknown".
inline std::ostream& operator<<(std::ostream& os, ConvDirection dir)
{
    const char* text = "Unknown";
    switch(dir)
    {
    case ConvDirection::FORWARD: text = "Forward"; break;
    case ConvDirection::BACKWARD_DATA: text = "Backward Data"; break;
    case ConvDirection::BACKWARD_WEIGHT: text = "Backward Weight"; break;
    default: break;
    }
    return os << text;
}
// Streams the enumerator name of a 1D grouped-convolution layout; others print "Unknown".
inline std::ostream& operator<<(std::ostream& os, GroupConvLayout1D layout)
{
    const char* text = "Unknown";
    switch(layout)
    {
    case GroupConvLayout1D::GNWC_GKXC_GNWK: text = "GNWC_GKXC_GNWK"; break;
    case GroupConvLayout1D::NWGC_GKXC_NWGK: text = "NWGC_GKXC_NWGK"; break;
    case GroupConvLayout1D::NGCW_GKXC_NGKW: text = "NGCW_GKXC_NGKW"; break;
    case GroupConvLayout1D::NGCW_GKCX_NGKW: text = "NGCW_GKCX_NGKW"; break;
    default: break;
    }
    return os << text;
}
// Streams the enumerator name of a 2D grouped-convolution layout; others print "Unknown".
inline std::ostream& operator<<(std::ostream& os, GroupConvLayout2D layout)
{
    const char* text = "Unknown";
    switch(layout)
    {
    case GroupConvLayout2D::GNHWC_GKYXC_GNHWK: text = "GNHWC_GKYXC_GNHWK"; break;
    case GroupConvLayout2D::NHWGC_GKYXC_NHWGK: text = "NHWGC_GKYXC_NHWGK"; break;
    case GroupConvLayout2D::NGCHW_GKYXC_NGKHW: text = "NGCHW_GKYXC_NGKHW"; break;
    case GroupConvLayout2D::NGCHW_GKCYX_NGKHW: text = "NGCHW_GKCYX_NGKHW"; break;
    default: break;
    }
    return os << text;
}
// Streams the enumerator name of a 3D grouped-convolution layout; others print "Unknown".
inline std::ostream& operator<<(std::ostream& os, GroupConvLayout3D layout)
{
    const char* text = "Unknown";
    switch(layout)
    {
    case GroupConvLayout3D::GNDHWC_GKZYXC_GNDHWK: text = "GNDHWC_GKZYXC_GNDHWK"; break;
    case GroupConvLayout3D::NDHWGC_GKZYXC_NDHWGK: text = "NDHWGC_GKZYXC_NDHWGK"; break;
    case GroupConvLayout3D::NGCDHW_GKZYXC_NGKDHW: text = "NGCDHW_GKZYXC_NGKDHW"; break;
    case GroupConvLayout3D::NGCDHW_GKCZYX_NGKDHW: text = "NGCDHW_GKCZYX_NGKDHW"; break;
    default: break;
    }
    return os << text;
}
inline std::ostream& operator<<(std::ostream& os, FwdGroupConvDeviceOperation op)
{
using enum FwdGroupConvDeviceOperation;
switch(op)
{
case DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK:
return os << "DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK";
case DeviceGroupedConvFwdMultipleD_Wmma_CShuffle:
return os << "DeviceGroupedConvFwdMultipleD_Wmma_CShuffle";
case DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle:
return os << "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle";
case DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3:
return os << "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3";
case DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor:
return os << "DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor";
default: return os << "Unknown";
}
}
inline std::ostream& operator<<(std::ostream& os, BwdDataGroupConvDeviceOperation op)
{
using enum BwdDataGroupConvDeviceOperation;
switch(op)
{
case DeviceGroupedConvBwdDataMultipleD: return os << "DeviceGroupedConvBwdDataMultipleD";
case DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle:
return os << "DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle";
case DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1:
return os << "DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1";
default: return os << "Unknown";
}
}
inline std::ostream& operator<<(std::ostream& os, BwdWeightGroupConvDeviceOperation op)
{
using enum BwdWeightGroupConvDeviceOperation;
switch(op)
{
case DeviceGroupedConvBwdWeight: return os << "DeviceGroupedConvBwdWeight";
case DeviceGroupedConvBwdWeight_Dl: return os << "DeviceGroupedConvBwdWeight_Dl";
case DeviceGroupedConvBwdWeight_Xdl_CShuffle:
return os << "DeviceGroupedConvBwdWeight_Xdl_CShuffle";
case DeviceGroupedConvBwdWeight_Xdl_CShuffleV3:
return os << "DeviceGroupedConvBwdWeight_Xdl_CShuffleV3";
case DeviceGroupedConvBwdWeight_Wmma_CShuffle:
return os << "DeviceGroupedConvBwdWeight_Wmma_CShuffle";
case DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle:
return os << "DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle";
case DeviceGroupedConvBwdWeightMultipleD: return os << "DeviceGroupedConvBwdWeightMultipleD";
case DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle:
return os << "DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle";
default: return os << "Unknown";
}
}
// Streams the enumerator name of an ElementwiseOperation; any value not listed prints "Unknown".
inline std::ostream& operator<<(std::ostream& os, ElementwiseOperation op)
{
    const char* text = "Unknown";
    switch(op)
    {
    case ElementwiseOperation::BIAS: text = "BIAS"; break;
    case ElementwiseOperation::BIAS_CLAMP: text = "BIAS_CLAMP"; break;
    case ElementwiseOperation::BIAS_BNORM_CLAMP: text = "BIAS_BNORM_CLAMP"; break;
    case ElementwiseOperation::BILINEAR: text = "BILINEAR"; break;
    case ElementwiseOperation::CLAMP: text = "CLAMP"; break;
    case ElementwiseOperation::SCALE: text = "SCALE"; break;
    case ElementwiseOperation::PASS_THROUGH: text = "PASS_THROUGH"; break;
    default: break;
    }
    return os << text;
}
// Streams the enumerator name of a PipelineVersion; any value not listed prints "Unknown".
inline std::ostream& operator<<(std::ostream& os, PipelineVersion ver)
{
    const char* text = "Unknown";
    switch(ver)
    {
    case PipelineVersion::V1: text = "V1"; break;
    case PipelineVersion::V2: text = "V2"; break;
    case PipelineVersion::V3: text = "V3"; break;
    case PipelineVersion::V4: text = "V4"; break;
    case PipelineVersion::V5: text = "V5"; break;
    case PipelineVersion::WEIGHT_ONLY: text = "WEIGHT_ONLY"; break;
    default: break;
    }
    return os << text;
}
// Streams the enumerator name of a GemmSpecialization; any value not listed prints "Unknown".
inline std::ostream& operator<<(std::ostream& os, GemmSpecialization spec)
{
    const char* text = "Unknown";
    switch(spec)
    {
    case GemmSpecialization::Default: text = "Default"; break;
    case GemmSpecialization::MPadding: text = "MPadding"; break;
    case GemmSpecialization::NPadding: text = "NPadding"; break;
    case GemmSpecialization::KPadding: text = "KPadding"; break;
    case GemmSpecialization::MNPadding: text = "MNPadding"; break;
    case GemmSpecialization::MKPadding: text = "MKPadding"; break;
    case GemmSpecialization::NKPadding: text = "NKPadding"; break;
    case GemmSpecialization::MNKPadding: text = "MNKPadding"; break;
    case GemmSpecialization::OPadding: text = "OPadding"; break;
    case GemmSpecialization::MOPadding: text = "MOPadding"; break;
    case GemmSpecialization::NOPadding: text = "NOPadding"; break;
    case GemmSpecialization::KOPadding: text = "KOPadding"; break;
    case GemmSpecialization::MNOPadding: text = "MNOPadding"; break;
    case GemmSpecialization::MKOPadding: text = "MKOPadding"; break;
    case GemmSpecialization::NKOPadding: text = "NKOPadding"; break;
    case GemmSpecialization::MNKOPadding: text = "MNKOPadding"; break;
    default: break;
    }
    return os << text;
}
// Streams the enumerator name of a ConvFwdSpecialization; any value not listed prints "Unknown".
inline std::ostream& operator<<(std::ostream& os, ConvFwdSpecialization spec)
{
    const char* text = "Unknown";
    switch(spec)
    {
    case ConvFwdSpecialization::DEFAULT: text = "DEFAULT"; break;
    case ConvFwdSpecialization::FILTER_1X1_PAD0: text = "FILTER_1X1_PAD0"; break;
    case ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0: text = "FILTER_1X1_STRIDE1_PAD0"; break;
    case ConvFwdSpecialization::FILTER_3x3: text = "FILTER_3x3"; break;
    default: break;
    }
    return os << text;
}
// Streams the enumerator name of a ConvBwdDataSpecialization; others print "Unknown".
inline std::ostream& operator<<(std::ostream& os, ConvBwdDataSpecialization spec)
{
    const char* text = "Unknown";
    switch(spec)
    {
    case ConvBwdDataSpecialization::DEFAULT: text = "DEFAULT"; break;
    case ConvBwdDataSpecialization::FILTER_1X1_STRIDE1_PAD0:
        text = "FILTER_1X1_STRIDE1_PAD0";
        break;
    default: break;
    }
    return os << text;
}
// Streams the enumerator name of a ConvBwdWeightSpecialization; others print "Unknown".
inline std::ostream& operator<<(std::ostream& os, ConvBwdWeightSpecialization spec)
{
    const char* text = "Unknown";
    switch(spec)
    {
    case ConvBwdWeightSpecialization::DEFAULT: text = "DEFAULT"; break;
    case ConvBwdWeightSpecialization::FILTER_1X1_STRIDE1_PAD0:
        text = "FILTER_1X1_STRIDE1_PAD0";
        break;
    case ConvBwdWeightSpecialization::FILTER_1X1_PAD0: text = "FILTER_1X1_PAD0"; break;
    case ConvBwdWeightSpecialization::ODD_C: text = "ODD_C"; break;
    default: break;
    }
    return os << text;
}
// Streams the enumerator name of a GemmPadding; any value not listed prints "Unknown".
inline std::ostream& operator<<(std::ostream& os, GemmPadding padding)
{
    const char* text = "Unknown";
    switch(padding)
    {
    case GemmPadding::DEFAULT: text = "DEFAULT"; break;
    case GemmPadding::M_PADDING: text = "M_PADDING"; break;
    case GemmPadding::N_PADDING: text = "N_PADDING"; break;
    case GemmPadding::K_PADDING: text = "K_PADDING"; break;
    case GemmPadding::MN_PADDING: text = "MN_PADDING"; break;
    case GemmPadding::MK_PADDING: text = "MK_PADDING"; break;
    case GemmPadding::NK_PADDING: text = "NK_PADDING"; break;
    case GemmPadding::MNK_PADDING: text = "MNK_PADDING"; break;
    case GemmPadding::O_PADDING: text = "O_PADDING"; break;
    case GemmPadding::MO_PADDING: text = "MO_PADDING"; break;
    case GemmPadding::NO_PADDING: text = "NO_PADDING"; break;
    case GemmPadding::KO_PADDING: text = "KO_PADDING"; break;
    case GemmPadding::MNO_PADDING: text = "MNO_PADDING"; break;
    case GemmPadding::MKO_PADDING: text = "MKO_PADDING"; break;
    case GemmPadding::NKO_PADDING: text = "NKO_PADDING"; break;
    case GemmPadding::MNKO_PADDING: text = "MNKO_PADDING"; break;
    default: break;
    }
    return os << text;
}
// Streams the enumerator name of a PipelineScheduler; any value not listed prints "Unknown".
inline std::ostream& operator<<(std::ostream& os, PipelineScheduler sched)
{
    const char* text = "Unknown";
    switch(sched)
    {
    case PipelineScheduler::DEFAULT: text = "DEFAULT"; break;
    case PipelineScheduler::INTRAWAVE: text = "INTRAWAVE"; break;
    case PipelineScheduler::INTERWAVE: text = "INTERWAVE"; break;
    default: break;
    }
    return os << text;
}
// ostream operator overload for std::variant of layout types.
// Visits the variant and streams whichever layout enum it currently holds using that enum's
// own operator<<.
inline std::ostream&
operator<<(std::ostream& os,
           const std::variant<GroupConvLayout1D, GroupConvLayout2D, GroupConvLayout3D>& layout)
{
    return std::visit([&os](const auto& held) -> std::ostream& { return os << held; }, layout);
}
// ostream operator overload for std::variant of convolution specializations.
// Visits the variant and streams whichever specialization enum it currently holds using that
// enum's own operator<<.
inline std::ostream& operator<<(std::ostream& os,
                                const std::variant<ConvFwdSpecialization,
                                                   ConvBwdDataSpecialization,
                                                   ConvBwdWeightSpecialization>& spec)
{
    return std::visit([&os](const auto& held) -> std::ostream& { return os << held; }, spec);
}
} // namespace ck_tile::builder

View File

@@ -43,6 +43,8 @@ add_ck_builder_test(test_ckb_build_fwd_instances
conv/test_ckb_conv_fwd_2d_bf16.cpp
conv/test_ckb_conv_fwd_2d_fp16.cpp
conv/test_ckb_conv_fwd_2d_fp32.cpp
conv/test_ckb_conv_fwd_2d_dl_fp16.cpp
conv/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp
conv/test_ckb_conv_fwd_3d_bf16.cpp
conv/test_ckb_conv_fwd_3d_fp16.cpp
conv/test_ckb_conv_fwd_3d_fp32.cpp)
@@ -67,6 +69,9 @@ add_ck_factory_test(test_ckb_factory_grouped_convolution_forward_dynamic_op test
# Registers the test_conv_traits target, built from conv/test_conv_traits.cpp.
add_ck_builder_test(test_conv_traits
conv/test_conv_traits.cpp)
# Registers the test_conv_description target, built from test_conv_description.cpp.
add_ck_builder_test(test_conv_description
test_conv_description.cpp)
# Function to add all test_ckb targets to a list
function(collect_test_ckb_targets result_var)
# Get all targets in current directory

View File

@@ -0,0 +1,69 @@
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "utils/ckb_conv_test_common.hpp"
using namespace ck_tile::builder::test_utils;
namespace ck_tile::builder::testing {
// 2D FP16 forward DL instance, GNHWC_GKYXC_GNHWK layout, default convolution specialization.
// The actual checks live in run_test_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK.
TEST(FwdConvInstances, Create_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK_Instance_2D_FP16_GNHWC)
{
    constexpr ThreadBlock kThreadBlock{.block_size = 256,
                                       .tile_size  = {.m = 128, .n = 128, .k = 16}};

    constexpr ConvSignature kSignature{
        .spatial_dim           = 2,
        .direction             = ConvDirection::FORWARD,
        .layout                = GroupConvLayout2D::GNHWC_GKYXC_GNHWK,
        .data_type             = DataType::FP16,
        .elementwise_operation = ElementwiseOperation::PASS_THROUGH,
        .device_operation =
            FwdGroupConvDeviceOperation::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK};

    run_test_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK<kSignature,
                                                            kThreadBlock,
                                                            ConvFwdSpecialization::DEFAULT>();
}
// 2D FP16 forward DL instance, NHWGC_GKYXC_NHWGK layout, default convolution specialization.
// The actual checks live in run_test_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK.
TEST(FwdConvInstances, Create_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK_Instance_2D_FP16_NHWGC)
{
    constexpr ThreadBlock kThreadBlock{.block_size = 256,
                                       .tile_size  = {.m = 128, .n = 128, .k = 16}};

    constexpr ConvSignature kSignature{
        .spatial_dim           = 2,
        .direction             = ConvDirection::FORWARD,
        .layout                = GroupConvLayout2D::NHWGC_GKYXC_NHWGK,
        .data_type             = DataType::FP16,
        .elementwise_operation = ElementwiseOperation::PASS_THROUGH,
        .device_operation =
            FwdGroupConvDeviceOperation::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK};

    run_test_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK<kSignature,
                                                            kThreadBlock,
                                                            ConvFwdSpecialization::DEFAULT>();
}
// 2D FP16 forward DL instance, GNHWC_GKYXC_GNHWK layout, FILTER_1X1_PAD0 specialization.
// The actual checks live in run_test_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK.
TEST(FwdConvInstances,
     Create_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK_Instance_2D_FP16_FILTER_1X1_PAD0)
{
    constexpr ThreadBlock kThreadBlock{.block_size = 256,
                                       .tile_size  = {.m = 128, .n = 128, .k = 16}};

    constexpr ConvSignature kSignature{
        .spatial_dim           = 2,
        .direction             = ConvDirection::FORWARD,
        .layout                = GroupConvLayout2D::GNHWC_GKYXC_GNHWK,
        .data_type             = DataType::FP16,
        .elementwise_operation = ElementwiseOperation::PASS_THROUGH,
        .device_operation =
            FwdGroupConvDeviceOperation::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK};

    run_test_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK<
        kSignature,
        kThreadBlock,
        ConvFwdSpecialization::FILTER_1X1_PAD0>();
}
} // namespace ck_tile::builder::testing

View File

@@ -0,0 +1,53 @@
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "utils/ckb_conv_test_common.hpp"
using namespace ck_tile::builder::test_utils;
namespace ck_tile::builder::testing {
// 2D FP16 forward Large_Tensor instance, GNHWC_GKYXC_GNHWK layout, default specialization,
// 256-thread block with a 256x128x32 tile.
// The actual checks live in run_test_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor.
TEST(FwdConvInstances,
     Create_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor_Instance_2D_FP16_GNHWC)
{
    constexpr ThreadBlock kThreadBlock{.block_size = 256,
                                       .tile_size  = {.m = 256, .n = 128, .k = 32}};

    constexpr ConvSignature kSignature{
        .spatial_dim           = 2,
        .direction             = ConvDirection::FORWARD,
        .layout                = GroupConvLayout2D::GNHWC_GKYXC_GNHWK,
        .data_type             = DataType::FP16,
        .elementwise_operation = ElementwiseOperation::PASS_THROUGH,
        .device_operation =
            FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor};

    run_test_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<
        kSignature,
        kThreadBlock,
        ConvFwdSpecialization::DEFAULT>();
}
// 2D FP16 forward Large_Tensor instance, GNHWC_GKYXC_GNHWK layout, FILTER_1X1_PAD0
// specialization, 128-thread block with a 128x128x32 tile.
// The actual checks live in run_test_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor.
TEST(
    FwdConvInstances,
    Create_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor_Instance_2D_FP16_GNHWC_Filter1x1Pad0)
{
    constexpr ThreadBlock kThreadBlock{.block_size = 128,
                                       .tile_size  = {.m = 128, .n = 128, .k = 32}};

    constexpr ConvSignature kSignature{
        .spatial_dim           = 2,
        .direction             = ConvDirection::FORWARD,
        .layout                = GroupConvLayout2D::GNHWC_GKYXC_GNHWK,
        .data_type             = DataType::FP16,
        .elementwise_operation = ElementwiseOperation::PASS_THROUGH,
        .device_operation =
            FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor};

    run_test_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<
        kSignature,
        kThreadBlock,
        ConvFwdSpecialization::FILTER_1X1_PAD0>();
}
} // namespace ck_tile::builder::testing

View File

@@ -214,4 +214,84 @@ static_assert(
static_assert(
ckb::SpecifiesLoopScheduler<ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle>);
// DL-specific descriptors
// Plain aggregate of per-thread size parameters for the DL kernel configuration. The
// static_assert below checks that it models ckb::DlThreadConfigDescriptor.
// NOTE(review): field meanings are inferred from their names only — confirm against the
// descriptor concept before documenting them further.
struct DlThreadConfig
{
size_t k0_per_block;
size_t k1;
size_t m1_per_thread;
size_t n1_per_thread;
size_t k_per_thread;
};
static_assert(ckb::DlThreadConfigDescriptor<DlThreadConfig>);
// Plain aggregate holding two 2-element arrays describing the DL thread cluster. The
// static_assert below checks that it models ckb::DlThreadClusterDescriptor.
struct DlThreadCluster
{
std::array<size_t, 2> m1_xs; // e.g., {8, 2}
std::array<size_t, 2> n1_xs; // e.g., {8, 2}
};
static_assert(ckb::DlThreadClusterDescriptor<DlThreadCluster>);
// Plain aggregate of seven 4-element arrays describing a DL block transfer; the type name
// indicates the K0_M0_M1_K1 dimension order. It has the same fields as DlBlockTransferK0N0N1K1
// but is a distinct type. The static_assert below checks that it models
// ckb::DlBlockTransferK0M0M1K1Descriptor.
struct DlBlockTransferK0M0M1K1
{
std::array<size_t, 4> thread_slice_lengths;
std::array<size_t, 4> thread_cluster_lengths;
std::array<size_t, 4> thread_cluster_arrange_order;
std::array<size_t, 4> src_access_order;
std::array<size_t, 4> src_vector_tensor_lengths;
std::array<size_t, 4> src_vector_tensor_contiguous_dim_order;
std::array<size_t, 4> dst_vector_tensor_lengths;
};
static_assert(ckb::DlBlockTransferK0M0M1K1Descriptor<DlBlockTransferK0M0M1K1>);
// Plain aggregate of seven 4-element arrays describing a DL block transfer; the type name
// indicates the K0_N0_N1_K1 dimension order. It has the same fields as DlBlockTransferK0M0M1K1
// but is a distinct type. The static_assert below checks that it models
// ckb::DlBlockTransferK0N0N1K1Descriptor.
struct DlBlockTransferK0N0N1K1
{
std::array<size_t, 4> thread_slice_lengths;
std::array<size_t, 4> thread_cluster_lengths;
std::array<size_t, 4> thread_cluster_arrange_order;
std::array<size_t, 4> src_access_order;
std::array<size_t, 4> src_vector_tensor_lengths;
std::array<size_t, 4> src_vector_tensor_contiguous_dim_order;
std::array<size_t, 4> dst_vector_tensor_lengths;
};
static_assert(ckb::DlBlockTransferK0N0N1K1Descriptor<DlBlockTransferK0N0N1K1>);
// Plain aggregate describing the DL C-thread transfer: a 6-element access order, a vector
// dimension index and a per-vector scalar count. The static_assert below checks that it models
// ckb::DlCThreadTransferDescriptor.
struct DlCThreadTransfer
{
std::array<size_t, 6> src_dst_access_order;
size_t src_dst_vector_dim;
size_t dst_scalar_per_vector;
};
static_assert(ckb::DlCThreadTransferDescriptor<DlCThreadTransfer>);
// Algorithm descriptor for the DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK kernel: bundles
// the thread block, the forward/GEMM specializations and the DL-specific sub-descriptors.
// The static_asserts that follow verify, one concept at a time, that this aggregate satisfies
// each ckb concept it is expected to model.
struct ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
{
ThreadBlock thread_block;
ConvFwdSpecialization fwd_specialization;
GemmSpecialization gemm_specialization;
DlThreadConfig dl_thread_config;
DlThreadCluster dl_thread_cluster;
DlBlockTransferK0M0M1K1 dl_block_transfer_a;
DlBlockTransferK0N0N1K1 dl_block_transfer_b;
DlCThreadTransfer dl_c_thread_transfer;
};
static_assert(
ckb::ConvAlgorithmDescriptor<ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK>);
static_assert(
ckb::SpecifiesThreadBlock<ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK>);
static_assert(ckb::SpecifiesFwdConcSpecialization<
ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK>);
static_assert(
ckb::SpecifiesGemmSpecialization<ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK>);
static_assert(
ckb::SpecifiesDlThreadConfig<ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK>);
static_assert(
ckb::SpecifiesDlThreadCluster<ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK>);
static_assert(
ckb::SpecifiesDlBlockTransferA<ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK>);
static_assert(
ckb::SpecifiesDlBlockTransferB<ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK>);
static_assert(
ckb::SpecifiesDlCThreadTransfer<ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK>);
} // namespace ck_tile::builder::test

View File

@@ -0,0 +1,169 @@
// SPDX-License-Identifier: MIT
// Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
#include <gtest/gtest.h>
#include <gmock/gmock.h>
#include <ck_tile/builder/conv_builder.hpp>
#include <ck_tile/builder/reflect/conv_description.hpp>
#include "testing_utils.hpp"
#include "impl/conv_signature_types.hpp"
#include "impl/conv_algorithm_types.hpp"
namespace {
namespace ckb = ck_tile::builder;
namespace ckr = ck_tile::reflect::conv;
namespace ckt = ck_tile::test;
// Defines the signature of the convolution operation to be tested.
// This includes dimensionality, direction, data layout, and data type.
// Every member has a default, so a default-constructed ConvSignature describes a 2D forward
// FP16 convolution in GNHWC_GKYXC_GNHWK layout with a pass-through elementwise operation,
// targeting DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3. The static_assert below checks
// that the struct models ckb::ConvSignatureDescriptor.
struct ConvSignature
{
int spatial_dim = 2;
ckb::ConvDirection direction = ckb::ConvDirection::FORWARD;
ckb::GroupConvLayout layout = ckb::GroupConvLayout2D::GNHWC_GKYXC_GNHWK;
ckb::DataType data_type = ckb::DataType::FP16;
ckb::ElementwiseOperation elementwise_operation = ckb::ElementwiseOperation::PASS_THROUGH;
ckb::GroupConvDeviceOp device_operation =
ckb::FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3;
};
static_assert(ckb::ConvSignatureDescriptor<ConvSignature>);
// Algorithm configuration used by the description tests. All members carry default values, so
// a default-constructed DefaultAlgorithm is a complete configuration; the values below are the
// ones the expected description strings in this file are written against. The static_assert
// after the struct checks that it models ckb::ConvAlgorithmDescriptor.
struct DefaultAlgorithm
{
// 256 threads per block working on a 256x256x32 tile.
ckb::test::ThreadBlock thread_block{.block_size = 256,
.tile_size = {.m = 256, .n = 256, .k = 32}};
// XDL GEMM parameters: 16x16 per XDL, 4x4 XDLs per wave.
ckb::test::GridwiseXdlGemm gridwise_gemm{.ak1 = 8,
.bk1 = 8,
.m_per_xdl = 16,
.n_per_xdl = 16,
.m_xdl_per_wave = 4,
.n_xdl_per_wave = 4};
// Block transfer settings for A, B and C; A and B use identical values here.
ckb::test::BlockTransferABC block_transfer{
.block_transfer_a = {.k0 = 4, .m_n = 256, .k1 = 8},
.block_transfer_b = {.k0 = 4, .m_n = 256, .k1 = 8},
.thread_cluster_dims_c = {.m_block = 1,
.m_wave_per_xdl = 32,
.n_block = 1,
.n_wave_per_xdl = 8},
.lds_transfer_a = {.src_vector_dim = 2,
.src_scalar_per_vector = 8,
.lds_dst_scalar_per_vector = 8,
.is_direct_load = true,
.lds_padding = false},
.lds_transfer_b = {.src_vector_dim = 2,
.src_scalar_per_vector = 8,
.lds_dst_scalar_per_vector = 8,
.is_direct_load = true,
.lds_padding = false},
.epilogue_c = {.m_per_wave_per_shuffle = 1,
.n_per_wave_per_shuffle = 1,
.scalar_per_vector = 8},
.block_transfer_access_order_a = {.order = {0, 1, 2}},
.block_transfer_access_order_b = {.order = {0, 1, 2}},
.src_access_order_a = {.order = {0, 1, 2}},
.src_access_order_b = {.order = {0, 1, 2}}};
ckb::ConvFwdSpecialization fwd_specialization = ckb::ConvFwdSpecialization::DEFAULT;
ckb::GemmSpecialization gemm_specialization = ckb::GemmSpecialization::Default;
// Pipeline V4 with the intrawave scheduler.
ckb::test::BlockGemm block_gemm{.pipeline_version = ckb::PipelineVersion::V4,
.scheduler = ckb::PipelineScheduler::INTRAWAVE};
};
static_assert(ckb::ConvAlgorithmDescriptor<DefaultAlgorithm>);
// The brief description of the default signature/algorithm pair is a single fixed line.
TEST(ConvDescriptionTest, DefaultInstanceHasBriefDescription)
{
    static constexpr ConvSignature kSignature{};
    static constexpr DefaultAlgorithm kAlgorithm{};

    const auto description = ckr::Describe<ckb::ConvBuilder<kSignature, kAlgorithm>>();
    EXPECT_THAT(description.brief(), ckt::StringEqWithDiff("2D Forward convolution"));
}
// Pins the full tree rendering produced by detailed() for the default signature/algorithm pair.
// The expected text is compared verbatim, so any whitespace change in the formatter shows up
// here as a diff.
// NOTE(review): the expected text pins output that looks unintended and should be confirmed:
//  - "Tile dimensions: 4×256×8×" ends with a dangling "×";
//  - "LDS data layout padding ...: 8" although DefaultAlgorithm sets lds_padding = false; for
//    the B tile, detailed() streams dst_scalar_per_vector_k1 under that label.
// If detailed() is corrected, these expected lines must change with it.
TEST(ConvDescriptionTest, DefaultInstanceHasDetailedDescription)
{
static constexpr const ConvSignature SIGNATURE;
static constexpr const DefaultAlgorithm ALGORITHM;
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
EXPECT_THAT(ckr::Describe<Builder>().detailed(),
ckt::StringEqWithDiff( //
"2D Forward Convolution Kernel\n"
"├─ Signature\n"
"│ ├─ Tensor Type: FP16\n"
"│ ├─ Memory Layout: GNHWC_GKYXC_GNHWK\n"
"│ ├─ Input elementwise operation: PASS_THROUGH\n"
"│ ├─ Weights elementwise operation: PASS_THROUGH\n"
"│ └─ Output elementwise operation: PASS_THROUGH\n"
"├─ Algorithm\n"
"│ ├─ Thread block size: 256\n"
"│ ├─ Data tile size: 256×256×32\n"
"│ ├─ Gemm padding: DEFAULT\n"
"│ ├─ Convolution specialization: DEFAULT\n"
"│ ├─ Pipeline version: V4\n"
"│ ├─ Pipeline scheduler: INTRAWAVE\n"
"│ ├─ Warp Gemm parameters: \n"
"│ │ ├─ subtile size: 16×16\n"
"│ │ └─ Number of warp gemm iterations: 4×4\n"
"│ ├─ Memory access:\n"
"│ │ ├─ A Tile transfer: \n"
"│ │ │ ├─ Tile dimensions: 4×256×8×\n"
"│ │ │ ├─ The innermost K subdimension size: 8\n"
"│ │ │ ├─ Spatial thread distribution over the data tile: 0×1×2\n"
"│ │ │ ├─ The order of accessing data tile axes: 0×1×2\n"
"│ │ │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n"
"│ │ │ ├─ Vector access (GMEM read) instruction size: 8\n"
"│ │ │ ├─ Vector access (LDS write) instruction size: 8\n"
"│ │ │ └─ LDS data layout padding (to prevent bank conflicts): 8\n"
"│ │ ├─ B Tile transfer: \n"
"│ │ │ ├─ Tile dimensions: 4×256×8×\n"
"│ │ │ ├─ The innermost K subdimension size: 8\n"
"│ │ │ ├─ Spatial thread distribution over the data tile: 0×1×2\n"
"│ │ │ ├─ The order of accessing data tile axes: 0×1×2\n"
"│ │ │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n"
"│ │ │ ├─ Vector access (GMEM read) instruction size: 8\n"
"│ │ │ ├─ Vector access (LDS write) instruction size: 8\n"
"│ │ │ └─ LDS data layout padding (to prevent bank conflicts): 8\n"
"│ │ └─ C Tile transfer: \n"
"│ │ ├─ Data shuffle (number of gemm instructions per iteration): 1×1\n"
"│ │ ├─ Spatial thread distribution used to store data: 1×32×1×8\n"
"│ │ └─ Vector access (GMEM write) instruction size: 8\n"
"│ └─ \n"
"└─ "));
}
// NOTE: BackwardDataInstanceHasDetailedDescription test is disabled because ConvFactory
// does not have a specialization for backward data convolutions. The test fails with:
// "implicit instantiation of undefined template 'ck_tile::builder::ConvFactory<...>'"
//
// To enable this test, a ConvFactory specialization for backward data operations must be
// implemented first.
//
// TEST(ConvDescriptionTest, BackwardDataInstanceHasDetailedDescription)
// {
// struct BackwardDataSignature
// {
// int spatial_dim = 2;
// ckb::ConvDirection direction = ckb::ConvDirection::BACKWARD_DATA;
// ckb::GroupConvLayout layout = ckb::GroupConvLayout2D::GNHWC_GKYXC_GNHWK;
// ckb::DataType data_type = ckb::DataType::FP16;
// ckb::ElementwiseOperation elementwise_operation =
// ckb::ElementwiseOperation::PASS_THROUGH; ckb::GroupConvDeviceOp device_operation =
// ckb::BwdDataGroupConvDeviceOperation::DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1;
// };
// static_assert(ckb::ConvSignatureDescriptor<BackwardDataSignature>);
//
// static constexpr const BackwardDataSignature SIGNATURE;
// static constexpr const DefaultAlgorithm ALGORITHM;
// using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
//
// // Verify Brief works
// EXPECT_THAT(ckr::Describe<Builder>().brief(),
// ckt::StringEqWithDiff("2D Backward Data convolution"));
//
// // Verify detailed works - to be updated once ConvFactory is implemented
// EXPECT_THAT(ckr::Describe<Builder>().detailed(),
// ckt::StringEqWithDiff("PLACEHOLDER"));
// }
} // namespace

View File

@@ -235,4 +235,149 @@ constexpr void run_test_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle()
EXPECT_NE(invoker_ptr, nullptr);
}
template <ConvSignature FwdConvSignature,
ThreadBlock FwdThreadBlock,
ConvFwdSpecialization FwdConvSpecialization>
constexpr void run_test_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK()
{
// DL thread configuration
constexpr DlThreadConfig DlThreadCfg{
.k0_per_block = 16, .k1 = 2, .m1_per_thread = 4, .n1_per_thread = 4, .k_per_thread = 1};
// DL thread cluster
constexpr DlThreadCluster DlCluster{.m1_xs = {8, 2}, .n1_xs = {8, 2}};
// DL A block transfer - K0_M0_M1_K1 format
constexpr DlBlockTransferK0M0M1K1 DlBlockTransferA{
.thread_slice_lengths = {8, 1, 1, 2},
.thread_cluster_lengths = {2, 1, 128, 1},
.thread_cluster_arrange_order = {1, 2, 0, 3},
.src_access_order = {1, 2, 0, 3},
.src_vector_tensor_lengths = {4, 1, 1, 2},
.src_vector_tensor_contiguous_dim_order = {1, 2, 0, 3},
.dst_vector_tensor_lengths = {1, 1, 1, 2}};
// DL B block transfer - K0_N0_N1_K1 format
constexpr DlBlockTransferK0N0N1K1 DlBlockTransferB{
.thread_slice_lengths = {8, 1, 1, 2},
.thread_cluster_lengths = {2, 1, 128, 1},
.thread_cluster_arrange_order = {1, 2, 0, 3},
.src_access_order = {1, 2, 0, 3},
.src_vector_tensor_lengths = {4, 1, 1, 2},
.src_vector_tensor_contiguous_dim_order = {1, 2, 0, 3},
.dst_vector_tensor_lengths = {1, 1, 1, 2}};
// DL C thread transfer
constexpr DlCThreadTransfer DlCTransfer{.src_dst_access_order = {0, 1, 2, 3, 4, 5},
.src_dst_vector_dim = 5,
.dst_scalar_per_vector = 4};
constexpr ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK FwdConvAlgorithm{
.thread_block = FwdThreadBlock,
.fwd_specialization = FwdConvSpecialization,
.gemm_specialization = GemmSpecialization::MNKPadding,
.dl_thread_config = DlThreadCfg,
.dl_thread_cluster = DlCluster,
.dl_block_transfer_a = DlBlockTransferA,
.dl_block_transfer_b = DlBlockTransferB,
.dl_c_thread_transfer = DlCTransfer};
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
auto instance = typename Builder::Instance{};
const auto kernel_string = instance.GetTypeString();
std::cout << "Generated kernel: " << kernel_string << std::endl;
EXPECT_GT(kernel_string.size(), 0);
EXPECT_TRUE(kernel_string.starts_with("DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK"));
// Verify specialization is correct
if(FwdConvSpecialization == ConvFwdSpecialization::DEFAULT)
EXPECT_TRUE(kernel_string.find("Default") != std::string::npos);
else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_1X1_PAD0)
EXPECT_TRUE(kernel_string.find("Filter1x1Pad0") != std::string::npos);
else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0)
EXPECT_TRUE(kernel_string.find("Filter1x1Stride1Pad0") != std::string::npos);
else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_3x3)
EXPECT_TRUE(kernel_string.find("Filter3x3") != std::string::npos);
const auto invoker_ptr = instance.MakeInvokerPointer();
EXPECT_NE(invoker_ptr, nullptr);
}
// Test helper for DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
// Note: Large_Tensor has identical parameters to regular XDL CShuffle
template <ConvSignature FwdConvSignature,
ThreadBlock FwdThreadBlock,
ConvFwdSpecialization FwdConvSpecialization>
constexpr void run_test_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor()
{
constexpr GridwiseXdlGemm FwdGemmParams{.ak1 = 8,
.bk1 = 8,
.m_per_xdl = 32,
.n_per_xdl = 32,
.m_xdl_per_wave = 2,
.n_xdl_per_wave = 1};
constexpr BlockTransferABC FwdBlockTransfer{.block_transfer_a = {.k0 = 4, .m_n = 16, .k1 = 1},
.block_transfer_b = {.k0 = 4, .m_n = 16, .k1 = 1},
.thread_cluster_dims_c = {.m_block = 1,
.m_wave_per_xdl = 16,
.n_block = 1,
.n_wave_per_xdl = 4},
.lds_transfer_a = {.src_vector_dim = 2,
.src_scalar_per_vector = 8,
.lds_dst_scalar_per_vector = 8,
.is_direct_load = false,
.lds_padding = true},
.lds_transfer_b = {.src_vector_dim = 2,
.src_scalar_per_vector = 8,
.lds_dst_scalar_per_vector = 8,
.is_direct_load = false,
.lds_padding = true},
.epilogue_c = {.m_per_wave_per_shuffle = 1,
.n_per_wave_per_shuffle = 1,
.scalar_per_vector = 8},
.block_transfer_access_order_a = {1, 0, 2},
.block_transfer_access_order_b = {1, 0, 2},
.src_access_order_a = {1, 0, 2},
.src_access_order_b = {1, 0, 2}};
// Large_Tensor uses the same descriptor as regular XDL CShuffle
constexpr ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle FwdConvAlgorithm{
.thread_block = FwdThreadBlock,
.gridwise_gemm = FwdGemmParams,
.block_transfer = FwdBlockTransfer,
.fwd_specialization = FwdConvSpecialization,
.gemm_specialization = GemmSpecialization::MNKPadding,
.num_gemm_k_prefetch_stages = 1,
.num_groups_to_merge = 1,
.loop_scheduler = LoopScheduler::DEFAULT};
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
auto instance = typename Builder::Instance{};
const auto kernel_string = instance.GetTypeString();
std::cout << "Generated kernel: " << kernel_string << std::endl;
EXPECT_GT(kernel_string.size(), 0);
EXPECT_TRUE(
kernel_string.starts_with("DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor"));
// Verify specialization is correct
if(FwdConvSpecialization == ConvFwdSpecialization::DEFAULT)
EXPECT_TRUE(kernel_string.find("Default") != std::string::npos);
else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_1X1_PAD0)
EXPECT_TRUE(kernel_string.find("Filter1x1Pad0") != std::string::npos);
else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0)
EXPECT_TRUE(kernel_string.find("Filter1x1Stride1Pad0") != std::string::npos);
else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_3x3)
EXPECT_TRUE(kernel_string.find("Filter3x3") != std::string::npos);
const auto invoker_ptr = instance.MakeInvokerPointer();
EXPECT_NE(invoker_ptr, nullptr);
}
} // namespace ck_tile::builder::test_utils

View File

@@ -727,7 +727,8 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_gufusion_v3<
});
});
HotLoopScheduler();
if constexpr(MPerBlock >= 64)
HotLoopScheduler();
__builtin_amdgcn_sched_barrier(0);
};

View File

@@ -4,6 +4,7 @@
#pragma once
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v3.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v1.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_gufusion_v3.hpp"
namespace ck {
@@ -45,7 +46,28 @@ constexpr auto BlockGemmMXBPreshufflePipeline_Selector()
}
else
{
return nullptr;
return BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v1<
BlkGemmPipeSche,
ThreadBlockSize,
ScaleBlockSize,
ADataType,
AScaleDataType,
BDataType,
BScaleDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>{};
}
}
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)

View File

@@ -0,0 +1,891 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_operation/gpu/block/blockwise_gemm_mx_pipeline_xdlops_base.hpp"
namespace ck {
// Naive pipeline with lowest resource request per WGP
// GlobalPrefetchStages: 2
// LocalPreFillStages: 1
// LocalPreFetchStages: 1
// LocalSharedMemoryBuffer: 1
//
// Primary template of the v1 B-preshuffle MX MoE blockwise GEMM pipeline. It declares no
// members; the BlockGemmPipelineScheduler::Intrawave partial specialization that follows in
// this file carries the implementation. Instantiating it with any other scheduler value
// yields this empty struct.
//
// NOTE(review): the first parameter is named BlkGemmPipelineVer although its type is
// BlockGemmPipelineScheduler, i.e. it selects the scheduler, not a pipeline version.
template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
index_t ThreadBlockSize,
index_t ScaleBlockSize,
typename ADataType,
typename AScaleDataType,
typename BDataType,
typename BScaleDataType,
typename ATileDesc,
typename BTileDesc,
typename AMmaTileDesc,
typename BMmaTileDesc,
index_t ABlockTransferSrcScalarPerVector,
index_t BBlockTransferSrcScalarPerVector,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerXDL,
index_t NPerXDL,
index_t MRepeat, // MXdlPerWave
index_t NRepeat, // NXdlPerWave
index_t KPack>
struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v1
{
};
template <index_t ThreadBlockSize,
index_t ScaleBlockSize,
typename ADataType,
typename AScaleDataType,
typename BDataType,
typename BScaleDataType,
typename ATileDesc,
typename BTileDesc,
typename AMmaTileDesc,
typename BMmaTileDesc,
index_t ABlockTransferSrcScalarPerVector,
index_t BBlockTransferSrcScalarPerVector,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerXDL,
index_t NPerXDL,
index_t MRepeat, // MXdlPerWave
index_t NRepeat, // NXdlPerWave
index_t KPack>
struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v1<BlockGemmPipelineScheduler::Intrawave,
ThreadBlockSize,
ScaleBlockSize,
ADataType,
AScaleDataType,
BDataType,
BScaleDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>
: BlockwiseGemmXdlops_mx_pipeline_base<ThreadBlockSize,
ADataType,
BDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>
{
using Base = BlockwiseGemmXdlops_mx_pipeline_base<ThreadBlockSize,
ADataType,
BDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>;
using Base::A_K1;
using Base::I0;
using Base::I1;
using Base::KRepeat;
using Base::MWaves;
using Base::NWaves;
using Base::WaveSize;
using Base::xdlops_gemm;
using typename Base::HotLoopInstList;
using Base::CalculateCThreadOriginDataIndex;
using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
using Base::GetCThreadBuffer;
using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
using Base::GetWaveIdx;
using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::a_block_desc_m0_m1_m2_m3_k;
using Base::b_block_desc_n0_n1_n2_n3_k;
using Base::AMmaKStride;
using Base::APackedSize;
using Base::BMmaKStride;
using Base::BPackedSize;
using Base::KThreadChunk;
using Base::KXdlPack;
using Base::MXdlPack;
using Base::NXdlPack;
using AccType = typename Base::AccType;
using Tuple5 = typename Base::Tuple5;
using ComputeTypeA = typename Base::ComputeTypeA;
using ComputeTypeB = typename Base::ComputeTypeB;
// Pipeline depth constants. BlockHasHotloop() compares the K-block count against
// PrefetchStages. PrefillStages, GlobalBufferNum and HotloopLocalBufSwitch are not read
// anywhere else in this struct; presumably they are consumed by the enclosing gridwise
// GEMM -- verify against the caller.
static constexpr index_t PrefetchStages = 2;
static constexpr index_t PrefillStages = 1;
static constexpr index_t GlobalBufferNum = 2;
static constexpr index_t HotloopLocalBufSwitch = MRepeat % 2 == 0 ? 0 : 1;
// Number of scale loads issued per K block. The scale prefetch loops in Run() iterate
// (MRepeat / MXdlPack) x (KRepeat / KXdlPack) times for A and (NRepeat / NXdlPack) x
// (KRepeat / KXdlPack) times for B.
// NOTE(review): the expressions below evaluate left to right as
// ((MRepeat / MXdlPack) * KRepeat) / KXdlPack, which only equals the loop trip count when
// KRepeat is a multiple of KXdlPack -- confirm that this always holds.
static constexpr auto num_buffer_load_a_scale = MRepeat / MXdlPack * KRepeat / KXdlPack;
static constexpr auto num_buffer_load_b_scale = NRepeat / NXdlPack * KRepeat / KXdlPack;
// Count handed to the explicit s_waitcnt in Run(): scale loads plus B buffer loads.
static constexpr auto async_vmcnt =
num_buffer_load_a_scale + num_buffer_load_b_scale + HotLoopInstList::B_Buffer_Load_Inst_Num;
// Immediate operand for __builtin_amdgcn_s_waitcnt. 3952 == 0xF70; the low four bits of
// async_vmcnt are added at bits [3:0] and the remaining bits are shifted to bit 14 and up
// (16384 == 1 << 14).
// NOTE(review): presumably this follows the gfx9 s_waitcnt layout (vmcnt split across
// [3:0] and [15:14], 0xF70 leaving the other counters at their maximum) -- confirm against
// the ISA manual. There is no check that async_vmcnt < 64, above which the high part no
// longer fits in two bits.
static constexpr auto async_vmcnt_encoding = 3952 + async_vmcnt % 16 + async_vmcnt / 16 * 16384;
static constexpr auto ScalesPerKBlockSize =
KPerBlock / ScaleBlockSize; // How many mx-vectors per K block
//> How many mx-vectors in each row/col is processed in one call to xdlops_gemm.Run()
static constexpr auto ScalesPerXdlopsRun =
(APackedSize * KPack * xdlops_gemm.K0PerXdlops) / ScaleBlockSize;
//> How many scales a thread must read to accommodate one call to xdlops_gemm.Run()
static constexpr auto ScalesPerXdlopsRunPerThread =
ScalesPerXdlopsRun / xdlops_gemm.mfma_instr.num_input_blks;
// Element type of a single MX scale.
using mx_scale_t = e8m0_bexp_t;
// How many mx_scale_t values fit in one AScaleDataType / BScaleDataType element.
static constexpr auto scale_pack_size_a = sizeof(AScaleDataType) / sizeof(mx_scale_t);
static constexpr auto scale_pack_size_b = sizeof(BScaleDataType) / sizeof(mx_scale_t);
// A packed scale element must evenly tile the KXdlPack x MXdlPack (resp. NXdlPack) scales.
static_assert(KXdlPack * MXdlPack % scale_pack_size_a == 0,
"A scale pack data type too large!");
static_assert(KXdlPack * NXdlPack % scale_pack_size_b == 0,
"B scale pack data type too large!");
// Number of packed scale elements gathered into the scale vector passed to each
// xdlops_gemm.Run() call.
static constexpr auto a_scale_thread_vec_size = KXdlPack * MXdlPack / scale_pack_size_a;
static constexpr auto b_scale_thread_vec_size = KXdlPack * NXdlPack / scale_pack_size_b;
// Host-side query: the hot loop in Run() is needed only when there are more K blocks
// (num_loop) than prefetch stages.
__host__ static constexpr bool BlockHasHotloop(index_t num_loop)
{
    return PrefetchStages < num_loop;
}
// Host-side query: selects the tail variant of Run() from the parity of the K-block count.
// Returns TailNumber::Even when num_loop is divisible by two, TailNumber::Odd otherwise.
__host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
{
    const bool has_even_trip_count = (num_loop % 2) == 0;
    if(has_even_trip_count)
    {
        return TailNumber::Even;
    }
    return TailNumber::Odd;
}
// Emits the instruction-scheduling hints for one hot-loop step of Run().
//
// The body contains only __builtin_amdgcn_sched_group_barrier calls; it moves no data. Each
// call asks the scheduler to place a group of instructions of one class next in the
// sequence. Masks used here, as labelled by the inline comments: 0x008 MFMA, 0x020 VMEM
// read, 0x100 DS read, 0x200 DS write. The order of the calls is the requested interleave
// pattern, so statements must not be reordered.
__device__ static constexpr auto HotLoopScheduler()
{
constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num;
constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
// B buffer loads (scaled by MWaves) plus the A and B scale loads share one VMEM-read phase.
constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num * MWaves +
num_buffer_load_a_scale + num_buffer_load_b_scale;
// Number of MFMAs placed before each VMEM read: 1 for 32-wide XDL tiles, 2 otherwise.
constexpr auto mfma_interleave = MPerXDL == 32 ? 1 : 2;
// B global
static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) {
ignore = i;
// Large tiles (both block dims >= 128) get twice as many MFMAs per VMEM read.
if constexpr(MPerBlock >= 128 && NPerBlock >= 128)
{
__builtin_amdgcn_sched_group_barrier(0x008, 2 * mfma_interleave, 0);
}
else
{
__builtin_amdgcn_sched_group_barrier(0x008, mfma_interleave, 0);
}
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
});
// A global
// One MFMA / DS write / MFMA / VMEM read group per A buffer load.
static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
});
// A local
// For 32-wide XDL tiles the DS reads are issued two per MFMA, so half as many groups.
static_for<0, MPerXDL == 32 ? num_ds_read_inst_a / 2 : num_ds_read_inst_a, 1>{}(
[&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, MPerXDL == 32 ? 2 : 1, 0); // DS read
});
}
// Runs the blockwise GEMM over all K blocks of one output tile.
//
// Structure:
//   1. Prologue: load the first A tile into a_block_buf and the first B tile into thread
//      buffer slot 0, load the first set of A/B scales into slot 0, wait, then copy the A
//      tile from a_block_buf into the per-thread A buffer and clear the accumulator.
//   2. Hot loop (only if HasMainLoop): two steps per iteration. Each step loads the next
//      B tile / scales into one slot while the MFMAs consume the other slot, then refreshes
//      the per-thread A buffer from a_block_buf.
//   3. Tail: TailNumber::Even performs one more load and two MFMA passes (slot 0, then
//      slot 1); TailNumber::Odd performs a single MFMA pass on slot 0.
//
// Template parameters:
//   HasMainLoop - compile-time switch for the hot loop; see BlockHasHotloop().
//   TailNum     - compile-time tail variant; see BlockLoopTailNum().
//
// Parameters:
//   a_* / b_*          - grid/block descriptors, copy objects, buffers and per-K-block window
//                        steps for A and B. b_block_bufs is unused: B is loaded straight into
//                        thread buffers.
//   c_thread_buf       - accumulator; cleared here and updated by every MFMA.
//   a_scale_* / b_scale_* - descriptors, thread copy objects and grid buffers for the scales.
//   num_loop           - number of K blocks; only read by the hot-loop exit condition.
//
// The repeated pipeline stages are expressed once as local lambdas. The order in which the
// stages are invoked is exactly the order of the original hand-unrolled code.
template <bool HasMainLoop,
          TailNumber TailNum,
          typename AGridDesc,
          typename ABlockDesc,
          typename ABlockTransfer,
          typename AGridBuffer,
          typename ABlockBuffer,
          typename ABlockTransferStep,
          typename BGridDesc,
          typename BBlockDesc,
          typename BBlockTransfer,
          typename BGridBuffer,
          typename BBlockBuffer,
          typename BBlockTransferStep,
          typename CThreadBuffer,
          typename AScaleGridBuffer,
          typename AScaleGridDesc,
          typename AScaleThreadTransfer,
          typename BScaleGridBuffer,
          typename BScaleGridDesc,
          typename BScaleThreadTransfer>
__device__ void Run(
    // ABlockCopy
    const AGridDesc& a_grid_desc,
    const ABlockDesc& a_block_desc,
    ABlockTransfer& a_blockwise_copy,
    const AGridBuffer& a_grid_buf,
    ABlockBuffer& a_block_buf,
    const ABlockTransferStep& a_block_copy_step,
    // BBlockCopy
    const BGridDesc& b_grid_desc,
    const BBlockDesc& b_block_desc,
    BBlockTransfer& b_blockwise_copy,
    const BGridBuffer& b_grid_buf,
    BBlockBuffer& b_block_bufs,
    const BBlockTransferStep& b_block_copy_step,
    // CThread
    CThreadBuffer& c_thread_buf,
    // A and B scales
    const AScaleGridDesc& a_scale_grid_desc,
    AScaleThreadTransfer& a_scale_thread_copy,
    const AScaleGridBuffer& a_scale_grid_buf,
    const BScaleGridDesc& b_scale_grid_desc,
    BScaleThreadTransfer& b_scale_thread_copy,
    const BScaleGridBuffer& b_scale_grid_buf,
    index_t num_loop) const
{
    ignore = b_block_bufs;
    __builtin_amdgcn_sched_barrier(0);

    auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
        a_thread_desc_.GetElementSpaceSize());
    auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
        b_thread_desc_.GetElementSpaceSize());

    // Two B slots: one is filled from global memory while the other feeds the MFMAs.
    StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs;
    constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0, I0);

    auto a_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AScaleDataType>(
        a_scale_thread_desc.GetElementSpaceSize());
    auto b_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, BScaleDataType>(
        b_scale_thread_desc.GetElementSpaceSize());

    // Two scale slots, switched together with the B slots.
    StaticallyIndexedArray<decltype(a_scale_thread_buf), Number<2>{}> a_scale_thread_bufs;
    StaticallyIndexedArray<decltype(b_scale_thread_buf), Number<2>{}> b_scale_thread_bufs;

    // Loads the A scales of the current K block into slot `mem_buf`. On return the source
    // window sits MWaves * (MRepeat / MXdlPack) rows past its starting row.
    auto load_a_scales = [&](auto mem_buf) {
        static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
            static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
                a_scale_thread_copy.Run(a_scale_grid_desc,
                                        a_scale_grid_buf,
                                        a_scale_thread_desc,
                                        make_tuple(m0, k0, I0),
                                        a_scale_thread_bufs(mem_buf));
                a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
                                                       make_multi_index(0, I1, 0));
            });
            a_scale_thread_copy.MoveSrcSliceWindow(
                a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0));
        });
    };

    // restore row id and advance to the next set of scales
    auto advance_a_scale_window = [&]() {
        a_scale_thread_copy.MoveSrcSliceWindow(
            a_scale_grid_desc,
            make_multi_index(-MWaves * MRepeat / MXdlPack, KRepeat / KXdlPack, 0));
    };

    // Loads the B scales of the current K block into slot `mem_buf`. On return the source
    // window sits NWaves * (NRepeat / NXdlPack) columns past its starting column.
    auto load_b_scales = [&](auto mem_buf) {
        static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
            static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
                b_scale_thread_copy.Run(b_scale_grid_desc,
                                        b_scale_grid_buf,
                                        b_scale_thread_desc,
                                        make_tuple(n0, k0, I0),
                                        b_scale_thread_bufs(mem_buf));
                b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
                                                       make_multi_index(0, I1, 0));
            });
            b_scale_thread_copy.MoveSrcSliceWindow(
                b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0));
        });
    };

    // restore col id and advance to the next set of scales
    // NWaves * NPerXDL * NRepeat == NPerBlock
    auto advance_b_scale_window = [&]() {
        b_scale_thread_copy.MoveSrcSliceWindow(
            b_scale_grid_desc,
            make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0));
    };

    // Copies the A tile currently held in a_block_buf into the per-thread A buffer.
    auto read_a_block_to_thread = [&]() {
        static_for<0, MRepeat, 1>{}([&](auto m0) {
            static_for<0, KRepeat, 1>{}([&](auto k) {
                constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize *
                                        (APackedSize * KPack / xdlops_gemm.K1PerXdlops);
                static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}(
                    [&](auto chunk) {
                        constexpr auto a_k_step_chunk =
                            k_step +
                            chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
                        a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k,
                                           make_tuple(Number<m0 / MXdlPack>{},
                                                      I0,
                                                      Number<m0 % MXdlPack>{},
                                                      I0,
                                                      Number<a_k_step_chunk>{}),
                                           a_block_buf,
                                           a_thread_desc_,
                                           make_tuple(Number<m0 / MXdlPack>{},
                                                      I0,
                                                      Number<m0 % MXdlPack>{},
                                                      k,
                                                      Number<chunk * KThreadChunk>{}),
                                           a_thread_buf);
                    });
            });
        });
    };

    // Accumulates one K block into c_thread_buf, reading A from a_thread_buf and B plus
    // both scale sets from slot `comp_buf`.
    auto run_mfma_block = [&](auto comp_buf) {
        static_for<0, MRepeat, 1>{}([&](auto m0) {
            constexpr auto im_major = m0 / MXdlPack;
            constexpr auto im_minor = m0 % MXdlPack;
            static_for<0, KRepeat, 1>{}([&](auto k0) {
                constexpr auto ik_major = k0 / KXdlPack;
                constexpr auto ik_minor = k0 % KXdlPack;
                static_for<0, NRepeat, 1>{}([&](auto n0) {
                    constexpr auto in_major = n0 / NXdlPack;
                    constexpr auto in_minor = n0 % NXdlPack;

                    constexpr index_t a_scale_offset =
                        a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0));
                    constexpr index_t b_scale_offset =
                        b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0));

                    static_assert(0 < ScalesPerXdlopsRunPerThread,
                                  "Must have at least one scale per Xdlops "
                                  "per Thread.");

                    vector_type<AScaleDataType, a_scale_thread_vec_size> a_scale_thread_vec;
                    vector_type<BScaleDataType, b_scale_thread_vec_size> b_scale_thread_vec;

                    // Pack scale_thread_buf into scale_thread_vec
                    static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) {
                        a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
                            a_scale_thread_bufs(comp_buf)[Number<a_scale_offset + s>{}];
                    });
                    static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) {
                        b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
                            b_scale_thread_bufs(comp_buf)[Number<b_scale_offset + s>{}];
                    });

                    vector_type<ComputeTypeA, KPack> a_thread_vec;
                    vector_type<ComputeTypeB, KPack> b_thread_vec;

                    static_for<0, KPack, 1>{}([&](auto ik) {
                        a_thread_vec.template AsType<ComputeTypeA>()(ik) =
                            a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                make_tuple(im_major, I0, im_minor, k0, ik))>{}];
                        b_thread_vec.template AsType<ComputeTypeB>()(ik) =
                            b_thread_bufs[comp_buf][Number<b_thread_desc_.CalculateOffset(
                                make_tuple(in_major, I0, in_minor, k0, ik))>{}];
                    });

                    using mfma_input_type_a =
                        typename vector_type<ComputeTypeA,
                                             xdlops_gemm.K1PerXdlops / APackedSize>::type;
                    using mfma_input_type_b =
                        typename vector_type<ComputeTypeB,
                                             xdlops_gemm.K1PerXdlops / BPackedSize>::type;
                    using mfma_scale_input_type_a =
                        typename vector_type<AScaleDataType, a_scale_thread_vec_size>::type;
                    using mfma_scale_input_type_b =
                        typename vector_type<BScaleDataType, b_scale_thread_vec_size>::type;

                    constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
                        make_tuple(im_major, in_major, im_minor, in_minor, 0));

                    // MFMA accumulation
                    xdlops_gemm.template Run<ik_minor * MXdlPack + im_minor,
                                             ik_minor * NXdlPack + in_minor>(
                        a_thread_vec.template AsType<mfma_input_type_a>(),
                        a_scale_thread_vec.template AsType<mfma_scale_input_type_a>(),
                        b_thread_vec.template AsType<mfma_input_type_b>(),
                        b_scale_thread_vec.template AsType<mfma_scale_input_type_b>(),
                        c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                });
            });
        });
    };

    // Global prefetch 1
    a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf);
    b_blockwise_copy.Run(
        b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_thread_bufs(I0));

    a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
    b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

    __builtin_amdgcn_sched_barrier(0);

    // Prefetch a_scales, then b_scales, each followed by its window advance
    load_a_scales(I0);
    advance_a_scale_window();
    load_b_scales(I0);
    advance_b_scale_window();

    // Local prefetch 1, sync the async load
    __builtin_amdgcn_s_waitcnt(async_vmcnt_encoding);
    block_sync_lds();
    read_a_block_to_thread();

    // Initialize C
    c_thread_buf.Clear();
    __builtin_amdgcn_sched_barrier(0);

    // main body
    if constexpr(HasMainLoop)
    {
        // One pipeline step: fetch the next K block into `scale_mem_buf` while the MFMAs
        // consume `scale_comp_buf`, then refresh the per-thread A buffer.
        auto LoopFunc = [&](auto scale_comp_buf, auto scale_mem_buf) {
            b_blockwise_copy.Run(b_grid_desc,
                                 b_grid_buf,
                                 b_block_desc,
                                 b_block_origin_idx,
                                 b_thread_bufs(scale_mem_buf));

            // All threads must be done reading the previous A tile before it is overwritten.
            block_sync_lds();
            a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf);

            load_a_scales(scale_mem_buf);
            advance_a_scale_window();
            load_b_scales(scale_mem_buf);
            advance_b_scale_window();

            a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
            b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

            run_mfma_block(scale_comp_buf);

            block_sync_lds();
            read_a_block_to_thread();

            HotLoopScheduler();
            __builtin_amdgcn_sched_barrier(0);
        };

        // loop over k with the step KPerBlock; two K blocks per iteration so that the
        // buffer slots return to their starting roles
        index_t i = 0;
        do
        {
            LoopFunc(I0, I1);
            LoopFunc(I1, I0);

            i += 2;
        } while(i < (num_loop - 2));
    }

    // tail
    if constexpr(TailNum == TailNumber::Even)
    {
        // Two K blocks remain: fetch the last one into slot 1 while slot 0 is consumed.
        b_blockwise_copy.Run(
            b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_thread_bufs(I1));

        block_sync_lds();
        a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf);

        // Last scale fetch: the source windows are not advanced afterwards.
        load_a_scales(I1);
        load_b_scales(I1);

        run_mfma_block(I0);

        __builtin_amdgcn_s_waitcnt(async_vmcnt_encoding);
        block_sync_lds();
        read_a_block_to_thread();

        __builtin_amdgcn_sched_barrier(0);

        run_mfma_block(I1);
    }
    else if constexpr(TailNum == TailNumber::Odd)
    {
        // One K block remains and it is already resident in slot 0.
        run_mfma_block(I0);
    }
}
// Packed 3-D layout of the per-thread A scale buffer:
// (MRepeat / MXdlPack, KRepeat / KXdlPack, ScalesPerXdlopsRunPerThread * a_scale_thread_vec_size).
// Run() sizes its scale buffers from it, passes it as the destination descriptor of
// a_scale_thread_copy, and uses it to compute the (m_major, k_major, 0) read offsets.
// TODO: make this field protected when a_scale_thread_copy_ is moved
// here
static constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat / MXdlPack>{},
Number<KRepeat / KXdlPack>{},
Number<ScalesPerXdlopsRunPerThread * a_scale_thread_vec_size>{}));
// Packed 3-D layout of the per-thread B scale buffer:
// (NRepeat / NXdlPack, KRepeat / KXdlPack, ScalesPerXdlopsRunPerThread * b_scale_thread_vec_size).
// Used by Run() the same way as a_scale_thread_desc, with (n_major, k_major, 0) offsets.
// TODO: make this field protected when b_scale_thread_copy_ is moved
// here
static constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<NRepeat / NXdlPack>{},
Number<KRepeat / KXdlPack>{},
Number<ScalesPerXdlopsRunPerThread * b_scale_thread_vec_size>{}));
protected:
using Base::a_thread_copy_;
using Base::a_thread_desc_;
using Base::b_thread_copy_;
using Base::b_thread_desc_;
using Base::c_thread_desc_;
};
} // namespace ck

View File

@@ -226,85 +226,197 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v3<BlockGemmPipelineSched
// constexpr auto num_dsread_a_mfma =
// (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate;
constexpr auto num_total_stages = MRepeat;
constexpr auto num_total_stages = std::max(2, MRepeat);
if constexpr(num_total_stages > 2)
{
// Group num_mfma_perstage num_ds_read_a_perstage
// since we want to reuse a local register buffer
constexpr auto num_mfma_perstage = num_mfma_inst / num_total_stages;
constexpr auto num_ds_read_a_perstage = num_ds_read_inst_a / num_total_stages;
// Group num_mfma_perstage num_ds_read_a_perstage
// since we want to reuse a local register buffer
constexpr auto num_mfma_perstage = num_mfma_inst / num_total_stages;
constexpr auto num_ds_read_a_perstage = num_ds_read_inst_a / num_total_stages;
constexpr auto num_ds_read_a_mfma_perstage =
math::integer_divide_ceil(num_ds_read_a_perstage, ds_read_a_mfma_rate);
constexpr auto num_ds_read_a_mfma_perstage =
math::integer_divide_ceil(num_ds_read_a_perstage, ds_read_a_mfma_rate);
constexpr auto num_ds_read_a_prefetch_stages = 2;
constexpr auto num_ds_read_a_prefetch_stages = 2;
constexpr auto buffer_load_perstage_more =
math::integer_divide_ceil((num_buffer_load_stage1), (num_total_stages - 2));
constexpr auto buffer_load_perstage_less =
math::integer_divide_floor((num_buffer_load_stage1), (num_total_stages - 2));
constexpr auto buffer_load_perstage_stage2 =
math::integer_divide_floor((num_buffer_load_stage2), 2);
constexpr auto buffer_load_perstage_more =
math::integer_divide_ceil((num_buffer_load_stage1), (num_total_stages - 2));
constexpr auto buffer_load_perstage_less =
math::integer_divide_floor((num_buffer_load_stage1), (num_total_stages - 2));
constexpr auto buffer_load_perstage_stage2 =
math::integer_divide_floor((num_buffer_load_stage2), 2);
constexpr auto buffer_load_stages_more =
num_buffer_load_stage1 -
math::integer_divide_floor(num_buffer_load_stage1, (num_total_stages - 2)) *
((num_total_stages - 2));
constexpr auto buffer_load_stages_more =
num_buffer_load_stage1 -
math::integer_divide_floor(num_buffer_load_stage1, (num_total_stages - 2)) *
((num_total_stages - 2));
constexpr auto buffer_load_issue_point_interval_more =
num_mfma_perstage / buffer_load_perstage_more;
constexpr auto buffer_load_issue_point_interval_less =
num_mfma_perstage / buffer_load_perstage_less;
constexpr auto buffer_load_issue_point_interval_stage2 =
num_mfma_perstage / buffer_load_perstage_stage2;
constexpr auto buffer_load_issue_point_interval_more =
num_mfma_perstage / buffer_load_perstage_more;
constexpr auto buffer_load_issue_point_interval_less =
num_mfma_perstage / buffer_load_perstage_less;
constexpr auto buffer_load_issue_point_interval_stage2 =
num_mfma_perstage / buffer_load_perstage_stage2;
// Stage 1
// global read more
static_for<0, buffer_load_stages_more, 1>{}([&](auto /*i*/) {
static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
// Stage 1
// global read more
static_for<0, buffer_load_stages_more, 1>{}([&](auto /*i*/) {
static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
if constexpr(imfma % buffer_load_issue_point_interval_more == 0)
if constexpr(imfma % buffer_load_issue_point_interval_more == 0)
{
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
{
__builtin_amdgcn_sched_group_barrier(
0x100, ds_read_a_mfma_rate, 0); // DS read
}
});
});
// global read less
static_for<0, (num_total_stages - 2 - buffer_load_stages_more), 1>{}([&](auto /*i*/) {
static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
if constexpr(imfma % buffer_load_issue_point_interval_less == 0)
{
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
{
__builtin_amdgcn_sched_group_barrier(
0x100, ds_read_a_mfma_rate, 0); // DS read
}
});
});
// Stage 2, Sync
// lds synchronization, prefetch next loop local A
static_for<0, num_ds_read_a_prefetch_stages, 1>{}([&](auto /*i*/) {
static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
if constexpr(imfma % buffer_load_issue_point_interval_stage2 == 0)
{
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
{
__builtin_amdgcn_sched_group_barrier(
0x100, ds_read_a_mfma_rate, 0); // DS read
}
});
});
}
else
{
constexpr auto num_buffer_load_total = num_buffer_load_inst_a + num_buffer_load_inst_b +
num_buffer_load_a_scale +
num_buffer_load_b_scale;
constexpr auto num_dsread_a_mfma = math::integer_divide_ceil(
num_ds_read_inst_a, ds_read_a_mfma_rate); // how many mfma per dsread_a
// stage 1
constexpr auto num_mfma_stage1 = num_mfma_inst - num_dsread_a_mfma;
constexpr auto mfma_perstage_more =
math::integer_divide_ceil(num_mfma_stage1, num_buffer_load_total);
constexpr auto mfma_perstage_less =
math::integer_divide_floor(num_mfma_stage1, num_buffer_load_total);
constexpr auto mfma_stages_more =
num_mfma_stage1 - mfma_perstage_less * num_buffer_load_total;
static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) {
if constexpr(i < mfma_stages_more)
{
static_for<0, mfma_perstage_more, 1>{}([&](auto) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
else
{
static_for<0, mfma_perstage_less, 1>{}([&](auto) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
});
if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) {
if constexpr((i + num_buffer_load_inst_a) < mfma_stages_more)
{
static_for<0, mfma_perstage_more, 1>{}([&](auto /*imfma*/) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
else
{
static_for<0, mfma_perstage_less, 1>{}([&](auto /*imfma*/) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
});
static_for<0, num_buffer_load_a_scale, 1>{}([&](auto i) {
if constexpr((i + num_buffer_load_inst_a + num_buffer_load_inst_b) <
mfma_stages_more)
{
static_for<0, mfma_perstage_more, 1>{}([&](auto /*imfma*/) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
else
{
static_for<0, mfma_perstage_less, 1>{}([&](auto /*imfma*/) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
});
static_for<0, num_buffer_load_b_scale, 1>{}([&](auto i) {
if constexpr((i + num_buffer_load_inst_a + num_buffer_load_inst_b +
num_buffer_load_a_scale) < mfma_stages_more)
{
static_for<0, mfma_perstage_more, 1>{}([&](auto /*imfma*/) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
else
{
static_for<0, mfma_perstage_less, 1>{}([&](auto /*imfma*/) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
});
// stage 2
static_for<0, num_dsread_a_mfma, 1>{}([&](auto i) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >=
ds_read_a_mfma_rate)
{
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
}
});
});
// global read less
static_for<0, (num_total_stages - 2 - buffer_load_stages_more), 1>{}([&](auto /*i*/) {
static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
if constexpr(imfma % buffer_load_issue_point_interval_less == 0)
else
{
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
{
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
__builtin_amdgcn_sched_group_barrier(
0x100,
num_ds_read_inst_a - (num_dsread_a_mfma - 1) * ds_read_a_mfma_rate,
0); // DS read
}
});
});
// Stage 2, Sync
// lds synchronization, prefetch next loop local A
static_for<0, num_ds_read_a_prefetch_stages, 1>{}([&](auto /*i*/) {
static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
if constexpr(imfma % buffer_load_issue_point_interval_stage2 == 0)
{
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
{
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
}
});
});
}
}
template <bool HasMainLoop,

View File

@@ -122,7 +122,7 @@ struct DeviceMoeGemmMXBPreShuffle : public DeviceMoEGemmMXBPreShuffle<ALayout,
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave_,
math::max(2, NXdlPerWave_),
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,

View File

@@ -429,8 +429,8 @@ struct GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
constexpr index_t WaveSize = BlockSize / (MWave * NWave);
constexpr index_t NkSwizzleNumber = Number<WaveSize * KPack>{};
return make_naive_tensor_descriptor_packed(
make_tuple(N0 / NWave / NXdlPack, NWave, NXdlPack, K0, NkSwizzleNumber));
return make_naive_tensor_descriptor_packed(make_tuple(
math::integer_divide_ceil(N0, NWave * NXdlPack), NWave, NXdlPack, K0, NkSwizzleNumber));
}
__host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1(

View File

@@ -48,28 +48,25 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
kernel_moe_mxgemm(typename GridwiseGemm::Argument karg)
{
#if defined(__gfx9__)
if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
{
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
karg.p_sorted_token_ids,
karg.p_sorted_expert_ids,
karg.p_max_token_id,
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
karg.p_a_scale_grid + splitk_batch_offset.a_k_split_offset,
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
karg.p_b_scale_grid + splitk_batch_offset.b_k_split_offset,
karg.p_ds_grid,
karg.p_c_grid,
p_shared,
karg,
karg.a_element_op,
karg.b_element_op,
karg.c_element_op);
}
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
karg.p_sorted_token_ids,
karg.p_sorted_expert_ids,
karg.p_max_token_id,
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
karg.p_a_scale_grid + splitk_batch_offset.a_scale_k_split_offset,
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
karg.p_b_scale_grid + splitk_batch_offset.b_scale_k_split_offset,
karg.p_ds_grid,
karg.p_c_grid,
p_shared,
karg,
karg.a_element_op,
karg.b_element_op,
karg.c_element_op);
#else
ignore = karg;
#endif // end of if (defined(__gfx9__))
@@ -1249,7 +1246,6 @@ struct GridwiseMoeGemmMX_BPreshuffle
// Tells the caller whether a problem with reduction length K needs the
// pipeline's main K-block loop. The decision itself is delegated to
// BlockwiseGemmPipe::BlockHasHotloop.
//
// K: reduction length; it is divided by KPerBlock to obtain the number of
//    K-block iterations handed to the pipeline.
__host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
{
    // Integer division: a trailing partial KPerBlock tile is not counted.
    return BlockwiseGemmPipe::BlockHasHotloop(static_cast<index_t>(K / KPerBlock));
}
@@ -1279,7 +1275,6 @@ struct GridwiseMoeGemmMX_BPreshuffle
// using Block2CTileMapDefault = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock,
// NPerBlock>;
#if 0
template <bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
TailNumber TailNum = TailNumber::Odd>
@@ -1298,9 +1293,10 @@ struct GridwiseMoeGemmMX_BPreshuffle
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
{
ignore = a_element_op;
ignore = b_element_op;
index_t BN0Shuffled = CalculateBN0Shuffled(problem.N);
index_t BK0Shuffled = CalculateBK0Shuffled(problem.K);
index_t BN0Shuffled = CalculateBN0Shuffled(problem.N);
index_t BK0Shuffled = CalculateBK0Shuffled(problem.K);
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK,
problem.MPadded,
@@ -1317,29 +1313,41 @@ struct GridwiseMoeGemmMX_BPreshuffle
problem.NPadded,
problem.StrideC);
const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor_packed(
make_tuple((IsInputGemm ? problem.NumTokens : problem.M) / (MXdlPack * MPerBlock),
// We pad the M unconditionally for Scale
const auto Padded_Scale_M =
math::integer_divide_ceil(problem.M, ScaleBlockSize) * ScaleBlockSize;
const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor(
make_tuple(Padded_Scale_M / (MXdlPack * MPerXdl),
math::integer_divide_ceil(problem.K, (ScaleBlockSize / APackedSize)) /
(KXdlPack * 64 / MPerXdl),
64 * KXdlPack * MXdlPack / scale_pack_size_a));
64 * KXdlPack * MXdlPack / scale_pack_size_a),
make_tuple(math::integer_divide_ceil(problem.K * problem.KBatch,
(ScaleBlockSize / APackedSize)) *
MPerXdl * MXdlPack / scale_pack_size_a,
64 * KXdlPack * MXdlPack / scale_pack_size_a,
1));
const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor_packed(
const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor(
make_tuple(problem.N / (NXdlPack * NPerXdl),
math::integer_divide_ceil(problem.K, (ScaleBlockSize / BPackedSize)) /
(KXdlPack * 64 / NPerXdl),
64 * KXdlPack * NXdlPack / scale_pack_size_b));
64 * KXdlPack * NXdlPack / scale_pack_size_b),
make_tuple(math::integer_divide_ceil(problem.K * problem.KBatch,
(ScaleBlockSize / BPackedSize)) *
NPerXdl * NXdlPack / scale_pack_size_b,
64 * KXdlPack * NXdlPack / scale_pack_size_b,
1));
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n, problem.MBlock, problem.NBlock);
const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
// static_assert(NSwizzle == false, "to do fix: need another pr in sorting merged");
const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
if(expert_block_id * MPerBlock >= max_token_id)
return;
const index_t expert_id =
__builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);
const auto block_mn = [&]() -> std::pair<int, int> {
if constexpr(NSwizzle)
{
@@ -1372,86 +1380,78 @@ struct GridwiseMoeGemmMX_BPreshuffle
constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
constexpr auto AKThreads = AK0Threads * AK1Threads;
constexpr auto AMRepeats = MPerBlock / AMThreads;
const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats;
const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads;
if(token_pos >= max_token_id || token0 >= problem.NumTokens)
return;
StaticallyIndexedArray<IndexType, AMRepeats> gather_offsets;
static_for<0, AMRepeats, 1>{}([&](auto m0) {
const index_t fused_token = p_sorted_token_ids[token_pos + m0];
const index_t fused_token = p_sorted_token_ids[token_pos + m0 * AMThreads];
index_t token_offset = fused_token & 0xffffff;
if constexpr(!IsInputGemm)
{
token_offset = token_offset * problem.TopK + (fused_token >> 24);
}
gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K / APackedSize;
gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
});
const index_t expert_stride =
__builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
const index_t expert_scale_stride =
__builtin_amdgcn_readfirstlane(problem.N * (IsInputGemm ? 2 : 1) *
math::integer_divide_ceil(problem.K, ScaleBlockSize));
const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
problem.N * (IsInputGemm ? 2 : 1) *
math::integer_divide_ceil(problem.K, ScaleBlockSize / BPackedSize));
// N0, K0, Blocksize*KPack
const index_t n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);
__builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave / NXdlPack);
// Grid buffer creation
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid + expert_id * expert_stride / BPackedSize,
b_grid_desc_bpreshuffled.GetElementSpaceSize());
p_b_grid + expert_id * expert_stride, b_grid_desc_bpreshuffled.GetElementSpaceSize());
// A, B scale buffer
const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_scale_grid + expert_id * expert_scale_stride,
p_b_scale_grid + (expert_id * expert_scale_stride) / sizeof(BScaleDataType),
b_scale_grid_desc_bn_ak.GetElementSpaceSize());
// A matrix in LDS memory, dst of blockwise copy
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
// B matrix in LDS memory, dst of blockwise copy
// dummy
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
// A matrix blockwise copy
auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1_gather<
// A matrix blockwise direct to LDS copy
auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_Gather_DirectLoad<
ThisThreadBlock,
AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<AK0Number, MPerBlock, AK1Number>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ADataType,
LDSTypeA,
ADataType,
decltype(a_grid_desc_ak0_m_ak1),
decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder,
Sequence<0, 1, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true,
IndexType,
1,
BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
a_element_op,
a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{},
gather_offsets);
1>(a_grid_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
gather_offsets);
// Thread-wise copy
// K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack
auto b_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
auto b_block_buf_ping = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
auto b_block_buf_pong = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong);
auto b_blockwise_copy =
ThreadwiseTensorSliceTransfer_v2<BDataType,
@@ -1463,7 +1463,7 @@ struct GridwiseMoeGemmMX_BPreshuffle
Number<NXdlPack>{},
Number<KRepeat>{},
Number<BK1Value>{}>,
Sequence<1, 2, 0, 3>,
Sequence<0, 1, 2, 3, 4>,
4,
BBlockTransferSrcScalarPerVector,
BThreadTransferSrcResetCoordinateAfterRun,
@@ -1472,16 +1472,16 @@ struct GridwiseMoeGemmMX_BPreshuffle
make_multi_index(n_block_data_idx_on_grid,
get_warp_local_1d_id() % NWave,
0,
KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
0,
KPack * (get_thread_local_1d_id() % WarpSize)));
// LDS allocation for A and B: be careful of alignment
// Cast after lds
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<LDSTypeA*>(p_shared),
a_block_desc_ak0_m_ak1.GetElementSpaceSize() / APackedSize);
static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, KRepeat, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, 0, KRepeat, 0);
// Blockwise GEMM pipeline
static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
@@ -1505,13 +1505,16 @@ struct GridwiseMoeGemmMX_BPreshuffle
const auto waveId_m = wave_idx[I0];
const auto waveId_n = wave_idx[I1];
static constexpr auto mfma = BlockwiseGemmPipe::xdlops_gemm.mfma;
auto thread_offset_shuffled =
get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize * KXdlPack * MXdlPack;
auto a_thread_offset_m = waveId_m;
// get each thread's offset in the scale tensor
const index_t token_scale_pos = block_m_id * MPerBlock;
if(token_scale_pos >= max_token_id || token0 >= problem.NumTokens)
return;
auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
AScaleDataType,
AScaleDataType,
@@ -1538,7 +1541,7 @@ struct GridwiseMoeGemmMX_BPreshuffle
Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths
Sequence<0, 1, 2>, // DimAccessOrder
2, // SrcVectorDim
KXdlPack * MXdlPack / scale_pack_size_b, // SrcScalarPerVector
KXdlPack * NXdlPack / scale_pack_size_b, // SrcScalarPerVector
1, // SrcScalarStrideInVector
true>(b_scale_grid_desc_bn_ak,
make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n,
@@ -1547,29 +1550,37 @@ struct GridwiseMoeGemmMX_BPreshuffle
if constexpr(IsInputGemm)
{
const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize;
const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2;
const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid_up + expert_id * expert_stride / BPackedSize,
p_b_grid_up + expert_id * expert_stride,
b_grid_desc_bpreshuffled.GetElementSpaceSize());
auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2<
BDataType,
BDataType,
decltype(b_grid_desc_bpreshuffled),
decltype(b_block_desc_bk0_n_bk1),
Sequence<Number<NXdlPerWave>{}, I1, Number<KRepeat>{}, Number<BK1Value>{}>,
Sequence<1, 2, 0, 3>,
3,
BBlockTransferSrcScalarPerVector,
BThreadTransferSrcResetCoordinateAfterRun,
true>(b_grid_desc_bpreshuffled,
make_multi_index(n_block_data_idx_on_grid,
get_warp_local_1d_id() % NWave,
0,
KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
const BScaleDataType* p_b_scale_grid_up = p_b_scale_grid + expert_scale_stride / 2;
const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_scale_grid_up + expert_id * expert_scale_stride,
auto b_blockwise_copy_up =
ThreadwiseTensorSliceTransfer_v2<BDataType,
BDataType,
decltype(b_grid_desc_bpreshuffled),
decltype(b_block_desc_bk0_n_bk1),
Sequence<Number<NXdlPerWave / NXdlPack>{},
I1,
Number<NXdlPack>{},
Number<KRepeat>{},
Number<BK1Value>{}>,
Sequence<0, 1, 2, 3, 4>,
4,
BBlockTransferSrcScalarPerVector,
BThreadTransferSrcResetCoordinateAfterRun,
true>(
b_grid_desc_bpreshuffled,
make_multi_index(n_block_data_idx_on_grid,
get_warp_local_1d_id() % NWave,
0,
0,
KPack * (get_thread_local_1d_id() % WarpSize)));
const BScaleDataType* p_b_scale_grid_up =
p_b_scale_grid + expert_scale_stride / 2 / sizeof(BScaleDataType);
const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_scale_grid_up + expert_id * expert_scale_stride / sizeof(BScaleDataType),
b_scale_grid_desc_bn_ak.GetElementSpaceSize());
auto b_scale_thread_copy_up = ThreadwiseTensorSliceTransfer_v2<
BScaleDataType,
BScaleDataType,
@@ -1587,25 +1598,30 @@ struct GridwiseMoeGemmMX_BPreshuffle
thread_offset_shuffled / scale_pack_size_b));
blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
// A
a_grid_desc_ak0_m_ak1,
a_block_desc_ak0_m_ak1,
a_blockwise_copy,
a_grid_buf,
a_block_buf,
a_block_slice_copy_step,
// Gate and Up
b_grid_desc_bpreshuffled,
b_block_desc_bk0_n_bk1,
b_blockwise_copy,
b_blockwise_copy_up,
b_grid_buf,
b_grid_buf_up,
b_block_buf,
b_block_bufs,
b_block_slice_copy_step,
// C
c_thread_buf,
c_thread_buf_up,
// A scale
a_scale_grid_desc_am_ak,
a_scale_thread_copy,
a_scale_grid_buf,
// B scale
b_scale_grid_desc_bn_ak,
b_scale_thread_copy,
b_scale_thread_copy_up,
@@ -1616,23 +1632,23 @@ struct GridwiseMoeGemmMX_BPreshuffle
else
{
blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
a_grid_desc_ak0_m_ak1,
a_grid_desc_ak0_m_ak1, // A
a_block_desc_ak0_m_ak1,
a_blockwise_copy,
a_grid_buf,
a_block_buf,
a_block_slice_copy_step,
b_grid_desc_bpreshuffled,
b_grid_desc_bpreshuffled, // B
b_block_desc_bk0_n_bk1,
b_blockwise_copy,
b_grid_buf,
b_block_buf,
b_block_bufs,
b_block_slice_copy_step,
c_thread_buf,
a_scale_grid_desc_am_ak,
c_thread_buf, // C
a_scale_grid_desc_am_ak, // A scale
a_scale_thread_copy,
a_scale_grid_buf,
b_scale_grid_desc_bn_ak,
b_scale_grid_desc_bn_ak, // B scale
b_scale_thread_copy,
b_scale_grid_buf,
num_k_block_main_loop);
@@ -1643,84 +1659,101 @@ struct GridwiseMoeGemmMX_BPreshuffle
static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
"wrong!");
static_assert(CShuffleMXdlPerWavePerShuffle % MXdlPack == 0 &&
CShuffleNXdlPerWavePerShuffle % NXdlPack == 0,
"wrong!");
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
// TODO: hacky, fix it!
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();
// TODO: hacky, fix it!
// c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();
constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
constexpr auto M5 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I8);
constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I9);
// mul scales
static_assert(M0 * M1 * M2 * M3 * M4 == MPerBlock);
static_assert(M4 == 4);
static_assert(M0 * M1 * M2 * M3 * M4 * M5 == MPerBlock);
static_assert(M5 == 4);
const index_t m1 = get_warp_local_1d_id() / NWave;
const index_t m3 = threadIdx.x % get_warp_size() / MPerXdl;
const index_t m4 = threadIdx.x % get_warp_size() / MPerXdl;
vector_type<float, 4> topk_weights; // for gemm2 only
static_for<0, NXdlPerWave, 1>{}([&](auto n0) {
static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave
static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk
const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 +
m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4;
if constexpr(MulRoutedWeight)
{
topk_weights = *c_style_pointer_cast<const vector_type<float, M4>*>(
p_ds_grid[I2] + m_pos);
}
static_for<0, M4, 1>{}([&](auto m4) { // m_inst_group_size
constexpr index_t c_offset =
blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
make_tuple(m0, n0, m2 * M4 + m4));
constexpr auto cidx = Number<c_offset>{};
if constexpr(IsInputGemm) // gu fusion
{
if constexpr(ActivationOperation == Activation::silu_and_mul)
{
float gate = c_thread_buf[cidx];
float up = c_thread_buf_up[cidx];
if constexpr(MulRoutedWeight)
{
gate = gate * topk_weights.AsType<float>()[m4];
up = up * topk_weights.AsType<float>()[m4];
}
tensor_operation::element_wise::Silu{}(gate, gate);
c_thread_buf_fp32(cidx) = gate * up;
}
else if(ActivationOperation == Activation::gelu_and_mul)
{
float gate = c_thread_buf[cidx];
float up = c_thread_buf_up[cidx];
if constexpr(MulRoutedWeight)
{
gate = gate * topk_weights.AsType<float>()[m4];
up = up * topk_weights.AsType<float>()[m4];
}
tensor_operation::element_wise::Gelu{}(gate, gate);
c_thread_buf_fp32(cidx) = gate * up;
}
}
else
{
c_thread_buf_fp32(cidx) = c_thread_buf[cidx];
static_for<0, NXdlPerWave / NXdlPack, 1>{}([&](auto n0) {
static_for<0, NXdlPack, 1>{}([&](auto inxdl) { // NXdlPack
static_for<0, MXdlPerWave / MXdlPack, 1>{}([&](auto m0) { // MXDLPerWave
static_for<0, MXdlPack, 1>{}([&](auto imxdl) { // MXdlPack
static_for<0, M3, 1>{}([&](auto m3) { // m_inst_num_groups_per_blk
const index_t m_pos = block_m_id * MPerBlock +
m0 * M2 * M1 * M3 * M4 * M5 +
m1 * M2 * M3 * M4 * M5 +
imxdl * M3 * M4 * M5 + m3 * M4 * M5 + m4 * M5;
if constexpr(MulRoutedWeight)
{
c_thread_buf_fp32(cidx) =
topk_weights.AsType<float>()[m4] * c_thread_buf_fp32[cidx];
topk_weights =
*c_style_pointer_cast<const vector_type<float, M5>*>(
p_ds_grid[I2] + m_pos);
}
}
static_for<0, M5, 1>{}([&](auto m5) { // m_inst_group_size
constexpr index_t c_offset =
blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
make_tuple(m0, n0, imxdl, inxdl, m3 * M5 + m5));
constexpr auto cidx = Number<c_offset>{};
if constexpr(IsInputGemm) // gu fusion
{
if constexpr(ActivationOperation ==
Activation::silu_and_mul)
{
float gate = c_thread_buf[cidx];
float up = c_thread_buf_up[cidx];
if constexpr(MulRoutedWeight)
{
gate = gate * topk_weights.AsType<float>()[m5];
up = up * topk_weights.AsType<float>()[m5];
}
tensor_operation::element_wise::Silu{}(gate, gate);
c_thread_buf_fp32(cidx) = gate * up;
}
else if(ActivationOperation == Activation::gelu_and_mul)
{
float gate = c_thread_buf[cidx];
float up = c_thread_buf_up[cidx];
if constexpr(MulRoutedWeight)
{
gate = gate * topk_weights.AsType<float>()[m5];
up = up * topk_weights.AsType<float>()[m5];
}
tensor_operation::element_wise::Gelu{}(gate, gate);
c_thread_buf_fp32(cidx) = gate * up;
}
}
else
{
c_thread_buf_fp32(cidx) = c_thread_buf[cidx];
if constexpr(MulRoutedWeight)
{
c_thread_buf_fp32(cidx) =
topk_weights.AsType<float>()[m5] *
c_thread_buf_fp32[cidx];
}
}
});
});
});
});
});
@@ -1738,19 +1771,25 @@ struct GridwiseMoeGemmMX_BPreshuffle
make_tuple(
make_freeze_transform(I0),
make_unmerge_transform(make_tuple(
Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
M1, // M1 = MWave
M2, // M2 * M3 * M4 = MPerXdl
Number<CShuffleMXdlPerWavePerShuffle / MXdlPack>{}, // M0 (MXdlPerWave) per
// shuffle
M1, // M1 = MWave
M2, // M2 * M3 * M4 = MPerXdl
M3,
M4)),
M4,
M5)),
make_freeze_transform(I0),
make_unmerge_transform(make_tuple(
Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
N1, // N1 = NWave
N2))), // N2 = NPerXdl
Number<CShuffleNXdlPerWavePerShuffle / NXdlPack>{}, // N0 (NXdlPerWave)
// per shuffle
N1, // N1 = NWave
N2, // N2 = NXdlPack
N3))), // N3 = NPerXdl
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(
Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));
make_tuple(Sequence<>{},
Sequence<0, 2, 4, 6, 7, 8>{},
Sequence<>{},
Sequence<1, 3, 5, 9>{}));
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
@@ -1762,8 +1801,8 @@ struct GridwiseMoeGemmMX_BPreshuffle
const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4, M5))),
make_tuple(Sequence<0, 1, 2, 3, 4, 5>{}),
make_tuple(Sequence<0>{}));
const auto m_thread_data_on_block_idx =
@@ -1772,8 +1811,8 @@ struct GridwiseMoeGemmMX_BPreshuffle
const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3))),
make_tuple(Sequence<0, 1, 2, 3>{}),
make_tuple(Sequence<0>{}));
const auto n_thread_data_on_block_idx =
@@ -1781,36 +1820,39 @@ struct GridwiseMoeGemmMX_BPreshuffle
make_multi_index(n_thread_data_on_block));
// shuffle: threadwise copy C from VGPR to LDS
auto c_thread_copy_vgpr_to_lds =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
CShuffleDataType,
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
ck::tensor_operation::element_wise::PassThrough,
Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
I1,
I1,
M2,
I1,
M4,
I1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7,
1,
InMemoryDataOperationEnum::Set,
1,
true>{
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
make_multi_index(0,
0,
m_thread_data_on_block_idx[I1],
n_thread_data_on_block_idx[I1],
m_thread_data_on_block_idx[I2],
m_thread_data_on_block_idx[I3],
m_thread_data_on_block_idx[I4],
n_thread_data_on_block_idx[I2]),
ck::tensor_operation::element_wise::PassThrough{}};
auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
AccDataType,
CShuffleDataType,
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
ck::tensor_operation::element_wise::PassThrough,
Sequence<CShuffleMXdlPerWavePerShuffle / MXdlPack,
CShuffleNXdlPerWavePerShuffle / NXdlPack,
I1,
I1,
M2,
N2,
M3,
I1,
M5,
I1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
9,
1,
InMemoryDataOperationEnum::Set,
1,
true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
make_multi_index(0,
0,
m_thread_data_on_block_idx[I1],
n_thread_data_on_block_idx[I1],
m_thread_data_on_block_idx[I2],
n_thread_data_on_block_idx[I2],
m_thread_data_on_block_idx[I3],
m_thread_data_on_block_idx[I4],
m_thread_data_on_block_idx[I5],
n_thread_data_on_block_idx[I3]),
ck::tensor_operation::element_wise::PassThrough{}};
using EDataType = CDataType;
@@ -1859,7 +1901,7 @@ struct GridwiseMoeGemmMX_BPreshuffle
using CDEBlockTransferCluster =
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
constexpr index_t scatter_weight_idx = 1; // hack fix felix
constexpr index_t scatter_weight_idx = 3; // hack fix felix
auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
ThisThreadBlock,
decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
@@ -1867,8 +1909,9 @@ struct GridwiseMoeGemmMX_BPreshuffle
decltype(c_ds_desc_refs),
decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
CElementwiseOperation,
Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make Sequence
// support arbitrary type
Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make
// Sequence support
// arbitrary type
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
@@ -1898,13 +1941,25 @@ struct GridwiseMoeGemmMX_BPreshuffle
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
constexpr auto sfc_c_vgpr =
SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
SpaceFillingCurve<Sequence<MXdlPerWave / MXdlPack,
NXdlPerWave / NXdlPack,
1,
1,
MXdlPack,
NXdlPack,
M2,
1,
M4,
1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
Sequence<CShuffleMXdlPerWavePerShuffle / MXdlPack,
CShuffleNXdlPerWavePerShuffle / NXdlPack,
1,
1,
MXdlPack,
NXdlPack,
M2,
1,
M4,
@@ -1984,7 +2039,6 @@ struct GridwiseMoeGemmMX_BPreshuffle
});
}
}
#endif
template <bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,

View File

@@ -74,6 +74,21 @@ struct GroupedConvTraits
}
public:
// Fixed values for Implicit GEMM.
// Groups compile-time settings in one place: every member below is either a
// literal constant or a fixed layout alias.
struct FixedGemmParams
{
// NOTE(review): presumably forwarded to the GEMM tile partitioner (group count
// and M01 factor); the consumer is not visible here - confirm.
static constexpr ck_tile::index_t TilePartitionerGroupNum = 8;
static constexpr ck_tile::index_t TilePartitionerM01 = 4;
// NOTE(review): presumably the padding switches for the GEMM M, N and K
// dimensions; all three are set to true - confirm meaning against the consumer.
static constexpr bool kPadM = true;
static constexpr bool kPadN = true;
static constexpr bool kPadK = true;
// Boolean switches whose semantics are defined by the code that reads them
// (not visible here). Only FixedVectorSize is set to true.
static constexpr bool TransposeC = false;
static constexpr bool FixedVectorSize = true;
static constexpr bool UseStructuredSparsity = false;
static constexpr bool Persistent = false;
// Layout alias fixed to row-major.
// NOTE(review): presumably the layout of the GEMM output (E) tensor - confirm.
using ELayout = ck_tile::tensor_layout::gemm::RowMajor;
};
// Compile time parameters
static constexpr bool EnableSplitImage = EnableSplitImage_;
static constexpr index_t NumGroupsToMerge = NumGroupsToMerge_;
static constexpr index_t NDimSpatial = NDimSpatial_;
@@ -82,31 +97,43 @@ struct GroupedConvTraits
using WeiLayout = WeiLayout_;
using DsLayout = DsLayout_;
using OutLayout = OutLayout_;
// Forward Gemm Layouts
using AsLayoutFwd = ck_tile::tensor_layout::gemm::RowMajor;
using BsLayoutFwd = ck_tile::tensor_layout::gemm::ColumnMajor;
using CLayoutFwd = ck_tile::tensor_layout::gemm::RowMajor;
// Backward Data Gemm Layouts
using AsLayoutBwdData = ck_tile::tensor_layout::gemm::RowMajor;
using BsLayoutBwdData = ck_tile::tensor_layout::gemm::RowMajor;
using CLayoutBwdData = ck_tile::tensor_layout::gemm::RowMajor;
// Backward Weight Gemm Layouts
using AsLayoutBwdWeight = ck_tile::tensor_layout::gemm::ColumnMajor;
using BsLayoutBwdWeight = ck_tile::tensor_layout::gemm::RowMajor;
using CLayoutBwdWeight = ck_tile::tensor_layout::gemm::RowMajor;
template <ck_tile::index_t NumWaveGroups = 1>
using GroupedConvImplicitGemmTraitsFwd =
TileGemmTraits<true,
true,
true,
ck_tile::tensor_layout::gemm::RowMajor,
ck_tile::tensor_layout::gemm::ColumnMajor,
ck_tile::tensor_layout::gemm::RowMajor>;
using GroupedConvImplicitGemmTraitsBwdData =
TileGemmTraits<true,
true,
true,
ck_tile::tensor_layout::gemm::RowMajor,
ck_tile::tensor_layout::gemm::RowMajor,
ck_tile::tensor_layout::gemm::RowMajor>;
using GroupedConvImplicitGemmTraitsBwdWeight =
TileGemmTraits<true,
true,
true,
ck_tile::tensor_layout::gemm::ColumnMajor,
ck_tile::tensor_layout::gemm::RowMajor,
ck_tile::tensor_layout::gemm::RowMajor>;
TileGemmTraits<true, true, true, AsLayoutFwd, BsLayoutFwd, CLayoutFwd, NumWaveGroups>;
template <ck_tile::index_t NumWaveGroups = 1>
using GroupedConvImplicitGemmTraitsBwdData = TileGemmTraits<true,
true,
true,
AsLayoutBwdData,
BsLayoutBwdData,
CLayoutBwdData,
NumWaveGroups>;
template <ck_tile::index_t NumWaveGroups = 1>
using GroupedConvImplicitGemmTraitsBwdWeight = TileGemmTraits<true,
true,
true,
AsLayoutBwdWeight,
BsLayoutBwdWeight,
CLayoutBwdWeight,
NumWaveGroups>;
static constexpr ck_tile::index_t VectorSizeA = VectorSizeA_;
static constexpr ck_tile::index_t VectorSizeB = VectorSizeB_;
static constexpr ck_tile::index_t VectorSizeC = VectorSizeC_;
static constexpr index_t NumDTensor = DsLayout::size();
static constexpr ck_tile::index_t NumDTensor = DsLayout::size();
using ImplicitGemmDsLayout = decltype(generate_implicit_gemm_layout());
};