mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
* something khushbu can help with * v1 v2 works with flatmm develop * v0 v1 v2 numerical error gone * Fixing numerical error, and interchange preshuffle configs to match with flatmm * Refactor GEMM pipeline configurations and integrate preshuffle support - Updated preshuffle pipeline definitions to include multiple versions (V1, V2, V3). - Changed the pipeline constant from CK_TILE_PIPELINE_PRESHUFFLE to CK_TILE_PIPELINE_PRESHUFFLE_V3 in relevant configurations. - Removed obsolete code and comments * clang format * fix vectorloadsize bug * add the Preshuffle3 * update kwarp calculation in gemm utils * update vector size A and B correctly in V2 pipeline; Added few more changes to align with dteng's branch * fix: add CK_GFX950_SUPPORT macro for gfx950 detection * default disable rotating buffer * docs(CHANGELOG): update changelog for rocm 7.0 * Revert "docs(CHANGELOG): update changelog for rocm 7.0" This reverts commit 2bc16fff84. * Remove unused Preshuffle V3 pipeline and related code; update gemm function to use Preshuffle V2; clean up comments and formatting in various files. 
* revert example/ck_tile/flatmm to its original state * remove comment added by second author * switch to xor ALDSDescriptor * modify the MakeALdsDescriptor() * temporary profiling script * getting rid of line marker compiler error * UniversalWeightPreshufflePipelineAgBgCrPolicy now derives from UniversalGemmBasePolicy * add a minor fix for the config * typo fix * Fix formatting in lambda function for WeightPreshufflePipelineAGmemBGmemCRegV2 * revert change in include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp * revert change in include/ck_tile/core/arch/amd_buffer_addressing.hpp * reenable the GemmSpatiallyLocalTilePartitioner * make GemmConfigPreshuffle_1 for v1 pipeline, GemmConfigPreshuffle_2 for v2 pipeline * remove hardcoded true for preshuffle bool template argument * rename script * remove gemm_profilie.sh script * merge conflict resolve * clang formatted * typo fix * Remove duplicate include of block_gemm_areg_bsmem_creg_v2r1.hpp in gemm.hpp * Remove commented-out code in UniversalWeightPreshufflePipelineAgBgCrPolicy * Fix missing newline at end of file in run_gemm_example.inc * Remove unused barrier call in BlockWeightPreshuffleASmemBSmemCRegV1 * addressing review comments * removing debug code * addressing review comments * Revert "addressing review comments" This reverts commit 29c45192ba. * updating tile_engine code * addressing review comments --------- Co-authored-by: amd-khushbu <khuagarw@amd.com> Co-authored-by: ThomasNing <thomas.ning@amd.com>
170 lines
5.7 KiB
C++
170 lines
5.7 KiB
C++
// SPDX-License-Identifier: MIT
|
|
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
|
|
|
#pragma once
|
|
|
|
#include <iostream>
|
|
#include <string>
|
|
|
|
#include "ck_tile/core.hpp"
|
|
#include "ck_tile/ops/common.hpp"
|
|
#include "ck_tile/host/concat.hpp"
|
|
#include "ck_tile/host/kernel_launch.hpp"
|
|
#include "ck_tile/host/stream_utils.hpp"
|
|
#include "ck_tile/core/utility/env.hpp"
|
|
#include "ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp"
|
|
#include "ck_tile/core/utility/type_traits.hpp"
|
|
|
|
namespace ck_tile {
|
|
|
|
/// @brief The GEMM kernel host arguments.
|
|
///
|
|
/// @par Overview
|
|
/// This structure is passed to @ref GemmKernel "GemmKernel" when creating kernel arguments
|
|
/// object. It contain all necessary information required to build proper kernel argument
|
|
/// and launch kernel on GPU.
|
|
/// This structure defines the GEMM problem configuration by stating all required information
|
|
/// like M,N,K sizes and respective strides.
|
|
struct GemmHostArgs
|
|
{
|
|
CK_TILE_HOST GemmHostArgs() = default;
|
|
CK_TILE_HOST GemmHostArgs(const void* a_ptr_,
|
|
const void* b_ptr_,
|
|
void* e_ptr_,
|
|
index_t k_batch_,
|
|
index_t M_,
|
|
index_t N_,
|
|
index_t K_,
|
|
index_t stride_A_,
|
|
index_t stride_B_,
|
|
index_t stride_E_)
|
|
: a_ptr(a_ptr_),
|
|
b_ptr(b_ptr_),
|
|
e_ptr(e_ptr_),
|
|
M(M_),
|
|
N(N_),
|
|
K(K_),
|
|
stride_A(stride_A_),
|
|
stride_B(stride_B_),
|
|
stride_E(stride_E_),
|
|
k_batch(k_batch_)
|
|
{
|
|
}
|
|
|
|
const void* a_ptr;
|
|
const void* b_ptr;
|
|
union
|
|
{
|
|
void* e_ptr;
|
|
void* c_ptr;
|
|
};
|
|
|
|
index_t M;
|
|
index_t N;
|
|
index_t K;
|
|
index_t stride_A;
|
|
index_t stride_B;
|
|
|
|
union
|
|
{
|
|
index_t stride_E;
|
|
index_t stride_C;
|
|
};
|
|
|
|
index_t k_batch;
|
|
};
|
|
|
|
template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
|
|
struct GemmKernel
|
|
{
|
|
/// @brief Inject the UniversalGemmKernel base class to support execution of all necessary
|
|
/// functions.
|
|
using UniversalGemmKernel =
|
|
UniversalGemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>;
|
|
|
|
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
|
|
using GemmPipeline = remove_cvref_t<GemmPipeline_>;
|
|
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
|
|
|
|
/// @brief Specify the layout configurations for A, B, E and D
|
|
using ALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
|
|
using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
|
|
using ELayout = remove_cvref_t<typename GemmPipeline::CLayout>;
|
|
|
|
/// @brief Specify the data type configurations for A, B, E and D
|
|
using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
|
|
using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
|
|
using EDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
|
|
|
|
/// @brief ALayout and ADataType are expected to be scalars, not a tuple.
|
|
static_assert(
|
|
!is_detected<is_tuple, ALayout>::value && !is_detected<is_tuple, ADataType>::value,
|
|
"ALayout and ADataType must be scalars. Multiple parameters are not currently supported.");
|
|
|
|
/// @brief BLayout and BDataType are expected to be scalars, not a tuple.
|
|
static_assert(
|
|
!is_detected<is_tuple, BLayout>::value && !is_detected<is_tuple, BDataType>::value,
|
|
"BLayout and BDataType must be scalars. Multiple parameters are not currently supported.");
|
|
|
|
/// @brief C/ELayout and C/EDataType are expected to be scalars, not a tuple.
|
|
static_assert(!is_detected<is_tuple, ELayout>::value &&
|
|
!is_detected<is_tuple, EDataType>::value,
|
|
"C/ELayout and C/EDataType must be scalars.");
|
|
|
|
static constexpr index_t NumATensor = 1;
|
|
static constexpr index_t NumBTensor = 1;
|
|
|
|
CK_TILE_HOST static auto GetName() -> const std::string
|
|
{
|
|
return UniversalGemmKernel::GetName();
|
|
}
|
|
|
|
CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch) -> dim3
|
|
{
|
|
return UniversalGemmKernel::GridSize(M, N, KBatch);
|
|
}
|
|
|
|
CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3
|
|
{
|
|
return UniversalGemmKernel::MaxOccupancyGridSize(s);
|
|
}
|
|
|
|
CK_TILE_HOST static constexpr auto BlockSize() -> dim3
|
|
{
|
|
return UniversalGemmKernel::BlockSize();
|
|
}
|
|
|
|
CK_TILE_HOST static constexpr auto MakeKernelArgs(const GemmHostArgs& hostArgs) ->
|
|
typename UniversalGemmKernel::KernelArgs
|
|
{
|
|
/// @brief Universal GEMM requires array objects and corresponding stride information for
|
|
/// matrices A, B.
|
|
return UniversalGemmKernel::MakeKernelArgs(
|
|
UniversalGemmHostArgs<NumATensor, NumBTensor /*NumDTensor = 0 */>(
|
|
{hostArgs.a_ptr},
|
|
{hostArgs.b_ptr},
|
|
{/*hostArgs.ds_ptr*/},
|
|
hostArgs.e_ptr,
|
|
hostArgs.k_batch,
|
|
hostArgs.M,
|
|
hostArgs.N,
|
|
hostArgs.K,
|
|
{hostArgs.stride_A},
|
|
{hostArgs.stride_B},
|
|
{/*hostArgs.stride_Ds*/},
|
|
hostArgs.stride_E));
|
|
}
|
|
|
|
CK_TILE_HOST static auto
|
|
IsSupportedArgument(const typename UniversalGemmKernel::KernelArgs& kargs) -> bool
|
|
{
|
|
return UniversalGemmKernel::IsSupportedArgument(kargs);
|
|
}
|
|
|
|
CK_TILE_DEVICE auto operator()(typename UniversalGemmKernel::KernelArgs kargs) const -> void
|
|
{
|
|
UniversalGemmKernel{}.template operator()(kargs);
|
|
}
|
|
};
|
|
} // namespace ck_tile
|