mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 11:47:48 +00:00
Merge branch 'develop' into letaoqin/gemm_bias_activation
This commit is contained in:
@@ -132,7 +132,11 @@ if(GPU_ARCHS)
|
||||
unset(GPU_TARGETS CACHE)
|
||||
unset(AMDGPU_TARGETS CACHE)
|
||||
endif()
|
||||
|
||||
if(GPU_TARGETS)
|
||||
set(USER_GPU_TARGETS 1)
|
||||
else()
|
||||
set(USER_GPU_TARGETS 0)
|
||||
endif()
|
||||
find_package(hip)
|
||||
# No assumption that HIP kernels are launched with uniform block size for backward compatibility
|
||||
# SWDEV-413293 and https://reviews.llvm.org/D155213
|
||||
@@ -162,7 +166,7 @@ endif()
|
||||
if(GPU_ARCHS)
|
||||
set(CK_GPU_TARGETS ${GPU_ARCHS})
|
||||
else()
|
||||
if(GPU_TARGETS)
|
||||
if(USER_GPU_TARGETS)
|
||||
set(CK_GPU_TARGETS ${GPU_TARGETS})
|
||||
endif()
|
||||
endif()
|
||||
@@ -545,7 +549,7 @@ ENDFOREACH()
|
||||
add_custom_target(instances DEPENDS utility;${CK_DEVICE_INSTANCES} SOURCES ${INSTANCE_FILES})
|
||||
add_subdirectory(library)
|
||||
|
||||
if(NOT GPU_ARCHS)
|
||||
if(NOT GPU_ARCHS AND USER_GPU_TARGETS)
|
||||
rocm_package_setup_component(tests
|
||||
LIBRARY_NAME composablekernel
|
||||
PACKAGE_NAME tests # Prevent -static suffix on package name
|
||||
|
||||
10
Jenkinsfile
vendored
10
Jenkinsfile
vendored
@@ -353,7 +353,7 @@ def buildHipClangJob(Map conf=[:]){
|
||||
def prefixpath = conf.get("prefixpath", "/opt/rocm")
|
||||
|
||||
// Jenkins is complaining about the render group
|
||||
def dockerOpts="--rm --device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined"
|
||||
def dockerOpts="--device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined"
|
||||
if (conf.get("enforce_xnack_on", false)) {
|
||||
dockerOpts = dockerOpts + " --env HSA_XNACK=1 "
|
||||
}
|
||||
@@ -412,7 +412,7 @@ def runCKProfiler(Map conf=[:]){
|
||||
def prefixpath = conf.get("prefixpath", "/opt/rocm")
|
||||
|
||||
// Jenkins is complaining about the render group
|
||||
def dockerOpts="--rm --device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined"
|
||||
def dockerOpts="--device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined"
|
||||
if (conf.get("enforce_xnack_on", false)) {
|
||||
dockerOpts = dockerOpts + " --env HSA_XNACK=1 "
|
||||
}
|
||||
@@ -544,7 +544,7 @@ def Build_CK(Map conf=[:]){
|
||||
def prefixpath = conf.get("prefixpath", "/opt/rocm")
|
||||
|
||||
// Jenkins is complaining about the render group
|
||||
def dockerOpts="--rm --device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined"
|
||||
def dockerOpts="--device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined"
|
||||
if (conf.get("enforce_xnack_on", false)) {
|
||||
dockerOpts = dockerOpts + " --env HSA_XNACK=1 "
|
||||
}
|
||||
@@ -660,7 +660,7 @@ def process_results(Map conf=[:]){
|
||||
def prefixpath = "/opt/rocm"
|
||||
|
||||
// Jenkins is complaining about the render group
|
||||
def dockerOpts="--rm --cap-add=SYS_PTRACE --security-opt seccomp=unconfined"
|
||||
def dockerOpts="--cap-add=SYS_PTRACE --security-opt seccomp=unconfined"
|
||||
if (conf.get("enforce_xnack_on", false)) {
|
||||
dockerOpts = dockerOpts + " --env HSA_XNACK=1 "
|
||||
}
|
||||
@@ -1138,7 +1138,7 @@ pipeline {
|
||||
execute_args = """ cmake -D CMAKE_PREFIX_PATH=/opt/rocm \
|
||||
-D CMAKE_CXX_COMPILER="${build_compiler()}" \
|
||||
-D CMAKE_BUILD_TYPE=Release \
|
||||
-D GPU_ARCHS="gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201" \
|
||||
-D GPU_ARCHS="gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102" \
|
||||
-D CMAKE_CXX_FLAGS=" -O3 " .. && make -j64 """
|
||||
}
|
||||
steps{
|
||||
|
||||
@@ -91,6 +91,7 @@ Docker images are available on [DockerHub](https://hub.docker.com/r/rocm/composa
|
||||
|
||||
If you don't set `GPU_TARGETS` on the cmake command line, CK is built for all GPU targets
|
||||
supported by the current compiler (this may take a long time).
|
||||
Tests and examples will only get built if the GPU_TARGETS is set by the user on the cmake command line.
|
||||
|
||||
NOTE: If you try setting `GPU_TARGETS` to a list of architectures, the build will only work if the
|
||||
architectures are similar, e.g., `gfx908;gfx90a`, or `gfx1100;gfx1101;gfx11012`. Otherwise, if you
|
||||
|
||||
@@ -12,12 +12,6 @@ API reference guide
|
||||
This document contains details of the APIs for the Composable Kernel (CK) library and introduces
|
||||
some of the key design principles that are used to write new classes that extend CK functionality.
|
||||
|
||||
=================
|
||||
Using CK API
|
||||
=================
|
||||
|
||||
This section describes how to use the CK library API.
|
||||
|
||||
=================
|
||||
CK Datatypes
|
||||
=================
|
||||
|
||||
@@ -117,9 +117,9 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
|
||||
|
||||
auto f_get_default_stride =
|
||||
[](std::size_t row, std::size_t col, ck::index_t stride, auto layout) {
|
||||
if(stride == -1)
|
||||
if(stride == 0)
|
||||
{
|
||||
// give a chance if stride is -1, return a default packed stride
|
||||
// give a chance if stride is 0, return a default packed stride
|
||||
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return static_cast<std::size_t>(col);
|
||||
|
||||
@@ -41,18 +41,39 @@ template <typename LayoutA,
|
||||
float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
|
||||
{
|
||||
// The kPadA, kPadB, kPadC & kBlockPerCu should also come from the Codegen part.
|
||||
constexpr bool kPadA = true;
|
||||
constexpr bool kPadB = true;
|
||||
constexpr bool kPadA = true;
|
||||
constexpr bool kPadB = true;
|
||||
constexpr bool kTilePermute = false;
|
||||
|
||||
constexpr int kBlockPerCu = 1;
|
||||
|
||||
using TilePartitioner = ck_tile::GemmTilePartitioner<GemmShape>;
|
||||
using GemmEpilogue = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadA, kPadB>>;
|
||||
|
||||
// The rank and permutation will also be generate out by the CodeGen part.
|
||||
constexpr ck_tile::index_t kOutputRank = 2;
|
||||
|
||||
// Whether doing the CShuffle (transpose before the global memory), depending on the output
|
||||
// layout.
|
||||
constexpr bool CShuffleEpilogue =
|
||||
std::is_same_v<LayoutC, ck_tile::tensor_layout::gemm::ColumnMajor>;
|
||||
|
||||
using GemmEpilogue = std::conditional_t<
|
||||
CShuffleEpilogue,
|
||||
ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<AccDataType,
|
||||
CDataType,
|
||||
kPadA,
|
||||
kPadB,
|
||||
kTilePermute,
|
||||
kOutputRank,
|
||||
1,
|
||||
0,
|
||||
TilePartitioner::kM,
|
||||
TilePartitioner::kN>>,
|
||||
ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadA, kPadB>>>;
|
||||
// ToDo: Will add the codegen part to test different pipeline policies in GEMM.
|
||||
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
|
||||
using Kernel =
|
||||
ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue, LayoutA, LayoutB, LayoutC>;
|
||||
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
|
||||
auto kargs = Kernel::MakeKargs(args.p_a,
|
||||
args.p_b,
|
||||
@@ -255,15 +276,13 @@ int main(int argc, char* argv[])
|
||||
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
|
||||
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
|
||||
|
||||
using CodegenPipelineProblem = ck_tile::BlockGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CodegenGemmShape,
|
||||
kPadA,
|
||||
kPadB,
|
||||
kPadC>;
|
||||
using CodegenGemmTraits = ck_tile::
|
||||
TileGemmTraits<kPadA, kPadB, kPadC, matrix_a_layout, matrix_b_layout, matrix_c_layout>;
|
||||
|
||||
using CodegenGemmPipeline = ck_tile::BlockGemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
|
||||
using CodegenPipelineProblem = ck_tile::
|
||||
GemmPipelineProblem<ADataType, BDataType, AccDataType, CodegenGemmShape, CodegenGemmTraits>;
|
||||
|
||||
using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
|
||||
|
||||
invoke_gemm<ck_tile::half_t,
|
||||
matrix_a_layout,
|
||||
@@ -341,7 +360,13 @@ int main(int argc, char* argv[])
|
||||
ck_tile::HostTensor<CDataType> c_host_gpu_ref(c_dimensions);
|
||||
ck_tile::DeviceMem c_gpu_buf(c_host_gpu_ref.get_element_space_size_in_bytes());
|
||||
|
||||
ck_tile::reference_gemm_gpu<ADataType, BDataType, AccDataType, CDataType>(
|
||||
ck_tile::reference_gemm_gpu<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
matrix_a_layout,
|
||||
matrix_b_layout,
|
||||
matrix_c_layout>(
|
||||
a_buf, b_buf, c_gpu_buf, M, N, K, stride_a, stride_b, stride_c);
|
||||
|
||||
c_buf.FromDevice(c_host_gpu_ref.data());
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
#include "device_base.hpp"
|
||||
@@ -31,13 +31,13 @@ struct DeviceCGemm : public BaseOperator
|
||||
CElementwiseOperation c_element_op,
|
||||
ck::index_t KBatch = 1) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
virtual std::size_t GetWorkspaceSize(index_t MRaw,
|
||||
index_t NRaw,
|
||||
index_t KRaw,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
index_t StrideC) = 0;
|
||||
index_t StrideC) const = 0;
|
||||
};
|
||||
|
||||
template <typename AElementwiseOperation,
|
||||
|
||||
@@ -598,10 +598,26 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
|
||||
[[maybe_unused]] index_t K,
|
||||
[[maybe_unused]] index_t StrideA,
|
||||
[[maybe_unused]] index_t StrideB,
|
||||
index_t StrideC) override
|
||||
index_t StrideC) const override
|
||||
{
|
||||
return 2 * sizeof(CDataType) * GetCElementSpaceSize(M, N, StrideC);
|
||||
}
|
||||
|
||||
std::size_t GetWorkSpaceSize(const BaseArgument* base_arg) const override
|
||||
{
|
||||
const auto* parg = dynamic_cast<const Argument*>(base_arg);
|
||||
|
||||
if(!parg)
|
||||
{
|
||||
std::ostringstream err;
|
||||
err << "Provided argument pointer is not of an Argument class!"
|
||||
<< " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
|
||||
return GetWorkspaceSize(
|
||||
parg->M, parg->N, parg->K, parg->StrideA, parg->StrideB, parg->StrideC);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
|
||||
@@ -58,7 +58,7 @@ struct thread_buffer {
|
||||
template <index_t I> CK_TILE_HOST_DEVICE constexpr const auto& at() const { return get(I); }
|
||||
template <index_t I> CK_TILE_HOST_DEVICE constexpr auto& at(number<I>) { return get(I); }
|
||||
template <index_t I> CK_TILE_HOST_DEVICE constexpr const auto& at(number<I>) const { return get(I); }
|
||||
|
||||
|
||||
template <typename X_,
|
||||
typename std::enable_if<has_same_scalar_type<value_type, X_>::value, bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr auto _get_as() const
|
||||
|
||||
@@ -27,7 +27,9 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
|
||||
const BElementOp& b_element_op = {},
|
||||
const ACCElementOp& acc_element_op = {})
|
||||
{
|
||||
const int N = b_n_k.mDesc.get_lengths()[0];
|
||||
const int N = (std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
|
||||
? b_n_k.mDesc.get_lengths()[0]
|
||||
: b_n_k.mDesc.get_lengths()[1];
|
||||
const int K = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
|
||||
? a_m_k.mDesc.get_lengths()[1]
|
||||
: a_m_k.mDesc.get_lengths()[0];
|
||||
@@ -45,20 +47,31 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
|
||||
ADataType v_a = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
|
||||
? a_element_op(a_m_k(m, k))
|
||||
: a_element_op(a_m_k(k, m));
|
||||
BDataType v_b = b_element_op(b_n_k(n, k));
|
||||
BDataType v_b = (std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
|
||||
? b_element_op(b_n_k(n, k))
|
||||
: b_element_op(b_n_k(k, n));
|
||||
|
||||
v_acc += ck_tile::type_convert<AccDataType>(v_a) *
|
||||
ck_tile::type_convert<AccDataType>(v_b);
|
||||
}
|
||||
|
||||
c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
|
||||
CDataType& c_ref = (std::is_same_v<LayoutC, tensor_layout::gemm::RowMajor>)
|
||||
? c_m_n(m, n)
|
||||
: c_m_n(n, m);
|
||||
c_ref = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
|
||||
}
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f, M)(std::thread::hardware_concurrency());
|
||||
}
|
||||
|
||||
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename LayoutA,
|
||||
typename LayoutB,
|
||||
typename LayoutC>
|
||||
__global__ void naive_gemm_kernel(ADataType* A,
|
||||
BDataType* B,
|
||||
CDataType* C,
|
||||
@@ -76,18 +89,32 @@ __global__ void naive_gemm_kernel(ADataType* A,
|
||||
if(row < M && col < N)
|
||||
{
|
||||
AccDataType acc = 0.0;
|
||||
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
acc += static_cast<AccDataType>(A[row * strideA + k]) *
|
||||
static_cast<AccDataType>(B[col * strideB + k]);
|
||||
// Adjust indexing based on matrix layout
|
||||
int a_index = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
|
||||
? row * strideA + k
|
||||
: k * strideA + row;
|
||||
int b_index = (std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
|
||||
? col * strideB + k
|
||||
: k * strideB + col;
|
||||
acc += static_cast<AccDataType>(A[a_index]) * static_cast<AccDataType>(B[b_index]);
|
||||
}
|
||||
|
||||
C[row * strideC + col] = acc; // Store as AccDataType
|
||||
int c_index = (std::is_same_v<LayoutC, tensor_layout::gemm::RowMajor>)
|
||||
? row * strideC + col
|
||||
: col * strideC + row;
|
||||
C[c_index] = acc;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename LayoutA,
|
||||
typename LayoutB,
|
||||
typename LayoutC>
|
||||
void reference_gemm_gpu(DeviceMem& a_device,
|
||||
DeviceMem& b_device,
|
||||
DeviceMem& c_device,
|
||||
@@ -145,7 +172,7 @@ void reference_gemm_gpu(DeviceMem& a_device,
|
||||
int numThreadsPerBlock = 256; // Common choice for threads per block
|
||||
int numBlocks = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;
|
||||
|
||||
naive_gemm_kernel<ADataType, BDataType, AccDataType, CDataType>
|
||||
naive_gemm_kernel<ADataType, BDataType, AccDataType, CDataType, LayoutA, LayoutB, LayoutC>
|
||||
<<<numBlocks, numThreadsPerBlock>>>(d_A, d_B, d_C, M, N, K, stride_a, stride_b, stride_c);
|
||||
errC = hipMemcpy(
|
||||
c_device.GetDeviceBuffer(), d_C, M * N * sizeof(CDataType), hipMemcpyDeviceToHost);
|
||||
|
||||
@@ -3,5 +3,6 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp"
|
||||
#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
|
||||
171
include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp
Normal file
171
include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp
Normal file
@@ -0,0 +1,171 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
#define CK_TILE_MAX_RANK 5
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// this epilogue aiming to store a matrix with different layout from the shared memory to the global
|
||||
// memory.
|
||||
template <typename AccDataType_,
|
||||
typename ODataType_,
|
||||
bool kPadM_,
|
||||
bool kPadN_,
|
||||
bool kTilePermute_,
|
||||
index_t kRank_,
|
||||
index_t kPerm0,
|
||||
index_t kPerm1,
|
||||
index_t TileSize0,
|
||||
index_t TileSize1,
|
||||
index_t kPerm2 = 0,
|
||||
index_t kPerm3 = 0,
|
||||
index_t kPerm4 = 0,
|
||||
index_t TileSize2 = 0,
|
||||
index_t TileSize3 = 0,
|
||||
index_t TileSize4 = 0>
|
||||
struct CShuffleEpilogueProblem
|
||||
{
|
||||
using AccDataType = remove_cvref_t<AccDataType_>;
|
||||
using ODataType = remove_cvref_t<ODataType_>;
|
||||
static constexpr bool kPadM = kPadM_;
|
||||
static constexpr bool kPadN = kPadN_;
|
||||
static constexpr bool kTilePermute = kTilePermute_;
|
||||
static constexpr index_t kRank = kRank_;
|
||||
static constexpr index_t kPerm[CK_TILE_MAX_RANK] = {kPerm0, kPerm1, kPerm2, kPerm3, kPerm4};
|
||||
static constexpr index_t tile_sizes[CK_TILE_MAX_RANK] = {
|
||||
TileSize0, TileSize1, TileSize2, TileSize3, TileSize4};
|
||||
};
|
||||
|
||||
template <typename Problem_, typename Policy_ = void>
|
||||
struct CShuffleEpilogue
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
|
||||
using ODataType = remove_cvref_t<typename Problem::ODataType>;
|
||||
static constexpr bool kPadM = Problem::kPadM;
|
||||
static constexpr bool kPadN = Problem::kPadN;
|
||||
const index_t* kPerm = Problem::kPerm;
|
||||
static constexpr bool kTilePermute = Problem::kTilePermute;
|
||||
static constexpr index_t kRank = Problem::kRank;
|
||||
const index_t* tile_sizes = Problem::tile_sizes;
|
||||
|
||||
// No additional shared memory needed
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 0; }
|
||||
|
||||
template <typename OAccTile>
|
||||
CK_TILE_DEVICE void permute_tile_data(OAccTile& o_acc_tile)
|
||||
{
|
||||
using DataType = typename OAccTile::DataType;
|
||||
|
||||
// Get thread buffer
|
||||
auto& thread_buf = o_acc_tile.get_thread_buffer();
|
||||
|
||||
// Create a temporary buffer to hold the permuted data
|
||||
thread_buffer<DataType, OAccTile::kThreadElementSpaceSize> permuted_thread_buf;
|
||||
|
||||
// Get the lengths of each dimension
|
||||
auto thread_tensor_lengths = o_acc_tile.get_lengths();
|
||||
|
||||
// Total number of elements
|
||||
index_t total_elements = OAccTile::kThreadElementSpaceSize;
|
||||
|
||||
// Iterate over all elements
|
||||
for(index_t linear_idx = 0; linear_idx < total_elements; ++linear_idx)
|
||||
{
|
||||
// Convert linear index to multi-dimensional indices
|
||||
array<index_t, kRank> indices;
|
||||
index_t remaining = linear_idx;
|
||||
static_for<0, kRank, 1>{}([&](auto i) {
|
||||
constexpr auto rev_i = kRank - 1 - i;
|
||||
indices(rev_i) = remaining % thread_tensor_lengths.get(number<rev_i>{});
|
||||
remaining /= thread_tensor_lengths.get(number<rev_i>{});
|
||||
});
|
||||
|
||||
// Apply the permutation
|
||||
array<index_t, kRank> permuted_indices;
|
||||
static_for<0, kRank, 1>{}(
|
||||
[&](auto i) { permuted_indices(i) = indices.get(number<Problem::kPerm[i]>{}); });
|
||||
|
||||
// Compute offsets
|
||||
index_t dst_offset = 0;
|
||||
index_t stride = 1;
|
||||
|
||||
static_for<0, kRank, 1>{}([&](auto i) {
|
||||
constexpr auto rev_i = kRank - 1 - i;
|
||||
dst_offset += permuted_indices[rev_i] * stride;
|
||||
stride *= thread_tensor_lengths.get(number<rev_i>{});
|
||||
});
|
||||
|
||||
// Move the data
|
||||
permuted_thread_buf(dst_offset) = thread_buf[linear_idx];
|
||||
}
|
||||
|
||||
// Copy the permuted data back to the original thread buffer
|
||||
for(index_t i = 0; i < total_elements; ++i)
|
||||
{
|
||||
thread_buf.set_as(i, permuted_thread_buf.get(i));
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ODramWindowTmp, typename OAccTile>
|
||||
CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp, OAccTile& o_acc_tile)
|
||||
{
|
||||
const auto& current_window_origin = o_dram_window_tmp.get_window_origin();
|
||||
|
||||
// Compute the tile coordinates by dividing the window origin by the tile sizes
|
||||
index_t tile_coords[CK_TILE_MAX_RANK] = {0};
|
||||
for(index_t i = 0; i < kRank; ++i)
|
||||
{
|
||||
tile_coords[i] = current_window_origin[i] / tile_sizes[i];
|
||||
// printf("The tile_coord is: %d", tile_coords[i]);
|
||||
}
|
||||
|
||||
// Apply the permutation to the tile coordinates
|
||||
index_t permuted_tile_coords[CK_TILE_MAX_RANK];
|
||||
for(index_t i = 0; i < kRank; ++i)
|
||||
{
|
||||
permuted_tile_coords[i] = tile_coords[kPerm[i]];
|
||||
// printf("The new permuted_tile_coords is: %d", permuted_tile_coords[i]);
|
||||
}
|
||||
|
||||
// Compute the permuted window origin
|
||||
index_t permuted_window_origin[CK_TILE_MAX_RANK] = {0};
|
||||
for(index_t i = 0; i < kRank; ++i)
|
||||
{
|
||||
permuted_window_origin[i] = permuted_tile_coords[i] * tile_sizes[i];
|
||||
// printf("The new permuted_window_origin is: %d", permuted_window_origin[i]);
|
||||
}
|
||||
|
||||
typename ODramWindowTmp::BottomTensorIndex step = {};
|
||||
for(index_t i = 0; i < kRank; ++i)
|
||||
{
|
||||
step[i] = permuted_window_origin[i] - current_window_origin[i];
|
||||
}
|
||||
|
||||
// Move the window
|
||||
move_tile_window(o_dram_window_tmp, step);
|
||||
|
||||
// Permute the data within the tile if necessary
|
||||
if constexpr(kTilePermute)
|
||||
{
|
||||
permute_tile_data(o_acc_tile);
|
||||
}
|
||||
|
||||
// Store the tile data to the permuted location
|
||||
if constexpr(kPadM || kPadN)
|
||||
{
|
||||
store_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
|
||||
buffer_store_fence();
|
||||
}
|
||||
else
|
||||
{
|
||||
store_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -5,8 +5,9 @@
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp"
|
||||
@@ -25,15 +26,21 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
|
||||
{
|
||||
using BlockGemmProblem = BlockGemmPipelineProblem<
|
||||
typename Problem::QDataType,
|
||||
typename Problem::KDataType,
|
||||
typename Problem::AccDataType,
|
||||
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
|
||||
Problem::BlockFmhaShape::kN0,
|
||||
Problem::BlockFmhaShape::kK0>,
|
||||
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
|
||||
typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
|
||||
using GemmProblem =
|
||||
GemmPipelineProblem<typename Problem::QDataType,
|
||||
typename Problem::KDataType,
|
||||
typename Problem::AccDataType,
|
||||
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
|
||||
Problem::BlockFmhaShape::kN0,
|
||||
Problem::BlockFmhaShape::kK0>,
|
||||
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
|
||||
typename Problem::BlockFmhaShape::Gemm0WarpTile>,
|
||||
TileGemmTraits<Problem::kPadSeqLenQ,
|
||||
Problem::kPadSeqLenK,
|
||||
Problem::kPadHeadDimQ,
|
||||
typename tensor_layout::gemm::RowMajor,
|
||||
typename tensor_layout::gemm::ColumnMajor,
|
||||
typename tensor_layout::gemm::RowMajor>>;
|
||||
|
||||
using WarpGemm = WarpGemmMfmaDispatcher<
|
||||
typename Problem::QDataType,
|
||||
@@ -52,21 +59,27 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
|
||||
WarpGemm>;
|
||||
|
||||
return BlockGemmARegBRegCRegV1<BlockGemmProblem, BlockGemmPolicy>{};
|
||||
return BlockGemmARegBRegCRegV1<GemmProblem, BlockGemmPolicy>{};
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetPTOGradTBlockGemm()
|
||||
{
|
||||
using BlockGemmProblem = BlockGemmPipelineProblem<
|
||||
typename Problem::GemmDataType,
|
||||
typename Problem::OGradDataType,
|
||||
typename Problem::AccDataType,
|
||||
TileGemmShape<sequence<Problem::BlockFmhaShape::kN0,
|
||||
Problem::BlockFmhaShape::kVHeaddim,
|
||||
Problem::BlockFmhaShape::kK1>,
|
||||
typename Problem::BlockFmhaShape::Gemm1BlockWarps,
|
||||
typename Problem::BlockFmhaShape::Gemm1WarpTile>>;
|
||||
using GemmProblem =
|
||||
GemmPipelineProblem<typename Problem::GemmDataType,
|
||||
typename Problem::OGradDataType,
|
||||
typename Problem::AccDataType,
|
||||
TileGemmShape<sequence<Problem::BlockFmhaShape::kN0,
|
||||
Problem::BlockFmhaShape::kVHeaddim,
|
||||
Problem::BlockFmhaShape::kK1>,
|
||||
typename Problem::BlockFmhaShape::Gemm1BlockWarps,
|
||||
typename Problem::BlockFmhaShape::Gemm1WarpTile>,
|
||||
TileGemmTraits<Problem::kPadSeqLenQ,
|
||||
Problem::kPadHeadDimV,
|
||||
Problem::kPadHeadDimV,
|
||||
typename tensor_layout::gemm::RowMajor,
|
||||
typename tensor_layout::gemm::ColumnMajor,
|
||||
typename tensor_layout::gemm::RowMajor>>;
|
||||
|
||||
using WarpGemm =
|
||||
WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
|
||||
@@ -84,21 +97,27 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
typename Problem::BlockFmhaShape::Gemm1BlockWarps,
|
||||
WarpGemm>;
|
||||
|
||||
return BlockGemmARegBRegCRegV1<BlockGemmProblem, BlockGemmPolicy>{};
|
||||
return BlockGemmARegBRegCRegV1<GemmProblem, BlockGemmPolicy>{};
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetOGradVBlockGemm()
|
||||
{
|
||||
using BlockGemmProblem = BlockGemmPipelineProblem<
|
||||
typename Problem::OGradDataType,
|
||||
typename Problem::VDataType,
|
||||
typename Problem::AccDataType,
|
||||
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
|
||||
Problem::BlockFmhaShape::kN0,
|
||||
Problem::BlockFmhaShape::kK2>,
|
||||
typename Problem::BlockFmhaShape::Gemm2BlockWarps,
|
||||
typename Problem::BlockFmhaShape::Gemm2WarpTile>>;
|
||||
using GemmProblem =
|
||||
GemmPipelineProblem<typename Problem::OGradDataType,
|
||||
typename Problem::VDataType,
|
||||
typename Problem::AccDataType,
|
||||
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
|
||||
Problem::BlockFmhaShape::kN0,
|
||||
Problem::BlockFmhaShape::kK2>,
|
||||
typename Problem::BlockFmhaShape::Gemm2BlockWarps,
|
||||
typename Problem::BlockFmhaShape::Gemm2WarpTile>,
|
||||
TileGemmTraits<Problem::kPadSeqLenQ,
|
||||
Problem::kPadSeqLenK,
|
||||
Problem::kPadHeadDimQ,
|
||||
typename tensor_layout::gemm::RowMajor,
|
||||
typename tensor_layout::gemm::ColumnMajor,
|
||||
typename tensor_layout::gemm::RowMajor>>;
|
||||
|
||||
using WarpGemm = WarpGemmMfmaDispatcher<
|
||||
typename Problem::OGradDataType,
|
||||
@@ -117,21 +136,27 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
typename Problem::BlockFmhaShape::Gemm2BlockWarps,
|
||||
WarpGemm>;
|
||||
|
||||
return BlockGemmARegBRegCRegV1<BlockGemmProblem, BlockGemmPolicy>{};
|
||||
return BlockGemmARegBRegCRegV1<GemmProblem, BlockGemmPolicy>{};
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetSGradTQTBlockGemm()
|
||||
{
|
||||
using BlockGemmProblem = BlockGemmPipelineProblem<
|
||||
typename Problem::GemmDataType,
|
||||
typename Problem::QDataType,
|
||||
typename Problem::AccDataType,
|
||||
TileGemmShape<sequence<Problem::BlockFmhaShape::kN0,
|
||||
Problem::BlockFmhaShape::kQKHeaddim,
|
||||
Problem::BlockFmhaShape::kK3>,
|
||||
typename Problem::BlockFmhaShape::Gemm3BlockWarps,
|
||||
typename Problem::BlockFmhaShape::Gemm3WarpTile>>;
|
||||
using GemmProblem =
|
||||
GemmPipelineProblem<typename Problem::GemmDataType,
|
||||
typename Problem::QDataType,
|
||||
typename Problem::AccDataType,
|
||||
TileGemmShape<sequence<Problem::BlockFmhaShape::kN0,
|
||||
Problem::BlockFmhaShape::kQKHeaddim,
|
||||
Problem::BlockFmhaShape::kK3>,
|
||||
typename Problem::BlockFmhaShape::Gemm3BlockWarps,
|
||||
typename Problem::BlockFmhaShape::Gemm3WarpTile>,
|
||||
TileGemmTraits<Problem::kPadSeqLenK,
|
||||
Problem::kPadHeadDimQ,
|
||||
Problem::kPadSeqLenK,
|
||||
typename tensor_layout::gemm::RowMajor,
|
||||
typename tensor_layout::gemm::ColumnMajor,
|
||||
typename tensor_layout::gemm::RowMajor>>;
|
||||
|
||||
using WarpGemm =
|
||||
WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
|
||||
@@ -149,21 +174,27 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
typename Problem::BlockFmhaShape::Gemm3BlockWarps,
|
||||
WarpGemm>;
|
||||
|
||||
return BlockGemmARegBRegCRegV1<BlockGemmProblem, BlockGemmPolicy>{};
|
||||
return BlockGemmARegBRegCRegV1<GemmProblem, BlockGemmPolicy>{};
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetSGradKTBlockGemm()
|
||||
{
|
||||
using BlockGemmProblem = BlockGemmPipelineProblem<
|
||||
typename Problem::GemmDataType,
|
||||
typename Problem::KDataType,
|
||||
typename Problem::AccDataType,
|
||||
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
|
||||
Problem::BlockFmhaShape::kQKHeaddim,
|
||||
Problem::BlockFmhaShape::kK4>,
|
||||
typename Problem::BlockFmhaShape::Gemm4BlockWarps,
|
||||
typename Problem::BlockFmhaShape::Gemm4WarpTile>>;
|
||||
using GemmProblem =
|
||||
GemmPipelineProblem<typename Problem::GemmDataType,
|
||||
typename Problem::KDataType,
|
||||
typename Problem::AccDataType,
|
||||
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
|
||||
Problem::BlockFmhaShape::kQKHeaddim,
|
||||
Problem::BlockFmhaShape::kK4>,
|
||||
typename Problem::BlockFmhaShape::Gemm4BlockWarps,
|
||||
typename Problem::BlockFmhaShape::Gemm4WarpTile>,
|
||||
TileGemmTraits<Problem::kPadSeqLenQ,
|
||||
Problem::kPadHeadDimQ,
|
||||
Problem::kPadSeqLenK,
|
||||
typename tensor_layout::gemm::RowMajor,
|
||||
typename tensor_layout::gemm::ColumnMajor,
|
||||
typename tensor_layout::gemm::RowMajor>>;
|
||||
|
||||
using WarpGemm =
|
||||
WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
|
||||
@@ -181,7 +212,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
typename Problem::BlockFmhaShape::Gemm4BlockWarps,
|
||||
WarpGemm>;
|
||||
|
||||
return BlockGemmARegBRegCRegV1<BlockGemmProblem, BlockGemmPolicy>{};
|
||||
return BlockGemmARegBRegCRegV1<GemmProblem, BlockGemmPolicy>{};
|
||||
}
|
||||
|
||||
// these are for global load
|
||||
|
||||
@@ -5,8 +5,9 @@
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp"
|
||||
@@ -75,15 +76,21 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
|
||||
{
|
||||
using BlockGemmProblem = BlockGemmPipelineProblem<
|
||||
typename Problem::QDataType,
|
||||
typename Problem::KDataType,
|
||||
typename Problem::SaccDataType,
|
||||
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
|
||||
Problem::BlockFmhaShape::kN0,
|
||||
Problem::BlockFmhaShape::kK0>,
|
||||
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
|
||||
typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
|
||||
using GemmProblem =
|
||||
GemmPipelineProblem<typename Problem::QDataType,
|
||||
typename Problem::KDataType,
|
||||
typename Problem::SaccDataType,
|
||||
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
|
||||
Problem::BlockFmhaShape::kN0,
|
||||
Problem::BlockFmhaShape::kK0>,
|
||||
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
|
||||
typename Problem::BlockFmhaShape::Gemm0WarpTile>,
|
||||
TileGemmTraits<Problem::kPadSeqLenQ,
|
||||
Problem::kPadSeqLenK,
|
||||
Problem::kPadHeadDimQ,
|
||||
typename tensor_layout::gemm::RowMajor,
|
||||
typename tensor_layout::gemm::ColumnMajor,
|
||||
typename tensor_layout::gemm::RowMajor>>;
|
||||
|
||||
constexpr auto warp_gemm = []() {
|
||||
if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
|
||||
@@ -116,7 +123,7 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
|
||||
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
|
||||
decltype(warp_gemm)>;
|
||||
|
||||
return BlockGemmARegBSmemCRegV2<BlockGemmProblem, BlockGemmPolicy>{};
|
||||
return BlockGemmARegBSmemCRegV2<GemmProblem, BlockGemmPolicy>{};
|
||||
}
|
||||
};
|
||||
|
||||
@@ -199,15 +206,21 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
|
||||
{
|
||||
using BlockGemmProblem = BlockGemmPipelineProblem<
|
||||
typename Problem::QDataType,
|
||||
typename Problem::KDataType,
|
||||
typename Problem::SaccDataType,
|
||||
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
|
||||
Problem::BlockFmhaShape::kN0,
|
||||
Problem::BlockFmhaShape::kK0>,
|
||||
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
|
||||
typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
|
||||
using GemmProblem =
|
||||
GemmPipelineProblem<typename Problem::QDataType,
|
||||
typename Problem::KDataType,
|
||||
typename Problem::SaccDataType,
|
||||
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
|
||||
Problem::BlockFmhaShape::kN0,
|
||||
Problem::BlockFmhaShape::kK0>,
|
||||
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
|
||||
typename Problem::BlockFmhaShape::Gemm0WarpTile>,
|
||||
TileGemmTraits<Problem::kPadSeqLenQ,
|
||||
Problem::kPadSeqLenK,
|
||||
Problem::kPadHeadDimQ,
|
||||
typename tensor_layout::gemm::RowMajor,
|
||||
typename tensor_layout::gemm::ColumnMajor,
|
||||
typename tensor_layout::gemm::RowMajor>>;
|
||||
|
||||
constexpr auto warp_gemm = []() {
|
||||
if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
|
||||
@@ -240,7 +253,7 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
|
||||
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
|
||||
decltype(warp_gemm)>;
|
||||
|
||||
return BlockGemmASmemBSmemCRegV1<BlockGemmProblem, BlockGemmPolicy>{};
|
||||
return BlockGemmASmemBSmemCRegV1<GemmProblem, BlockGemmPolicy>{};
|
||||
}
|
||||
};
|
||||
|
||||
@@ -954,15 +967,21 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetKVBlockGemm()
|
||||
{
|
||||
using BlockGemmProblem = BlockGemmPipelineProblem<
|
||||
typename Problem::PDataType,
|
||||
typename Problem::VDataType,
|
||||
typename Problem::OaccDataType,
|
||||
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
|
||||
Problem::BlockFmhaShape::kN1,
|
||||
Problem::BlockFmhaShape::kK1>,
|
||||
typename Problem::BlockFmhaShape::Gemm1BlockWarps,
|
||||
typename Problem::BlockFmhaShape::Gemm1WarpTile>>;
|
||||
using GemmProblem =
|
||||
GemmPipelineProblem<typename Problem::PDataType,
|
||||
typename Problem::VDataType,
|
||||
typename Problem::OaccDataType,
|
||||
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
|
||||
Problem::BlockFmhaShape::kN1,
|
||||
Problem::BlockFmhaShape::kK1>,
|
||||
typename Problem::BlockFmhaShape::Gemm1BlockWarps,
|
||||
typename Problem::BlockFmhaShape::Gemm1WarpTile>,
|
||||
TileGemmTraits<Problem::kPadSeqLenQ,
|
||||
Problem::kPadSeqLenK,
|
||||
Problem::kPadHeadDimQ,
|
||||
typename tensor_layout::gemm::RowMajor,
|
||||
typename tensor_layout::gemm::ColumnMajor,
|
||||
typename tensor_layout::gemm::RowMajor>>;
|
||||
|
||||
auto warp_gemm = [&]() {
|
||||
if constexpr(std::is_same_v<typename Problem::KDataType, fp8_t> &&
|
||||
@@ -996,7 +1015,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
|
||||
typename Problem::OaccDataType,
|
||||
typename Problem::BlockFmhaShape::Gemm1BlockWarps,
|
||||
WarpGemm>;
|
||||
return BlockGemmARegBSmemCRegV2<BlockGemmProblem, BlockGemmPolicy>{};
|
||||
return BlockGemmARegBSmemCRegV2<GemmProblem, BlockGemmPolicy>{};
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -23,12 +23,13 @@
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp"
|
||||
|
||||
@@ -11,20 +11,12 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename TilePartitioner_,
|
||||
typename GemmPipeline_,
|
||||
typename EpiloguePipeline_,
|
||||
typename LayoutA_,
|
||||
typename LayoutB_,
|
||||
typename LayoutC_>
|
||||
template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
|
||||
struct GemmKernel
|
||||
{
|
||||
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
|
||||
using GemmPipeline = remove_cvref_t<GemmPipeline_>;
|
||||
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
|
||||
using LayoutA = remove_cvref_t<LayoutA_>;
|
||||
using LayoutB = remove_cvref_t<LayoutB_>;
|
||||
using LayoutC = remove_cvref_t<LayoutC_>;
|
||||
static constexpr index_t KernelBlockSize = GemmPipeline::kBlockSize;
|
||||
|
||||
using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
|
||||
@@ -32,6 +24,10 @@ struct GemmKernel
|
||||
using CAccDataType = remove_cvref_t<typename GemmPipeline::CDataType>;
|
||||
using CODataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
|
||||
|
||||
using LayoutA = remove_cvref_t<typename GemmPipeline::LayoutA>;
|
||||
using LayoutB = remove_cvref_t<typename GemmPipeline::LayoutB>;
|
||||
using LayoutC = remove_cvref_t<typename GemmPipeline::LayoutC>;
|
||||
|
||||
__host__ static constexpr auto GridSize(index_t M_size, index_t N_size, index_t Batch_size)
|
||||
{
|
||||
return TilePartitioner::GridSize(M_size, N_size, Batch_size);
|
||||
@@ -184,6 +180,7 @@ struct GemmKernel
|
||||
c_pad_view,
|
||||
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
|
||||
{i_m, i_n});
|
||||
|
||||
EpiloguePipeline{}(CBlockWindow_pad, acc);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -4,15 +4,15 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// A Tile Window: global memory
|
||||
// B Tile Window: global memory
|
||||
// C Distributed tensor: register
|
||||
template <typename Problem, typename Policy = BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy>
|
||||
struct BlockGemmPipelineAGmemBGmemCRegV1
|
||||
template <typename Problem, typename Policy = GemmPipelineAGmemBGmemCRegV1DefaultPolicy>
|
||||
struct GemmPipelineAGmemBGmemCRegV1
|
||||
{
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
@@ -33,6 +33,10 @@ struct BlockGemmPipelineAGmemBGmemCRegV1
|
||||
static constexpr bool kPadB = Problem::kPadB;
|
||||
static constexpr bool kPadC = Problem::kPadC;
|
||||
|
||||
using LayoutA = remove_cvref_t<typename Problem::LayoutA>;
|
||||
using LayoutB = remove_cvref_t<typename Problem::LayoutB>;
|
||||
using LayoutC = remove_cvref_t<typename Problem::LayoutC>;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetStaticLdsSize()
|
||||
{
|
||||
return ck_tile::integer_divide_ceil(
|
||||
@@ -7,9 +7,9 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Default policy for BlockGemmPipelineAGmemBGmemCRegV1
|
||||
// Default policy for GemmPipelineAGmemBGmemCRegV1
|
||||
// Default policy class should not be templated, put template on member functions instead
|
||||
struct BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy
|
||||
struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
|
||||
{
|
||||
#if 0
|
||||
// 2d
|
||||
@@ -4,15 +4,15 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// A Tile Window: global memory
|
||||
// B Tile Window: global memory
|
||||
// C Distributed tensor: register
|
||||
template <typename Problem, typename Policy = BlockGemmPipelineAGmemBGmemCRegV2DefaultPolicy>
|
||||
struct BlockGemmPipelineAGmemBGmemCRegV2
|
||||
template <typename Problem, typename Policy = GemmPipelineAGmemBGmemCRegV2DefaultPolicy>
|
||||
struct GemmPipelineAGmemBGmemCRegV2
|
||||
{
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
@@ -7,12 +7,11 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Default policy for BlockGemmPipelineAGmemBGmemCRegV2
|
||||
// Default policy for GemmPipelineAGmemBGmemCRegV2
|
||||
// Default policy class should not be templated, put template on member functions instead
|
||||
// NOTE: policy should be binded to its corresponding operation. It's just a coincidence that
|
||||
// BlockGemmPipelineAGmemBGmemCRegV2DefaultPolicy is the same as
|
||||
// BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy
|
||||
using BlockGemmPipelineAGmemBGmemCRegV2DefaultPolicy =
|
||||
BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy;
|
||||
// GemmPipelineAGmemBGmemCRegV2DefaultPolicy is the same as
|
||||
// GemmPipelineAGmemBGmemCRegV1DefaultPolicy
|
||||
using GemmPipelineAGmemBGmemCRegV2DefaultPolicy = GemmPipelineAGmemBGmemCRegV1DefaultPolicy;
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -13,20 +13,23 @@ template <typename ADataType_,
|
||||
typename BDataType_,
|
||||
typename CDataType_,
|
||||
typename BlockGemmShape_,
|
||||
bool kPadA_ = false,
|
||||
bool kPadB_ = false,
|
||||
bool kPadC_ = false>
|
||||
struct BlockGemmPipelineProblem
|
||||
typename TileGemmTraits_>
|
||||
struct GemmPipelineProblem
|
||||
{
|
||||
using ADataType = remove_cvref_t<ADataType_>;
|
||||
using BDataType = remove_cvref_t<BDataType_>;
|
||||
using CDataType = remove_cvref_t<CDataType_>;
|
||||
using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
|
||||
using GemmTraits = remove_cvref_t<TileGemmTraits_>;
|
||||
|
||||
static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size();
|
||||
static constexpr bool kPadA = kPadA_;
|
||||
static constexpr bool kPadB = kPadB_;
|
||||
static constexpr bool kPadC = kPadC_;
|
||||
static constexpr bool kPadA = GemmTraits::kPadA;
|
||||
static constexpr bool kPadB = GemmTraits::kPadB;
|
||||
static constexpr bool kPadC = GemmTraits::kPadC;
|
||||
|
||||
using LayoutA = remove_cvref_t<typename GemmTraits::LayoutA>;
|
||||
using LayoutB = remove_cvref_t<typename GemmTraits::LayoutB>;
|
||||
using LayoutC = remove_cvref_t<typename GemmTraits::LayoutC>;
|
||||
|
||||
static constexpr index_t AlignmentA = kPadA ? 1 : VectorLoadSize / sizeof(ADataType);
|
||||
static constexpr index_t AlignmentB = kPadB ? 1 : VectorLoadSize / sizeof(BDataType);
|
||||
27
include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp
Normal file
27
include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp
Normal file
@@ -0,0 +1,27 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <bool kPadA_,
|
||||
bool kPadB_,
|
||||
bool kPadC_,
|
||||
typename LayoutA_,
|
||||
typename LayoutB_,
|
||||
typename LayoutC_>
|
||||
struct TileGemmTraits
|
||||
{
|
||||
static constexpr bool kPadA = kPadA_;
|
||||
static constexpr bool kPadB = kPadB_;
|
||||
static constexpr bool kPadC = kPadC_;
|
||||
|
||||
using LayoutA = LayoutA_;
|
||||
using LayoutB = LayoutB_;
|
||||
using LayoutC = LayoutC_;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
Reference in New Issue
Block a user