mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 05:31:24 +00:00
Ck tile gemm example (#1488)
* Checkpoint: Finished with the tile example & kernel verification, working on the different matrix layout * Finished the Matrix Layout feature set up. Note: Need to modify the inner block to solve the shuffle problem in the future. * Fix: Clang Format, API fixed from fmha * fix with better naming convention * revert back the pipeline code of fmha * Fixed: Addressed the comments and merge the GEMM shape of GEMM Operator and FMHA Operator to one. * clang format with the reference_gemm file * convert the clang format with the remod.py * Changed the format and variable name of the kernel gemm_shape and partitioner --------- Co-authored-by: thomasning <thomasning@banff-cyxtera-s70-4.ctr.dcgpu>
This commit is contained in:
@@ -4,7 +4,8 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1_default_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -27,9 +28,9 @@ struct BlockGemmARegBGmemCRegV1
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
// use BlockGemmARegBSmemCRegV1 as the underlying block-GEMM implementation
|
||||
using BlockGemmARegBSmemCRegImpl = BlockGemmARegBSmemCRegV1<
|
||||
using BlockGemmARegBGmemCRegImpl = BlockGemmARegBGmemCRegV1<
|
||||
BlockGemmProblem<ADataType, BDataType, CDataType, kBlockSize, BlockGemmShape>,
|
||||
BlockGemmARegBSmemCRegV1DefaultPolicy>;
|
||||
BlockGemmARegBGmemCRegV1DefaultPolicy>;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetStaticLdsSize()
|
||||
{
|
||||
@@ -82,7 +83,7 @@ struct BlockGemmARegBGmemCRegV1
|
||||
block_sync_lds();
|
||||
|
||||
// block GEMM
|
||||
BlockGemmARegBSmemCRegImpl{}(c_block_tensor, a_block_tensor, b_block_smem_window);
|
||||
BlockGemmARegBGmemCRegImpl{}(c_block_tensor, a_block_tensor, b_block_smem_window);
|
||||
}
|
||||
|
||||
// C = A * B
|
||||
@@ -128,7 +129,7 @@ struct BlockGemmARegBGmemCRegV1
|
||||
block_sync_lds();
|
||||
|
||||
// block GEMM
|
||||
return BlockGemmARegBSmemCRegImpl{}(a_block_tensor, b_block_smem_window);
|
||||
return BlockGemmARegBGmemCRegImpl{}(a_block_tensor, b_block_smem_window);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -49,6 +49,10 @@ struct BlockGemmASmemBSmemCRegV1DefaultPolicy
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution{}, 4, 1);
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unsupported data type configuration for GEMM warp execution.");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
176
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
Normal file
176
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
Normal file
@@ -0,0 +1,176 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include <iostream>
|
||||
|
||||
#include <string>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename TilePartitioner_,
|
||||
typename GemmPipeline_,
|
||||
typename EpiloguePipeline_,
|
||||
typename LayoutA_,
|
||||
typename LayoutB_,
|
||||
typename LayoutC_>
|
||||
struct GemmKernel
|
||||
{
|
||||
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
|
||||
using GemmPipeline = remove_cvref_t<GemmPipeline_>;
|
||||
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
|
||||
using LayoutA = remove_cvref_t<LayoutA_>;
|
||||
using LayoutB = remove_cvref_t<LayoutB_>;
|
||||
using LayoutC = remove_cvref_t<LayoutC_>;
|
||||
static constexpr index_t KernelBlockSize = GemmPipeline::KernelBlockSize;
|
||||
|
||||
using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
|
||||
using CAccDataType = remove_cvref_t<typename GemmPipeline::CDataType>;
|
||||
using CODataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
|
||||
|
||||
__host__ static constexpr auto GridSize(index_t M_size, index_t N_size, index_t Batch_size)
|
||||
{
|
||||
return TilePartitioner::GridSize(M_size, N_size, Batch_size);
|
||||
}
|
||||
|
||||
__host__ static constexpr auto BlockSize() { return dim3(KernelBlockSize); }
|
||||
|
||||
struct GemmCommonKargs
|
||||
{
|
||||
const void* a_ptr;
|
||||
const void* b_ptr;
|
||||
void* c_ptr;
|
||||
|
||||
float epsilon;
|
||||
|
||||
ck_tile::index_t M;
|
||||
ck_tile::index_t N;
|
||||
ck_tile::index_t K;
|
||||
ck_tile::index_t stride_A;
|
||||
ck_tile::index_t stride_B;
|
||||
ck_tile::index_t stride_C;
|
||||
};
|
||||
|
||||
CK_TILE_HOST static constexpr GemmCommonKargs MakeKargs(const void* a_ptr,
|
||||
const void* b_ptr,
|
||||
void* c_ptr,
|
||||
float epsilon,
|
||||
ck_tile::index_t M,
|
||||
ck_tile::index_t N,
|
||||
ck_tile::index_t K,
|
||||
ck_tile::index_t stride_A,
|
||||
ck_tile::index_t stride_B,
|
||||
ck_tile::index_t stride_C)
|
||||
{
|
||||
return GemmCommonKargs{a_ptr, b_ptr, c_ptr, epsilon, M, N, K, stride_A, stride_B, stride_C};
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
|
||||
{
|
||||
return ck_tile::max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void operator()(GemmCommonKargs kargs) const
|
||||
{
|
||||
const index_t i_m = TilePartitioner::iM;
|
||||
const index_t i_n = TilePartitioner::iN;
|
||||
// options
|
||||
const ADataType* a_start = static_cast<const ADataType*>(kargs.a_ptr);
|
||||
const BDataType* b_start = static_cast<const BDataType*>(kargs.b_ptr);
|
||||
// Convert pointers to tensor views
|
||||
auto a_tensor_view = [&]() {
|
||||
if constexpr(std::is_same_v<LayoutA, tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
a_start,
|
||||
make_tuple(kargs.M, kargs.K),
|
||||
make_tuple(1, kargs.stride_A),
|
||||
number<GemmPipeline::AlignmentA>{},
|
||||
number<1>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
a_start,
|
||||
make_tuple(kargs.M, kargs.K),
|
||||
make_tuple(kargs.stride_A, 1),
|
||||
number<GemmPipeline::AlignmentA>{},
|
||||
number<1>{});
|
||||
}
|
||||
}();
|
||||
|
||||
auto b_tensor_view = [&]() {
|
||||
if constexpr(std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
b_start,
|
||||
make_tuple(kargs.N, kargs.K),
|
||||
make_tuple(1, kargs.stride_B),
|
||||
number<GemmPipeline::AlignmentB>{},
|
||||
number<1>{});
|
||||
}
|
||||
else
|
||||
{ // Default NK layout
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
b_start,
|
||||
make_tuple(kargs.N, kargs.K),
|
||||
make_tuple(kargs.stride_B, 1),
|
||||
number<GemmPipeline::AlignmentB>{},
|
||||
number<1>{});
|
||||
}
|
||||
}();
|
||||
|
||||
auto ABlockWindow = make_tile_window(
|
||||
a_tensor_view,
|
||||
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
|
||||
{i_m, 0});
|
||||
|
||||
auto BBlockWindow = make_tile_window(
|
||||
b_tensor_view,
|
||||
make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
|
||||
{i_n, 0});
|
||||
|
||||
// allocate LDS
|
||||
__shared__ char smem_ptr[GetSmemSize()];
|
||||
|
||||
const index_t num_loop = (kargs.K + TilePartitioner::kK - 1) / TilePartitioner::kK;
|
||||
|
||||
auto acc = GemmPipeline{}(ABlockWindow, BBlockWindow, num_loop, smem_ptr);
|
||||
|
||||
CODataType* c_start = static_cast<CODataType*>(kargs.c_ptr);
|
||||
|
||||
auto c_tensor_view = [&]() {
|
||||
if constexpr(std::is_same_v<LayoutC, tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
c_start,
|
||||
make_tuple(kargs.M, kargs.N),
|
||||
make_tuple(1, kargs.stride_C),
|
||||
number<GemmPipeline::AlignmentC>{},
|
||||
number<1>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
c_start,
|
||||
make_tuple(kargs.M, kargs.N),
|
||||
make_tuple(kargs.stride_C, 1),
|
||||
number<GemmPipeline::AlignmentC>{},
|
||||
number<1>{});
|
||||
}
|
||||
}();
|
||||
|
||||
auto CBlockWindow = make_tile_window(
|
||||
c_tensor_view,
|
||||
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
|
||||
{i_m, i_n});
|
||||
// epilogue.
|
||||
EpiloguePipeline{}(CBlockWindow, acc);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
38
include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp
Normal file
38
include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp
Normal file
@@ -0,0 +1,38 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
template <typename BlockGemmShape_>
|
||||
struct GemmTilePartitioner
|
||||
{
|
||||
using BlockGemmShape = ck_tile::remove_cvref_t<BlockGemmShape_>;
|
||||
|
||||
static constexpr ck_tile::index_t kM = BlockGemmShape::kM;
|
||||
static constexpr ck_tile::index_t kN = BlockGemmShape::kN;
|
||||
static constexpr ck_tile::index_t kK = BlockGemmShape::kK;
|
||||
|
||||
const index_t iM = __builtin_amdgcn_readfirstlane(i_tile_m * kM);
|
||||
const index_t iN = __builtin_amdgcn_readfirstlane(i_tile_n * kN);
|
||||
|
||||
CK_TILE_HOST static constexpr auto
|
||||
GridSize(ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t batch_size)
|
||||
{
|
||||
ck_tile::index_t GridDimX = (M + kM - 1) / kM;
|
||||
ck_tile::index_t GridDimY = (N + kN - 1) / kN;
|
||||
ck_tile::index_t GridDimZ = batch_size;
|
||||
return dim3(GridDimX, GridDimY, GridDimZ);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE auto operator()()
|
||||
{
|
||||
const index_t i_GridDimX = blockIdx.x;
|
||||
const index_t i_GridDimY = blockIdx.y;
|
||||
const index_t i_GridDimZ = blockIdx.z;
|
||||
return ck_tile::make_tuple(i_GridDimX, i_GridDimY, i_GridDimZ);
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
@@ -4,6 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -18,12 +19,16 @@ struct BlockGemmPipelineAGmemBGmemCRegV1
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
static constexpr index_t KernelBlockSize = Problem::KernelBlockSize;
|
||||
|
||||
static constexpr index_t kMPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t kNPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t kKPerBlock = BlockGemmShape::kK;
|
||||
|
||||
static constexpr index_t AlignmentA = Problem::AlignmentA;
|
||||
static constexpr index_t AlignmentB = Problem::AlignmentB;
|
||||
static constexpr index_t AlignmentC = Problem::AlignmentC;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetStaticLdsSize()
|
||||
{
|
||||
return ck_tile::integer_divide_ceil(
|
||||
@@ -35,6 +40,11 @@ struct BlockGemmPipelineAGmemBGmemCRegV1
|
||||
Policy::template MakeBLdsBlockDescriptor<Problem>().get_element_space_size();
|
||||
}
|
||||
|
||||
/// Shared-memory requirement of this pipeline, as reported by its policy.
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
    // The pipeline adds nothing of its own; the policy's figure is returned as is.
    const ck_tile::index_t policy_bytes = Policy::template GetSmemSize<Problem>();
    return policy_bytes;
}
|
||||
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BDramBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
@@ -140,8 +150,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV1
|
||||
}
|
||||
|
||||
index_t iCounter = num_loop - 1;
|
||||
|
||||
do
|
||||
while(iCounter > 0)
|
||||
{
|
||||
// global read i + 1
|
||||
a_block_tile = load_tile(a_copy_dram_window);
|
||||
@@ -167,8 +176,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV1
|
||||
store_tile(b_copy_lds_window, b_block_tile_tmp);
|
||||
|
||||
iCounter--;
|
||||
|
||||
} while(iCounter > 0);
|
||||
}
|
||||
|
||||
// tail
|
||||
{
|
||||
|
||||
@@ -91,6 +91,33 @@ struct BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy
|
||||
|
||||
return b_lds_block_desc;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeA()
|
||||
{
|
||||
constexpr index_t smem_size_a = sizeof(typename Problem::ADataType) *
|
||||
MakeALdsBlockDescriptor<Problem>().get_element_space_size();
|
||||
return smem_size_a;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeB()
|
||||
{
|
||||
constexpr index_t smem_size_b = sizeof(typename Problem::BDataType) *
|
||||
MakeBLdsBlockDescriptor<Problem>().get_element_space_size();
|
||||
return smem_size_b;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
|
||||
{
|
||||
constexpr index_t smem_size_a = GetSmemSizeA<Problem>();
|
||||
constexpr index_t smem_size_b = GetSmemSizeB<Problem>();
|
||||
index_t smem_size = 0;
|
||||
smem_size += smem_size_a + smem_size_b;
|
||||
|
||||
return smem_size;
|
||||
}
|
||||
#elif 1
|
||||
// fake XOR
|
||||
template <typename Problem>
|
||||
@@ -168,7 +195,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy
|
||||
{
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t KernelBlockSize = Problem::KernelBlockSize;
|
||||
|
||||
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
|
||||
@@ -177,7 +204,9 @@ struct BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy
|
||||
constexpr index_t K0 = kKPerBlock / K1;
|
||||
constexpr index_t M2 = get_warp_size() / K0;
|
||||
#if 1 // coalesce reading for each blocks
|
||||
constexpr index_t M1 = kBlockSize / get_warp_size();
|
||||
constexpr index_t M1 = KernelBlockSize / get_warp_size();
|
||||
static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
|
||||
static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
|
||||
constexpr index_t M0 = kMPerBlock / (M2 * M1);
|
||||
|
||||
return make_static_tile_distribution(
|
||||
@@ -188,7 +217,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
#else // coalesce reading for each warps
|
||||
constexpr index_t M0 = kBlockSize / get_warp_size();
|
||||
constexpr index_t M0 = KernelBlockSize / get_warp_size();
|
||||
constexpr index_t M1 = kMPerBlock / (M2 * M0);
|
||||
|
||||
return make_static_tile_distribution(
|
||||
@@ -206,7 +235,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy
|
||||
{
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t KernelBlockSize = Problem::KernelBlockSize;
|
||||
|
||||
constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
|
||||
@@ -215,7 +244,9 @@ struct BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy
|
||||
constexpr index_t K0 = kKPerBlock / K1;
|
||||
constexpr index_t N2 = get_warp_size() / K0;
|
||||
#if 1 // coalesce reading for each blocks
|
||||
constexpr index_t N1 = kBlockSize / get_warp_size();
|
||||
constexpr index_t N1 = KernelBlockSize / get_warp_size();
|
||||
static_assert(N2 != 0, "M2 is zero, which will lead to a division by zero error.");
|
||||
static_assert(N1 != 0, "M1 is zero, which will lead to a division by zero error.");
|
||||
constexpr index_t N0 = kNPerBlock / (N2 * N1);
|
||||
|
||||
return make_static_tile_distribution(
|
||||
@@ -226,7 +257,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
#else // coalesce reading for each warps
|
||||
constexpr index_t N0 = kBlockSize / get_warp_size();
|
||||
constexpr index_t N0 = KernelBlockSize / get_warp_size();
|
||||
constexpr index_t N1 = kNPerBlock / (N2 * N0);
|
||||
|
||||
return make_static_tile_distribution(
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -18,7 +19,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV2
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
static constexpr index_t KernelBlockSize = Problem::KernelBlockSize;
|
||||
|
||||
static constexpr index_t kMPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t kNPerBlock = BlockGemmShape::kN;
|
||||
|
||||
@@ -5,13 +5,17 @@
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
#define VectorLoadSize 16
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename ADataType_,
|
||||
typename BDataType_,
|
||||
typename CDataType_,
|
||||
index_t kBlockSize_,
|
||||
typename BlockGemmShape_>
|
||||
typename BlockGemmShape_,
|
||||
bool kPadA_ = false,
|
||||
bool kPadB_ = false,
|
||||
bool kPadC_ = false>
|
||||
struct BlockGemmPipelineProblem
|
||||
{
|
||||
using ADataType = remove_cvref_t<ADataType_>;
|
||||
@@ -19,7 +23,14 @@ struct BlockGemmPipelineProblem
|
||||
using CDataType = remove_cvref_t<CDataType_>;
|
||||
using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
|
||||
|
||||
static constexpr index_t kBlockSize = kBlockSize_;
|
||||
static constexpr index_t KernelBlockSize = BlockGemmShape::NumWarps * get_warp_size();
|
||||
static constexpr bool kPadA = kPadA_;
|
||||
static constexpr bool kPadB = kPadB_;
|
||||
static constexpr bool kPadC = kPadC_;
|
||||
|
||||
static constexpr index_t AlignmentA = kPadA ? VectorLoadSize / sizeof(ADataType) : 1;
|
||||
static constexpr index_t AlignmentB = kPadB ? VectorLoadSize / sizeof(BDataType) : 1;
|
||||
static constexpr index_t AlignmentC = kPadC ? VectorLoadSize / sizeof(CDataType) : 1;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -7,12 +7,18 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <index_t kMPerTile, index_t kNPerTile, index_t kKPerTile>
|
||||
template <typename BlockTile_, typename BlockWarps_, typename WarpTile_>
|
||||
struct TileGemmShape
|
||||
{
|
||||
static constexpr index_t kM = kMPerTile;
|
||||
static constexpr index_t kN = kNPerTile;
|
||||
static constexpr index_t kK = kKPerTile;
|
||||
using BlockTile = remove_cvref_t<BlockTile_>;
|
||||
using BlockWarps = remove_cvref_t<BlockWarps_>;
|
||||
using WarpTile = remove_cvref_t<WarpTile_>;
|
||||
|
||||
static constexpr index_t NumWarps = reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{});
|
||||
|
||||
static constexpr index_t kM = BlockTile::at(number<0>{});
|
||||
static constexpr index_t kN = BlockTile::at(number<1>{});
|
||||
static constexpr index_t kK = BlockTile::at(number<2>{});
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user