mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 15:54:31 +00:00
Merge branch 'develop' into jograner/conv-bwd-weight-kbatch-in-2gb-limit
This commit is contained in:
6
.gitignore
vendored
6
.gitignore
vendored
@@ -66,6 +66,12 @@ docs/doxygen/xml
|
||||
cmake-build*/
|
||||
build*/
|
||||
|
||||
# LSP configuration
|
||||
.clangd
|
||||
|
||||
# User-defined CMake presets
|
||||
CMakeUserPresets.json
|
||||
|
||||
# Python virtualenv
|
||||
.venv/
|
||||
|
||||
|
||||
@@ -181,7 +181,7 @@ constexpr ck::index_t ScaleBlockSize = 32; // scaling block
|
||||
constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 f4 = 128 fp4x2
|
||||
static constexpr ck::index_t Nswizzle = false;
|
||||
static constexpr ck::index_t ActOP = 0; // 0: gelu_and_mul, 1: silu_and_mul
|
||||
static constexpr ck::index_t MPerBlock = 128;
|
||||
static constexpr ck::index_t MPerBlock = 32;
|
||||
static constexpr bool MulRoutedWeight = true;
|
||||
|
||||
// clang-format off
|
||||
@@ -190,10 +190,10 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmMXBPreShuffl
|
||||
A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, BElementOp, CDEElementOp, GemmSpec,
|
||||
ScaleBlockSize, 256,
|
||||
MPerBlock, 64, KPerBlock,
|
||||
MPerBlock, 128, KPerBlock,
|
||||
16, 16,
|
||||
16, 16,
|
||||
4, 2,
|
||||
2, 2,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1,
|
||||
2, 2, S<1, 32, 1, 8>, S<8, 1, 1, 1>,
|
||||
@@ -213,10 +213,10 @@ int main(int argc, char* argv[])
|
||||
ck::index_t sorted_size = sorted_tile_num * MPerBlock;
|
||||
ck::index_t valid_size = valid_tile_num * MPerBlock;
|
||||
|
||||
ck::index_t N = 6144;
|
||||
ck::index_t K = 4096;
|
||||
ck::index_t N = 7168;
|
||||
ck::index_t K = 256;
|
||||
ck::index_t experts = 8;
|
||||
ck::index_t tokens = 832;
|
||||
ck::index_t tokens = 208;
|
||||
ck::index_t topk = 2;
|
||||
|
||||
if(argc == 1)
|
||||
|
||||
@@ -14,19 +14,11 @@
|
||||
|
||||
struct ConvConfigBase
|
||||
{
|
||||
static constexpr bool kPadM = true;
|
||||
static constexpr bool kPadN = true;
|
||||
static constexpr bool kPadK = true;
|
||||
|
||||
static constexpr bool TransposeC = false;
|
||||
|
||||
static constexpr ck_tile::index_t VectorSizeA = 4;
|
||||
static constexpr ck_tile::index_t VectorSizeB = 8;
|
||||
static constexpr ck_tile::index_t VectorSizeC = 8;
|
||||
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
|
||||
static constexpr ck_tile::index_t TileParitionerM01 = 4;
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
|
||||
static constexpr ck_tile::index_t NumWaveGroups = 1;
|
||||
@@ -210,9 +202,9 @@ struct ConvConfigComputeV5 : public ConvConfigBase
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V5;
|
||||
static constexpr ck_tile::index_t NumWaNumWaveGroups = 2;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V5;
|
||||
static constexpr ck_tile::index_t NumWaveGroups = 2;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
|
||||
@@ -22,8 +22,6 @@ struct GroupedConvolutionBackwardDataInvoker
|
||||
static float grouped_conv_bwd_data(const ck_tile::GroupedConvBwdDataHostArgs& args,
|
||||
const ck_tile::stream_config& s)
|
||||
{
|
||||
constexpr int kBlockPerCu = 1;
|
||||
|
||||
// Implicit GEMM Traits
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<ConvConfig::M_Tile, ConvConfig::N_Tile, ConvConfig::K_Tile>,
|
||||
@@ -32,36 +30,33 @@ struct GroupedConvolutionBackwardDataInvoker
|
||||
ConvConfig::N_Warp_Tile,
|
||||
ConvConfig::K_Warp_Tile>>;
|
||||
|
||||
constexpr ck_tile::index_t VectorSizeA = 8;
|
||||
constexpr ck_tile::index_t VectorSizeB = 8;
|
||||
constexpr ck_tile::index_t VectorSizeC = 8;
|
||||
|
||||
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
|
||||
using TilePartitioner =
|
||||
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
|
||||
ConvConfig::TileParitionerGroupNum,
|
||||
ConvConfig::TileParitionerM01>;
|
||||
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
|
||||
using GroupedConvTraitsType = ck_tile::GroupedConvTraits<NDimSpatial,
|
||||
ConvSpec,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
DsLayout,
|
||||
OutLayout,
|
||||
VectorSizeA,
|
||||
VectorSizeB,
|
||||
VectorSizeC>;
|
||||
ConvConfig::VectorSizeA,
|
||||
ConvConfig::VectorSizeB,
|
||||
ConvConfig::VectorSizeC>;
|
||||
|
||||
using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner<
|
||||
GemmShape,
|
||||
GroupedConvTraitsType::FixedGemmParams::TilePartitionerGroupNum,
|
||||
GroupedConvTraitsType::FixedGemmParams::TilePartitionerM01>;
|
||||
|
||||
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<
|
||||
ConvConfig::kPadM,
|
||||
ConvConfig::kPadN,
|
||||
ConvConfig::kPadK,
|
||||
GroupedConvTraitsType::FixedGemmParams::kPadM,
|
||||
GroupedConvTraitsType::FixedGemmParams::kPadN,
|
||||
GroupedConvTraitsType::FixedGemmParams::kPadK,
|
||||
ConvConfig::DoubleSmemBuffer,
|
||||
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdData::AsLayout,
|
||||
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdData::BsLayout,
|
||||
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdData::CLayout,
|
||||
ConvConfig::TransposeC,
|
||||
false,
|
||||
false, // Persistent,
|
||||
typename GroupedConvTraitsType::AsLayoutBwdData,
|
||||
typename GroupedConvTraitsType::BsLayoutBwdData,
|
||||
typename GroupedConvTraitsType::CLayoutBwdData,
|
||||
GroupedConvTraitsType::FixedGemmParams::TransposeC,
|
||||
GroupedConvTraitsType::FixedGemmParams::UseStructuredSparsity,
|
||||
GroupedConvTraitsType::FixedGemmParams::Persistent,
|
||||
ConvConfig::NumWaveGroups>;
|
||||
|
||||
using GemmPipelineProblem = ck_tile::GemmPipelineProblem<
|
||||
@@ -69,13 +64,14 @@ struct GroupedConvolutionBackwardDataInvoker
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdData,
|
||||
typename GroupedConvTraitsType::template GroupedConvImplicitGemmTraitsBwdData<
|
||||
ConvConfig::NumWaveGroups>,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
InDataType,
|
||||
true,
|
||||
VectorSizeA,
|
||||
VectorSizeB>;
|
||||
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
|
||||
GroupedConvTraitsType::VectorSizeA,
|
||||
GroupedConvTraitsType::VectorSizeB>;
|
||||
|
||||
using BaseGemmPipeline = typename PipelineTypeTraits<
|
||||
ConvConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
|
||||
@@ -93,95 +89,96 @@ struct GroupedConvolutionBackwardDataInvoker
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
float ave_time{0};
|
||||
|
||||
const auto Run =
|
||||
[&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = ConvConfig::Scheduler;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
const auto Run = [&](const auto has_hot_loop_,
|
||||
const auto tail_number_,
|
||||
const auto memory_operation_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = ConvConfig::Scheduler;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
using UniversalGemmProblem =
|
||||
ck_tile::UniversalGemmPipelineProblem<OutDataType,
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
InDataType,
|
||||
true,
|
||||
VectorSizeA,
|
||||
VectorSizeB>;
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
|
||||
OutDataType,
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
InDataType,
|
||||
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
|
||||
GroupedConvTraitsType::VectorSizeA,
|
||||
GroupedConvTraitsType::VectorSizeB>;
|
||||
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
ConvConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
ConvConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
|
||||
using ConvEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
|
||||
OutDataType,
|
||||
WeiDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
InDataType,
|
||||
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
|
||||
ck_tile::tensor_layout::gemm::RowMajor,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
ConvConfig::M_Warp,
|
||||
ConvConfig::N_Warp,
|
||||
ConvConfig::M_Warp_Tile,
|
||||
ConvConfig::N_Warp_Tile,
|
||||
ConvConfig::K_Warp_Tile,
|
||||
ConvConfig::TransposeC,
|
||||
memory_operation,
|
||||
1,
|
||||
true,
|
||||
GroupedConvTraitsType::VectorSizeC>>;
|
||||
using ConvEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
|
||||
OutDataType,
|
||||
WeiDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
InDataType,
|
||||
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
|
||||
typename GroupedConvTraitsType::FixedGemmParams::ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
ConvConfig::M_Warp,
|
||||
ConvConfig::N_Warp,
|
||||
ConvConfig::M_Warp_Tile,
|
||||
ConvConfig::N_Warp_Tile,
|
||||
ConvConfig::K_Warp_Tile,
|
||||
GroupedConvTraitsType::FixedGemmParams::TransposeC,
|
||||
memory_operation,
|
||||
ConvConfig::NumWaveGroups,
|
||||
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
|
||||
GroupedConvTraitsType::VectorSizeC>>;
|
||||
|
||||
using Kernel = ck_tile::GroupedConvolutionBackwardDataKernel<GroupedConvTraitsType,
|
||||
TilePartitioner,
|
||||
GemmPipeline,
|
||||
ConvEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
using Kernel = ck_tile::GroupedConvolutionBackwardDataKernel<GroupedConvTraitsType,
|
||||
TilePartitioner,
|
||||
GemmPipeline,
|
||||
ConvEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(args);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 grids = Kernel::GridSize(args);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n");
|
||||
}
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n");
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << GemmShape::GetName() << '\n'
|
||||
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
|
||||
<< "pipeline: " << GemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
|
||||
<< "}" << '\n'
|
||||
<< "Vector size A: " << GemmPipeline::GetVectorSizeA()
|
||||
<< ", Vector size B: " << GemmPipeline::GetVectorSizeB()
|
||||
<< ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl;
|
||||
}
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << GemmShape::GetName() << '\n'
|
||||
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
|
||||
<< "pipeline: " << GemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
|
||||
<< "}" << '\n'
|
||||
<< "Vector size A: " << GemmPipeline::GetVectorSizeA()
|
||||
<< ", Vector size B: " << GemmPipeline::GetVectorSizeB()
|
||||
<< ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl;
|
||||
}
|
||||
|
||||
auto preprocess = [&]() {
|
||||
ck_tile::hip_check_error(hipMemsetAsync(
|
||||
kargs.in_ptr, 0, args.template GetInputByte<InDataType>(), s.stream_id_));
|
||||
};
|
||||
|
||||
ave_time = ck_tile::launch_kernel_time_mask(
|
||||
s,
|
||||
preprocess,
|
||||
ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
|
||||
return ave_time;
|
||||
auto preprocess = [&]() {
|
||||
ck_tile::hip_check_error(hipMemsetAsync(
|
||||
kargs.in_ptr, 0, args.template GetInputByte<InDataType>(), s.stream_id_));
|
||||
};
|
||||
|
||||
ave_time = ck_tile::launch_kernel_time_mask(
|
||||
s,
|
||||
preprocess,
|
||||
ck_tile::make_kernel<ConvConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
if(args.k_batch == 1)
|
||||
{
|
||||
|
||||
@@ -21,8 +21,6 @@ struct GroupedConvolutionBackwardWeightInvoker
|
||||
static float grouped_conv_bwd_weight(const ck_tile::GroupedConvBwdWeightHostArgs& args,
|
||||
const ck_tile::stream_config& s)
|
||||
{
|
||||
constexpr int kBlockPerCu = 1;
|
||||
|
||||
// Implicit GEMM Traits
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<ConvConfig::M_Tile, ConvConfig::N_Tile, ConvConfig::K_Tile>,
|
||||
@@ -31,37 +29,34 @@ struct GroupedConvolutionBackwardWeightInvoker
|
||||
ConvConfig::N_Warp_Tile,
|
||||
ConvConfig::K_Warp_Tile>>;
|
||||
|
||||
constexpr ck_tile::index_t VectorSizeA = ConvConfig::VectorSizeA;
|
||||
constexpr ck_tile::index_t VectorSizeB = ConvConfig::VectorSizeB;
|
||||
constexpr ck_tile::index_t VectorSizeC = ConvConfig::VectorSizeC;
|
||||
|
||||
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
|
||||
using TilePartitioner =
|
||||
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
|
||||
ConvConfig::TileParitionerGroupNum,
|
||||
ConvConfig::TileParitionerM01>;
|
||||
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
|
||||
using GroupedConvTraitsType = ck_tile::GroupedConvTraits<NDimSpatial,
|
||||
ConvSpec,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
DsLayout,
|
||||
OutLayout,
|
||||
VectorSizeA,
|
||||
VectorSizeB,
|
||||
VectorSizeC,
|
||||
ConvConfig::VectorSizeA,
|
||||
ConvConfig::VectorSizeB,
|
||||
ConvConfig::VectorSizeC,
|
||||
ConvConfig::NumGroupsToMerge>;
|
||||
|
||||
using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner<
|
||||
GemmShape,
|
||||
GroupedConvTraitsType::FixedGemmParams::TilePartitionerGroupNum,
|
||||
GroupedConvTraitsType::FixedGemmParams::TilePartitionerM01>;
|
||||
|
||||
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<
|
||||
ConvConfig::kPadM,
|
||||
ConvConfig::kPadN,
|
||||
ConvConfig::kPadK,
|
||||
GroupedConvTraitsType::FixedGemmParams::kPadM,
|
||||
GroupedConvTraitsType::FixedGemmParams::kPadN,
|
||||
GroupedConvTraitsType::FixedGemmParams::kPadK,
|
||||
ConvConfig::DoubleSmemBuffer,
|
||||
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::AsLayout,
|
||||
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::BsLayout,
|
||||
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::CLayout,
|
||||
ConvConfig::TransposeC,
|
||||
false,
|
||||
false, // Persistent,
|
||||
typename GroupedConvTraitsType::AsLayoutBwdWeight,
|
||||
typename GroupedConvTraitsType::BsLayoutBwdWeight,
|
||||
typename GroupedConvTraitsType::CLayoutBwdWeight,
|
||||
GroupedConvTraitsType::FixedGemmParams::TransposeC,
|
||||
GroupedConvTraitsType::FixedGemmParams::UseStructuredSparsity,
|
||||
GroupedConvTraitsType::FixedGemmParams::Persistent,
|
||||
ConvConfig::NumWaveGroups>;
|
||||
|
||||
using GemmPipelineProblem = ck_tile::GemmPipelineProblem<
|
||||
@@ -69,13 +64,14 @@ struct GroupedConvolutionBackwardWeightInvoker
|
||||
InDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight,
|
||||
typename GroupedConvTraitsType::template GroupedConvImplicitGemmTraitsBwdWeight<
|
||||
ConvConfig::NumWaveGroups>,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
WeiDataType,
|
||||
true,
|
||||
VectorSizeA,
|
||||
VectorSizeB>;
|
||||
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
|
||||
GroupedConvTraitsType::VectorSizeA,
|
||||
GroupedConvTraitsType::VectorSizeB>;
|
||||
|
||||
using BaseGemmPipeline = typename PipelineTypeTraits<
|
||||
ConvConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
|
||||
@@ -101,21 +97,21 @@ struct GroupedConvolutionBackwardWeightInvoker
|
||||
constexpr auto scheduler = ConvConfig::Scheduler;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
using UniversalGemmProblem =
|
||||
ck_tile::UniversalGemmPipelineProblem<OutDataType,
|
||||
InDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
WeiDataType,
|
||||
true,
|
||||
VectorSizeA,
|
||||
VectorSizeB>;
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
|
||||
OutDataType,
|
||||
InDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
WeiDataType,
|
||||
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
|
||||
GroupedConvTraitsType::VectorSizeA,
|
||||
GroupedConvTraitsType::VectorSizeB>;
|
||||
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
ConvConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
@@ -127,7 +123,7 @@ struct GroupedConvolutionBackwardWeightInvoker
|
||||
AccDataType,
|
||||
WeiDataType,
|
||||
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
|
||||
ck_tile::tensor_layout::gemm::RowMajor,
|
||||
typename GroupedConvTraitsType::FixedGemmParams::ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
@@ -136,10 +132,10 @@ struct GroupedConvolutionBackwardWeightInvoker
|
||||
ConvConfig::M_Warp_Tile,
|
||||
ConvConfig::N_Warp_Tile,
|
||||
ConvConfig::K_Warp_Tile,
|
||||
ConvConfig::TransposeC,
|
||||
GroupedConvTraitsType::FixedGemmParams::TransposeC,
|
||||
memory_operation,
|
||||
1,
|
||||
true,
|
||||
ConvConfig::NumWaveGroups,
|
||||
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
|
||||
GroupedConvTraitsType::VectorSizeC>>;
|
||||
|
||||
using Kernel = ck_tile::GroupedConvolutionBackwardWeightKernel<GroupedConvTraitsType,
|
||||
@@ -184,7 +180,7 @@ struct GroupedConvolutionBackwardWeightInvoker
|
||||
ave_time = ck_tile::launch_kernel_time_mask(
|
||||
s,
|
||||
preprocess,
|
||||
ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
ck_tile::make_kernel<ConvConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
@@ -23,8 +23,6 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
|
||||
{
|
||||
using WorkspaceDataType = float;
|
||||
|
||||
constexpr int kBlockPerCu = 1;
|
||||
|
||||
// Implicit GEMM Traits
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<ConvConfig::M_Tile, ConvConfig::N_Tile, ConvConfig::K_Tile>,
|
||||
@@ -33,36 +31,34 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
|
||||
ConvConfig::N_Warp_Tile,
|
||||
ConvConfig::K_Warp_Tile>>;
|
||||
|
||||
constexpr ck_tile::index_t VectorSizeA = 4;
|
||||
constexpr ck_tile::index_t VectorSizeB = 8;
|
||||
constexpr ck_tile::index_t VectorSizeC = 8;
|
||||
|
||||
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
|
||||
using TilePartitioner =
|
||||
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
|
||||
ConvConfig::TileParitionerGroupNum,
|
||||
ConvConfig::TileParitionerM01>;
|
||||
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
|
||||
using GroupedConvTraitsType = ck_tile::GroupedConvTraits<NDimSpatial,
|
||||
ConvSpec,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
DsLayout,
|
||||
OutLayout,
|
||||
VectorSizeA,
|
||||
VectorSizeB,
|
||||
VectorSizeC>;
|
||||
ConvConfig::VectorSizeA,
|
||||
ConvConfig::VectorSizeB,
|
||||
ConvConfig::VectorSizeC,
|
||||
ConvConfig::NumGroupsToMerge>;
|
||||
|
||||
using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner<
|
||||
GemmShape,
|
||||
GroupedConvTraitsType::FixedGemmParams::TilePartitionerGroupNum,
|
||||
GroupedConvTraitsType::FixedGemmParams::TilePartitionerM01>;
|
||||
|
||||
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<
|
||||
ConvConfig::kPadM,
|
||||
ConvConfig::kPadN,
|
||||
ConvConfig::kPadK,
|
||||
GroupedConvTraitsType::FixedGemmParams::kPadM,
|
||||
GroupedConvTraitsType::FixedGemmParams::kPadN,
|
||||
GroupedConvTraitsType::FixedGemmParams::kPadK,
|
||||
ConvConfig::DoubleSmemBuffer,
|
||||
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::AsLayout,
|
||||
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::BsLayout,
|
||||
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight::CLayout,
|
||||
ConvConfig::TransposeC,
|
||||
false,
|
||||
false, // Persistent,
|
||||
typename GroupedConvTraitsType::AsLayoutBwdWeight,
|
||||
typename GroupedConvTraitsType::BsLayoutBwdWeight,
|
||||
typename GroupedConvTraitsType::CLayoutBwdWeight,
|
||||
GroupedConvTraitsType::FixedGemmParams::TransposeC,
|
||||
GroupedConvTraitsType::FixedGemmParams::UseStructuredSparsity,
|
||||
GroupedConvTraitsType::FixedGemmParams::Persistent,
|
||||
ConvConfig::NumWaveGroups>;
|
||||
|
||||
using GemmPipelineProblem = ck_tile::GemmPipelineProblem<
|
||||
@@ -70,13 +66,14 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
|
||||
InDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdWeight,
|
||||
typename GroupedConvTraitsType::template GroupedConvImplicitGemmTraitsBwdWeight<
|
||||
ConvConfig::NumWaveGroups>,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
WeiDataType,
|
||||
true,
|
||||
VectorSizeA,
|
||||
VectorSizeB>;
|
||||
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
|
||||
GroupedConvTraitsType::VectorSizeA,
|
||||
GroupedConvTraitsType::VectorSizeB>;
|
||||
|
||||
using BaseGemmPipeline = typename PipelineTypeTraits<
|
||||
ConvConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
|
||||
@@ -102,21 +99,21 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
|
||||
constexpr auto scheduler = ConvConfig::Scheduler;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
using UniversalGemmProblem =
|
||||
ck_tile::UniversalGemmPipelineProblem<OutDataType,
|
||||
InDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
WeiDataType,
|
||||
true,
|
||||
VectorSizeA,
|
||||
VectorSizeB>;
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
|
||||
OutDataType,
|
||||
InDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
WeiDataType,
|
||||
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
|
||||
GroupedConvTraitsType::VectorSizeA,
|
||||
GroupedConvTraitsType::VectorSizeB>;
|
||||
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
ConvConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
@@ -128,7 +125,7 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
|
||||
AccDataType,
|
||||
WorkspaceDataType, // C: Workspace normally Out
|
||||
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
|
||||
ck_tile::tensor_layout::gemm::RowMajor,
|
||||
typename GroupedConvTraitsType::FixedGemmParams::ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
@@ -139,8 +136,8 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
|
||||
ConvConfig::K_Warp_Tile,
|
||||
GemmPipelineProblem::TransposeC,
|
||||
memory_operation,
|
||||
1,
|
||||
true,
|
||||
ConvConfig::NumWaveGroups,
|
||||
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
|
||||
GroupedConvTraitsType::VectorSizeC>>;
|
||||
|
||||
using Kernel = ck_tile::GroupedConvolutionBackwardWeightKernel<GroupedConvTraitsType,
|
||||
@@ -235,16 +232,17 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
|
||||
ave_time = ck_tile::launch_kernel_time_mask(
|
||||
s,
|
||||
preprocess,
|
||||
ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs),
|
||||
ck_tile::make_kernel<kBlockPerCu>(ElementwiseKernel{},
|
||||
kGridSize,
|
||||
kBlockSize,
|
||||
0,
|
||||
input_size,
|
||||
ck_tile::make_tuple(shape[1], 1), // Input Stride
|
||||
ck_tile::make_tuple(shape[1], 1), // Output Stride
|
||||
input_tensors,
|
||||
static_cast<WeiDataType*>(c_ptr)));
|
||||
ck_tile::make_kernel<ConvConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs),
|
||||
ck_tile::make_kernel<ConvConfig::kBlockPerCu>(
|
||||
ElementwiseKernel{},
|
||||
kGridSize,
|
||||
kBlockSize,
|
||||
0,
|
||||
input_size,
|
||||
ck_tile::make_tuple(shape[1], 1), // Input Stride
|
||||
ck_tile::make_tuple(shape[1], 1), // Output Stride
|
||||
input_tensors,
|
||||
static_cast<WeiDataType*>(c_ptr)));
|
||||
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
@@ -32,8 +32,6 @@ struct GroupedConvolutionForwardInvoker
|
||||
{
|
||||
std::cout << "[INVOKER] grouped_conv_fwd called, NDimSpatial=" << NDimSpatial << "\n";
|
||||
}
|
||||
constexpr int kBlockPerCu = 1;
|
||||
|
||||
// Implicit GEMM Traits
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<ConvConfig::M_Tile, ConvConfig::N_Tile, ConvConfig::K_Tile>,
|
||||
@@ -42,38 +40,34 @@ struct GroupedConvolutionForwardInvoker
|
||||
ConvConfig::N_Warp_Tile,
|
||||
ConvConfig::K_Warp_Tile>>;
|
||||
|
||||
constexpr ck_tile::index_t VectorSizeA = 8;
|
||||
constexpr ck_tile::index_t VectorSizeB = 8;
|
||||
constexpr ck_tile::index_t VectorSizeC = 8;
|
||||
constexpr ck_tile::index_t NumGroupsToMerge = 1;
|
||||
|
||||
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
|
||||
using TilePartitioner =
|
||||
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
|
||||
ConvConfig::TileParitionerGroupNum,
|
||||
ConvConfig::TileParitionerM01>;
|
||||
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
|
||||
using GroupedConvTraitsType = ck_tile::GroupedConvTraits<NDimSpatial,
|
||||
ConvSpec,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
DsLayout,
|
||||
OutLayout,
|
||||
VectorSizeA,
|
||||
VectorSizeB,
|
||||
VectorSizeC,
|
||||
NumGroupsToMerge>;
|
||||
ConvConfig::VectorSizeA,
|
||||
ConvConfig::VectorSizeB,
|
||||
ConvConfig::VectorSizeC,
|
||||
ConvConfig::NumGroupsToMerge>;
|
||||
|
||||
using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner<
|
||||
GemmShape,
|
||||
GroupedConvTraitsType::FixedGemmParams::TilePartitionerGroupNum,
|
||||
GroupedConvTraitsType::FixedGemmParams::TilePartitionerM01>;
|
||||
|
||||
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<
|
||||
ConvConfig::kPadM,
|
||||
ConvConfig::kPadN,
|
||||
ConvConfig::kPadK,
|
||||
GroupedConvTraitsType::FixedGemmParams::kPadM,
|
||||
GroupedConvTraitsType::FixedGemmParams::kPadN,
|
||||
GroupedConvTraitsType::FixedGemmParams::kPadK,
|
||||
ConvConfig::DoubleSmemBuffer,
|
||||
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsFwd::AsLayout,
|
||||
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsFwd::BsLayout,
|
||||
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsFwd::CLayout,
|
||||
ConvConfig::TransposeC,
|
||||
false,
|
||||
false, // Persistent,
|
||||
typename GroupedConvTraitsType::AsLayoutFwd,
|
||||
typename GroupedConvTraitsType::BsLayoutFwd,
|
||||
typename GroupedConvTraitsType::CLayoutFwd,
|
||||
GroupedConvTraitsType::FixedGemmParams::TransposeC,
|
||||
GroupedConvTraitsType::FixedGemmParams::UseStructuredSparsity,
|
||||
GroupedConvTraitsType::FixedGemmParams::Persistent,
|
||||
ConvConfig::NumWaveGroups>;
|
||||
|
||||
using GemmPipelineProblem = ck_tile::GemmPipelineProblem<
|
||||
@@ -81,13 +75,14 @@ struct GroupedConvolutionForwardInvoker
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsFwd,
|
||||
typename GroupedConvTraitsType::template GroupedConvImplicitGemmTraitsFwd<
|
||||
ConvConfig::NumWaveGroups>,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
OutDataType,
|
||||
true,
|
||||
VectorSizeA,
|
||||
VectorSizeB>;
|
||||
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
|
||||
GroupedConvTraitsType::VectorSizeA,
|
||||
GroupedConvTraitsType::VectorSizeB>;
|
||||
|
||||
using BaseGemmPipeline = typename PipelineTypeTraits<
|
||||
ConvConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
|
||||
@@ -116,21 +111,21 @@ struct GroupedConvolutionForwardInvoker
|
||||
constexpr auto scheduler = ConvConfig::Scheduler;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
using UniversalGemmProblem =
|
||||
ck_tile::UniversalGemmPipelineProblem<InDataType,
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
OutDataType,
|
||||
true,
|
||||
VectorSizeA,
|
||||
VectorSizeB>;
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
OutDataType,
|
||||
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
|
||||
GroupedConvTraitsType::VectorSizeA,
|
||||
GroupedConvTraitsType::VectorSizeB>;
|
||||
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
ConvConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
@@ -142,7 +137,7 @@ struct GroupedConvolutionForwardInvoker
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
|
||||
ck_tile::tensor_layout::gemm::RowMajor,
|
||||
typename GroupedConvTraitsType::FixedGemmParams::ELayout,
|
||||
CDElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
@@ -151,10 +146,10 @@ struct GroupedConvolutionForwardInvoker
|
||||
ConvConfig::M_Warp_Tile,
|
||||
ConvConfig::N_Warp_Tile,
|
||||
ConvConfig::K_Warp_Tile,
|
||||
ConvConfig::TransposeC,
|
||||
GroupedConvTraitsType::FixedGemmParams::TransposeC,
|
||||
memory_operation,
|
||||
1,
|
||||
true,
|
||||
ConvConfig::NumWaveGroups,
|
||||
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
|
||||
GroupedConvTraitsType::VectorSizeC>>;
|
||||
|
||||
using Kernel = ck_tile::GroupedConvolutionForwardKernel<GroupedConvTraitsType,
|
||||
@@ -185,8 +180,9 @@ struct GroupedConvolutionForwardInvoker
|
||||
<< ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl;
|
||||
}
|
||||
|
||||
ave_time = ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
ave_time = ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<ConvConfig::kBlockPerCu>(
|
||||
Kernel{}, grids, blocks, 0, kargs));
|
||||
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
@@ -25,7 +25,6 @@ struct GroupedConvolutionForwardInvoker
|
||||
{
|
||||
std::cout << "[INVOKER] grouped_conv_fwd called, NDimSpatial=" << NDimSpatial << "\n";
|
||||
}
|
||||
constexpr int kBlockPerCu = 1;
|
||||
|
||||
// Implicit GEMM Traits
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
@@ -35,27 +34,18 @@ struct GroupedConvolutionForwardInvoker
|
||||
ConvConfig::N_Warp_Tile,
|
||||
ConvConfig::K_Warp_Tile>>;
|
||||
|
||||
constexpr ck_tile::index_t VectorSizeA = 8;
|
||||
constexpr ck_tile::index_t VectorSizeB = 8;
|
||||
constexpr ck_tile::index_t VectorSizeC = 8;
|
||||
|
||||
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
|
||||
using TilePartitioner =
|
||||
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
|
||||
ConvConfig::TileParitionerGroupNum,
|
||||
ConvConfig::TileParitionerM01>;
|
||||
|
||||
using GroupedConvTraitsTypeDefault = ck_tile::GroupedConvTraits<NDimSpatial,
|
||||
ConvSpec,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
DsLayout,
|
||||
OutLayout,
|
||||
VectorSizeA,
|
||||
VectorSizeB,
|
||||
VectorSizeC,
|
||||
1, /*NumGroupsToMerge*/
|
||||
false /*EnableSplitImage*/>;
|
||||
using GroupedConvTraitsTypeDefault =
|
||||
ck_tile::GroupedConvTraits<NDimSpatial,
|
||||
ConvSpec,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
DsLayout,
|
||||
OutLayout,
|
||||
ConvConfig::VectorSizeA,
|
||||
ConvConfig::VectorSizeB,
|
||||
ConvConfig::VectorSizeC,
|
||||
ConvConfig::NumGroupsToMerge>;
|
||||
|
||||
using GroupedConvTraitsTypeLargeTensor =
|
||||
ck_tile::GroupedConvTraits<NDimSpatial,
|
||||
@@ -64,23 +54,28 @@ struct GroupedConvolutionForwardInvoker
|
||||
WeiLayout,
|
||||
DsLayout,
|
||||
OutLayout,
|
||||
VectorSizeA,
|
||||
VectorSizeB,
|
||||
VectorSizeC,
|
||||
1, /*NumGroupsToMerge*/
|
||||
ConvConfig::VectorSizeA,
|
||||
ConvConfig::VectorSizeB,
|
||||
ConvConfig::VectorSizeC,
|
||||
ConvConfig::NumGroupsToMerge,
|
||||
true /*EnableSplitImage*/>;
|
||||
|
||||
using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner<
|
||||
GemmShape,
|
||||
GroupedConvTraitsTypeDefault::FixedGemmParams::TilePartitionerGroupNum,
|
||||
GroupedConvTraitsTypeDefault::FixedGemmParams::TilePartitionerM01>;
|
||||
|
||||
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<
|
||||
ConvConfig::kPadM,
|
||||
ConvConfig::kPadN,
|
||||
ConvConfig::kPadK,
|
||||
GroupedConvTraitsTypeDefault::FixedGemmParams::kPadM,
|
||||
GroupedConvTraitsTypeDefault::FixedGemmParams::kPadN,
|
||||
GroupedConvTraitsTypeDefault::FixedGemmParams::kPadK,
|
||||
ConvConfig::DoubleSmemBuffer,
|
||||
typename GroupedConvTraitsTypeDefault::GroupedConvImplicitGemmTraitsFwd::AsLayout,
|
||||
typename GroupedConvTraitsTypeDefault::GroupedConvImplicitGemmTraitsFwd::BsLayout,
|
||||
typename GroupedConvTraitsTypeDefault::GroupedConvImplicitGemmTraitsFwd::CLayout,
|
||||
ConvConfig::TransposeC,
|
||||
false,
|
||||
false, // Persistent,
|
||||
typename GroupedConvTraitsTypeDefault::AsLayoutFwd,
|
||||
typename GroupedConvTraitsTypeDefault::BsLayoutFwd,
|
||||
typename GroupedConvTraitsTypeDefault::CLayoutFwd,
|
||||
GroupedConvTraitsTypeDefault::FixedGemmParams::TransposeC,
|
||||
GroupedConvTraitsTypeDefault::FixedGemmParams::UseStructuredSparsity,
|
||||
GroupedConvTraitsTypeDefault::FixedGemmParams::Persistent,
|
||||
ConvConfig::NumWaveGroups>;
|
||||
|
||||
using GemmPipelineProblem = ck_tile::GemmPipelineProblem<
|
||||
@@ -88,13 +83,14 @@ struct GroupedConvolutionForwardInvoker
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
typename GroupedConvTraitsTypeDefault::GroupedConvImplicitGemmTraitsFwd,
|
||||
typename GroupedConvTraitsTypeDefault::template GroupedConvImplicitGemmTraitsFwd<
|
||||
ConvConfig::NumWaveGroups>,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
OutDataType,
|
||||
true,
|
||||
VectorSizeA,
|
||||
VectorSizeB>;
|
||||
GroupedConvTraitsTypeDefault::FixedGemmParams::FixedVectorSize,
|
||||
GroupedConvTraitsTypeDefault::VectorSizeA,
|
||||
GroupedConvTraitsTypeDefault::VectorSizeB>;
|
||||
|
||||
using BaseGemmPipeline = typename PipelineTypeTraits<
|
||||
ConvConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
|
||||
@@ -116,9 +112,9 @@ struct GroupedConvolutionForwardInvoker
|
||||
using TransformType =
|
||||
ck_tile::TransformConvFwdToGemm<NDimSpatial,
|
||||
ck_tile::ConvolutionSpecialization::Default,
|
||||
VectorSizeA,
|
||||
VectorSizeB,
|
||||
VectorSizeC,
|
||||
GroupedConvTraitsTypeDefault::VectorSizeA,
|
||||
GroupedConvTraitsTypeDefault::VectorSizeB,
|
||||
GroupedConvTraitsTypeDefault::VectorSizeC,
|
||||
1, // NumGroupsToMerge
|
||||
false, // SplitN
|
||||
InDataType,
|
||||
@@ -264,21 +260,21 @@ struct GroupedConvolutionForwardInvoker
|
||||
GroupedConvTraitsTypeLargeTensor,
|
||||
GroupedConvTraitsTypeDefault>;
|
||||
|
||||
using UniversalGemmProblem =
|
||||
ck_tile::UniversalGemmPipelineProblem<InDataType,
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
OutDataType,
|
||||
true,
|
||||
VectorSizeA,
|
||||
VectorSizeB>;
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
OutDataType,
|
||||
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
|
||||
GroupedConvTraitsType::VectorSizeA,
|
||||
GroupedConvTraitsType::VectorSizeB>;
|
||||
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
ConvConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
@@ -290,7 +286,7 @@ struct GroupedConvolutionForwardInvoker
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
|
||||
ck_tile::tensor_layout::gemm::RowMajor,
|
||||
typename GroupedConvTraitsType::FixedGemmParams::ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
@@ -299,10 +295,10 @@ struct GroupedConvolutionForwardInvoker
|
||||
ConvConfig::M_Warp_Tile,
|
||||
ConvConfig::N_Warp_Tile,
|
||||
ConvConfig::K_Warp_Tile,
|
||||
ConvConfig::TransposeC,
|
||||
GroupedConvTraitsType::FixedGemmParams::TransposeC,
|
||||
memory_operation,
|
||||
1,
|
||||
true,
|
||||
ConvConfig::NumWaveGroups,
|
||||
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
|
||||
GroupedConvTraitsType::VectorSizeC>>;
|
||||
|
||||
// Use split-image kernel if layout supports it, otherwise use regular kernel
|
||||
@@ -368,7 +364,8 @@ struct GroupedConvolutionForwardInvoker
|
||||
}
|
||||
|
||||
ave_time = ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
s,
|
||||
ck_tile::make_kernel<ConvConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
@@ -78,66 +78,4 @@ struct UnsupportedEnumValue
|
||||
{
|
||||
};
|
||||
|
||||
// Helper functions to convert enums to strings
|
||||
constexpr std::string_view ConvDirectionToString(ConvDirection dir)
|
||||
{
|
||||
switch(dir)
|
||||
{
|
||||
case ConvDirection::FORWARD: return "Forward";
|
||||
case ConvDirection::BACKWARD_DATA: return "Backward Data";
|
||||
case ConvDirection::BACKWARD_WEIGHT: return "Backward Weight";
|
||||
default: return "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
constexpr std::string_view DataTypeToString(DataType dt)
|
||||
{
|
||||
switch(dt)
|
||||
{
|
||||
case DataType::FP16: return "FP16";
|
||||
case DataType::FP32: return "FP32";
|
||||
case DataType::BF16: return "BF16";
|
||||
case DataType::FP8: return "FP8";
|
||||
case DataType::I8: return "I8";
|
||||
case DataType::U8: return "U8";
|
||||
default: return "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
constexpr std::string_view LayoutToString(GroupConvLayout1D layout)
|
||||
{
|
||||
switch(layout)
|
||||
{
|
||||
case GroupConvLayout1D::GNWC_GKXC_GNWK: return "GNWC_GKXC_GNWK";
|
||||
case GroupConvLayout1D::NWGC_GKXC_NWGK: return "NWGC_GKXC_NWGK";
|
||||
case GroupConvLayout1D::NGCW_GKXC_NGKW: return "NGCW_GKXC_NGKW";
|
||||
case GroupConvLayout1D::NGCW_GKCX_NGKW: return "NGCW_GKCX_NGKW";
|
||||
default: return "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
constexpr std::string_view LayoutToString(GroupConvLayout2D layout)
|
||||
{
|
||||
switch(layout)
|
||||
{
|
||||
case GroupConvLayout2D::GNHWC_GKYXC_GNHWK: return "GNHWC_GKYXC_GNHWK";
|
||||
case GroupConvLayout2D::NHWGC_GKYXC_NHWGK: return "NHWGC_GKYXC_NHWGK";
|
||||
case GroupConvLayout2D::NGCHW_GKYXC_NGKHW: return "NGCHW_GKYXC_NGKHW";
|
||||
case GroupConvLayout2D::NGCHW_GKCYX_NGKHW: return "NGCHW_GKCYX_NGKHW";
|
||||
default: return "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
constexpr std::string_view LayoutToString(GroupConvLayout3D layout)
|
||||
{
|
||||
switch(layout)
|
||||
{
|
||||
case GroupConvLayout3D::GNDHWC_GKZYXC_GNDHWK: return "GNDHWC_GKZYXC_GNDHWK";
|
||||
case GroupConvLayout3D::NDHWGC_GKZYXC_NDHWGK: return "NDHWGC_GKZYXC_NDHWGK";
|
||||
case GroupConvLayout3D::NGCDHW_GKZYXC_NGKDHW: return "NGCDHW_GKZYXC_NGKDHW";
|
||||
case GroupConvLayout3D::NGCDHW_GKCZYX_NGKDHW: return "NGCDHW_GKCZYX_NGKDHW";
|
||||
default: return "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder
|
||||
|
||||
@@ -183,4 +183,87 @@ concept SpecifiesLoopScheduler = requires {
|
||||
{ T::loop_scheduler } -> std::convertible_to<PipelineScheduler>;
|
||||
};
|
||||
|
||||
/******************************************** */
|
||||
/* DL-specific descriptors and requirements */
|
||||
/******************************************** */
|
||||
|
||||
// Concept for DL thread configuration
|
||||
template <typename T>
|
||||
concept DlThreadConfigDescriptor = requires(T t) {
|
||||
{ t.k0_per_block } -> std::convertible_to<size_t>;
|
||||
{ t.k1 } -> std::convertible_to<size_t>;
|
||||
{ t.m1_per_thread } -> std::convertible_to<size_t>;
|
||||
{ t.n1_per_thread } -> std::convertible_to<size_t>;
|
||||
{ t.k_per_thread } -> std::convertible_to<size_t>;
|
||||
};
|
||||
|
||||
// Concept for DL thread cluster
|
||||
template <typename T>
|
||||
concept DlThreadClusterDescriptor = requires(T t) {
|
||||
{ t.m1_xs } -> std::convertible_to<std::array<size_t, 2>>;
|
||||
{ t.n1_xs } -> std::convertible_to<std::array<size_t, 2>>;
|
||||
};
|
||||
|
||||
// Concept for DL block transfer K0_M0_M1_K1 format
|
||||
template <typename T>
|
||||
concept DlBlockTransferK0M0M1K1Descriptor = requires(T t) {
|
||||
{ t.thread_slice_lengths } -> std::convertible_to<std::array<size_t, 4>>;
|
||||
{ t.thread_cluster_lengths } -> std::convertible_to<std::array<size_t, 4>>;
|
||||
{ t.thread_cluster_arrange_order } -> std::convertible_to<std::array<size_t, 4>>;
|
||||
{ t.src_access_order } -> std::convertible_to<std::array<size_t, 4>>;
|
||||
{ t.src_vector_tensor_lengths } -> std::convertible_to<std::array<size_t, 4>>;
|
||||
{ t.src_vector_tensor_contiguous_dim_order } -> std::convertible_to<std::array<size_t, 4>>;
|
||||
{ t.dst_vector_tensor_lengths } -> std::convertible_to<std::array<size_t, 4>>;
|
||||
};
|
||||
|
||||
// Concept for DL block transfer K0_N0_N1_K1 format
|
||||
template <typename T>
|
||||
concept DlBlockTransferK0N0N1K1Descriptor = requires(T t) {
|
||||
{ t.thread_slice_lengths } -> std::convertible_to<std::array<size_t, 4>>;
|
||||
{ t.thread_cluster_lengths } -> std::convertible_to<std::array<size_t, 4>>;
|
||||
{ t.thread_cluster_arrange_order } -> std::convertible_to<std::array<size_t, 4>>;
|
||||
{ t.src_access_order } -> std::convertible_to<std::array<size_t, 4>>;
|
||||
{ t.src_vector_tensor_lengths } -> std::convertible_to<std::array<size_t, 4>>;
|
||||
{ t.src_vector_tensor_contiguous_dim_order } -> std::convertible_to<std::array<size_t, 4>>;
|
||||
{ t.dst_vector_tensor_lengths } -> std::convertible_to<std::array<size_t, 4>>;
|
||||
};
|
||||
|
||||
// Concept for DL C thread transfer
|
||||
template <typename T>
|
||||
concept DlCThreadTransferDescriptor = requires(T t) {
|
||||
{ t.src_dst_access_order } -> std::convertible_to<std::array<size_t, 6>>;
|
||||
{ t.src_dst_vector_dim } -> std::convertible_to<size_t>;
|
||||
{ t.dst_scalar_per_vector } -> std::convertible_to<size_t>;
|
||||
};
|
||||
|
||||
// Concept to check if algorithm specifies DL thread config
|
||||
template <typename T>
|
||||
concept SpecifiesDlThreadConfig = requires {
|
||||
{ T::dl_thread_config } -> DlThreadConfigDescriptor;
|
||||
};
|
||||
|
||||
// Concept to check if algorithm specifies DL thread cluster
|
||||
template <typename T>
|
||||
concept SpecifiesDlThreadCluster = requires {
|
||||
{ T::dl_thread_cluster } -> DlThreadClusterDescriptor;
|
||||
};
|
||||
|
||||
// Concept to check if algorithm specifies DL A block transfer
|
||||
template <typename T>
|
||||
concept SpecifiesDlBlockTransferA = requires {
|
||||
{ T::dl_block_transfer_a } -> DlBlockTransferK0M0M1K1Descriptor;
|
||||
};
|
||||
|
||||
// Concept to check if algorithm specifies DL B block transfer
|
||||
template <typename T>
|
||||
concept SpecifiesDlBlockTransferB = requires {
|
||||
{ T::dl_block_transfer_b } -> DlBlockTransferK0N0N1K1Descriptor;
|
||||
};
|
||||
|
||||
// Concept to check if algorithm specifies DL C thread transfer
|
||||
template <typename T>
|
||||
concept SpecifiesDlCThreadTransfer = requires {
|
||||
{ T::dl_c_thread_transfer } -> DlCThreadTransferDescriptor;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder
|
||||
|
||||
@@ -36,9 +36,21 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp"
|
||||
// WORKAROUND: Macro namespace collision in upstream CK device operation headers.
|
||||
// device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp (line 41) and
|
||||
// device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp (line 51) both define
|
||||
// GridwiseGemmTemplateParameters macro without #undef, causing redefinition errors.
|
||||
// Use pragma push/pop to isolate the Large_Tensor header's macro scope.
|
||||
#pragma push_macro("GridwiseGemmTemplateParameters")
|
||||
#ifdef GridwiseGemmTemplateParameters
|
||||
#undef GridwiseGemmTemplateParameters
|
||||
#endif
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp"
|
||||
#pragma pop_macro("GridwiseGemmTemplateParameters")
|
||||
#include "ck_tile/builder/conv_signature_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_limits.hpp"
|
||||
@@ -990,4 +1002,263 @@ struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
|
||||
GRIDWISE_GEMM_PIPELINE_VERSION>;
|
||||
};
|
||||
|
||||
// Factory specialization for DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK instance
|
||||
// of a grouped forward convolution kernel using Direct Load (DL) approach.
|
||||
template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
ConvAlgorithmDescriptor auto ALGORITHM,
|
||||
StringLiteral VERSION>
|
||||
requires ConvDirectionIsForward<SIGNATURE> &&
|
||||
ConvDeviceOpIs_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK<SIGNATURE>
|
||||
struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
|
||||
{
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts = decltype(factory_internal::GetTensorLayout<SIGNATURE.layout,
|
||||
SPATIAL_DIM,
|
||||
ConvDirection::FORWARD>());
|
||||
using Types = factory_internal::ConvTensorTypes<SIGNATURE.data_type>;
|
||||
using Ops = factory_internal::ElementwiseOps<SIGNATURE.elementwise_operation>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
static_assert(SpecifiesThreadBlock<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify thread block info.");
|
||||
static_assert(SpecifiesFwdConcSpecialization<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify forward convolution "
|
||||
"specialization.");
|
||||
static_assert(SpecifiesGemmSpecialization<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify gemm specialization.");
|
||||
static_assert(SpecifiesDlThreadConfig<AlgorithmType>,
|
||||
"DL algorithm must specify thread config.");
|
||||
static_assert(SpecifiesDlThreadCluster<AlgorithmType>,
|
||||
"DL algorithm must specify thread cluster.");
|
||||
static_assert(SpecifiesDlBlockTransferA<AlgorithmType>,
|
||||
"DL algorithm must specify A block transfer.");
|
||||
static_assert(SpecifiesDlBlockTransferB<AlgorithmType>,
|
||||
"DL algorithm must specify B block transfer.");
|
||||
static_assert(SpecifiesDlCThreadTransfer<AlgorithmType>,
|
||||
"DL algorithm must specify C thread transfer.");
|
||||
|
||||
static constexpr auto FWD_CONV_SPECIALIZATION =
|
||||
factory_internal::SetFwdConvSpecialization<ALGORITHM>();
|
||||
static constexpr auto GEMM_SPECIALIZATION =
|
||||
factory_internal::SetGemmSpecialization<ALGORITHM>();
|
||||
|
||||
static constexpr auto BLOCK = factory_internal::SetThreadBlockInfo<ALGORITHM>();
|
||||
|
||||
// DL-specific parameters from algorithm descriptor
|
||||
static constexpr auto DL_THREAD_CFG = ALGORITHM.dl_thread_config;
|
||||
static constexpr ck::index_t K0PerBlock = DL_THREAD_CFG.k0_per_block;
|
||||
static constexpr ck::index_t K1 = DL_THREAD_CFG.k1;
|
||||
static constexpr ck::index_t M1PerThread = DL_THREAD_CFG.m1_per_thread;
|
||||
static constexpr ck::index_t N1PerThread = DL_THREAD_CFG.n1_per_thread;
|
||||
static constexpr ck::index_t KPerThread = DL_THREAD_CFG.k_per_thread;
|
||||
|
||||
// Thread cluster from descriptor
|
||||
static constexpr auto DL_CLUSTER = ALGORITHM.dl_thread_cluster;
|
||||
using M1N1ThreadClusterM1Xs = to_sequence_v<DL_CLUSTER.m1_xs>;
|
||||
using M1N1ThreadClusterN1Xs = to_sequence_v<DL_CLUSTER.n1_xs>;
|
||||
|
||||
// A Block Transfer from descriptor - K0_M0_M1_K1 tensor format
|
||||
static constexpr auto DL_A_TRANSFER = ALGORITHM.dl_block_transfer_a;
|
||||
using ABlockTransferThreadSliceLengths_K0_M0_M1_K1 =
|
||||
to_sequence_v<DL_A_TRANSFER.thread_slice_lengths>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M0_M1_K1 =
|
||||
to_sequence_v<DL_A_TRANSFER.thread_cluster_lengths>;
|
||||
using ABlockTransferThreadClusterArrangeOrder =
|
||||
to_sequence_v<DL_A_TRANSFER.thread_cluster_arrange_order>;
|
||||
using ABlockTransferSrcAccessOrder = to_sequence_v<DL_A_TRANSFER.src_access_order>;
|
||||
using ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1 =
|
||||
to_sequence_v<DL_A_TRANSFER.src_vector_tensor_lengths>;
|
||||
using ABlockTransferSrcVectorTensorContiguousDimOrder =
|
||||
to_sequence_v<DL_A_TRANSFER.src_vector_tensor_contiguous_dim_order>;
|
||||
using ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1 =
|
||||
to_sequence_v<DL_A_TRANSFER.dst_vector_tensor_lengths>;
|
||||
|
||||
// B Block Transfer from descriptor - K0_N0_N1_K1 tensor format
|
||||
static constexpr auto DL_B_TRANSFER = ALGORITHM.dl_block_transfer_b;
|
||||
using BBlockTransferThreadSliceLengths_K0_N0_N1_K1 =
|
||||
to_sequence_v<DL_B_TRANSFER.thread_slice_lengths>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N0_N1_K1 =
|
||||
to_sequence_v<DL_B_TRANSFER.thread_cluster_lengths>;
|
||||
using BBlockTransferThreadClusterArrangeOrder =
|
||||
to_sequence_v<DL_B_TRANSFER.thread_cluster_arrange_order>;
|
||||
using BBlockTransferSrcAccessOrder = to_sequence_v<DL_B_TRANSFER.src_access_order>;
|
||||
using BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1 =
|
||||
to_sequence_v<DL_B_TRANSFER.src_vector_tensor_lengths>;
|
||||
using BBlockTransferSrcVectorTensorContiguousDimOrder =
|
||||
to_sequence_v<DL_B_TRANSFER.src_vector_tensor_contiguous_dim_order>;
|
||||
using BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1 =
|
||||
to_sequence_v<DL_B_TRANSFER.dst_vector_tensor_lengths>;
|
||||
|
||||
// C Thread Transfer from descriptor
|
||||
static constexpr auto DL_C_TRANSFER = ALGORITHM.dl_c_thread_transfer;
|
||||
using CThreadTransferSrcDstAccessOrder = to_sequence_v<DL_C_TRANSFER.src_dst_access_order>;
|
||||
static constexpr ck::index_t CThreadTransferSrcDstVectorDim = DL_C_TRANSFER.src_dst_vector_dim;
|
||||
static constexpr ck::index_t CThreadTransferDstScalarPerVector =
|
||||
DL_C_TRANSFER.dst_scalar_per_vector;
|
||||
|
||||
// The DL forward convolution kernel class instance
|
||||
using Instance = ck::tensor_operation::device::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK<
|
||||
SPATIAL_DIM,
|
||||
typename Types::ADataType,
|
||||
typename Types::BDataType,
|
||||
typename Types::DsDataTypes,
|
||||
typename Types::EDataType,
|
||||
typename Types::AccDataType,
|
||||
typename Layouts::ALayout,
|
||||
typename Layouts::BLayout,
|
||||
typename Layouts::DsLayout,
|
||||
typename Layouts::ELayout,
|
||||
typename Ops::AElementwiseOp,
|
||||
typename Ops::BElementwiseOp,
|
||||
typename Ops::CDEElementwiseOp,
|
||||
FWD_CONV_SPECIALIZATION,
|
||||
GEMM_SPECIALIZATION,
|
||||
BLOCK.block_size,
|
||||
BLOCK.per_block.m,
|
||||
BLOCK.per_block.n,
|
||||
K0PerBlock,
|
||||
K1,
|
||||
M1PerThread,
|
||||
N1PerThread,
|
||||
KPerThread,
|
||||
M1N1ThreadClusterM1Xs,
|
||||
M1N1ThreadClusterN1Xs,
|
||||
ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
|
||||
ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
|
||||
ABlockTransferSrcVectorTensorContiguousDimOrder,
|
||||
ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
|
||||
BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
|
||||
BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
|
||||
BBlockTransferSrcVectorTensorContiguousDimOrder,
|
||||
BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
|
||||
CThreadTransferSrcDstAccessOrder,
|
||||
CThreadTransferSrcDstVectorDim,
|
||||
CThreadTransferDstScalarPerVector>;
|
||||
};
|
||||
|
||||
// Factory specialization for DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor instance
|
||||
// of a grouped forward convolution kernel with large tensor support (N-splitting).
|
||||
template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
ConvAlgorithmDescriptor auto ALGORITHM,
|
||||
StringLiteral VERSION>
|
||||
requires ConvDirectionIsForward<SIGNATURE> &&
|
||||
ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<SIGNATURE>
|
||||
struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
|
||||
{
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts = decltype(factory_internal::GetTensorLayout<SIGNATURE.layout,
|
||||
SPATIAL_DIM,
|
||||
ConvDirection::FORWARD>());
|
||||
using Types = factory_internal::ConvTensorTypes<SIGNATURE.data_type>;
|
||||
using Ops = factory_internal::ElementwiseOps<SIGNATURE.elementwise_operation>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
static_assert(SpecifiesThreadBlock<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify thread block info.");
|
||||
static_assert(SpecifiesGridwiseXdlGemm<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify gridwise GEMM info.");
|
||||
static_assert(SpecifiesBlockTransfer<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify block transfer info.");
|
||||
static_assert(SpecifiesLdsTransfer<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify LDS transfer info.");
|
||||
static_assert(
|
||||
SpecifiesThreadClusterAccessOrder<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify thread cluster access order info.");
|
||||
static_assert(SpecifiesSourceAccessOrder<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify source access order info.");
|
||||
static_assert(SpecifiesFwdConcSpecialization<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify forward convolution "
|
||||
"specialization.");
|
||||
static_assert(SpecifiesGemmSpecialization<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify gemm specialization.");
|
||||
static_assert(SpecifiesNumPrefetchStages<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify number of prefetch stages.");
|
||||
static_assert(SpecifiesLoopScheduler<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify loop scheduler.");
|
||||
|
||||
static constexpr auto FWD_CONV_SPECIALIZATION =
|
||||
factory_internal::SetFwdConvSpecialization<ALGORITHM>();
|
||||
static constexpr auto GEMM_SPECIALIZATION =
|
||||
factory_internal::SetGemmSpecialization<ALGORITHM>();
|
||||
static constexpr factory_internal::ConvSpec SPECIALIZATION{.conv_spec = FWD_CONV_SPECIALIZATION,
|
||||
.gemm_spec = GEMM_SPECIALIZATION};
|
||||
|
||||
static constexpr auto LOOP_SCHEDULER = factory_internal::SetLoopScheduler<ALGORITHM>();
|
||||
static constexpr auto BLOCK = factory_internal::SetThreadBlockInfo<ALGORITHM>();
|
||||
static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
|
||||
static constexpr auto A_BLOCK_TRANSFER =
|
||||
factory_internal::SetFwdConvABlockTransfer<ALGORITHM>();
|
||||
static constexpr auto B_BLOCK_TRANSFER =
|
||||
factory_internal::SetFwdConvBBlockTransfer<ALGORITHM>();
|
||||
static constexpr auto C_BLOCK_TRANSFER =
|
||||
factory_internal::SetCBlockTransfer<SIGNATURE, ALGORITHM>();
|
||||
|
||||
// Check limits for the algorithm parameters.
|
||||
static_assert(InputVectorTransferLimits<A_BLOCK_TRANSFER>);
|
||||
static_assert(InputVectorTransferLimits<B_BLOCK_TRANSFER>);
|
||||
static_assert(OutputVectorTransferLimits<C_BLOCK_TRANSFER>);
|
||||
static_assert(AccessOrderLimits<A_BLOCK_TRANSFER.thread_cluster_order>);
|
||||
static_assert(AccessOrderLimits<B_BLOCK_TRANSFER.thread_cluster_order>);
|
||||
static_assert(AccessOrderLimits<A_BLOCK_TRANSFER.src_access_order>);
|
||||
static_assert(AccessOrderLimits<B_BLOCK_TRANSFER.src_access_order>);
|
||||
|
||||
// The forward convolution kernel class instance with large tensor support.
|
||||
using Instance =
|
||||
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<
|
||||
SPATIAL_DIM,
|
||||
typename Layouts::ALayout,
|
||||
typename Layouts::BLayout,
|
||||
typename Layouts::DsLayout,
|
||||
typename Layouts::ELayout,
|
||||
typename Types::ADataType,
|
||||
typename Types::BDataType,
|
||||
typename Types::AccDataType,
|
||||
typename Types::CShuffleDataType,
|
||||
typename Types::DsDataTypes,
|
||||
typename Types::EDataType,
|
||||
typename Ops::AElementwiseOp,
|
||||
typename Ops::BElementwiseOp,
|
||||
typename Ops::CDEElementwiseOp,
|
||||
SPECIALIZATION.conv_spec,
|
||||
SPECIALIZATION.gemm_spec,
|
||||
ALGORITHM.num_gemm_k_prefetch_stages,
|
||||
BLOCK.block_size,
|
||||
BLOCK.per_block.m,
|
||||
BLOCK.per_block.n,
|
||||
BLOCK.per_block.k,
|
||||
GRIDWISE_GEMM.ak1,
|
||||
GRIDWISE_GEMM.bk1,
|
||||
GRIDWISE_GEMM.m_per_xdl,
|
||||
GRIDWISE_GEMM.n_per_xdl,
|
||||
GRIDWISE_GEMM.m_xdl_per_wave,
|
||||
GRIDWISE_GEMM.n_xdl_per_wave,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,
|
||||
A_BLOCK_TRANSFER.src_vector_dim,
|
||||
A_BLOCK_TRANSFER.src_scalar_per_vector,
|
||||
A_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
|
||||
A_BLOCK_TRANSFER.lds_padding,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.src_access_order>,
|
||||
B_BLOCK_TRANSFER.src_vector_dim,
|
||||
B_BLOCK_TRANSFER.src_scalar_per_vector,
|
||||
B_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
|
||||
B_BLOCK_TRANSFER.lds_padding,
|
||||
C_BLOCK_TRANSFER.m_per_wave_per_shuffle,
|
||||
C_BLOCK_TRANSFER.n_per_wave_per_shuffle,
|
||||
to_sequence_v<C_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
C_BLOCK_TRANSFER.scalar_per_vector,
|
||||
typename Types::AComputeType,
|
||||
typename Types::BComputeType,
|
||||
LOOP_SCHEDULER>;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder
|
||||
|
||||
@@ -33,30 +33,35 @@ concept ConvDirectionIsBackwardWeight = (Sig.direction == ConvDirection::BACKWAR
|
||||
// Predicate for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 =
|
||||
ConvDirectionIsForward<Sig> &&
|
||||
(Sig.device_operation._fwd ==
|
||||
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3);
|
||||
|
||||
// Predicate for DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIs_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK =
|
||||
ConvDirectionIsForward<Sig> &&
|
||||
(Sig.device_operation._fwd ==
|
||||
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK);
|
||||
|
||||
// Predicate for DeviceGroupedConvFwdMultipleD_Wmma_CShuffle operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle =
|
||||
ConvDirectionIsForward<Sig> &&
|
||||
(Sig.device_operation._fwd ==
|
||||
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleD_Wmma_CShuffle);
|
||||
|
||||
// Predicate for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle =
|
||||
ConvDirectionIsForward<Sig> &&
|
||||
(Sig.device_operation._fwd ==
|
||||
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle);
|
||||
|
||||
// Predicate for DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor =
|
||||
ConvDirectionIsForward<Sig> &&
|
||||
(Sig.device_operation._fwd ==
|
||||
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor);
|
||||
|
||||
@@ -76,48 +81,56 @@ concept ConvDeviceOpIsForward =
|
||||
// Predicate for DeviceGroupedConvBwdWeight operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight =
|
||||
ConvDirectionIsBackwardWeight<Sig> &&
|
||||
(Sig.device_operation._bwd_weight ==
|
||||
BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight);
|
||||
|
||||
// Predicate for DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle =
|
||||
ConvDirectionIsBackwardWeight<Sig> &&
|
||||
(Sig.device_operation._bwd_weight ==
|
||||
BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle);
|
||||
|
||||
// Predicate for DeviceGroupedConvBwdWeight_Xdl_CShuffle operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Xdl_CShuffle =
|
||||
ConvDirectionIsBackwardWeight<Sig> &&
|
||||
(Sig.device_operation._bwd_weight ==
|
||||
BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Xdl_CShuffle);
|
||||
|
||||
// Predicate for DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle =
|
||||
ConvDirectionIsBackwardWeight<Sig> &&
|
||||
(Sig.device_operation._bwd_weight ==
|
||||
BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle);
|
||||
|
||||
// Predicate for DeviceGroupedConvBwdWeight_Wmma_CShuffle operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Wmma_CShuffle =
|
||||
ConvDirectionIsBackwardWeight<Sig> &&
|
||||
(Sig.device_operation._bwd_weight ==
|
||||
BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Wmma_CShuffle);
|
||||
|
||||
// Predicate for DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 =
|
||||
ConvDirectionIsBackwardWeight<Sig> &&
|
||||
(Sig.device_operation._bwd_weight ==
|
||||
BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Xdl_CShuffleV3);
|
||||
|
||||
// Predicate for DeviceGroupedConvBwdWeightMultipleD operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeightMultipleD =
|
||||
ConvDirectionIsBackwardWeight<Sig> &&
|
||||
(Sig.device_operation._bwd_weight ==
|
||||
BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeightMultipleD);
|
||||
|
||||
// Predicate for DeviceGroupedConvBwdWeight_Dl operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Dl =
|
||||
ConvDirectionIsBackwardWeight<Sig> &&
|
||||
(Sig.device_operation._bwd_weight ==
|
||||
BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Dl);
|
||||
|
||||
@@ -140,18 +153,21 @@ concept ConvDeviceOpIsBackwardWeight =
|
||||
// Predicate for DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 =
|
||||
ConvDirectionIsBackwardData<Sig> &&
|
||||
(Sig.device_operation._bwd_data ==
|
||||
BwdDataGroupConvDeviceOperation::DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1);
|
||||
|
||||
// Predicate for DeviceGroupedConvBwdDataMultipleD operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD =
|
||||
ConvDirectionIsBackwardData<Sig> &&
|
||||
(Sig.device_operation._bwd_data ==
|
||||
BwdDataGroupConvDeviceOperation::DeviceGroupedConvBwdDataMultipleD);
|
||||
|
||||
// Predicate for DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle operation.
|
||||
template <auto Sig>
|
||||
concept ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle =
|
||||
ConvDirectionIsBackwardData<Sig> &&
|
||||
(Sig.device_operation._bwd_data ==
|
||||
BwdDataGroupConvDeviceOperation::DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle);
|
||||
|
||||
|
||||
@@ -0,0 +1,22 @@
|
||||
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck_tile::builder {
|
||||
|
||||
// Enumeration for CK Device Operation types.
|
||||
// This allows the builder to select which device operation template to instantiate
|
||||
// based on the user's requirements.
|
||||
enum class DeviceOpType
|
||||
{
|
||||
// Forward Convolution - Non-grouped
|
||||
CONV_FWD, // Maps to: DeviceConvFwd (TODO: No implementation with tuning params exists yet)
|
||||
|
||||
// Forward Convolution - Grouped
|
||||
GROUPED_CONV_FWD_MULTIPLE_ABD, // Maps to: DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
GROUPED_CONV_FWD_MULTIPLE_ABD_XDL_CSHUFFLE_V3, // Maps to:
|
||||
// DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder
|
||||
@@ -0,0 +1,268 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <concepts>
|
||||
#include <string_view>
|
||||
#include <sstream>
|
||||
#include <type_traits>
|
||||
#include <variant>
|
||||
|
||||
#include <ck_tile/builder/conv_signature_concepts.hpp>
|
||||
#include <ck_tile/builder/reflect/conv_traits.hpp>
|
||||
#include <ck_tile/builder/reflect/tree_formatter.hpp>
|
||||
|
||||
/// @file conv_description.hpp
|
||||
/// @brief Provides human-readable descriptions of ConvBuilder configurations
|
||||
|
||||
namespace ck_tile::reflect::conv {
|
||||
|
||||
struct ConvSignatureInfo
|
||||
{
|
||||
int spatial_dim;
|
||||
builder::ConvDirection direction;
|
||||
std::variant<builder::GroupConvLayout1D, builder::GroupConvLayout2D, builder::GroupConvLayout3D>
|
||||
layout;
|
||||
builder::DataType data_type;
|
||||
builder::ElementwiseOperation input_element_op;
|
||||
builder::ElementwiseOperation weight_element_op;
|
||||
builder::ElementwiseOperation output_element_op;
|
||||
};
|
||||
|
||||
// Algorithm information - groups all algorithm-related configuration
|
||||
struct GemmAlgorithmInfo
|
||||
{
|
||||
int thread_block_size;
|
||||
DataTileInfo tile_dims;
|
||||
WarpGemmParams warp_gemm;
|
||||
InputTileTransferInfo a_tile_transfer;
|
||||
InputTileTransferInfo b_tile_transfer;
|
||||
OutputTileTransferInfo c_tile_transfer;
|
||||
builder::PipelineVersion pipeline_version;
|
||||
builder::PipelineScheduler pipeline_scheduler;
|
||||
std::variant<builder::ConvFwdSpecialization,
|
||||
builder::ConvBwdDataSpecialization,
|
||||
builder::ConvBwdWeightSpecialization>
|
||||
conv_specialization;
|
||||
builder::GemmPadding padding;
|
||||
};
|
||||
|
||||
// Provides human-readable descriptions of ConvBuilder configurations.
|
||||
struct ConvDescription
|
||||
{
|
||||
ConvSignatureInfo signature;
|
||||
GemmAlgorithmInfo algorithm;
|
||||
|
||||
// Brief one-line summary
|
||||
std::string brief() const
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << signature.spatial_dim << "D " << signature.direction << " convolution";
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
// Detailed hierarchical description
|
||||
std::string detailed() const
|
||||
{
|
||||
TreeFormatter f;
|
||||
f.writeLine(0, signature.spatial_dim, "D ", signature.direction, " Convolution Kernel");
|
||||
f.writeLine(1, "Signature");
|
||||
f.writeLine(2, "Tensor Type: ", signature.data_type);
|
||||
f.writeLine(2, "Memory Layout: ", signature.layout);
|
||||
f.writeLine(2, "Input elementwise operation: ", signature.input_element_op);
|
||||
f.writeLine(2, "Weights elementwise operation: ", signature.weight_element_op);
|
||||
f.writeLast(2, "Output elementwise operation: ", signature.output_element_op);
|
||||
|
||||
f.writeLine(1, "Algorithm");
|
||||
// Compute Block section
|
||||
f.writeLine(2, "Thread block size: ", algorithm.thread_block_size);
|
||||
f.writeLine(2,
|
||||
"Data tile size: ",
|
||||
algorithm.tile_dims.m,
|
||||
"×",
|
||||
algorithm.tile_dims.n,
|
||||
"×",
|
||||
algorithm.tile_dims.k);
|
||||
f.writeLine(2, "Gemm padding: ", algorithm.padding);
|
||||
f.writeLine(2, "Convolution specialization: ", algorithm.conv_specialization);
|
||||
// Pipeline section
|
||||
f.writeLine(2, "Pipeline version: ", algorithm.pipeline_version);
|
||||
f.writeLine(2, "Pipeline scheduler: ", algorithm.pipeline_scheduler);
|
||||
f.writeLine(2, "Warp Gemm parameters: ");
|
||||
f.writeLine(
|
||||
3, "subtile size: ", algorithm.warp_gemm.gemm_m, "×", algorithm.warp_gemm.gemm_n);
|
||||
f.writeLast(3,
|
||||
"Number of warp gemm iterations: ",
|
||||
algorithm.warp_gemm.m_iter,
|
||||
"×",
|
||||
algorithm.warp_gemm.n_iter);
|
||||
|
||||
// Memory Access section
|
||||
f.writeLine(2, "Memory access:");
|
||||
|
||||
f.writeLine(3, "A Tile transfer: ");
|
||||
f.writeLine(4,
|
||||
"Tile dimensions: ",
|
||||
algorithm.a_tile_transfer.tile_dimensions.k0,
|
||||
"×",
|
||||
algorithm.a_tile_transfer.tile_dimensions.m_or_n,
|
||||
"×",
|
||||
algorithm.a_tile_transfer.tile_dimensions.k1,
|
||||
"×");
|
||||
f.writeLine(
|
||||
4, "The innermost K subdimension size: ", algorithm.a_tile_transfer.transfer_params.k1);
|
||||
f.writeLine(4,
|
||||
"Spatial thread distribution over the data tile: ",
|
||||
algorithm.a_tile_transfer.transfer_params.thread_cluster_order[0],
|
||||
"×",
|
||||
algorithm.a_tile_transfer.transfer_params.thread_cluster_order[1],
|
||||
"×",
|
||||
algorithm.a_tile_transfer.transfer_params.thread_cluster_order[2]);
|
||||
f.writeLine(4,
|
||||
"The order of accessing data tile axes: ",
|
||||
algorithm.a_tile_transfer.transfer_params.src_access_order[0],
|
||||
"×",
|
||||
algorithm.a_tile_transfer.transfer_params.src_access_order[1],
|
||||
"×",
|
||||
algorithm.a_tile_transfer.transfer_params.src_access_order[2]);
|
||||
f.writeLine(4,
|
||||
"Vectorized memory access axis index (with contiguous memory): ",
|
||||
algorithm.a_tile_transfer.transfer_params.src_vector_dim);
|
||||
f.writeLine(4,
|
||||
"Vector access (GMEM read) instruction size: ",
|
||||
algorithm.a_tile_transfer.transfer_params.src_scalar_per_vector);
|
||||
f.writeLine(4,
|
||||
"Vector access (LDS write) instruction size: ",
|
||||
algorithm.a_tile_transfer.transfer_params.dst_scalar_per_vector_k1);
|
||||
f.writeLast(4,
|
||||
"LDS data layout padding (to prevent bank conflicts): ",
|
||||
algorithm.a_tile_transfer.transfer_params.dst_scalar_per_vector_k1);
|
||||
|
||||
f.writeLine(3, "B Tile transfer: ");
|
||||
f.writeLine(4,
|
||||
"Tile dimensions: ",
|
||||
algorithm.b_tile_transfer.tile_dimensions.k0,
|
||||
"×",
|
||||
algorithm.b_tile_transfer.tile_dimensions.m_or_n,
|
||||
"×",
|
||||
algorithm.b_tile_transfer.tile_dimensions.k1,
|
||||
"×");
|
||||
f.writeLine(
|
||||
4, "The innermost K subdimension size: ", algorithm.b_tile_transfer.transfer_params.k1);
|
||||
f.writeLine(4,
|
||||
"Spatial thread distribution over the data tile: ",
|
||||
algorithm.b_tile_transfer.transfer_params.thread_cluster_order[0],
|
||||
"×",
|
||||
algorithm.b_tile_transfer.transfer_params.thread_cluster_order[1],
|
||||
"×",
|
||||
algorithm.b_tile_transfer.transfer_params.thread_cluster_order[2]);
|
||||
f.writeLine(4,
|
||||
"The order of accessing data tile axes: ",
|
||||
algorithm.b_tile_transfer.transfer_params.src_access_order[0],
|
||||
"×",
|
||||
algorithm.b_tile_transfer.transfer_params.src_access_order[1],
|
||||
"×",
|
||||
algorithm.b_tile_transfer.transfer_params.src_access_order[2]);
|
||||
f.writeLine(4,
|
||||
"Vectorized memory access axis index (with contiguous memory): ",
|
||||
algorithm.b_tile_transfer.transfer_params.src_vector_dim);
|
||||
f.writeLine(4,
|
||||
"Vector access (GMEM read) instruction size: ",
|
||||
algorithm.b_tile_transfer.transfer_params.src_scalar_per_vector);
|
||||
f.writeLine(4,
|
||||
"Vector access (LDS write) instruction size: ",
|
||||
algorithm.b_tile_transfer.transfer_params.dst_scalar_per_vector_k1);
|
||||
f.writeLast(4,
|
||||
"LDS data layout padding (to prevent bank conflicts): ",
|
||||
algorithm.b_tile_transfer.transfer_params.dst_scalar_per_vector_k1);
|
||||
|
||||
f.writeLast(3, "C Tile transfer: ");
|
||||
f.writeLine(4,
|
||||
"Data shuffle (number of gemm instructions per iteration): ",
|
||||
algorithm.c_tile_transfer.shuffle_params.m_gemms_per_shuffle,
|
||||
"×",
|
||||
algorithm.c_tile_transfer.shuffle_params.n_gemms_per_shuffle);
|
||||
f.writeLine(4,
|
||||
"Spatial thread distribution used to store data: ",
|
||||
algorithm.c_tile_transfer.thread_cluster_dims[0],
|
||||
"×",
|
||||
algorithm.c_tile_transfer.thread_cluster_dims[1],
|
||||
"×",
|
||||
algorithm.c_tile_transfer.thread_cluster_dims[2],
|
||||
"×",
|
||||
algorithm.c_tile_transfer.thread_cluster_dims[3]);
|
||||
f.writeLast(4,
|
||||
"Vector access (GMEM write) instruction size: ",
|
||||
algorithm.c_tile_transfer.scalar_per_vector);
|
||||
f.writeLast(2);
|
||||
f.writeLast(1);
|
||||
return f.getString();
|
||||
}
|
||||
|
||||
// Educational explanation of optimization choices
|
||||
std::string explain() const
|
||||
{
|
||||
std::ostringstream oss;
|
||||
// Placeholder for future implementation
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
// Performance characteristics and use case guidance
|
||||
std::string suggest() const
|
||||
{
|
||||
std::ostringstream oss;
|
||||
// Placeholder for future implementation
|
||||
return oss.str();
|
||||
}
|
||||
};
|
||||
|
||||
// Helper concept to detect if a type has InstanceTraits specialization
|
||||
template <typename T>
|
||||
concept HasInstanceTraits = requires { typename InstanceTraits<T>; };
|
||||
|
||||
// Helper concept to detect ConvBuilder types
|
||||
template <typename T>
|
||||
concept IsConvBuilder = requires {
|
||||
typename T::Factory;
|
||||
typename T::Instance;
|
||||
};
|
||||
|
||||
// Primary factory function: Create ConvDescription from Instance type directly
|
||||
template <typename Instance>
|
||||
requires HasInstanceTraits<Instance>
|
||||
ConvDescription Describe()
|
||||
{
|
||||
using Traits = ConvTraits<Instance>;
|
||||
|
||||
return ConvDescription{
|
||||
.signature = ConvSignatureInfo{.spatial_dim = Traits::spatial_dim,
|
||||
.direction = Traits::direction,
|
||||
.layout = Traits::layout,
|
||||
.data_type = Traits::data_type,
|
||||
.input_element_op = Traits::input_element_op,
|
||||
.weight_element_op = Traits::weight_element_op,
|
||||
.output_element_op = Traits::output_element_op},
|
||||
.algorithm = GemmAlgorithmInfo{.thread_block_size = Traits::thread_block_size,
|
||||
.tile_dims = Traits::tile_dims,
|
||||
.warp_gemm = Traits::warp_gemm,
|
||||
.a_tile_transfer = Traits::a_tile_transfer,
|
||||
.b_tile_transfer = Traits::b_tile_transfer,
|
||||
.c_tile_transfer = Traits::c_tile_transfer,
|
||||
.pipeline_version = Traits::pipeline_version,
|
||||
.pipeline_scheduler = Traits::pipeline_scheduler,
|
||||
.conv_specialization = Traits::conv_specialization,
|
||||
.padding = Traits::gemm_padding}};
|
||||
}
|
||||
|
||||
// Backward compatibility: Create ConvDescription from Builder type
|
||||
template <typename Builder>
|
||||
requires IsConvBuilder<Builder> && (!HasInstanceTraits<Builder>)
|
||||
ConvDescription Describe()
|
||||
{
|
||||
// Delegate to Instance-based version
|
||||
using Instance = typename Builder::Instance;
|
||||
return Describe<Instance>();
|
||||
}
|
||||
|
||||
} // namespace ck_tile::reflect::conv
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
|
||||
@@ -13,7 +13,10 @@
|
||||
#include <string_view>
|
||||
#include <sstream>
|
||||
#include <type_traits>
|
||||
#include <climits>
|
||||
#include <limits.h>
|
||||
#include <cmath>
|
||||
#include <ostream>
|
||||
#include <iostream>
|
||||
#include <ck/utility/data_type.hpp>
|
||||
#include <ck/utility/sequence.hpp>
|
||||
#include <ck/utility/blkgemmpipe_scheduler.hpp>
|
||||
|
||||
@@ -0,0 +1,106 @@
|
||||
#pragma once
|
||||
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
|
||||
namespace ck_tile::reflect {
|
||||
|
||||
// Helper class for formatting hierarchical tree structures with proper indentation
|
||||
// and tree-drawing characters (├─, └─, │, etc.)
|
||||
//
|
||||
// Example Usage:
|
||||
//
|
||||
// TreeFormatter f;
|
||||
// f.writeLine(0, "Root");
|
||||
// f.writeLine(1, "Branch 1");
|
||||
// f.writeLine(2, "Item 1a");
|
||||
// f.writeLast(2, "Item 1b");
|
||||
// f.writeLast(1, "Branch 2");
|
||||
// f.writeLast(2, "Item 2a");
|
||||
// std::cout << f.getString() << "\n";
|
||||
//
|
||||
// Generated Output:
|
||||
//
|
||||
// Root
|
||||
// ├─ Branch 1
|
||||
// │ ├─ Item 1a
|
||||
// │ └─ Item 1b
|
||||
// └─ Branch 2
|
||||
// └─ Item 2a
|
||||
class TreeFormatter
|
||||
{
|
||||
public:
|
||||
TreeFormatter() = default;
|
||||
|
||||
// Write a line at the specified indentation level (branch continues after this)
|
||||
template <typename... Args>
|
||||
void writeLine(int indent_level, Args&&... args)
|
||||
{
|
||||
writeLineImpl(indent_level, false, std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
// Write the last line at the specified indentation level (branch ends)
|
||||
template <typename... Args>
|
||||
void writeLast(int indent_level, Args&&... args)
|
||||
{
|
||||
writeLineImpl(indent_level, true, std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
// Get the formatted string (removes trailing newline if present)
|
||||
std::string getString() const
|
||||
{
|
||||
std::string result = oss_.str();
|
||||
if(!result.empty() && result.back() == '\n')
|
||||
{
|
||||
result.pop_back();
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
private:
|
||||
std::ostringstream oss_;
|
||||
std::vector<bool> is_last_at_level_; // Tracks which levels have ended
|
||||
|
||||
// Implementation of line writing with tree symbols
|
||||
template <typename... Args>
|
||||
void writeLineImpl(int indent_level, bool is_last, Args&&... args)
|
||||
{
|
||||
// Ensure we have enough tracking space
|
||||
if(static_cast<size_t>(indent_level) >= is_last_at_level_.size())
|
||||
{
|
||||
is_last_at_level_.resize(indent_level + 1, false);
|
||||
// Level 0 (root) should always be treated as "last" since it has no tree symbols
|
||||
if(is_last_at_level_.size() > 0)
|
||||
{
|
||||
is_last_at_level_[0] = true;
|
||||
}
|
||||
}
|
||||
|
||||
// Draw the tree structure
|
||||
// Start from level 1 (skip level 0 which is the root with no symbols)
|
||||
for(int i = 1; i < indent_level; ++i)
|
||||
{
|
||||
// For all parent levels, draw vertical line or space based on whether they ended
|
||||
oss_ << (is_last_at_level_[i] ? " " : "│ ");
|
||||
}
|
||||
|
||||
// Draw the branch symbol for the current level
|
||||
if(indent_level > 0)
|
||||
{
|
||||
oss_ << (is_last ? "└─ " : "├─ ");
|
||||
}
|
||||
|
||||
// Write the content using fold expression with direct stream insertion
|
||||
((oss_ << std::forward<Args>(args)), ...);
|
||||
|
||||
oss_ << '\n';
|
||||
|
||||
// Update tracking for this level AFTER writing the line
|
||||
// This ensures future lines at deeper levels know if this level ended
|
||||
is_last_at_level_[indent_level] = is_last;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile::reflect
|
||||
@@ -3,6 +3,10 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ostream>
|
||||
#include <string_view>
|
||||
#include <variant>
|
||||
|
||||
namespace ck_tile::builder {
|
||||
|
||||
enum class DataType
|
||||
@@ -215,4 +219,275 @@ enum class PipelineScheduler
|
||||
INTERWAVE
|
||||
};
|
||||
|
||||
// ostream operator overloads for enum classes
|
||||
inline std::ostream& operator<<(std::ostream& os, DataType dt)
|
||||
{
|
||||
using enum DataType;
|
||||
switch(dt)
|
||||
{
|
||||
case FP16: return os << "FP16";
|
||||
case FP32: return os << "FP32";
|
||||
case BF16: return os << "BF16";
|
||||
case FP8: return os << "FP8";
|
||||
case I8: return os << "I8";
|
||||
case U8: return os << "U8";
|
||||
default: return os << "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, ConvDirection dir)
|
||||
{
|
||||
using enum ConvDirection;
|
||||
switch(dir)
|
||||
{
|
||||
case FORWARD: return os << "Forward";
|
||||
case BACKWARD_DATA: return os << "Backward Data";
|
||||
case BACKWARD_WEIGHT: return os << "Backward Weight";
|
||||
default: return os << "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, GroupConvLayout1D layout)
|
||||
{
|
||||
using enum GroupConvLayout1D;
|
||||
switch(layout)
|
||||
{
|
||||
case GNWC_GKXC_GNWK: return os << "GNWC_GKXC_GNWK";
|
||||
case NWGC_GKXC_NWGK: return os << "NWGC_GKXC_NWGK";
|
||||
case NGCW_GKXC_NGKW: return os << "NGCW_GKXC_NGKW";
|
||||
case NGCW_GKCX_NGKW: return os << "NGCW_GKCX_NGKW";
|
||||
default: return os << "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, GroupConvLayout2D layout)
|
||||
{
|
||||
using enum GroupConvLayout2D;
|
||||
switch(layout)
|
||||
{
|
||||
case GNHWC_GKYXC_GNHWK: return os << "GNHWC_GKYXC_GNHWK";
|
||||
case NHWGC_GKYXC_NHWGK: return os << "NHWGC_GKYXC_NHWGK";
|
||||
case NGCHW_GKYXC_NGKHW: return os << "NGCHW_GKYXC_NGKHW";
|
||||
case NGCHW_GKCYX_NGKHW: return os << "NGCHW_GKCYX_NGKHW";
|
||||
default: return os << "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, GroupConvLayout3D layout)
|
||||
{
|
||||
using enum GroupConvLayout3D;
|
||||
switch(layout)
|
||||
{
|
||||
case GNDHWC_GKZYXC_GNDHWK: return os << "GNDHWC_GKZYXC_GNDHWK";
|
||||
case NDHWGC_GKZYXC_NDHWGK: return os << "NDHWGC_GKZYXC_NDHWGK";
|
||||
case NGCDHW_GKZYXC_NGKDHW: return os << "NGCDHW_GKZYXC_NGKDHW";
|
||||
case NGCDHW_GKCZYX_NGKDHW: return os << "NGCDHW_GKCZYX_NGKDHW";
|
||||
default: return os << "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, FwdGroupConvDeviceOperation op)
|
||||
{
|
||||
using enum FwdGroupConvDeviceOperation;
|
||||
switch(op)
|
||||
{
|
||||
case DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK:
|
||||
return os << "DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK";
|
||||
case DeviceGroupedConvFwdMultipleD_Wmma_CShuffle:
|
||||
return os << "DeviceGroupedConvFwdMultipleD_Wmma_CShuffle";
|
||||
case DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle:
|
||||
return os << "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle";
|
||||
case DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3:
|
||||
return os << "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3";
|
||||
case DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor:
|
||||
return os << "DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor";
|
||||
default: return os << "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, BwdDataGroupConvDeviceOperation op)
|
||||
{
|
||||
using enum BwdDataGroupConvDeviceOperation;
|
||||
switch(op)
|
||||
{
|
||||
case DeviceGroupedConvBwdDataMultipleD: return os << "DeviceGroupedConvBwdDataMultipleD";
|
||||
case DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle:
|
||||
return os << "DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle";
|
||||
case DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1:
|
||||
return os << "DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1";
|
||||
default: return os << "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, BwdWeightGroupConvDeviceOperation op)
|
||||
{
|
||||
using enum BwdWeightGroupConvDeviceOperation;
|
||||
switch(op)
|
||||
{
|
||||
case DeviceGroupedConvBwdWeight: return os << "DeviceGroupedConvBwdWeight";
|
||||
case DeviceGroupedConvBwdWeight_Dl: return os << "DeviceGroupedConvBwdWeight_Dl";
|
||||
case DeviceGroupedConvBwdWeight_Xdl_CShuffle:
|
||||
return os << "DeviceGroupedConvBwdWeight_Xdl_CShuffle";
|
||||
case DeviceGroupedConvBwdWeight_Xdl_CShuffleV3:
|
||||
return os << "DeviceGroupedConvBwdWeight_Xdl_CShuffleV3";
|
||||
case DeviceGroupedConvBwdWeight_Wmma_CShuffle:
|
||||
return os << "DeviceGroupedConvBwdWeight_Wmma_CShuffle";
|
||||
case DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle:
|
||||
return os << "DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle";
|
||||
case DeviceGroupedConvBwdWeightMultipleD: return os << "DeviceGroupedConvBwdWeightMultipleD";
|
||||
case DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle:
|
||||
return os << "DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle";
|
||||
default: return os << "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, ElementwiseOperation op)
|
||||
{
|
||||
using enum ElementwiseOperation;
|
||||
switch(op)
|
||||
{
|
||||
case BIAS: return os << "BIAS";
|
||||
case BIAS_CLAMP: return os << "BIAS_CLAMP";
|
||||
case BIAS_BNORM_CLAMP: return os << "BIAS_BNORM_CLAMP";
|
||||
case BILINEAR: return os << "BILINEAR";
|
||||
case CLAMP: return os << "CLAMP";
|
||||
case SCALE: return os << "SCALE";
|
||||
case PASS_THROUGH: return os << "PASS_THROUGH";
|
||||
default: return os << "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, PipelineVersion ver)
|
||||
{
|
||||
using enum PipelineVersion;
|
||||
switch(ver)
|
||||
{
|
||||
case V1: return os << "V1";
|
||||
case V2: return os << "V2";
|
||||
case V3: return os << "V3";
|
||||
case V4: return os << "V4";
|
||||
case V5: return os << "V5";
|
||||
case WEIGHT_ONLY: return os << "WEIGHT_ONLY";
|
||||
default: return os << "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, GemmSpecialization spec)
|
||||
{
|
||||
using enum GemmSpecialization;
|
||||
switch(spec)
|
||||
{
|
||||
case Default: return os << "Default";
|
||||
case MPadding: return os << "MPadding";
|
||||
case NPadding: return os << "NPadding";
|
||||
case KPadding: return os << "KPadding";
|
||||
case MNPadding: return os << "MNPadding";
|
||||
case MKPadding: return os << "MKPadding";
|
||||
case NKPadding: return os << "NKPadding";
|
||||
case MNKPadding: return os << "MNKPadding";
|
||||
case OPadding: return os << "OPadding";
|
||||
case MOPadding: return os << "MOPadding";
|
||||
case NOPadding: return os << "NOPadding";
|
||||
case KOPadding: return os << "KOPadding";
|
||||
case MNOPadding: return os << "MNOPadding";
|
||||
case MKOPadding: return os << "MKOPadding";
|
||||
case NKOPadding: return os << "NKOPadding";
|
||||
case MNKOPadding: return os << "MNKOPadding";
|
||||
default: return os << "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, ConvFwdSpecialization spec)
|
||||
{
|
||||
using enum ConvFwdSpecialization;
|
||||
switch(spec)
|
||||
{
|
||||
case DEFAULT: return os << "DEFAULT";
|
||||
case FILTER_1X1_PAD0: return os << "FILTER_1X1_PAD0";
|
||||
case FILTER_1X1_STRIDE1_PAD0: return os << "FILTER_1X1_STRIDE1_PAD0";
|
||||
case FILTER_3x3: return os << "FILTER_3x3";
|
||||
default: return os << "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, ConvBwdDataSpecialization spec)
|
||||
{
|
||||
using enum ConvBwdDataSpecialization;
|
||||
switch(spec)
|
||||
{
|
||||
case DEFAULT: return os << "DEFAULT";
|
||||
case FILTER_1X1_STRIDE1_PAD0: return os << "FILTER_1X1_STRIDE1_PAD0";
|
||||
default: return os << "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, ConvBwdWeightSpecialization spec)
|
||||
{
|
||||
using enum ConvBwdWeightSpecialization;
|
||||
switch(spec)
|
||||
{
|
||||
case DEFAULT: return os << "DEFAULT";
|
||||
case FILTER_1X1_STRIDE1_PAD0: return os << "FILTER_1X1_STRIDE1_PAD0";
|
||||
case FILTER_1X1_PAD0: return os << "FILTER_1X1_PAD0";
|
||||
case ODD_C: return os << "ODD_C";
|
||||
default: return os << "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, GemmPadding padding)
|
||||
{
|
||||
using enum GemmPadding;
|
||||
switch(padding)
|
||||
{
|
||||
case DEFAULT: return os << "DEFAULT";
|
||||
case M_PADDING: return os << "M_PADDING";
|
||||
case N_PADDING: return os << "N_PADDING";
|
||||
case K_PADDING: return os << "K_PADDING";
|
||||
case MN_PADDING: return os << "MN_PADDING";
|
||||
case MK_PADDING: return os << "MK_PADDING";
|
||||
case NK_PADDING: return os << "NK_PADDING";
|
||||
case MNK_PADDING: return os << "MNK_PADDING";
|
||||
case O_PADDING: return os << "O_PADDING";
|
||||
case MO_PADDING: return os << "MO_PADDING";
|
||||
case NO_PADDING: return os << "NO_PADDING";
|
||||
case KO_PADDING: return os << "KO_PADDING";
|
||||
case MNO_PADDING: return os << "MNO_PADDING";
|
||||
case MKO_PADDING: return os << "MKO_PADDING";
|
||||
case NKO_PADDING: return os << "NKO_PADDING";
|
||||
case MNKO_PADDING: return os << "MNKO_PADDING";
|
||||
default: return os << "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, PipelineScheduler sched)
|
||||
{
|
||||
using enum PipelineScheduler;
|
||||
switch(sched)
|
||||
{
|
||||
case DEFAULT: return os << "DEFAULT";
|
||||
case INTRAWAVE: return os << "INTRAWAVE";
|
||||
case INTERWAVE: return os << "INTERWAVE";
|
||||
default: return os << "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
// ostream operator overload for std::variant of layout types
|
||||
inline std::ostream&
|
||||
operator<<(std::ostream& os,
|
||||
const std::variant<GroupConvLayout1D, GroupConvLayout2D, GroupConvLayout3D>& layout)
|
||||
{
|
||||
std::visit([&os](const auto& l) { os << l; }, layout);
|
||||
return os;
|
||||
}
|
||||
|
||||
// ostream operator overload for std::variant of convolution specializations
|
||||
inline std::ostream& operator<<(std::ostream& os,
|
||||
const std::variant<ConvFwdSpecialization,
|
||||
ConvBwdDataSpecialization,
|
||||
ConvBwdWeightSpecialization>& spec)
|
||||
{
|
||||
std::visit([&os](const auto& s) { os << s; }, spec);
|
||||
return os;
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder
|
||||
|
||||
@@ -43,6 +43,8 @@ add_ck_builder_test(test_ckb_build_fwd_instances
|
||||
conv/test_ckb_conv_fwd_2d_bf16.cpp
|
||||
conv/test_ckb_conv_fwd_2d_fp16.cpp
|
||||
conv/test_ckb_conv_fwd_2d_fp32.cpp
|
||||
conv/test_ckb_conv_fwd_2d_dl_fp16.cpp
|
||||
conv/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp
|
||||
conv/test_ckb_conv_fwd_3d_bf16.cpp
|
||||
conv/test_ckb_conv_fwd_3d_fp16.cpp
|
||||
conv/test_ckb_conv_fwd_3d_fp32.cpp)
|
||||
@@ -67,6 +69,9 @@ add_ck_factory_test(test_ckb_factory_grouped_convolution_forward_dynamic_op test
|
||||
add_ck_builder_test(test_conv_traits
|
||||
conv/test_conv_traits.cpp)
|
||||
|
||||
add_ck_builder_test(test_conv_description
|
||||
test_conv_description.cpp)
|
||||
|
||||
# Function to add all test_ckb targets to a list
|
||||
function(collect_test_ckb_targets result_var)
|
||||
# Get all targets in current directory
|
||||
|
||||
@@ -0,0 +1,69 @@
|
||||
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "utils/ckb_conv_test_common.hpp"
|
||||
|
||||
using namespace ck_tile::builder::test_utils;
|
||||
|
||||
namespace ck_tile::builder::testing {
|
||||
|
||||
TEST(FwdConvInstances, Create_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK_Instance_2D_FP16_GNHWC)
|
||||
{
|
||||
constexpr ConvSignature FwdConvSignature{
|
||||
.spatial_dim = 2,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout2D::GNHWC_GKYXC_GNHWK,
|
||||
.data_type = DataType::FP16,
|
||||
.elementwise_operation = ElementwiseOperation::PASS_THROUGH,
|
||||
.device_operation =
|
||||
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
|
||||
.tile_size = {.m = 128, .n = 128, .k = 16}};
|
||||
|
||||
run_test_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK<FwdConvSignature,
|
||||
FwdThreadBlock,
|
||||
ConvFwdSpecialization::DEFAULT>();
|
||||
}
|
||||
|
||||
TEST(FwdConvInstances, Create_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK_Instance_2D_FP16_NHWGC)
|
||||
{
|
||||
constexpr ConvSignature FwdConvSignature{
|
||||
.spatial_dim = 2,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout2D::NHWGC_GKYXC_NHWGK,
|
||||
.data_type = DataType::FP16,
|
||||
.elementwise_operation = ElementwiseOperation::PASS_THROUGH,
|
||||
.device_operation =
|
||||
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
|
||||
.tile_size = {.m = 128, .n = 128, .k = 16}};
|
||||
|
||||
run_test_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK<FwdConvSignature,
|
||||
FwdThreadBlock,
|
||||
ConvFwdSpecialization::DEFAULT>();
|
||||
}
|
||||
|
||||
TEST(FwdConvInstances,
|
||||
Create_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK_Instance_2D_FP16_FILTER_1X1_PAD0)
|
||||
{
|
||||
constexpr ConvSignature FwdConvSignature{
|
||||
.spatial_dim = 2,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout2D::GNHWC_GKYXC_GNHWK,
|
||||
.data_type = DataType::FP16,
|
||||
.elementwise_operation = ElementwiseOperation::PASS_THROUGH,
|
||||
.device_operation =
|
||||
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
|
||||
.tile_size = {.m = 128, .n = 128, .k = 16}};
|
||||
|
||||
run_test_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK<
|
||||
FwdConvSignature,
|
||||
FwdThreadBlock,
|
||||
ConvFwdSpecialization::FILTER_1X1_PAD0>();
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::testing
|
||||
@@ -0,0 +1,53 @@
|
||||
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "utils/ckb_conv_test_common.hpp"
|
||||
|
||||
using namespace ck_tile::builder::test_utils;
|
||||
|
||||
namespace ck_tile::builder::testing {
|
||||
|
||||
TEST(FwdConvInstances,
|
||||
Create_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor_Instance_2D_FP16_GNHWC)
|
||||
{
|
||||
constexpr ConvSignature FwdConvSignature{
|
||||
.spatial_dim = 2,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout2D::GNHWC_GKYXC_GNHWK,
|
||||
.data_type = DataType::FP16,
|
||||
.elementwise_operation = ElementwiseOperation::PASS_THROUGH,
|
||||
.device_operation =
|
||||
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
|
||||
.tile_size = {.m = 256, .n = 128, .k = 32}};
|
||||
|
||||
run_test_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<
|
||||
FwdConvSignature,
|
||||
FwdThreadBlock,
|
||||
ConvFwdSpecialization::DEFAULT>();
|
||||
}
|
||||
|
||||
TEST(
|
||||
FwdConvInstances,
|
||||
Create_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor_Instance_2D_FP16_GNHWC_Filter1x1Pad0)
|
||||
{
|
||||
constexpr ConvSignature FwdConvSignature{
|
||||
.spatial_dim = 2,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout2D::GNHWC_GKYXC_GNHWK,
|
||||
.data_type = DataType::FP16,
|
||||
.elementwise_operation = ElementwiseOperation::PASS_THROUGH,
|
||||
.device_operation =
|
||||
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock{.block_size = 128,
|
||||
.tile_size = {.m = 128, .n = 128, .k = 32}};
|
||||
|
||||
run_test_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<
|
||||
FwdConvSignature,
|
||||
FwdThreadBlock,
|
||||
ConvFwdSpecialization::FILTER_1X1_PAD0>();
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::testing
|
||||
@@ -214,4 +214,84 @@ static_assert(
|
||||
static_assert(
|
||||
ckb::SpecifiesLoopScheduler<ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle>);
|
||||
|
||||
// DL-specific descriptors
|
||||
struct DlThreadConfig
|
||||
{
|
||||
size_t k0_per_block;
|
||||
size_t k1;
|
||||
size_t m1_per_thread;
|
||||
size_t n1_per_thread;
|
||||
size_t k_per_thread;
|
||||
};
|
||||
static_assert(ckb::DlThreadConfigDescriptor<DlThreadConfig>);
|
||||
|
||||
struct DlThreadCluster
|
||||
{
|
||||
std::array<size_t, 2> m1_xs; // e.g., {8, 2}
|
||||
std::array<size_t, 2> n1_xs; // e.g., {8, 2}
|
||||
};
|
||||
static_assert(ckb::DlThreadClusterDescriptor<DlThreadCluster>);
|
||||
|
||||
struct DlBlockTransferK0M0M1K1
|
||||
{
|
||||
std::array<size_t, 4> thread_slice_lengths;
|
||||
std::array<size_t, 4> thread_cluster_lengths;
|
||||
std::array<size_t, 4> thread_cluster_arrange_order;
|
||||
std::array<size_t, 4> src_access_order;
|
||||
std::array<size_t, 4> src_vector_tensor_lengths;
|
||||
std::array<size_t, 4> src_vector_tensor_contiguous_dim_order;
|
||||
std::array<size_t, 4> dst_vector_tensor_lengths;
|
||||
};
|
||||
static_assert(ckb::DlBlockTransferK0M0M1K1Descriptor<DlBlockTransferK0M0M1K1>);
|
||||
|
||||
struct DlBlockTransferK0N0N1K1
|
||||
{
|
||||
std::array<size_t, 4> thread_slice_lengths;
|
||||
std::array<size_t, 4> thread_cluster_lengths;
|
||||
std::array<size_t, 4> thread_cluster_arrange_order;
|
||||
std::array<size_t, 4> src_access_order;
|
||||
std::array<size_t, 4> src_vector_tensor_lengths;
|
||||
std::array<size_t, 4> src_vector_tensor_contiguous_dim_order;
|
||||
std::array<size_t, 4> dst_vector_tensor_lengths;
|
||||
};
|
||||
static_assert(ckb::DlBlockTransferK0N0N1K1Descriptor<DlBlockTransferK0N0N1K1>);
|
||||
|
||||
struct DlCThreadTransfer
|
||||
{
|
||||
std::array<size_t, 6> src_dst_access_order;
|
||||
size_t src_dst_vector_dim;
|
||||
size_t dst_scalar_per_vector;
|
||||
};
|
||||
static_assert(ckb::DlCThreadTransferDescriptor<DlCThreadTransfer>);
|
||||
|
||||
struct ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
|
||||
{
|
||||
ThreadBlock thread_block;
|
||||
ConvFwdSpecialization fwd_specialization;
|
||||
GemmSpecialization gemm_specialization;
|
||||
DlThreadConfig dl_thread_config;
|
||||
DlThreadCluster dl_thread_cluster;
|
||||
DlBlockTransferK0M0M1K1 dl_block_transfer_a;
|
||||
DlBlockTransferK0N0N1K1 dl_block_transfer_b;
|
||||
DlCThreadTransfer dl_c_thread_transfer;
|
||||
};
|
||||
static_assert(
|
||||
ckb::ConvAlgorithmDescriptor<ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK>);
|
||||
static_assert(
|
||||
ckb::SpecifiesThreadBlock<ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK>);
|
||||
static_assert(ckb::SpecifiesFwdConcSpecialization<
|
||||
ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK>);
|
||||
static_assert(
|
||||
ckb::SpecifiesGemmSpecialization<ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK>);
|
||||
static_assert(
|
||||
ckb::SpecifiesDlThreadConfig<ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK>);
|
||||
static_assert(
|
||||
ckb::SpecifiesDlThreadCluster<ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK>);
|
||||
static_assert(
|
||||
ckb::SpecifiesDlBlockTransferA<ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK>);
|
||||
static_assert(
|
||||
ckb::SpecifiesDlBlockTransferB<ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK>);
|
||||
static_assert(
|
||||
ckb::SpecifiesDlCThreadTransfer<ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK>);
|
||||
|
||||
} // namespace ck_tile::builder::test
|
||||
|
||||
169
experimental/builder/test/test_conv_description.cpp
Normal file
169
experimental/builder/test/test_conv_description.cpp
Normal file
@@ -0,0 +1,169 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <gmock/gmock.h>
|
||||
|
||||
#include <ck_tile/builder/conv_builder.hpp>
|
||||
#include <ck_tile/builder/reflect/conv_description.hpp>
|
||||
#include "testing_utils.hpp"
|
||||
#include "impl/conv_signature_types.hpp"
|
||||
#include "impl/conv_algorithm_types.hpp"
|
||||
|
||||
namespace {
|
||||
|
||||
namespace ckb = ck_tile::builder;
|
||||
namespace ckr = ck_tile::reflect::conv;
|
||||
namespace ckt = ck_tile::test;
|
||||
|
||||
// Defines the signature of the convolution operation to be tested.
|
||||
// This includes dimensionality, direction, data layout, and data type.
|
||||
struct ConvSignature
|
||||
{
|
||||
int spatial_dim = 2;
|
||||
ckb::ConvDirection direction = ckb::ConvDirection::FORWARD;
|
||||
ckb::GroupConvLayout layout = ckb::GroupConvLayout2D::GNHWC_GKYXC_GNHWK;
|
||||
ckb::DataType data_type = ckb::DataType::FP16;
|
||||
ckb::ElementwiseOperation elementwise_operation = ckb::ElementwiseOperation::PASS_THROUGH;
|
||||
ckb::GroupConvDeviceOp device_operation =
|
||||
ckb::FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3;
|
||||
};
|
||||
static_assert(ckb::ConvSignatureDescriptor<ConvSignature>);
|
||||
|
||||
struct DefaultAlgorithm
|
||||
{
|
||||
ckb::test::ThreadBlock thread_block{.block_size = 256,
|
||||
.tile_size = {.m = 256, .n = 256, .k = 32}};
|
||||
|
||||
ckb::test::GridwiseXdlGemm gridwise_gemm{.ak1 = 8,
|
||||
.bk1 = 8,
|
||||
.m_per_xdl = 16,
|
||||
.n_per_xdl = 16,
|
||||
.m_xdl_per_wave = 4,
|
||||
.n_xdl_per_wave = 4};
|
||||
|
||||
ckb::test::BlockTransferABC block_transfer{
|
||||
.block_transfer_a = {.k0 = 4, .m_n = 256, .k1 = 8},
|
||||
.block_transfer_b = {.k0 = 4, .m_n = 256, .k1 = 8},
|
||||
.thread_cluster_dims_c = {.m_block = 1,
|
||||
.m_wave_per_xdl = 32,
|
||||
.n_block = 1,
|
||||
.n_wave_per_xdl = 8},
|
||||
.lds_transfer_a = {.src_vector_dim = 2,
|
||||
.src_scalar_per_vector = 8,
|
||||
.lds_dst_scalar_per_vector = 8,
|
||||
.is_direct_load = true,
|
||||
.lds_padding = false},
|
||||
.lds_transfer_b = {.src_vector_dim = 2,
|
||||
.src_scalar_per_vector = 8,
|
||||
.lds_dst_scalar_per_vector = 8,
|
||||
.is_direct_load = true,
|
||||
.lds_padding = false},
|
||||
.epilogue_c = {.m_per_wave_per_shuffle = 1,
|
||||
.n_per_wave_per_shuffle = 1,
|
||||
.scalar_per_vector = 8},
|
||||
.block_transfer_access_order_a = {.order = {0, 1, 2}},
|
||||
.block_transfer_access_order_b = {.order = {0, 1, 2}},
|
||||
.src_access_order_a = {.order = {0, 1, 2}},
|
||||
.src_access_order_b = {.order = {0, 1, 2}}};
|
||||
|
||||
ckb::ConvFwdSpecialization fwd_specialization = ckb::ConvFwdSpecialization::DEFAULT;
|
||||
ckb::GemmSpecialization gemm_specialization = ckb::GemmSpecialization::Default;
|
||||
ckb::test::BlockGemm block_gemm{.pipeline_version = ckb::PipelineVersion::V4,
|
||||
.scheduler = ckb::PipelineScheduler::INTRAWAVE};
|
||||
};
|
||||
static_assert(ckb::ConvAlgorithmDescriptor<DefaultAlgorithm>);
|
||||
|
||||
TEST(ConvDescriptionTest, DefaultInstanceHasBriefDescription)
|
||||
{
|
||||
static constexpr const ConvSignature SIGNATURE;
|
||||
static constexpr const DefaultAlgorithm ALGORITHM;
|
||||
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
|
||||
EXPECT_THAT(ckr::Describe<Builder>().brief(), ckt::StringEqWithDiff("2D Forward convolution"));
|
||||
}
|
||||
|
||||
TEST(ConvDescriptionTest, DefaultInstanceHasDetailedDescription)
|
||||
{
|
||||
static constexpr const ConvSignature SIGNATURE;
|
||||
static constexpr const DefaultAlgorithm ALGORITHM;
|
||||
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
|
||||
EXPECT_THAT(ckr::Describe<Builder>().detailed(),
|
||||
ckt::StringEqWithDiff( //
|
||||
"2D Forward Convolution Kernel\n"
|
||||
"├─ Signature\n"
|
||||
"│ ├─ Tensor Type: FP16\n"
|
||||
"│ ├─ Memory Layout: GNHWC_GKYXC_GNHWK\n"
|
||||
"│ ├─ Input elementwise operation: PASS_THROUGH\n"
|
||||
"│ ├─ Weights elementwise operation: PASS_THROUGH\n"
|
||||
"│ └─ Output elementwise operation: PASS_THROUGH\n"
|
||||
"├─ Algorithm\n"
|
||||
"│ ├─ Thread block size: 256\n"
|
||||
"│ ├─ Data tile size: 256×256×32\n"
|
||||
"│ ├─ Gemm padding: DEFAULT\n"
|
||||
"│ ├─ Convolution specialization: DEFAULT\n"
|
||||
"│ ├─ Pipeline version: V4\n"
|
||||
"│ ├─ Pipeline scheduler: INTRAWAVE\n"
|
||||
"│ ├─ Warp Gemm parameters: \n"
|
||||
"│ │ ├─ subtile size: 16×16\n"
|
||||
"│ │ └─ Number of warp gemm iterations: 4×4\n"
|
||||
"│ ├─ Memory access:\n"
|
||||
"│ │ ├─ A Tile transfer: \n"
|
||||
"│ │ │ ├─ Tile dimensions: 4×256×8×\n"
|
||||
"│ │ │ ├─ The innermost K subdimension size: 8\n"
|
||||
"│ │ │ ├─ Spatial thread distribution over the data tile: 0×1×2\n"
|
||||
"│ │ │ ├─ The order of accessing data tile axes: 0×1×2\n"
|
||||
"│ │ │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n"
|
||||
"│ │ │ ├─ Vector access (GMEM read) instruction size: 8\n"
|
||||
"│ │ │ ├─ Vector access (LDS write) instruction size: 8\n"
|
||||
"│ │ │ └─ LDS data layout padding (to prevent bank conflicts): 8\n"
|
||||
"│ │ ├─ B Tile transfer: \n"
|
||||
"│ │ │ ├─ Tile dimensions: 4×256×8×\n"
|
||||
"│ │ │ ├─ The innermost K subdimension size: 8\n"
|
||||
"│ │ │ ├─ Spatial thread distribution over the data tile: 0×1×2\n"
|
||||
"│ │ │ ├─ The order of accessing data tile axes: 0×1×2\n"
|
||||
"│ │ │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n"
|
||||
"│ │ │ ├─ Vector access (GMEM read) instruction size: 8\n"
|
||||
"│ │ │ ├─ Vector access (LDS write) instruction size: 8\n"
|
||||
"│ │ │ └─ LDS data layout padding (to prevent bank conflicts): 8\n"
|
||||
"│ │ └─ C Tile transfer: \n"
|
||||
"│ │ ├─ Data shuffle (number of gemm instructions per iteration): 1×1\n"
|
||||
"│ │ ├─ Spatial thread distribution used to store data: 1×32×1×8\n"
|
||||
"│ │ └─ Vector access (GMEM write) instruction size: 8\n"
|
||||
"│ └─ \n"
|
||||
"└─ "));
|
||||
}
|
||||
|
||||
// NOTE: BackwardDataInstanceHasDetailedDescription test is disabled because ConvFactory
|
||||
// does not have a specialization for backward data convolutions. The test fails with:
|
||||
// "implicit instantiation of undefined template 'ck_tile::builder::ConvFactory<...>'"
|
||||
//
|
||||
// To enable this test, a ConvFactory specialization for backward data operations must be
|
||||
// implemented first.
|
||||
//
|
||||
// TEST(ConvDescriptionTest, BackwardDataInstanceHasDetailedDescription)
|
||||
// {
|
||||
// struct BackwardDataSignature
|
||||
// {
|
||||
// int spatial_dim = 2;
|
||||
// ckb::ConvDirection direction = ckb::ConvDirection::BACKWARD_DATA;
|
||||
// ckb::GroupConvLayout layout = ckb::GroupConvLayout2D::GNHWC_GKYXC_GNHWK;
|
||||
// ckb::DataType data_type = ckb::DataType::FP16;
|
||||
// ckb::ElementwiseOperation elementwise_operation =
|
||||
// ckb::ElementwiseOperation::PASS_THROUGH; ckb::GroupConvDeviceOp device_operation =
|
||||
// ckb::BwdDataGroupConvDeviceOperation::DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1;
|
||||
// };
|
||||
// static_assert(ckb::ConvSignatureDescriptor<BackwardDataSignature>);
|
||||
//
|
||||
// static constexpr const BackwardDataSignature SIGNATURE;
|
||||
// static constexpr const DefaultAlgorithm ALGORITHM;
|
||||
// using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
|
||||
//
|
||||
// // Verify Brief works
|
||||
// EXPECT_THAT(ckr::Describe<Builder>().brief(),
|
||||
// ckt::StringEqWithDiff("2D Backward Data convolution"));
|
||||
//
|
||||
// // Verify detailed works - to be updated once ConvFactory is implemented
|
||||
// EXPECT_THAT(ckr::Describe<Builder>().detailed(),
|
||||
// ckt::StringEqWithDiff("PLACEHOLDER"));
|
||||
// }
|
||||
} // namespace
|
||||
@@ -235,4 +235,149 @@ constexpr void run_test_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle()
|
||||
EXPECT_NE(invoker_ptr, nullptr);
|
||||
}
|
||||
|
||||
template <ConvSignature FwdConvSignature,
|
||||
ThreadBlock FwdThreadBlock,
|
||||
ConvFwdSpecialization FwdConvSpecialization>
|
||||
constexpr void run_test_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK()
|
||||
{
|
||||
// DL thread configuration
|
||||
constexpr DlThreadConfig DlThreadCfg{
|
||||
.k0_per_block = 16, .k1 = 2, .m1_per_thread = 4, .n1_per_thread = 4, .k_per_thread = 1};
|
||||
|
||||
// DL thread cluster
|
||||
constexpr DlThreadCluster DlCluster{.m1_xs = {8, 2}, .n1_xs = {8, 2}};
|
||||
|
||||
// DL A block transfer - K0_M0_M1_K1 format
|
||||
constexpr DlBlockTransferK0M0M1K1 DlBlockTransferA{
|
||||
.thread_slice_lengths = {8, 1, 1, 2},
|
||||
.thread_cluster_lengths = {2, 1, 128, 1},
|
||||
.thread_cluster_arrange_order = {1, 2, 0, 3},
|
||||
.src_access_order = {1, 2, 0, 3},
|
||||
.src_vector_tensor_lengths = {4, 1, 1, 2},
|
||||
.src_vector_tensor_contiguous_dim_order = {1, 2, 0, 3},
|
||||
.dst_vector_tensor_lengths = {1, 1, 1, 2}};
|
||||
|
||||
// DL B block transfer - K0_N0_N1_K1 format
|
||||
constexpr DlBlockTransferK0N0N1K1 DlBlockTransferB{
|
||||
.thread_slice_lengths = {8, 1, 1, 2},
|
||||
.thread_cluster_lengths = {2, 1, 128, 1},
|
||||
.thread_cluster_arrange_order = {1, 2, 0, 3},
|
||||
.src_access_order = {1, 2, 0, 3},
|
||||
.src_vector_tensor_lengths = {4, 1, 1, 2},
|
||||
.src_vector_tensor_contiguous_dim_order = {1, 2, 0, 3},
|
||||
.dst_vector_tensor_lengths = {1, 1, 1, 2}};
|
||||
|
||||
// DL C thread transfer
|
||||
constexpr DlCThreadTransfer DlCTransfer{.src_dst_access_order = {0, 1, 2, 3, 4, 5},
|
||||
.src_dst_vector_dim = 5,
|
||||
.dst_scalar_per_vector = 4};
|
||||
|
||||
constexpr ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK FwdConvAlgorithm{
|
||||
.thread_block = FwdThreadBlock,
|
||||
.fwd_specialization = FwdConvSpecialization,
|
||||
.gemm_specialization = GemmSpecialization::MNKPadding,
|
||||
.dl_thread_config = DlThreadCfg,
|
||||
.dl_thread_cluster = DlCluster,
|
||||
.dl_block_transfer_a = DlBlockTransferA,
|
||||
.dl_block_transfer_b = DlBlockTransferB,
|
||||
.dl_c_thread_transfer = DlCTransfer};
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
|
||||
auto instance = typename Builder::Instance{};
|
||||
|
||||
const auto kernel_string = instance.GetTypeString();
|
||||
std::cout << "Generated kernel: " << kernel_string << std::endl;
|
||||
EXPECT_GT(kernel_string.size(), 0);
|
||||
|
||||
EXPECT_TRUE(kernel_string.starts_with("DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK"));
|
||||
|
||||
// Verify specialization is correct
|
||||
if(FwdConvSpecialization == ConvFwdSpecialization::DEFAULT)
|
||||
EXPECT_TRUE(kernel_string.find("Default") != std::string::npos);
|
||||
else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_1X1_PAD0)
|
||||
EXPECT_TRUE(kernel_string.find("Filter1x1Pad0") != std::string::npos);
|
||||
else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0)
|
||||
EXPECT_TRUE(kernel_string.find("Filter1x1Stride1Pad0") != std::string::npos);
|
||||
else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_3x3)
|
||||
EXPECT_TRUE(kernel_string.find("Filter3x3") != std::string::npos);
|
||||
|
||||
const auto invoker_ptr = instance.MakeInvokerPointer();
|
||||
EXPECT_NE(invoker_ptr, nullptr);
|
||||
}
|
||||
|
||||
// Test helper for DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
|
||||
// Note: Large_Tensor has identical parameters to regular XDL CShuffle
|
||||
template <ConvSignature FwdConvSignature,
|
||||
ThreadBlock FwdThreadBlock,
|
||||
ConvFwdSpecialization FwdConvSpecialization>
|
||||
constexpr void run_test_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor()
|
||||
{
|
||||
constexpr GridwiseXdlGemm FwdGemmParams{.ak1 = 8,
|
||||
.bk1 = 8,
|
||||
.m_per_xdl = 32,
|
||||
.n_per_xdl = 32,
|
||||
.m_xdl_per_wave = 2,
|
||||
.n_xdl_per_wave = 1};
|
||||
|
||||
constexpr BlockTransferABC FwdBlockTransfer{.block_transfer_a = {.k0 = 4, .m_n = 16, .k1 = 1},
|
||||
.block_transfer_b = {.k0 = 4, .m_n = 16, .k1 = 1},
|
||||
.thread_cluster_dims_c = {.m_block = 1,
|
||||
.m_wave_per_xdl = 16,
|
||||
.n_block = 1,
|
||||
.n_wave_per_xdl = 4},
|
||||
.lds_transfer_a = {.src_vector_dim = 2,
|
||||
.src_scalar_per_vector = 8,
|
||||
.lds_dst_scalar_per_vector = 8,
|
||||
.is_direct_load = false,
|
||||
.lds_padding = true},
|
||||
.lds_transfer_b = {.src_vector_dim = 2,
|
||||
.src_scalar_per_vector = 8,
|
||||
.lds_dst_scalar_per_vector = 8,
|
||||
.is_direct_load = false,
|
||||
.lds_padding = true},
|
||||
.epilogue_c = {.m_per_wave_per_shuffle = 1,
|
||||
.n_per_wave_per_shuffle = 1,
|
||||
.scalar_per_vector = 8},
|
||||
.block_transfer_access_order_a = {1, 0, 2},
|
||||
.block_transfer_access_order_b = {1, 0, 2},
|
||||
.src_access_order_a = {1, 0, 2},
|
||||
.src_access_order_b = {1, 0, 2}};
|
||||
|
||||
// Large_Tensor uses the same descriptor as regular XDL CShuffle
|
||||
constexpr ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle FwdConvAlgorithm{
|
||||
.thread_block = FwdThreadBlock,
|
||||
.gridwise_gemm = FwdGemmParams,
|
||||
.block_transfer = FwdBlockTransfer,
|
||||
.fwd_specialization = FwdConvSpecialization,
|
||||
.gemm_specialization = GemmSpecialization::MNKPadding,
|
||||
.num_gemm_k_prefetch_stages = 1,
|
||||
.num_groups_to_merge = 1,
|
||||
.loop_scheduler = LoopScheduler::DEFAULT};
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
|
||||
auto instance = typename Builder::Instance{};
|
||||
|
||||
const auto kernel_string = instance.GetTypeString();
|
||||
std::cout << "Generated kernel: " << kernel_string << std::endl;
|
||||
EXPECT_GT(kernel_string.size(), 0);
|
||||
|
||||
EXPECT_TRUE(
|
||||
kernel_string.starts_with("DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor"));
|
||||
|
||||
// Verify specialization is correct
|
||||
if(FwdConvSpecialization == ConvFwdSpecialization::DEFAULT)
|
||||
EXPECT_TRUE(kernel_string.find("Default") != std::string::npos);
|
||||
else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_1X1_PAD0)
|
||||
EXPECT_TRUE(kernel_string.find("Filter1x1Pad0") != std::string::npos);
|
||||
else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0)
|
||||
EXPECT_TRUE(kernel_string.find("Filter1x1Stride1Pad0") != std::string::npos);
|
||||
else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_3x3)
|
||||
EXPECT_TRUE(kernel_string.find("Filter3x3") != std::string::npos);
|
||||
|
||||
const auto invoker_ptr = instance.MakeInvokerPointer();
|
||||
EXPECT_NE(invoker_ptr, nullptr);
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::test_utils
|
||||
|
||||
@@ -727,7 +727,8 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_gufusion_v3<
|
||||
});
|
||||
});
|
||||
|
||||
HotLoopScheduler();
|
||||
if constexpr(MPerBlock >= 64)
|
||||
HotLoopScheduler();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
};
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v3.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v1.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_gufusion_v3.hpp"
|
||||
|
||||
namespace ck {
|
||||
@@ -45,7 +46,28 @@ constexpr auto BlockGemmMXBPreshufflePipeline_Selector()
|
||||
}
|
||||
else
|
||||
{
|
||||
return nullptr;
|
||||
return BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v1<
|
||||
BlkGemmPipeSche,
|
||||
ThreadBlockSize,
|
||||
ScaleBlockSize,
|
||||
ADataType,
|
||||
AScaleDataType,
|
||||
BDataType,
|
||||
BScaleDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>{};
|
||||
}
|
||||
}
|
||||
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
|
||||
|
||||
@@ -0,0 +1,891 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_mx_pipeline_xdlops_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// Naive pipeline with lowest resource request per WGP
|
||||
// GlobalPrefetchStages: 2
|
||||
// LocalPreFillStages: 1
|
||||
// LocalPreFetchStages: 1
|
||||
// LocalSharedMemoryBuffer: 1
|
||||
|
||||
template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
|
||||
index_t ThreadBlockSize,
|
||||
index_t ScaleBlockSize,
|
||||
typename ADataType,
|
||||
typename AScaleDataType,
|
||||
typename BDataType,
|
||||
typename BScaleDataType,
|
||||
typename ATileDesc,
|
||||
typename BTileDesc,
|
||||
typename AMmaTileDesc,
|
||||
typename BMmaTileDesc,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t MPerXDL,
|
||||
index_t NPerXDL,
|
||||
index_t MRepeat, // MXdlPerWave
|
||||
index_t NRepeat, // NXdlPerWave
|
||||
index_t KPack>
|
||||
struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v1
|
||||
{
|
||||
};
|
||||
|
||||
template <index_t ThreadBlockSize,
|
||||
index_t ScaleBlockSize,
|
||||
typename ADataType,
|
||||
typename AScaleDataType,
|
||||
typename BDataType,
|
||||
typename BScaleDataType,
|
||||
typename ATileDesc,
|
||||
typename BTileDesc,
|
||||
typename AMmaTileDesc,
|
||||
typename BMmaTileDesc,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t MPerXDL,
|
||||
index_t NPerXDL,
|
||||
index_t MRepeat, // MXdlPerWave
|
||||
index_t NRepeat, // NXdlPerWave
|
||||
index_t KPack>
|
||||
struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v1<BlockGemmPipelineScheduler::Intrawave,
|
||||
ThreadBlockSize,
|
||||
ScaleBlockSize,
|
||||
ADataType,
|
||||
AScaleDataType,
|
||||
BDataType,
|
||||
BScaleDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>
|
||||
: BlockwiseGemmXdlops_mx_pipeline_base<ThreadBlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>
|
||||
|
||||
{
|
||||
|
||||
using Base = BlockwiseGemmXdlops_mx_pipeline_base<ThreadBlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>;
|
||||
using Base::A_K1;
|
||||
using Base::I0;
|
||||
using Base::I1;
|
||||
using Base::KRepeat;
|
||||
using Base::MWaves;
|
||||
using Base::NWaves;
|
||||
using Base::WaveSize;
|
||||
using Base::xdlops_gemm;
|
||||
using typename Base::HotLoopInstList;
|
||||
|
||||
using Base::CalculateCThreadOriginDataIndex;
|
||||
using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
|
||||
using Base::GetCThreadBuffer;
|
||||
using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
|
||||
using Base::GetWaveIdx;
|
||||
using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
|
||||
using Base::a_block_desc_m0_m1_m2_m3_k;
|
||||
using Base::b_block_desc_n0_n1_n2_n3_k;
|
||||
|
||||
using Base::AMmaKStride;
|
||||
using Base::APackedSize;
|
||||
using Base::BMmaKStride;
|
||||
using Base::BPackedSize;
|
||||
using Base::KThreadChunk;
|
||||
|
||||
using Base::KXdlPack;
|
||||
using Base::MXdlPack;
|
||||
using Base::NXdlPack;
|
||||
|
||||
using AccType = typename Base::AccType;
|
||||
using Tuple5 = typename Base::Tuple5;
|
||||
using ComputeTypeA = typename Base::ComputeTypeA;
|
||||
using ComputeTypeB = typename Base::ComputeTypeB;
|
||||
|
||||
static constexpr index_t PrefetchStages = 2;
|
||||
static constexpr index_t PrefillStages = 1;
|
||||
static constexpr index_t GlobalBufferNum = 2;
|
||||
static constexpr index_t HotloopLocalBufSwitch = MRepeat % 2 == 0 ? 0 : 1;
|
||||
|
||||
static constexpr auto num_buffer_load_a_scale = MRepeat / MXdlPack * KRepeat / KXdlPack;
|
||||
static constexpr auto num_buffer_load_b_scale = NRepeat / NXdlPack * KRepeat / KXdlPack;
|
||||
static constexpr auto async_vmcnt =
|
||||
num_buffer_load_a_scale + num_buffer_load_b_scale + HotLoopInstList::B_Buffer_Load_Inst_Num;
|
||||
static constexpr auto async_vmcnt_encoding = 3952 + async_vmcnt % 16 + async_vmcnt / 16 * 16384;
|
||||
|
||||
static constexpr auto ScalesPerKBlockSize =
|
||||
KPerBlock / ScaleBlockSize; // How many mx-vectors per K block
|
||||
|
||||
//> How many mx-vectors in each row/col is processed in one call to xdlops_gemm.Run()
|
||||
static constexpr auto ScalesPerXdlopsRun =
|
||||
(APackedSize * KPack * xdlops_gemm.K0PerXdlops) / ScaleBlockSize;
|
||||
|
||||
//> How many scales a thread must read to accommodate one call to xdlops_gemm.Run()
|
||||
static constexpr auto ScalesPerXdlopsRunPerThread =
|
||||
ScalesPerXdlopsRun / xdlops_gemm.mfma_instr.num_input_blks;
|
||||
|
||||
using mx_scale_t = e8m0_bexp_t;
|
||||
static constexpr auto scale_pack_size_a = sizeof(AScaleDataType) / sizeof(mx_scale_t);
|
||||
static constexpr auto scale_pack_size_b = sizeof(BScaleDataType) / sizeof(mx_scale_t);
|
||||
static_assert(KXdlPack * MXdlPack % scale_pack_size_a == 0,
|
||||
"A scale pack data type too large!");
|
||||
static_assert(KXdlPack * NXdlPack % scale_pack_size_b == 0,
|
||||
"B scale pack data type too large!");
|
||||
static constexpr auto a_scale_thread_vec_size = KXdlPack * MXdlPack / scale_pack_size_a;
|
||||
static constexpr auto b_scale_thread_vec_size = KXdlPack * NXdlPack / scale_pack_size_b;
|
||||
|
||||
__host__ static constexpr bool BlockHasHotloop(index_t num_loop)
|
||||
{
|
||||
return num_loop > PrefetchStages;
|
||||
}
|
||||
|
||||
__host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
|
||||
{
|
||||
return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd;
|
||||
}
|
||||
|
||||
__device__ static constexpr auto HotLoopScheduler()
|
||||
{
|
||||
constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num;
|
||||
constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
|
||||
constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num * MWaves +
|
||||
num_buffer_load_a_scale + num_buffer_load_b_scale;
|
||||
constexpr auto mfma_interleave = MPerXDL == 32 ? 1 : 2;
|
||||
// B global
|
||||
static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
if constexpr(MPerBlock >= 128 && NPerBlock >= 128)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 2 * mfma_interleave, 0);
|
||||
}
|
||||
else
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, mfma_interleave, 0);
|
||||
}
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
});
|
||||
|
||||
// A global
|
||||
static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
});
|
||||
|
||||
// A local
|
||||
static_for<0, MPerXDL == 32 ? num_ds_read_inst_a / 2 : num_ds_read_inst_a, 1>{}(
|
||||
[&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, MPerXDL == 32 ? 2 : 1, 0); // DS read
|
||||
});
|
||||
}
|
||||
|
||||
template <bool HasMainLoop,
|
||||
TailNumber TailNum,
|
||||
typename AGridDesc,
|
||||
typename ABlockDesc,
|
||||
typename ABlockTransfer,
|
||||
typename AGridBuffer,
|
||||
typename ABlockBuffer,
|
||||
typename ABlockTransferStep,
|
||||
typename BGridDesc,
|
||||
typename BBlockDesc,
|
||||
typename BBlockTransfer,
|
||||
typename BGridBuffer,
|
||||
typename BBlockBuffer,
|
||||
typename BBlockTransferStep,
|
||||
typename CThreadBuffer,
|
||||
typename AScaleGridBuffer,
|
||||
typename AScaleGridDesc,
|
||||
typename AScaleThreadTransfer,
|
||||
typename BScaleGridBuffer,
|
||||
typename BScaleGridDesc,
|
||||
typename BScaleThreadTransfer>
|
||||
__device__ void Run(
|
||||
// ABlockCopy
|
||||
const AGridDesc& a_grid_desc,
|
||||
const ABlockDesc& a_block_desc,
|
||||
ABlockTransfer& a_blockwise_copy,
|
||||
const AGridBuffer& a_grid_buf,
|
||||
ABlockBuffer& a_block_buf,
|
||||
const ABlockTransferStep& a_block_copy_step,
|
||||
// BBlockCopy
|
||||
const BGridDesc& b_grid_desc,
|
||||
const BBlockDesc& b_block_desc,
|
||||
BBlockTransfer& b_blockwise_copy,
|
||||
const BGridBuffer& b_grid_buf,
|
||||
BBlockBuffer& b_block_bufs,
|
||||
const BBlockTransferStep& b_block_copy_step,
|
||||
// CThread
|
||||
CThreadBuffer& c_thread_buf,
|
||||
// A and B scales
|
||||
const AScaleGridDesc& a_scale_grid_desc,
|
||||
AScaleThreadTransfer& a_scale_thread_copy,
|
||||
const AScaleGridBuffer& a_scale_grid_buf,
|
||||
const BScaleGridDesc& b_scale_grid_desc,
|
||||
BScaleThreadTransfer& b_scale_thread_copy,
|
||||
const BScaleGridBuffer& b_scale_grid_buf,
|
||||
index_t num_loop) const
|
||||
{
|
||||
ignore = b_block_bufs;
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
|
||||
a_thread_desc_.GetElementSpaceSize());
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
|
||||
b_thread_desc_.GetElementSpaceSize());
|
||||
|
||||
StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs;
|
||||
constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0, I0);
|
||||
|
||||
auto a_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AScaleDataType>(
|
||||
a_scale_thread_desc.GetElementSpaceSize());
|
||||
auto b_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, BScaleDataType>(
|
||||
b_scale_thread_desc.GetElementSpaceSize());
|
||||
|
||||
StaticallyIndexedArray<decltype(a_scale_thread_buf), Number<2>{}> a_scale_thread_bufs;
|
||||
StaticallyIndexedArray<decltype(b_scale_thread_buf), Number<2>{}> b_scale_thread_bufs;
|
||||
|
||||
// Global prefetch 1
|
||||
a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf);
|
||||
b_blockwise_copy.Run(
|
||||
b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_thread_bufs(I0));
|
||||
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// Prefetch a_scales
|
||||
static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
|
||||
static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
|
||||
a_scale_thread_copy.Run(a_scale_grid_desc,
|
||||
a_scale_grid_buf,
|
||||
a_scale_thread_desc,
|
||||
make_tuple(m0, k0, I0),
|
||||
a_scale_thread_bufs(I0));
|
||||
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
|
||||
make_multi_index(0, I1, 0));
|
||||
});
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0));
|
||||
});
|
||||
|
||||
// restore row id and advance to the next set of scales
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc,
|
||||
make_multi_index(-MWaves * MRepeat / MXdlPack, KRepeat / KXdlPack, 0));
|
||||
|
||||
// Prefetch b_scales
|
||||
static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
|
||||
static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
|
||||
b_scale_thread_copy.Run(b_scale_grid_desc,
|
||||
b_scale_grid_buf,
|
||||
b_scale_thread_desc,
|
||||
make_tuple(n0, k0, I0),
|
||||
b_scale_thread_bufs(I0));
|
||||
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
|
||||
make_multi_index(0, I1, 0));
|
||||
});
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(
|
||||
b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0));
|
||||
});
|
||||
|
||||
// restore col id and advance to the next set of scales
|
||||
// NWaves * NPerXDL * NRepeat == NPerBlock
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(
|
||||
b_scale_grid_desc,
|
||||
make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0));
|
||||
|
||||
// Local prefetch 1, sync the async load
|
||||
__builtin_amdgcn_s_waitcnt(async_vmcnt_encoding);
|
||||
block_sync_lds();
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k) {
|
||||
constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize *
|
||||
(APackedSize * KPack / xdlops_gemm.K1PerXdlops);
|
||||
static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}(
|
||||
[&](auto chunk) {
|
||||
constexpr auto a_k_step_chunk =
|
||||
k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
|
||||
a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k,
|
||||
make_tuple(Number<m0 / MXdlPack>{},
|
||||
I0,
|
||||
Number<m0 % MXdlPack>{},
|
||||
I0,
|
||||
Number<a_k_step_chunk>{}),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(Number<m0 / MXdlPack>{},
|
||||
I0,
|
||||
Number<m0 % MXdlPack>{},
|
||||
k,
|
||||
Number<chunk * KThreadChunk>{}),
|
||||
a_thread_buf);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
// Initialize C
|
||||
c_thread_buf.Clear();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
// main body
|
||||
if constexpr(HasMainLoop)
|
||||
{
|
||||
// loop over k with the step KPerBlock
|
||||
index_t i = 0;
|
||||
do
|
||||
{
|
||||
auto LoopFunc = [&](auto scale_comp_buf, auto scale_mem_buf) {
|
||||
b_blockwise_copy.Run(b_grid_desc,
|
||||
b_grid_buf,
|
||||
b_block_desc,
|
||||
b_block_origin_idx,
|
||||
b_thread_bufs(scale_mem_buf));
|
||||
|
||||
block_sync_lds();
|
||||
a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf);
|
||||
// Prefetch a_scales
|
||||
static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
|
||||
static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
|
||||
a_scale_thread_copy.Run(a_scale_grid_desc,
|
||||
a_scale_grid_buf,
|
||||
a_scale_thread_desc,
|
||||
make_tuple(m0, k0, I0),
|
||||
a_scale_thread_bufs(scale_mem_buf));
|
||||
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
|
||||
make_multi_index(0, I1, 0));
|
||||
});
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0));
|
||||
});
|
||||
|
||||
// restore row id and advance to the next set of scales
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc,
|
||||
make_multi_index(-MWaves * MRepeat / MXdlPack, KRepeat / KXdlPack, 0));
|
||||
|
||||
// Prefetch b_scales
|
||||
static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
|
||||
static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
|
||||
b_scale_thread_copy.Run(b_scale_grid_desc,
|
||||
b_scale_grid_buf,
|
||||
b_scale_thread_desc,
|
||||
make_tuple(n0, k0, I0),
|
||||
b_scale_thread_bufs(scale_mem_buf));
|
||||
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
|
||||
make_multi_index(0, I1, 0));
|
||||
});
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(
|
||||
b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0));
|
||||
});
|
||||
|
||||
// restore col id and advance to the next set of scales
|
||||
// NWaves * NPerXDL * NRepeat == NPerBlock
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(
|
||||
b_scale_grid_desc,
|
||||
make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0));
|
||||
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
constexpr auto im_major = m0 / MXdlPack;
|
||||
constexpr auto im_minor = m0 % MXdlPack;
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
constexpr auto ik_major = k0 / KXdlPack;
|
||||
constexpr auto ik_minor = k0 % KXdlPack;
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
constexpr auto in_major = n0 / NXdlPack;
|
||||
constexpr auto in_minor = n0 % NXdlPack;
|
||||
|
||||
constexpr index_t a_scale_offset =
|
||||
a_scale_thread_desc.CalculateOffset(
|
||||
make_tuple(im_major, ik_major, I0));
|
||||
constexpr index_t b_scale_offset =
|
||||
b_scale_thread_desc.CalculateOffset(
|
||||
make_tuple(in_major, ik_major, I0));
|
||||
|
||||
static_assert(0 < ScalesPerXdlopsRunPerThread,
|
||||
"Must have at least one scale per Xdlops "
|
||||
"per Thread.");
|
||||
|
||||
vector_type<AScaleDataType, a_scale_thread_vec_size>
|
||||
a_scale_thread_vec;
|
||||
vector_type<BScaleDataType, b_scale_thread_vec_size>
|
||||
b_scale_thread_vec;
|
||||
|
||||
// Pack scale_thread_buf into scale_thread_vec
|
||||
static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) {
|
||||
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
|
||||
a_scale_thread_bufs(
|
||||
scale_comp_buf)[Number<a_scale_offset + s>{}];
|
||||
});
|
||||
|
||||
static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) {
|
||||
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
|
||||
b_scale_thread_bufs(
|
||||
scale_comp_buf)[Number<b_scale_offset + s>{}];
|
||||
});
|
||||
|
||||
vector_type<ComputeTypeA, KPack> a_thread_vec;
|
||||
vector_type<ComputeTypeB, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(im_major, I0, im_minor, k0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeTypeB>()(ik) = b_thread_bufs
|
||||
[scale_comp_buf][Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(in_major, I0, in_minor, k0, ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type_a =
|
||||
typename vector_type<ComputeTypeA,
|
||||
xdlops_gemm.K1PerXdlops /
|
||||
APackedSize>::type;
|
||||
|
||||
using mfma_input_type_b =
|
||||
typename vector_type<ComputeTypeB,
|
||||
xdlops_gemm.K1PerXdlops /
|
||||
BPackedSize>::type;
|
||||
|
||||
using mfma_scale_input_type_a =
|
||||
typename vector_type<AScaleDataType,
|
||||
a_scale_thread_vec_size>::type;
|
||||
using mfma_scale_input_type_b =
|
||||
typename vector_type<BScaleDataType,
|
||||
b_scale_thread_vec_size>::type;
|
||||
|
||||
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
|
||||
make_tuple(im_major, in_major, im_minor, in_minor, 0));
|
||||
|
||||
// MFMA accumulation
|
||||
xdlops_gemm.template Run<ik_minor * MXdlPack + im_minor,
|
||||
ik_minor * NXdlPack + in_minor>(
|
||||
a_thread_vec.template AsType<mfma_input_type_a>(),
|
||||
a_scale_thread_vec.template AsType<mfma_scale_input_type_a>(),
|
||||
b_thread_vec.template AsType<mfma_input_type_b>(),
|
||||
b_scale_thread_vec.template AsType<mfma_scale_input_type_b>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k) {
|
||||
constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize *
|
||||
(APackedSize * KPack / xdlops_gemm.K1PerXdlops);
|
||||
static_for<0,
|
||||
xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk),
|
||||
1>{}([&](auto chunk) {
|
||||
constexpr auto a_k_step_chunk =
|
||||
k_step +
|
||||
chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
|
||||
a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k,
|
||||
make_tuple(Number<m0 / MXdlPack>{},
|
||||
I0,
|
||||
Number<m0 % MXdlPack>{},
|
||||
I0,
|
||||
Number<a_k_step_chunk>{}),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(Number<m0 / MXdlPack>{},
|
||||
I0,
|
||||
Number<m0 % MXdlPack>{},
|
||||
k,
|
||||
Number<chunk * KThreadChunk>{}),
|
||||
a_thread_buf);
|
||||
});
|
||||
});
|
||||
});
|
||||
HotLoopScheduler();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
};
|
||||
|
||||
LoopFunc(I0, I1);
|
||||
LoopFunc(I1, I0);
|
||||
|
||||
i += 2;
|
||||
} while(i < (num_loop - 2));
|
||||
}
|
||||
|
||||
// tail
|
||||
if constexpr(TailNum == TailNumber::Even)
|
||||
{
|
||||
b_blockwise_copy.Run(
|
||||
b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_thread_bufs(I1));
|
||||
|
||||
block_sync_lds();
|
||||
a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf);
|
||||
// Prefetch a_scales
|
||||
static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
|
||||
static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
|
||||
a_scale_thread_copy.Run(a_scale_grid_desc,
|
||||
a_scale_grid_buf,
|
||||
a_scale_thread_desc,
|
||||
make_tuple(m0, k0, I0),
|
||||
a_scale_thread_bufs(I1));
|
||||
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
|
||||
make_multi_index(0, I1, 0));
|
||||
});
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0));
|
||||
});
|
||||
|
||||
// Prefetch b_scales
|
||||
static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
|
||||
static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
|
||||
b_scale_thread_copy.Run(b_scale_grid_desc,
|
||||
b_scale_grid_buf,
|
||||
b_scale_thread_desc,
|
||||
make_tuple(n0, k0, I0),
|
||||
b_scale_thread_bufs(I1));
|
||||
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
|
||||
make_multi_index(0, I1, 0));
|
||||
});
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(
|
||||
b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0));
|
||||
});
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
constexpr auto im_major = m0 / MXdlPack;
|
||||
constexpr auto im_minor = m0 % MXdlPack;
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
constexpr auto ik_major = k0 / KXdlPack;
|
||||
constexpr auto ik_minor = k0 % KXdlPack;
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
constexpr auto in_major = n0 / NXdlPack;
|
||||
constexpr auto in_minor = n0 % NXdlPack;
|
||||
|
||||
constexpr index_t a_scale_offset =
|
||||
a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0));
|
||||
constexpr index_t b_scale_offset =
|
||||
b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0));
|
||||
|
||||
static_assert(0 < ScalesPerXdlopsRunPerThread,
|
||||
"Must have at least one scale per Xdlops "
|
||||
"per Thread.");
|
||||
|
||||
vector_type<AScaleDataType, a_scale_thread_vec_size> a_scale_thread_vec;
|
||||
vector_type<BScaleDataType, b_scale_thread_vec_size> b_scale_thread_vec;
|
||||
|
||||
// Pack scale_thread_buf into scale_thread_vec
|
||||
static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) {
|
||||
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
|
||||
a_scale_thread_bufs(I0)[Number<a_scale_offset + s>{}];
|
||||
});
|
||||
|
||||
static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) {
|
||||
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
|
||||
b_scale_thread_bufs(I0)[Number<b_scale_offset + s>{}];
|
||||
});
|
||||
|
||||
vector_type<ComputeTypeA, KPack> a_thread_vec;
|
||||
vector_type<ComputeTypeB, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(im_major, I0, im_minor, k0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
|
||||
b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(in_major, I0, in_minor, k0, ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type_a =
|
||||
typename vector_type<ComputeTypeA,
|
||||
xdlops_gemm.K1PerXdlops / APackedSize>::type;
|
||||
|
||||
using mfma_input_type_b =
|
||||
typename vector_type<ComputeTypeB,
|
||||
xdlops_gemm.K1PerXdlops / BPackedSize>::type;
|
||||
|
||||
using mfma_scale_input_type_a =
|
||||
typename vector_type<AScaleDataType, a_scale_thread_vec_size>::type;
|
||||
using mfma_scale_input_type_b =
|
||||
typename vector_type<BScaleDataType, b_scale_thread_vec_size>::type;
|
||||
|
||||
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
|
||||
make_tuple(im_major, in_major, im_minor, in_minor, 0));
|
||||
|
||||
// MFMA accumulation
|
||||
xdlops_gemm.template Run<ik_minor * MXdlPack + im_minor,
|
||||
ik_minor * NXdlPack + in_minor>(
|
||||
a_thread_vec.template AsType<mfma_input_type_a>(),
|
||||
a_scale_thread_vec.template AsType<mfma_scale_input_type_a>(),
|
||||
b_thread_vec.template AsType<mfma_input_type_b>(),
|
||||
b_scale_thread_vec.template AsType<mfma_scale_input_type_b>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
});
|
||||
});
|
||||
|
||||
// constexpr auto lds_buf = m0.value >= SwitchM ? I1 : I0;
|
||||
});
|
||||
__builtin_amdgcn_s_waitcnt(async_vmcnt_encoding);
|
||||
block_sync_lds();
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k) {
|
||||
constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize *
|
||||
(APackedSize * KPack / xdlops_gemm.K1PerXdlops);
|
||||
static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}(
|
||||
[&](auto chunk) {
|
||||
constexpr auto a_k_step_chunk =
|
||||
k_step +
|
||||
chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
|
||||
a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k,
|
||||
make_tuple(Number<m0 / MXdlPack>{},
|
||||
I0,
|
||||
Number<m0 % MXdlPack>{},
|
||||
I0,
|
||||
Number<a_k_step_chunk>{}),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(Number<m0 / MXdlPack>{},
|
||||
I0,
|
||||
Number<m0 % MXdlPack>{},
|
||||
k,
|
||||
Number<chunk * KThreadChunk>{}),
|
||||
a_thread_buf);
|
||||
});
|
||||
});
|
||||
});
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
constexpr auto im_major = m0 / MXdlPack;
|
||||
constexpr auto im_minor = m0 % MXdlPack;
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
constexpr auto ik_major = k0 / KXdlPack;
|
||||
constexpr auto ik_minor = k0 % KXdlPack;
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
constexpr auto in_major = n0 / NXdlPack;
|
||||
constexpr auto in_minor = n0 % NXdlPack;
|
||||
|
||||
constexpr index_t a_scale_offset =
|
||||
a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0));
|
||||
constexpr index_t b_scale_offset =
|
||||
b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0));
|
||||
|
||||
static_assert(0 < ScalesPerXdlopsRunPerThread,
|
||||
"Must have at least one scale per Xdlops "
|
||||
"per Thread.");
|
||||
|
||||
vector_type<AScaleDataType, a_scale_thread_vec_size> a_scale_thread_vec;
|
||||
vector_type<BScaleDataType, b_scale_thread_vec_size> b_scale_thread_vec;
|
||||
|
||||
// Pack scale_thread_buf into scale_thread_vec
|
||||
static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) {
|
||||
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
|
||||
a_scale_thread_bufs(I1)[Number<a_scale_offset + s>{}];
|
||||
});
|
||||
|
||||
static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) {
|
||||
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
|
||||
b_scale_thread_bufs(I1)[Number<b_scale_offset + s>{}];
|
||||
});
|
||||
|
||||
vector_type<ComputeTypeA, KPack> a_thread_vec;
|
||||
vector_type<ComputeTypeB, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(im_major, I0, im_minor, k0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
|
||||
b_thread_bufs[I1][Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(in_major, I0, in_minor, k0, ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type_a =
|
||||
typename vector_type<ComputeTypeA,
|
||||
xdlops_gemm.K1PerXdlops / APackedSize>::type;
|
||||
|
||||
using mfma_input_type_b =
|
||||
typename vector_type<ComputeTypeB,
|
||||
xdlops_gemm.K1PerXdlops / BPackedSize>::type;
|
||||
|
||||
using mfma_scale_input_type_a =
|
||||
typename vector_type<AScaleDataType, a_scale_thread_vec_size>::type;
|
||||
using mfma_scale_input_type_b =
|
||||
typename vector_type<BScaleDataType, b_scale_thread_vec_size>::type;
|
||||
|
||||
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
|
||||
make_tuple(im_major, in_major, im_minor, in_minor, 0));
|
||||
|
||||
// MFMA accumulation
|
||||
xdlops_gemm.template Run<ik_minor * MXdlPack + im_minor,
|
||||
ik_minor * NXdlPack + in_minor>(
|
||||
a_thread_vec.template AsType<mfma_input_type_a>(),
|
||||
a_scale_thread_vec.template AsType<mfma_scale_input_type_a>(),
|
||||
b_thread_vec.template AsType<mfma_input_type_b>(),
|
||||
b_scale_thread_vec.template AsType<mfma_scale_input_type_b>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
else if constexpr(TailNum == TailNumber::Odd)
|
||||
{
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
constexpr auto im_major = m0 / MXdlPack;
|
||||
constexpr auto im_minor = m0 % MXdlPack;
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
constexpr auto ik_major = k0 / KXdlPack;
|
||||
constexpr auto ik_minor = k0 % KXdlPack;
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
constexpr auto in_major = n0 / NXdlPack;
|
||||
constexpr auto in_minor = n0 % NXdlPack;
|
||||
|
||||
constexpr index_t a_scale_offset =
|
||||
a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0));
|
||||
constexpr index_t b_scale_offset =
|
||||
b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0));
|
||||
|
||||
static_assert(0 < ScalesPerXdlopsRunPerThread,
|
||||
"Must have at least one scale per Xdlops "
|
||||
"per Thread.");
|
||||
|
||||
vector_type<AScaleDataType, a_scale_thread_vec_size> a_scale_thread_vec;
|
||||
vector_type<BScaleDataType, b_scale_thread_vec_size> b_scale_thread_vec;
|
||||
|
||||
// Pack scale_thread_buf into scale_thread_vec
|
||||
static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) {
|
||||
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
|
||||
a_scale_thread_bufs(I0)[Number<a_scale_offset + s>{}];
|
||||
});
|
||||
|
||||
static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) {
|
||||
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
|
||||
b_scale_thread_bufs(I0)[Number<b_scale_offset + s>{}];
|
||||
});
|
||||
|
||||
vector_type<ComputeTypeA, KPack> a_thread_vec;
|
||||
vector_type<ComputeTypeB, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(im_major, I0, im_minor, k0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
|
||||
b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(in_major, I0, in_minor, k0, ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type_a =
|
||||
typename vector_type<ComputeTypeA,
|
||||
xdlops_gemm.K1PerXdlops / APackedSize>::type;
|
||||
|
||||
using mfma_input_type_b =
|
||||
typename vector_type<ComputeTypeB,
|
||||
xdlops_gemm.K1PerXdlops / BPackedSize>::type;
|
||||
|
||||
using mfma_scale_input_type_a =
|
||||
typename vector_type<AScaleDataType, a_scale_thread_vec_size>::type;
|
||||
using mfma_scale_input_type_b =
|
||||
typename vector_type<BScaleDataType, b_scale_thread_vec_size>::type;
|
||||
|
||||
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
|
||||
make_tuple(im_major, in_major, im_minor, in_minor, 0));
|
||||
|
||||
// MFMA accumulation
|
||||
xdlops_gemm.template Run<ik_minor * MXdlPack + im_minor,
|
||||
ik_minor * NXdlPack + in_minor>(
|
||||
a_thread_vec.template AsType<mfma_input_type_a>(),
|
||||
a_scale_thread_vec.template AsType<mfma_scale_input_type_a>(),
|
||||
b_thread_vec.template AsType<mfma_input_type_b>(),
|
||||
b_scale_thread_vec.template AsType<mfma_scale_input_type_b>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: make this field protected when a_scale_thread_copy_ is moved
|
||||
// here
|
||||
static constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<MRepeat / MXdlPack>{},
|
||||
Number<KRepeat / KXdlPack>{},
|
||||
Number<ScalesPerXdlopsRunPerThread * a_scale_thread_vec_size>{}));
|
||||
|
||||
// TODO: make this field protected when b_scale_thread_copy_ is moved
|
||||
// here
|
||||
static constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<NRepeat / NXdlPack>{},
|
||||
Number<KRepeat / KXdlPack>{},
|
||||
Number<ScalesPerXdlopsRunPerThread * b_scale_thread_vec_size>{}));
|
||||
|
||||
protected:
|
||||
using Base::a_thread_copy_;
|
||||
using Base::a_thread_desc_;
|
||||
using Base::b_thread_copy_;
|
||||
using Base::b_thread_desc_;
|
||||
using Base::c_thread_desc_;
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
@@ -226,85 +226,197 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v3<BlockGemmPipelineSched
|
||||
// constexpr auto num_dsread_a_mfma =
|
||||
// (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate;
|
||||
|
||||
constexpr auto num_total_stages = MRepeat;
|
||||
constexpr auto num_total_stages = std::max(2, MRepeat);
|
||||
if constexpr(num_total_stages > 2)
|
||||
{
|
||||
|
||||
// Group num_mfma_perstage num_ds_read_a_perstage
|
||||
// since we want to reuse a local register buffer
|
||||
constexpr auto num_mfma_perstage = num_mfma_inst / num_total_stages;
|
||||
constexpr auto num_ds_read_a_perstage = num_ds_read_inst_a / num_total_stages;
|
||||
// Group num_mfma_perstage num_ds_read_a_perstage
|
||||
// since we want to reuse a local register buffer
|
||||
constexpr auto num_mfma_perstage = num_mfma_inst / num_total_stages;
|
||||
constexpr auto num_ds_read_a_perstage = num_ds_read_inst_a / num_total_stages;
|
||||
|
||||
constexpr auto num_ds_read_a_mfma_perstage =
|
||||
math::integer_divide_ceil(num_ds_read_a_perstage, ds_read_a_mfma_rate);
|
||||
constexpr auto num_ds_read_a_mfma_perstage =
|
||||
math::integer_divide_ceil(num_ds_read_a_perstage, ds_read_a_mfma_rate);
|
||||
|
||||
constexpr auto num_ds_read_a_prefetch_stages = 2;
|
||||
constexpr auto num_ds_read_a_prefetch_stages = 2;
|
||||
|
||||
constexpr auto buffer_load_perstage_more =
|
||||
math::integer_divide_ceil((num_buffer_load_stage1), (num_total_stages - 2));
|
||||
constexpr auto buffer_load_perstage_less =
|
||||
math::integer_divide_floor((num_buffer_load_stage1), (num_total_stages - 2));
|
||||
constexpr auto buffer_load_perstage_stage2 =
|
||||
math::integer_divide_floor((num_buffer_load_stage2), 2);
|
||||
constexpr auto buffer_load_perstage_more =
|
||||
math::integer_divide_ceil((num_buffer_load_stage1), (num_total_stages - 2));
|
||||
constexpr auto buffer_load_perstage_less =
|
||||
math::integer_divide_floor((num_buffer_load_stage1), (num_total_stages - 2));
|
||||
constexpr auto buffer_load_perstage_stage2 =
|
||||
math::integer_divide_floor((num_buffer_load_stage2), 2);
|
||||
|
||||
constexpr auto buffer_load_stages_more =
|
||||
num_buffer_load_stage1 -
|
||||
math::integer_divide_floor(num_buffer_load_stage1, (num_total_stages - 2)) *
|
||||
((num_total_stages - 2));
|
||||
constexpr auto buffer_load_stages_more =
|
||||
num_buffer_load_stage1 -
|
||||
math::integer_divide_floor(num_buffer_load_stage1, (num_total_stages - 2)) *
|
||||
((num_total_stages - 2));
|
||||
|
||||
constexpr auto buffer_load_issue_point_interval_more =
|
||||
num_mfma_perstage / buffer_load_perstage_more;
|
||||
constexpr auto buffer_load_issue_point_interval_less =
|
||||
num_mfma_perstage / buffer_load_perstage_less;
|
||||
constexpr auto buffer_load_issue_point_interval_stage2 =
|
||||
num_mfma_perstage / buffer_load_perstage_stage2;
|
||||
constexpr auto buffer_load_issue_point_interval_more =
|
||||
num_mfma_perstage / buffer_load_perstage_more;
|
||||
constexpr auto buffer_load_issue_point_interval_less =
|
||||
num_mfma_perstage / buffer_load_perstage_less;
|
||||
constexpr auto buffer_load_issue_point_interval_stage2 =
|
||||
num_mfma_perstage / buffer_load_perstage_stage2;
|
||||
|
||||
// Stage 1
|
||||
// global read more
|
||||
static_for<0, buffer_load_stages_more, 1>{}([&](auto /*i*/) {
|
||||
static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
// Stage 1
|
||||
// global read more
|
||||
static_for<0, buffer_load_stages_more, 1>{}([&](auto /*i*/) {
|
||||
static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
|
||||
if constexpr(imfma % buffer_load_issue_point_interval_more == 0)
|
||||
if constexpr(imfma % buffer_load_issue_point_interval_more == 0)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
}
|
||||
|
||||
if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(
|
||||
0x100, ds_read_a_mfma_rate, 0); // DS read
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
// global read less
|
||||
static_for<0, (num_total_stages - 2 - buffer_load_stages_more), 1>{}([&](auto /*i*/) {
|
||||
static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
if constexpr(imfma % buffer_load_issue_point_interval_less == 0)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
}
|
||||
if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(
|
||||
0x100, ds_read_a_mfma_rate, 0); // DS read
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
// Stage 2, Sync
|
||||
// lds synchronization, prefetch next loop local A
|
||||
static_for<0, num_ds_read_a_prefetch_stages, 1>{}([&](auto /*i*/) {
|
||||
static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
if constexpr(imfma % buffer_load_issue_point_interval_stage2 == 0)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
}
|
||||
if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(
|
||||
0x100, ds_read_a_mfma_rate, 0); // DS read
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr auto num_buffer_load_total = num_buffer_load_inst_a + num_buffer_load_inst_b +
|
||||
num_buffer_load_a_scale +
|
||||
num_buffer_load_b_scale;
|
||||
constexpr auto num_dsread_a_mfma = math::integer_divide_ceil(
|
||||
num_ds_read_inst_a, ds_read_a_mfma_rate); // how many mfma per dsread_a
|
||||
|
||||
// stage 1
|
||||
constexpr auto num_mfma_stage1 = num_mfma_inst - num_dsread_a_mfma;
|
||||
|
||||
constexpr auto mfma_perstage_more =
|
||||
math::integer_divide_ceil(num_mfma_stage1, num_buffer_load_total);
|
||||
constexpr auto mfma_perstage_less =
|
||||
math::integer_divide_floor(num_mfma_stage1, num_buffer_load_total);
|
||||
|
||||
constexpr auto mfma_stages_more =
|
||||
num_mfma_stage1 - mfma_perstage_less * num_buffer_load_total;
|
||||
|
||||
static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) {
|
||||
if constexpr(i < mfma_stages_more)
|
||||
{
|
||||
static_for<0, mfma_perstage_more, 1>{}([&](auto) {
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
}
|
||||
else
|
||||
{
|
||||
static_for<0, mfma_perstage_less, 1>{}([&](auto) {
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
}
|
||||
});
|
||||
|
||||
if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
|
||||
static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) {
|
||||
if constexpr((i + num_buffer_load_inst_a) < mfma_stages_more)
|
||||
{
|
||||
static_for<0, mfma_perstage_more, 1>{}([&](auto /*imfma*/) {
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
}
|
||||
else
|
||||
{
|
||||
static_for<0, mfma_perstage_less, 1>{}([&](auto /*imfma*/) {
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
}
|
||||
});
|
||||
|
||||
static_for<0, num_buffer_load_a_scale, 1>{}([&](auto i) {
|
||||
if constexpr((i + num_buffer_load_inst_a + num_buffer_load_inst_b) <
|
||||
mfma_stages_more)
|
||||
{
|
||||
static_for<0, mfma_perstage_more, 1>{}([&](auto /*imfma*/) {
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
}
|
||||
else
|
||||
{
|
||||
static_for<0, mfma_perstage_less, 1>{}([&](auto /*imfma*/) {
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
}
|
||||
});
|
||||
|
||||
static_for<0, num_buffer_load_b_scale, 1>{}([&](auto i) {
|
||||
if constexpr((i + num_buffer_load_inst_a + num_buffer_load_inst_b +
|
||||
num_buffer_load_a_scale) < mfma_stages_more)
|
||||
{
|
||||
static_for<0, mfma_perstage_more, 1>{}([&](auto /*imfma*/) {
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
}
|
||||
else
|
||||
{
|
||||
static_for<0, mfma_perstage_less, 1>{}([&](auto /*imfma*/) {
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
}
|
||||
});
|
||||
|
||||
// stage 2
|
||||
static_for<0, num_dsread_a_mfma, 1>{}([&](auto i) {
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >=
|
||||
ds_read_a_mfma_rate)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
// global read less
|
||||
static_for<0, (num_total_stages - 2 - buffer_load_stages_more), 1>{}([&](auto /*i*/) {
|
||||
static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
if constexpr(imfma % buffer_load_issue_point_interval_less == 0)
|
||||
else
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
}
|
||||
if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(
|
||||
0x100,
|
||||
num_ds_read_inst_a - (num_dsread_a_mfma - 1) * ds_read_a_mfma_rate,
|
||||
0); // DS read
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
// Stage 2, Sync
|
||||
// lds synchronization, prefetch next loop local A
|
||||
static_for<0, num_ds_read_a_prefetch_stages, 1>{}([&](auto /*i*/) {
|
||||
static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
if constexpr(imfma % buffer_load_issue_point_interval_stage2 == 0)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
}
|
||||
if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
template <bool HasMainLoop,
|
||||
|
||||
@@ -122,7 +122,7 @@ struct DeviceMoeGemmMXBPreShuffle : public DeviceMoEGemmMXBPreShuffle<ALayout,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave_,
|
||||
math::max(2, NXdlPerWave_),
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
|
||||
@@ -429,8 +429,8 @@ struct GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle
|
||||
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
|
||||
constexpr index_t WaveSize = BlockSize / (MWave * NWave);
|
||||
constexpr index_t NkSwizzleNumber = Number<WaveSize * KPack>{};
|
||||
return make_naive_tensor_descriptor_packed(
|
||||
make_tuple(N0 / NWave / NXdlPack, NWave, NXdlPack, K0, NkSwizzleNumber));
|
||||
return make_naive_tensor_descriptor_packed(make_tuple(
|
||||
math::integer_divide_ceil(N0, NWave * NXdlPack), NWave, NXdlPack, K0, NkSwizzleNumber));
|
||||
}
|
||||
|
||||
__host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
|
||||
|
||||
@@ -48,28 +48,25 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
kernel_moe_mxgemm(typename GridwiseGemm::Argument karg)
|
||||
{
|
||||
#if defined(__gfx9__)
|
||||
if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
|
||||
{
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
|
||||
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
|
||||
karg.p_sorted_token_ids,
|
||||
karg.p_sorted_expert_ids,
|
||||
karg.p_max_token_id,
|
||||
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
|
||||
karg.p_a_scale_grid + splitk_batch_offset.a_k_split_offset,
|
||||
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
|
||||
karg.p_b_scale_grid + splitk_batch_offset.b_k_split_offset,
|
||||
karg.p_ds_grid,
|
||||
karg.p_c_grid,
|
||||
p_shared,
|
||||
karg,
|
||||
karg.a_element_op,
|
||||
karg.b_element_op,
|
||||
karg.c_element_op);
|
||||
}
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
|
||||
karg.p_sorted_token_ids,
|
||||
karg.p_sorted_expert_ids,
|
||||
karg.p_max_token_id,
|
||||
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
|
||||
karg.p_a_scale_grid + splitk_batch_offset.a_scale_k_split_offset,
|
||||
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
|
||||
karg.p_b_scale_grid + splitk_batch_offset.b_scale_k_split_offset,
|
||||
karg.p_ds_grid,
|
||||
karg.p_c_grid,
|
||||
p_shared,
|
||||
karg,
|
||||
karg.a_element_op,
|
||||
karg.b_element_op,
|
||||
karg.c_element_op);
|
||||
#else
|
||||
ignore = karg;
|
||||
#endif // end of if (defined(__gfx9__))
|
||||
@@ -1249,7 +1246,6 @@ struct GridwiseMoeGemmMX_BPreshuffle
|
||||
__host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
|
||||
{
|
||||
const index_t num_loop = K / KPerBlock;
|
||||
|
||||
return BlockwiseGemmPipe::BlockHasHotloop(num_loop);
|
||||
}
|
||||
|
||||
@@ -1279,7 +1275,6 @@ struct GridwiseMoeGemmMX_BPreshuffle
|
||||
// using Block2CTileMapDefault = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock,
|
||||
// NPerBlock>;
|
||||
|
||||
#if 0
|
||||
template <bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
|
||||
TailNumber TailNum = TailNumber::Odd>
|
||||
@@ -1298,9 +1293,10 @@ struct GridwiseMoeGemmMX_BPreshuffle
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op)
|
||||
{
|
||||
ignore = a_element_op;
|
||||
ignore = b_element_op;
|
||||
index_t BN0Shuffled = CalculateBN0Shuffled(problem.N);
|
||||
index_t BK0Shuffled = CalculateBK0Shuffled(problem.K);
|
||||
index_t BN0Shuffled = CalculateBN0Shuffled(problem.N);
|
||||
index_t BK0Shuffled = CalculateBK0Shuffled(problem.K);
|
||||
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
|
||||
IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK,
|
||||
problem.MPadded,
|
||||
@@ -1317,29 +1313,41 @@ struct GridwiseMoeGemmMX_BPreshuffle
|
||||
problem.NPadded,
|
||||
problem.StrideC);
|
||||
|
||||
const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor_packed(
|
||||
make_tuple((IsInputGemm ? problem.NumTokens : problem.M) / (MXdlPack * MPerBlock),
|
||||
// We pad the M unconditionaly for Scale
|
||||
const auto Padded_Scale_M =
|
||||
math::integer_divide_ceil(problem.M, ScaleBlockSize) * ScaleBlockSize;
|
||||
const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor(
|
||||
make_tuple(Padded_Scale_M / (MXdlPack * MPerXdl),
|
||||
math::integer_divide_ceil(problem.K, (ScaleBlockSize / APackedSize)) /
|
||||
(KXdlPack * 64 / MPerXdl),
|
||||
64 * KXdlPack * MXdlPack / scale_pack_size_a));
|
||||
64 * KXdlPack * MXdlPack / scale_pack_size_a),
|
||||
make_tuple(math::integer_divide_ceil(problem.K * problem.KBatch,
|
||||
(ScaleBlockSize / APackedSize)) *
|
||||
MPerXdl * MXdlPack / scale_pack_size_a,
|
||||
64 * KXdlPack * MXdlPack / scale_pack_size_a,
|
||||
1));
|
||||
|
||||
const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor_packed(
|
||||
const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor(
|
||||
make_tuple(problem.N / (NXdlPack * NPerXdl),
|
||||
math::integer_divide_ceil(problem.K, (ScaleBlockSize / BPackedSize)) /
|
||||
(KXdlPack * 64 / NPerXdl),
|
||||
64 * KXdlPack * NXdlPack / scale_pack_size_b));
|
||||
64 * KXdlPack * NXdlPack / scale_pack_size_b),
|
||||
make_tuple(math::integer_divide_ceil(problem.K * problem.KBatch,
|
||||
(ScaleBlockSize / BPackedSize)) *
|
||||
NPerXdl * NXdlPack / scale_pack_size_b,
|
||||
64 * KXdlPack * NXdlPack / scale_pack_size_b,
|
||||
1));
|
||||
|
||||
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
|
||||
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
c_grid_desc_m_n, problem.MBlock, problem.NBlock);
|
||||
const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
|
||||
// static_assert(NSwizzle == false, "to do fix: need another pr in sorting merged");
|
||||
|
||||
const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
|
||||
const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
|
||||
if(expert_block_id * MPerBlock >= max_token_id)
|
||||
return;
|
||||
const index_t expert_id =
|
||||
__builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);
|
||||
|
||||
const auto block_mn = [&]() -> std::pair<int, int> {
|
||||
if constexpr(NSwizzle)
|
||||
{
|
||||
@@ -1372,86 +1380,78 @@ struct GridwiseMoeGemmMX_BPreshuffle
|
||||
constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
|
||||
constexpr auto AKThreads = AK0Threads * AK1Threads;
|
||||
constexpr auto AMRepeats = MPerBlock / AMThreads;
|
||||
const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats;
|
||||
const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads;
|
||||
|
||||
if(token_pos >= max_token_id || token0 >= problem.NumTokens)
|
||||
return;
|
||||
StaticallyIndexedArray<IndexType, AMRepeats> gather_offsets;
|
||||
static_for<0, AMRepeats, 1>{}([&](auto m0) {
|
||||
const index_t fused_token = p_sorted_token_ids[token_pos + m0];
|
||||
const index_t fused_token = p_sorted_token_ids[token_pos + m0 * AMThreads];
|
||||
index_t token_offset = fused_token & 0xffffff;
|
||||
if constexpr(!IsInputGemm)
|
||||
{
|
||||
token_offset = token_offset * problem.TopK + (fused_token >> 24);
|
||||
}
|
||||
gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K / APackedSize;
|
||||
gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
|
||||
});
|
||||
|
||||
const index_t expert_stride =
|
||||
__builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
|
||||
const index_t expert_scale_stride =
|
||||
__builtin_amdgcn_readfirstlane(problem.N * (IsInputGemm ? 2 : 1) *
|
||||
math::integer_divide_ceil(problem.K, ScaleBlockSize));
|
||||
const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
|
||||
problem.N * (IsInputGemm ? 2 : 1) *
|
||||
math::integer_divide_ceil(problem.K, ScaleBlockSize / BPackedSize));
|
||||
|
||||
// N0, K0, Blocksize*KPack
|
||||
const index_t n_block_data_idx_on_grid =
|
||||
__builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);
|
||||
__builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave / NXdlPack);
|
||||
|
||||
// Gride buffer creation
|
||||
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
|
||||
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_b_grid + expert_id * expert_stride / BPackedSize,
|
||||
b_grid_desc_bpreshuffled.GetElementSpaceSize());
|
||||
p_b_grid + expert_id * expert_stride, b_grid_desc_bpreshuffled.GetElementSpaceSize());
|
||||
|
||||
// A, B scale buffer
|
||||
const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
|
||||
const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_b_scale_grid + expert_id * expert_scale_stride,
|
||||
p_b_scale_grid + (expert_id * expert_scale_stride) / sizeof(BScaleDataType),
|
||||
b_scale_grid_desc_bn_ak.GetElementSpaceSize());
|
||||
|
||||
// A matrix in LDS memory, dst of blockwise copy
|
||||
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
|
||||
|
||||
// B matrix in LDS memory, dst of blockwise copy
|
||||
// dummy
|
||||
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
|
||||
// A matrix blockwise copy
|
||||
auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1_gather<
|
||||
|
||||
// A matrix blockwise direct to LDS copy
|
||||
auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_Gather_DirectLoad<
|
||||
ThisThreadBlock,
|
||||
AElementwiseOperation,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
Sequence<AK0Number, MPerBlock, AK1Number>,
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ADataType,
|
||||
LDSTypeA,
|
||||
ADataType,
|
||||
decltype(a_grid_desc_ak0_m_ak1),
|
||||
decltype(a_block_desc_ak0_m_ak1),
|
||||
ABlockTransferSrcAccessOrder,
|
||||
Sequence<0, 1, 2>,
|
||||
ABlockTransferSrcVectorDim,
|
||||
2,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_AK1,
|
||||
1,
|
||||
1,
|
||||
AThreadTransferSrcResetCoordinateAfterRun,
|
||||
true,
|
||||
IndexType,
|
||||
1,
|
||||
BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1,
|
||||
make_multi_index(0, 0, 0),
|
||||
a_element_op,
|
||||
a_block_desc_ak0_m_ak1,
|
||||
make_multi_index(0, 0, 0),
|
||||
ck::tensor_operation::element_wise::PassThrough{},
|
||||
gather_offsets);
|
||||
1>(a_grid_desc_ak0_m_ak1,
|
||||
make_multi_index(0, 0, 0),
|
||||
a_block_desc_ak0_m_ak1,
|
||||
make_multi_index(0, 0, 0),
|
||||
gather_offsets);
|
||||
|
||||
// Thread-wise copy
|
||||
// K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack
|
||||
auto b_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
|
||||
auto b_block_buf_ping = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
|
||||
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
|
||||
auto b_block_buf_pong = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
|
||||
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
|
||||
auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong);
|
||||
|
||||
auto b_blockwise_copy =
|
||||
ThreadwiseTensorSliceTransfer_v2<BDataType,
|
||||
@@ -1463,7 +1463,7 @@ struct GridwiseMoeGemmMX_BPreshuffle
|
||||
Number<NXdlPack>{},
|
||||
Number<KRepeat>{},
|
||||
Number<BK1Value>{}>,
|
||||
Sequence<1, 2, 0, 3>,
|
||||
Sequence<0, 1, 2, 3, 4>,
|
||||
4,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BThreadTransferSrcResetCoordinateAfterRun,
|
||||
@@ -1472,16 +1472,16 @@ struct GridwiseMoeGemmMX_BPreshuffle
|
||||
make_multi_index(n_block_data_idx_on_grid,
|
||||
get_warp_local_1d_id() % NWave,
|
||||
0,
|
||||
KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
|
||||
0,
|
||||
KPack * (get_thread_local_1d_id() % WarpSize)));
|
||||
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
// Cast after lds
|
||||
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
static_cast<LDSTypeA*>(p_shared),
|
||||
a_block_desc_ak0_m_ak1.GetElementSpaceSize() / APackedSize);
|
||||
static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
|
||||
|
||||
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
|
||||
constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, KRepeat, 0);
|
||||
constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, 0, KRepeat, 0);
|
||||
|
||||
// Blockwise GEMM pipeline
|
||||
static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
|
||||
@@ -1505,13 +1505,16 @@ struct GridwiseMoeGemmMX_BPreshuffle
|
||||
const auto waveId_m = wave_idx[I0];
|
||||
const auto waveId_n = wave_idx[I1];
|
||||
|
||||
static constexpr auto mfma = BlockwiseGemmPipe::xdlops_gemm.mfma;
|
||||
|
||||
auto thread_offset_shuffled =
|
||||
get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize * KXdlPack * MXdlPack;
|
||||
|
||||
auto a_thread_offset_m = waveId_m;
|
||||
|
||||
// get each thread's offset int the scale tensor
|
||||
const index_t token_scale_pos = block_m_id * MPerBlock;
|
||||
if(token_scale_pos >= max_token_id || token0 >= problem.NumTokens)
|
||||
return;
|
||||
|
||||
auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
|
||||
AScaleDataType,
|
||||
AScaleDataType,
|
||||
@@ -1538,7 +1541,7 @@ struct GridwiseMoeGemmMX_BPreshuffle
|
||||
Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths
|
||||
Sequence<0, 1, 2>, // DimAccessOrder
|
||||
2, // SrcVectorDim
|
||||
KXdlPack * MXdlPack / scale_pack_size_b, // SrcScalarPerVector
|
||||
KXdlPack * NXdlPack / scale_pack_size_b, // SrcScalarPerVector
|
||||
1, // SrcScalarStrideInVector
|
||||
true>(b_scale_grid_desc_bn_ak,
|
||||
make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n,
|
||||
@@ -1547,29 +1550,37 @@ struct GridwiseMoeGemmMX_BPreshuffle
|
||||
|
||||
if constexpr(IsInputGemm)
|
||||
{
|
||||
const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize;
|
||||
const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2;
|
||||
const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_b_grid_up + expert_id * expert_stride / BPackedSize,
|
||||
p_b_grid_up + expert_id * expert_stride,
|
||||
b_grid_desc_bpreshuffled.GetElementSpaceSize());
|
||||
auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2<
|
||||
BDataType,
|
||||
BDataType,
|
||||
decltype(b_grid_desc_bpreshuffled),
|
||||
decltype(b_block_desc_bk0_n_bk1),
|
||||
Sequence<Number<NXdlPerWave>{}, I1, Number<KRepeat>{}, Number<BK1Value>{}>,
|
||||
Sequence<1, 2, 0, 3>,
|
||||
3,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BThreadTransferSrcResetCoordinateAfterRun,
|
||||
true>(b_grid_desc_bpreshuffled,
|
||||
make_multi_index(n_block_data_idx_on_grid,
|
||||
get_warp_local_1d_id() % NWave,
|
||||
0,
|
||||
KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
|
||||
const BScaleDataType* p_b_scale_grid_up = p_b_scale_grid + expert_scale_stride / 2;
|
||||
const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_b_scale_grid_up + expert_id * expert_scale_stride,
|
||||
auto b_blockwise_copy_up =
|
||||
ThreadwiseTensorSliceTransfer_v2<BDataType,
|
||||
BDataType,
|
||||
decltype(b_grid_desc_bpreshuffled),
|
||||
decltype(b_block_desc_bk0_n_bk1),
|
||||
Sequence<Number<NXdlPerWave / NXdlPack>{},
|
||||
I1,
|
||||
Number<NXdlPack>{},
|
||||
Number<KRepeat>{},
|
||||
Number<BK1Value>{}>,
|
||||
Sequence<0, 1, 2, 3, 4>,
|
||||
4,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BThreadTransferSrcResetCoordinateAfterRun,
|
||||
true>(
|
||||
b_grid_desc_bpreshuffled,
|
||||
make_multi_index(n_block_data_idx_on_grid,
|
||||
get_warp_local_1d_id() % NWave,
|
||||
0,
|
||||
0,
|
||||
KPack * (get_thread_local_1d_id() % WarpSize)));
|
||||
const BScaleDataType* p_b_scale_grid_up =
|
||||
p_b_scale_grid + expert_scale_stride / 2 / sizeof(BScaleDataType);
|
||||
const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_b_scale_grid_up + expert_id * expert_scale_stride / sizeof(BScaleDataType),
|
||||
b_scale_grid_desc_bn_ak.GetElementSpaceSize());
|
||||
|
||||
auto b_scale_thread_copy_up = ThreadwiseTensorSliceTransfer_v2<
|
||||
BScaleDataType,
|
||||
BScaleDataType,
|
||||
@@ -1587,25 +1598,30 @@ struct GridwiseMoeGemmMX_BPreshuffle
|
||||
thread_offset_shuffled / scale_pack_size_b));
|
||||
|
||||
blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
|
||||
// A
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
a_block_desc_ak0_m_ak1,
|
||||
a_blockwise_copy,
|
||||
a_grid_buf,
|
||||
a_block_buf,
|
||||
a_block_slice_copy_step,
|
||||
// Gate and Up
|
||||
b_grid_desc_bpreshuffled,
|
||||
b_block_desc_bk0_n_bk1,
|
||||
b_blockwise_copy,
|
||||
b_blockwise_copy_up,
|
||||
b_grid_buf,
|
||||
b_grid_buf_up,
|
||||
b_block_buf,
|
||||
b_block_bufs,
|
||||
b_block_slice_copy_step,
|
||||
// C
|
||||
c_thread_buf,
|
||||
c_thread_buf_up,
|
||||
// A scale
|
||||
a_scale_grid_desc_am_ak,
|
||||
a_scale_thread_copy,
|
||||
a_scale_grid_buf,
|
||||
// B scale
|
||||
b_scale_grid_desc_bn_ak,
|
||||
b_scale_thread_copy,
|
||||
b_scale_thread_copy_up,
|
||||
@@ -1616,23 +1632,23 @@ struct GridwiseMoeGemmMX_BPreshuffle
|
||||
else
|
||||
{
|
||||
blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
a_grid_desc_ak0_m_ak1, // A
|
||||
a_block_desc_ak0_m_ak1,
|
||||
a_blockwise_copy,
|
||||
a_grid_buf,
|
||||
a_block_buf,
|
||||
a_block_slice_copy_step,
|
||||
b_grid_desc_bpreshuffled,
|
||||
b_grid_desc_bpreshuffled, // B
|
||||
b_block_desc_bk0_n_bk1,
|
||||
b_blockwise_copy,
|
||||
b_grid_buf,
|
||||
b_block_buf,
|
||||
b_block_bufs,
|
||||
b_block_slice_copy_step,
|
||||
c_thread_buf,
|
||||
a_scale_grid_desc_am_ak,
|
||||
c_thread_buf, // C
|
||||
a_scale_grid_desc_am_ak, // A scale
|
||||
a_scale_thread_copy,
|
||||
a_scale_grid_buf,
|
||||
b_scale_grid_desc_bn_ak,
|
||||
b_scale_grid_desc_bn_ak, // B scale
|
||||
b_scale_thread_copy,
|
||||
b_scale_grid_buf,
|
||||
num_k_block_main_loop);
|
||||
@@ -1643,84 +1659,101 @@ struct GridwiseMoeGemmMX_BPreshuffle
|
||||
static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
|
||||
NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
|
||||
"wrong!");
|
||||
static_assert(CShuffleMXdlPerWavePerShuffle % MXdlPack == 0 &&
|
||||
CShuffleNXdlPerWavePerShuffle % NXdlPack == 0,
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
|
||||
|
||||
// TODO: hacky, fix it!
|
||||
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
|
||||
blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
|
||||
blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();
|
||||
|
||||
// TODO: hacky, fix it!
|
||||
// c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
|
||||
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
|
||||
blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
|
||||
blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();
|
||||
|
||||
constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
|
||||
constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
|
||||
constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
|
||||
constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
|
||||
constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
|
||||
constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
|
||||
constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
|
||||
constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
|
||||
constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
|
||||
constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
|
||||
constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
|
||||
constexpr auto M5 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I8);
|
||||
constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I9);
|
||||
|
||||
// mul scales
|
||||
static_assert(M0 * M1 * M2 * M3 * M4 == MPerBlock);
|
||||
static_assert(M4 == 4);
|
||||
|
||||
static_assert(M0 * M1 * M2 * M3 * M4 * M5 == MPerBlock);
|
||||
static_assert(M5 == 4);
|
||||
const index_t m1 = get_warp_local_1d_id() / NWave;
|
||||
const index_t m3 = threadIdx.x % get_warp_size() / MPerXdl;
|
||||
const index_t m4 = threadIdx.x % get_warp_size() / MPerXdl;
|
||||
|
||||
vector_type<float, 4> topk_weights; // for gemm2 only
|
||||
static_for<0, NXdlPerWave, 1>{}([&](auto n0) {
|
||||
static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave
|
||||
static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk
|
||||
const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 +
|
||||
m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4;
|
||||
if constexpr(MulRoutedWeight)
|
||||
{
|
||||
topk_weights = *c_style_pointer_cast<const vector_type<float, M4>*>(
|
||||
p_ds_grid[I2] + m_pos);
|
||||
}
|
||||
static_for<0, M4, 1>{}([&](auto m4) { // m_inst_group_size
|
||||
constexpr index_t c_offset =
|
||||
blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
|
||||
make_tuple(m0, n0, m2 * M4 + m4));
|
||||
constexpr auto cidx = Number<c_offset>{};
|
||||
|
||||
if constexpr(IsInputGemm) // gu fusion
|
||||
{
|
||||
if constexpr(ActivationOperation == Activation::silu_and_mul)
|
||||
{
|
||||
float gate = c_thread_buf[cidx];
|
||||
float up = c_thread_buf_up[cidx];
|
||||
if constexpr(MulRoutedWeight)
|
||||
{
|
||||
gate = gate * topk_weights.AsType<float>()[m4];
|
||||
up = up * topk_weights.AsType<float>()[m4];
|
||||
}
|
||||
tensor_operation::element_wise::Silu{}(gate, gate);
|
||||
c_thread_buf_fp32(cidx) = gate * up;
|
||||
}
|
||||
else if(ActivationOperation == Activation::gelu_and_mul)
|
||||
{
|
||||
float gate = c_thread_buf[cidx];
|
||||
float up = c_thread_buf_up[cidx];
|
||||
if constexpr(MulRoutedWeight)
|
||||
{
|
||||
gate = gate * topk_weights.AsType<float>()[m4];
|
||||
up = up * topk_weights.AsType<float>()[m4];
|
||||
}
|
||||
tensor_operation::element_wise::Gelu{}(gate, gate);
|
||||
c_thread_buf_fp32(cidx) = gate * up;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
c_thread_buf_fp32(cidx) = c_thread_buf[cidx];
|
||||
static_for<0, NXdlPerWave / NXdlPack, 1>{}([&](auto n0) {
|
||||
static_for<0, NXdlPack, 1>{}([&](auto inxdl) { // NXdlPack
|
||||
static_for<0, MXdlPerWave / MXdlPack, 1>{}([&](auto m0) { // MXDLPerWave
|
||||
static_for<0, MXdlPack, 1>{}([&](auto imxdl) { // MXdlPack
|
||||
static_for<0, M3, 1>{}([&](auto m3) { // m_inst_num_groups_per_blk
|
||||
const index_t m_pos = block_m_id * MPerBlock +
|
||||
m0 * M2 * M1 * M3 * M4 * M5 +
|
||||
m1 * M2 * M3 * M4 * M5 +
|
||||
imxdl * M3 * M4 * M5 + m3 * M4 * M5 + m4 * M5;
|
||||
if constexpr(MulRoutedWeight)
|
||||
{
|
||||
c_thread_buf_fp32(cidx) =
|
||||
topk_weights.AsType<float>()[m4] * c_thread_buf_fp32[cidx];
|
||||
topk_weights =
|
||||
*c_style_pointer_cast<const vector_type<float, M5>*>(
|
||||
p_ds_grid[I2] + m_pos);
|
||||
}
|
||||
}
|
||||
static_for<0, M5, 1>{}([&](auto m5) { // m_inst_group_size
|
||||
constexpr index_t c_offset =
|
||||
blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
|
||||
make_tuple(m0, n0, imxdl, inxdl, m3 * M5 + m5));
|
||||
constexpr auto cidx = Number<c_offset>{};
|
||||
|
||||
if constexpr(IsInputGemm) // gu fusion
|
||||
{
|
||||
if constexpr(ActivationOperation ==
|
||||
Activation::silu_and_mul)
|
||||
{
|
||||
float gate = c_thread_buf[cidx];
|
||||
float up = c_thread_buf_up[cidx];
|
||||
if constexpr(MulRoutedWeight)
|
||||
{
|
||||
gate = gate * topk_weights.AsType<float>()[m5];
|
||||
up = up * topk_weights.AsType<float>()[m5];
|
||||
}
|
||||
tensor_operation::element_wise::Silu{}(gate, gate);
|
||||
c_thread_buf_fp32(cidx) = gate * up;
|
||||
}
|
||||
else if(ActivationOperation == Activation::gelu_and_mul)
|
||||
{
|
||||
float gate = c_thread_buf[cidx];
|
||||
float up = c_thread_buf_up[cidx];
|
||||
if constexpr(MulRoutedWeight)
|
||||
{
|
||||
gate = gate * topk_weights.AsType<float>()[m5];
|
||||
up = up * topk_weights.AsType<float>()[m5];
|
||||
}
|
||||
tensor_operation::element_wise::Gelu{}(gate, gate);
|
||||
c_thread_buf_fp32(cidx) = gate * up;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
c_thread_buf_fp32(cidx) = c_thread_buf[cidx];
|
||||
if constexpr(MulRoutedWeight)
|
||||
{
|
||||
c_thread_buf_fp32(cidx) =
|
||||
topk_weights.AsType<float>()[m5] *
|
||||
c_thread_buf_fp32[cidx];
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1738,19 +1771,25 @@ struct GridwiseMoeGemmMX_BPreshuffle
|
||||
make_tuple(
|
||||
make_freeze_transform(I0),
|
||||
make_unmerge_transform(make_tuple(
|
||||
Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
|
||||
M1, // M1 = MWave
|
||||
M2, // M2 * M3 * M4 = MPerXdl
|
||||
Number<CShuffleMXdlPerWavePerShuffle / MXdlPack>{}, // M0 (MXdlPerWave) per
|
||||
// shuffle
|
||||
M1, // M1 = MWave
|
||||
M2, // M2 * M3 * M4 = MPerXdl
|
||||
M3,
|
||||
M4)),
|
||||
M4,
|
||||
M5)),
|
||||
make_freeze_transform(I0),
|
||||
make_unmerge_transform(make_tuple(
|
||||
Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
|
||||
N1, // N1 = NWave
|
||||
N2))), // N2 = NPerXdl
|
||||
Number<CShuffleNXdlPerWavePerShuffle / NXdlPack>{}, // N0 (NXdlPerWave)
|
||||
// per shuffle
|
||||
N1, // N1 = NWave
|
||||
N2, // N2 = NXdlPack
|
||||
N3))), // N3 = NPerXdl
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(
|
||||
Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));
|
||||
make_tuple(Sequence<>{},
|
||||
Sequence<0, 2, 4, 6, 7, 8>{},
|
||||
Sequence<>{},
|
||||
Sequence<1, 3, 5, 9>{}));
|
||||
|
||||
// calculate origin of thread output tensor on global memory
|
||||
// blockwise GEMM c matrix starting index
|
||||
@@ -1762,8 +1801,8 @@ struct GridwiseMoeGemmMX_BPreshuffle
|
||||
|
||||
const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
|
||||
make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
|
||||
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
|
||||
make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4, M5))),
|
||||
make_tuple(Sequence<0, 1, 2, 3, 4, 5>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
const auto m_thread_data_on_block_idx =
|
||||
@@ -1772,8 +1811,8 @@ struct GridwiseMoeGemmMX_BPreshuffle
|
||||
|
||||
const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
|
||||
make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
|
||||
make_tuple(Sequence<0, 1, 2>{}),
|
||||
make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3))),
|
||||
make_tuple(Sequence<0, 1, 2, 3>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
const auto n_thread_data_on_block_idx =
|
||||
@@ -1781,36 +1820,39 @@ struct GridwiseMoeGemmMX_BPreshuffle
|
||||
make_multi_index(n_thread_data_on_block));
|
||||
|
||||
// shuffle: threadwise copy C from VGPR to LDS
|
||||
auto c_thread_copy_vgpr_to_lds =
|
||||
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
|
||||
CShuffleDataType,
|
||||
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
|
||||
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
Sequence<CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
I1,
|
||||
I1,
|
||||
M2,
|
||||
I1,
|
||||
M4,
|
||||
I1>,
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
|
||||
7,
|
||||
1,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
1,
|
||||
true>{
|
||||
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
|
||||
make_multi_index(0,
|
||||
0,
|
||||
m_thread_data_on_block_idx[I1],
|
||||
n_thread_data_on_block_idx[I1],
|
||||
m_thread_data_on_block_idx[I2],
|
||||
m_thread_data_on_block_idx[I3],
|
||||
m_thread_data_on_block_idx[I4],
|
||||
n_thread_data_on_block_idx[I2]),
|
||||
ck::tensor_operation::element_wise::PassThrough{}};
|
||||
auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
|
||||
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
Sequence<CShuffleMXdlPerWavePerShuffle / MXdlPack,
|
||||
CShuffleNXdlPerWavePerShuffle / NXdlPack,
|
||||
I1,
|
||||
I1,
|
||||
M2,
|
||||
N2,
|
||||
M3,
|
||||
I1,
|
||||
M5,
|
||||
I1>,
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
|
||||
9,
|
||||
1,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
1,
|
||||
true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
|
||||
make_multi_index(0,
|
||||
0,
|
||||
m_thread_data_on_block_idx[I1],
|
||||
n_thread_data_on_block_idx[I1],
|
||||
m_thread_data_on_block_idx[I2],
|
||||
n_thread_data_on_block_idx[I2],
|
||||
m_thread_data_on_block_idx[I3],
|
||||
m_thread_data_on_block_idx[I4],
|
||||
m_thread_data_on_block_idx[I5],
|
||||
n_thread_data_on_block_idx[I3]),
|
||||
ck::tensor_operation::element_wise::PassThrough{}};
|
||||
|
||||
using EDataType = CDataType;
|
||||
|
||||
@@ -1859,7 +1901,7 @@ struct GridwiseMoeGemmMX_BPreshuffle
|
||||
using CDEBlockTransferCluster =
|
||||
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
|
||||
const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
|
||||
constexpr index_t scatter_weight_idx = 1; // hack fix felix
|
||||
constexpr index_t scatter_weight_idx = 3; // hack fix felix
|
||||
auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
|
||||
ThisThreadBlock,
|
||||
decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
|
||||
@@ -1867,8 +1909,9 @@ struct GridwiseMoeGemmMX_BPreshuffle
|
||||
decltype(c_ds_desc_refs),
|
||||
decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
|
||||
CElementwiseOperation,
|
||||
Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make Sequence
|
||||
// support arbitray type
|
||||
Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make
|
||||
// Sequence support
|
||||
// arbitray type
|
||||
Sequence<1,
|
||||
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
|
||||
1,
|
||||
@@ -1898,13 +1941,25 @@ struct GridwiseMoeGemmMX_BPreshuffle
|
||||
|
||||
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
|
||||
|
||||
constexpr auto sfc_c_vgpr =
|
||||
SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
|
||||
Sequence<CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
SpaceFillingCurve<Sequence<MXdlPerWave / MXdlPack,
|
||||
NXdlPerWave / NXdlPack,
|
||||
1,
|
||||
1,
|
||||
MXdlPack,
|
||||
NXdlPack,
|
||||
M2,
|
||||
1,
|
||||
M4,
|
||||
1>,
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
|
||||
Sequence<CShuffleMXdlPerWavePerShuffle / MXdlPack,
|
||||
CShuffleNXdlPerWavePerShuffle / NXdlPack,
|
||||
1,
|
||||
1,
|
||||
MXdlPack,
|
||||
NXdlPack,
|
||||
M2,
|
||||
1,
|
||||
M4,
|
||||
@@ -1984,7 +2039,6 @@ struct GridwiseMoeGemmMX_BPreshuffle
|
||||
});
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
template <bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
|
||||
|
||||
@@ -74,6 +74,21 @@ struct GroupedConvTraits
|
||||
}
|
||||
|
||||
public:
|
||||
// Fixed values for Implicit GEMM
|
||||
struct FixedGemmParams
|
||||
{
|
||||
static constexpr ck_tile::index_t TilePartitionerGroupNum = 8;
|
||||
static constexpr ck_tile::index_t TilePartitionerM01 = 4;
|
||||
static constexpr bool kPadM = true;
|
||||
static constexpr bool kPadN = true;
|
||||
static constexpr bool kPadK = true;
|
||||
static constexpr bool TransposeC = false;
|
||||
static constexpr bool FixedVectorSize = true;
|
||||
static constexpr bool UseStructuredSparsity = false;
|
||||
static constexpr bool Persistent = false;
|
||||
using ELayout = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
};
|
||||
// Compile time parameters
|
||||
static constexpr bool EnableSplitImage = EnableSplitImage_;
|
||||
static constexpr index_t NumGroupsToMerge = NumGroupsToMerge_;
|
||||
static constexpr index_t NDimSpatial = NDimSpatial_;
|
||||
@@ -82,31 +97,43 @@ struct GroupedConvTraits
|
||||
using WeiLayout = WeiLayout_;
|
||||
using DsLayout = DsLayout_;
|
||||
using OutLayout = OutLayout_;
|
||||
|
||||
// Forward Gemm Layouts
|
||||
using AsLayoutFwd = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using BsLayoutFwd = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using CLayoutFwd = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
// Backward Data Gemm Layouts
|
||||
using AsLayoutBwdData = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using BsLayoutBwdData = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using CLayoutBwdData = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
// Backward Weight Gemm Layouts
|
||||
using AsLayoutBwdWeight = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using BsLayoutBwdWeight = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using CLayoutBwdWeight = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
|
||||
template <ck_tile::index_t NumWaveGroups = 1>
|
||||
using GroupedConvImplicitGemmTraitsFwd =
|
||||
TileGemmTraits<true,
|
||||
true,
|
||||
true,
|
||||
ck_tile::tensor_layout::gemm::RowMajor,
|
||||
ck_tile::tensor_layout::gemm::ColumnMajor,
|
||||
ck_tile::tensor_layout::gemm::RowMajor>;
|
||||
using GroupedConvImplicitGemmTraitsBwdData =
|
||||
TileGemmTraits<true,
|
||||
true,
|
||||
true,
|
||||
ck_tile::tensor_layout::gemm::RowMajor,
|
||||
ck_tile::tensor_layout::gemm::RowMajor,
|
||||
ck_tile::tensor_layout::gemm::RowMajor>;
|
||||
using GroupedConvImplicitGemmTraitsBwdWeight =
|
||||
TileGemmTraits<true,
|
||||
true,
|
||||
true,
|
||||
ck_tile::tensor_layout::gemm::ColumnMajor,
|
||||
ck_tile::tensor_layout::gemm::RowMajor,
|
||||
ck_tile::tensor_layout::gemm::RowMajor>;
|
||||
TileGemmTraits<true, true, true, AsLayoutFwd, BsLayoutFwd, CLayoutFwd, NumWaveGroups>;
|
||||
template <ck_tile::index_t NumWaveGroups = 1>
|
||||
using GroupedConvImplicitGemmTraitsBwdData = TileGemmTraits<true,
|
||||
true,
|
||||
true,
|
||||
AsLayoutBwdData,
|
||||
BsLayoutBwdData,
|
||||
CLayoutBwdData,
|
||||
NumWaveGroups>;
|
||||
template <ck_tile::index_t NumWaveGroups = 1>
|
||||
using GroupedConvImplicitGemmTraitsBwdWeight = TileGemmTraits<true,
|
||||
true,
|
||||
true,
|
||||
AsLayoutBwdWeight,
|
||||
BsLayoutBwdWeight,
|
||||
CLayoutBwdWeight,
|
||||
NumWaveGroups>;
|
||||
static constexpr ck_tile::index_t VectorSizeA = VectorSizeA_;
|
||||
static constexpr ck_tile::index_t VectorSizeB = VectorSizeB_;
|
||||
static constexpr ck_tile::index_t VectorSizeC = VectorSizeC_;
|
||||
static constexpr index_t NumDTensor = DsLayout::size();
|
||||
static constexpr ck_tile::index_t NumDTensor = DsLayout::size();
|
||||
using ImplicitGemmDsLayout = decltype(generate_implicit_gemm_layout());
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user