// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT #include #include #include #include #include #include #include "ck_tile/host.hpp" #include "mx_flatmm.hpp" template using is_row_major_t = ck_tile::bool_constant< std::is_same_v, ck_tile::tensor_layout::gemm::RowMajor>>; template float mx_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs& args, const ck_tile::stream_config& s) { using FlatmmConfig = typename MXFlatmmArchTraits::Config; using FlatmmShape = ck_tile::TileGemmShape< ck_tile::sequence, ck_tile::sequence, ck_tile::sequence>; using MXGemmTraits = ck_tile::TileGemmUniversalTraits; using ComputeDataType = ADataType; static_assert(sizeof(ComputeDataType) >= sizeof(BDataType), "mixed_prec_flatmm requires ADataType is a wider type than BDataType"); constexpr auto scheduler = FlatmmConfig::Scheduler; ck_tile::ignore = Splitk; // determined by scale shuffle pattern constexpr int BlockedXDLN_PerWarp = MXFlatmmArchTraits::BlockedXDLN_PerWarp; using MXPipelineProblem = ck_tile::MXFlatmmPipelineProblem; using MXFlatmmPipeline = typename MXFlatmmArchTraits::template MXFlatmmPipeline; using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner; using GemmEpilogue = std::conditional_t< FlatmmConfig::TiledMMAPermuteN, ck_tile::PermuteNEpilogue>, // VectorSizeC ck_tile::CShuffleEpilogue>>; using Kernel = ck_tile::MXFlatmmKernel; auto kargs = Kernel::MakeKernelArgs(args); const dim3 grids = Kernel::GridSize(kargs); const dim3 blocks = Kernel::BlockSize(); if(!Kernel::IsSupportedArgument(kargs)) throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); if(s.log_level_ > 0) { std::cout << "Launching kernel with args:" << FlatmmShape::GetName() << "\n" << "Shape: " << FlatmmShape::GetName() << "\n" << "problem: " << MXPipelineProblem::GetName() << "\n" << "pipeline: " << MXFlatmmPipeline::GetName() << "\n" << "epilogue: " << GemmEpilogue::GetName() << "\n" << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; } // Declare rotating_mem_ptr here so it stays in scope until it is needed std::unique_ptr> rotating_mem_ptr; std::function preprocess; auto clear_gemm_output = [&]() { if(args.k_batch > 1) hipGetErrorString( hipMemsetAsync(args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_)); }; if(s.flush_cache_) { std::cout << "Flushing cache..." << std::endl; constexpr ck_tile::index_t APackedSize = ck_tile::numeric_traits::PackedSize; constexpr ck_tile::index_t BPackedSize = ck_tile::numeric_traits::PackedSize; ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( args.M, args.K, args.stride_A, is_row_major_t{})); ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( args.K, args.N, args.stride_B, is_row_major_t{})); auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize; auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize; rotating_mem_ptr = std::make_unique>( kargs.a_ptr, kargs.b_ptr, s.rotating_count_, size_a_buffer, size_b_buffer); rotating_mem_ptr->Print(); preprocess = [&]() { ck_tile::flush_icache(); rotating_mem_ptr->Next(); clear_gemm_output(); }; } else { preprocess = clear_gemm_output; } return ck_tile::launch_kernel_time_mask( s, preprocess, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); }