mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[rocm-libraries] ROCm/rocm-libraries#4797 (commit 1a30400)
[CK_TILE] Add CK Tile bwd weight profiler MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Motivation To compare old CK and CK Tile, we need to extend the current CK profiler to support running also CK Tile instance with the same API. In order to have the same instance coverage in CK Tile compared to the old CK, I've added code generation from old CK configurations to CK Tile instances using the CK Builder. ## Technical Details - The codegen python script for CK Tile fwd convs is extended to support also bwd weight and bwd data. - The generated instances are added to the CMake build (target `device_grouped_conv_bwd_weight_tile_instance`s). - A new profiler op (`grouped_conv_bwd_weight_tile`) has been added to the CK Profiler.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
fc1e1a5155
commit
ae4e632c7d
@@ -18,9 +18,11 @@ struct BlockGemmASmemBSmemCRegV1DefaultPolicy
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
constexpr bool is_a_load_tr = std::is_same_v<remove_cvref_t<typename Problem::ALayout>,
|
||||
tensor_layout::gemm::ColumnMajor>;
|
||||
tensor_layout::gemm::ColumnMajor> &&
|
||||
!std::is_same_v<typename Problem::ADataType, float>;
|
||||
constexpr bool is_b_load_tr = std::is_same_v<remove_cvref_t<typename Problem::BLayout>,
|
||||
tensor_layout::gemm::RowMajor>;
|
||||
tensor_layout::gemm::RowMajor> &&
|
||||
!std::is_same_v<typename Problem::BDataType, float>;
|
||||
#else
|
||||
constexpr bool is_a_load_tr = false;
|
||||
constexpr bool is_b_load_tr = false;
|
||||
|
||||
@@ -53,7 +53,9 @@ struct GemmPipelineAgBgCrImplBase
|
||||
using WarpTile = typename BlockGemmShape::WarpTile;
|
||||
constexpr index_t kKWarpTile = WarpTile::at(number<2>{});
|
||||
constexpr index_t kMaxKWarpTile = (sizeof(ADataType) == 1) ? 64 : 32;
|
||||
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
|
||||
if constexpr(std::is_same_v<ADataType, float>)
|
||||
return false;
|
||||
else if constexpr(std::is_same_v<BDataType, pk_int4_t>)
|
||||
return false;
|
||||
else if constexpr(kKWarpTile > kMaxKWarpTile)
|
||||
return false;
|
||||
@@ -65,7 +67,9 @@ struct GemmPipelineAgBgCrImplBase
|
||||
using WarpTile = typename BlockGemmShape::WarpTile;
|
||||
constexpr index_t kKWarpTile = WarpTile::at(number<2>{});
|
||||
constexpr index_t kMaxKWarpTile = (sizeof(BDataType) == 1) ? 64 : 32;
|
||||
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
|
||||
if constexpr(std::is_same_v<BDataType, float>)
|
||||
return false;
|
||||
else if constexpr(std::is_same_v<BDataType, pk_int4_t>)
|
||||
return false;
|
||||
else if constexpr(kKWarpTile > kMaxKWarpTile)
|
||||
return false;
|
||||
|
||||
@@ -49,7 +49,9 @@ struct UniversalGemmBasePolicy
|
||||
constexpr index_t kKWarpTile = WarpTile::at(number<2>{});
|
||||
// Max K warp tile for transpose load based on data type size
|
||||
constexpr index_t kMaxKWarpTile = (sizeof(ADataType) == 1) ? 64 : 32;
|
||||
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
|
||||
if constexpr(std::is_same_v<ADataType, float>)
|
||||
return false;
|
||||
else if constexpr(std::is_same_v<BDataType, pk_int4_t>)
|
||||
return false;
|
||||
else if constexpr(kKWarpTile > kMaxKWarpTile)
|
||||
return false;
|
||||
@@ -65,7 +67,9 @@ struct UniversalGemmBasePolicy
|
||||
constexpr index_t kKWarpTile = WarpTile::at(number<2>{});
|
||||
// Max K warp tile for transpose load based on data type size
|
||||
constexpr index_t kMaxKWarpTile = (sizeof(BDataType) == 1) ? 64 : 32;
|
||||
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
|
||||
if constexpr(std::is_same_v<BDataType, float>)
|
||||
return false;
|
||||
else if constexpr(std::is_same_v<BDataType, pk_int4_t>)
|
||||
return false;
|
||||
else if constexpr(kKWarpTile > kMaxKWarpTile)
|
||||
return false;
|
||||
|
||||
@@ -38,6 +38,7 @@ template<> struct Dispatcher<float, float, float, 16, 16, 16, false> { using Typ
|
||||
template<> struct Dispatcher<float, float, float, 16, 16, 8, false> { using Type = WarpGemmMfmaF32F32F32M16N16K8<>; };
|
||||
template<> struct Dispatcher<float, float, float, 32, 32, 4, false> { using Type = WarpGemmMfmaF32F32F32M32N32K4<>; };
|
||||
template<> struct Dispatcher<float, float, float, 32, 32, 8, false> { using Type = WarpGemmMfmaF32F32F32M32N32K8<>; };
|
||||
template<> struct Dispatcher<float, float, float, 32, 32, 8, false, false, false, EDouble> { using Type = WarpGemmMfmaF32F32F32M32N32K8<EDouble>; };
|
||||
template<> struct Dispatcher<float, float, float, 16, 16, 16, true> { using Type = WarpGemmMfmaF32F32F32M16N16K16TransposedCDistribution<>; };
|
||||
// fp16
|
||||
// ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/grouped_convolution/pipeline/grouped_conv_universal_pipeline_ag_bg_cr_policy.hpp"
|
||||
#include "ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp"
|
||||
#include "ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp"
|
||||
#include "ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp"
|
||||
|
||||
@@ -22,6 +22,15 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename... Args>
|
||||
CK_TILE_HOST void LogInfo(Args&&... args) noexcept
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_INFO(std::forward<Args>(args)...);
|
||||
}
|
||||
}
|
||||
|
||||
/// @brief The Grouped Convolution kernel device arguments.
|
||||
template <typename GroupedConvTraitsType_>
|
||||
struct GroupedConvBwdWeightKernelArgs
|
||||
@@ -106,13 +115,18 @@ struct GroupedConvBwdWeightKernelArgs
|
||||
|
||||
k_batch = args.k_batch;
|
||||
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
std::cout << "GemmM: " << GemmM << ", GemmN: " << GemmN << ", GemmK: " << GemmK
|
||||
<< ", GemmBatch: " << GemmBatch
|
||||
<< ", NumGroupsPerBatch: " << NumGroupsPerBatch << ", k_batch: " << k_batch
|
||||
<< std::endl;
|
||||
}
|
||||
LogInfo("GemmM: ",
|
||||
GemmM,
|
||||
", GemmN: ",
|
||||
GemmN,
|
||||
", GemmK: ",
|
||||
GemmK,
|
||||
", GemmBatch: ",
|
||||
GemmBatch,
|
||||
", NumGroupsPerBatch: ",
|
||||
NumGroupsPerBatch,
|
||||
", k_batch: ",
|
||||
k_batch);
|
||||
}
|
||||
|
||||
template <
|
||||
@@ -192,13 +206,18 @@ struct GroupedConvBwdWeightKernelArgs
|
||||
|
||||
k_batch = args.k_batch;
|
||||
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
std::cout << "GemmM: " << GemmM << ", GemmN: " << GemmN << ", GemmK: " << GemmK
|
||||
<< ", GemmBatch: " << GemmBatch
|
||||
<< ", NumGroupsPerBatch: " << NumGroupsPerBatch << ", k_batch: " << k_batch
|
||||
<< std::endl;
|
||||
}
|
||||
LogInfo("GemmM: ",
|
||||
GemmM,
|
||||
", GemmN: ",
|
||||
GemmN,
|
||||
", GemmK: ",
|
||||
GemmK,
|
||||
", GemmBatch: ",
|
||||
GemmBatch,
|
||||
", NumGroupsPerBatch: ",
|
||||
NumGroupsPerBatch,
|
||||
", k_batch: ",
|
||||
k_batch);
|
||||
}
|
||||
|
||||
template <
|
||||
@@ -285,13 +304,18 @@ struct GroupedConvBwdWeightKernelArgs
|
||||
|
||||
k_batch = args.k_batch;
|
||||
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
std::cout << "GemmM: " << GemmM << ", GemmN: " << GemmN << ", GemmK: " << GemmK
|
||||
<< ", GemmBatch: " << GemmBatch
|
||||
<< ", NumGroupsPerBatch: " << NumGroupsPerBatch << ", k_batch: " << k_batch
|
||||
<< std::endl;
|
||||
}
|
||||
LogInfo("GemmM: ",
|
||||
GemmM,
|
||||
", GemmN: ",
|
||||
GemmN,
|
||||
", GemmK: ",
|
||||
GemmK,
|
||||
", GemmBatch: ",
|
||||
GemmBatch,
|
||||
", NumGroupsPerBatch: ",
|
||||
NumGroupsPerBatch,
|
||||
", k_batch: ",
|
||||
k_batch);
|
||||
}
|
||||
|
||||
using ABCGridDescs = remove_cvref_t<
|
||||
@@ -474,12 +498,12 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
CK_TILE_HOST static constexpr GroupedConvBwdWeightKernelArgsSpecialized
|
||||
MakeKernelArgs(const GroupedConvBwdWeightHostArgs& hostArgs)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
std::cout << "MPerBlock: " << number<TilePartitioner::MPerBlock>{} << std::endl;
|
||||
std::cout << "NPerBlock: " << number<TilePartitioner::NPerBlock>{} << std::endl;
|
||||
std::cout << "KPerBlock: " << number<TilePartitioner::KPerBlock>{} << std::endl;
|
||||
}
|
||||
LogInfo("MPerBlock: ",
|
||||
number<TilePartitioner::MPerBlock>{},
|
||||
", NPerBlock: ",
|
||||
number<TilePartitioner::NPerBlock>{},
|
||||
", KPerBlock: ",
|
||||
number<TilePartitioner::KPerBlock>{});
|
||||
|
||||
auto kernel_args = GroupedConvBwdWeightKernelArgsSpecialized(hostArgs);
|
||||
|
||||
@@ -517,11 +541,7 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
}
|
||||
if(kargs.k_batch < 1)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR(
|
||||
"k_batch must be at least one. Ensure argument is created via MakeKernelArgs.");
|
||||
}
|
||||
LogInfo("k_batch must be at least one. Ensure argument is created via MakeKernelArgs.");
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -533,12 +553,8 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
// accuracy issues. Hence, we limit the maximum split-K value to 128 in such cases.
|
||||
if(kargs.k_batch > 128)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR(
|
||||
"For epilogue output data type that is not float/double, we must have "
|
||||
LogInfo("For epilogue output data type that is not float/double, we must have "
|
||||
"k_batch <= 128.");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@@ -548,20 +564,24 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
{
|
||||
if(kargs.k_batch != 1)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("Conditions not met for K_batch > 1!");
|
||||
}
|
||||
LogInfo("Conditions not met for K_batch > 1: VectorSizeC must be a multiple of 2 "
|
||||
"for fp16/bf16 when K_batch > 1.",
|
||||
"Now k_batch is ",
|
||||
kargs.k_batch,
|
||||
", VectorSizeC is ",
|
||||
GroupedConvTraitsType_::VectorSizeC);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if(kargs.GemmK < TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{}) * kargs.k_batch)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("KBatch is too large, part of GPU wouldn't be utilized!");
|
||||
}
|
||||
LogInfo("KBatch is too large, part of GPU wouldn't be utilized! GemmK: ",
|
||||
kargs.GemmK,
|
||||
", BlockGemmShape K: ",
|
||||
TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{}),
|
||||
", k_batch: ",
|
||||
kargs.k_batch);
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -581,6 +601,17 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
|
||||
if(!(SpatialDim == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0))
|
||||
{
|
||||
LogInfo("For Filter1x1Stride1Pad0 specialization, all spatial dimensions must "
|
||||
"be 1, stride must be 1, and padding must be 0. Now for dimension ",
|
||||
i,
|
||||
": SpatialDim is ",
|
||||
SpatialDim,
|
||||
", ConvStride is ",
|
||||
ConvStride,
|
||||
", LeftPad is ",
|
||||
LeftPad,
|
||||
", RightPad is ",
|
||||
RightPad);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@@ -596,6 +627,15 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
|
||||
if(!(SpatialDim == 1 && LeftPad == 0 && RightPad == 0))
|
||||
{
|
||||
LogInfo("For Filter1x1Pad0 specialization, all spatial dimensions must be 1 "
|
||||
"and padding must be 0. Now for dimension ",
|
||||
i,
|
||||
": SpatialDim is ",
|
||||
SpatialDim,
|
||||
", LeftPad is ",
|
||||
LeftPad,
|
||||
", RightPad is ",
|
||||
RightPad);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@@ -604,6 +644,7 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
{
|
||||
if(ConvC != 1)
|
||||
{
|
||||
LogInfo("For Filter3x3 specialization, ConvC must be 1. Now ConvC is ", ConvC);
|
||||
return false;
|
||||
}
|
||||
for(index_t i = 0; i < NDimSpatial; ++i)
|
||||
@@ -612,6 +653,11 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
|
||||
if(filter_spatial_dim != I3)
|
||||
{
|
||||
LogInfo("For Filter3x3 specialization, all spatial dimensions of the filter "
|
||||
"must be 3. Now for dimension ",
|
||||
i,
|
||||
", filter_spatial_dim is ",
|
||||
filter_spatial_dim);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@@ -620,8 +666,7 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
if constexpr(GroupedConvTraitsType_::ExplicitGemm &&
|
||||
ConvSpecialization != ConvolutionSpecialization::Filter1x1Stride1Pad0)
|
||||
{
|
||||
CK_TILE_ERROR(
|
||||
"Explicit Gemm is supported only for Filter1x1Stride1Pad0 specialization!");
|
||||
LogInfo("ExplicitGemm is only supported for Filter1x1Stride1Pad0 specialization.");
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -633,14 +678,16 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
// Check access per C
|
||||
if(ConvC % GroupedConvTraitsType_::VectorSizeB != 0)
|
||||
{
|
||||
CK_TILE_ERROR("Conv C is not a multiple of vector load size for "
|
||||
"input image!");
|
||||
LogInfo("Conv C is not a multiple of vector load size for input! ConvC: ",
|
||||
ConvC,
|
||||
", VectorSizeB: ",
|
||||
GroupedConvTraitsType_::VectorSizeB);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
CK_TILE_ERROR("Not supported input layout!");
|
||||
LogInfo("Not supported input layout! Now InLayout is ", InLayout::name);
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -650,13 +697,16 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
{
|
||||
if(ConvC % GroupedConvTraitsType_::VectorSizeC != 0)
|
||||
{
|
||||
CK_TILE_ERROR("Conv C is not a multiple of vector load size for weight!");
|
||||
LogInfo("Conv C is not a multiple of vector load size for weight! ConvC: ",
|
||||
ConvC,
|
||||
", VectorSizeC: ",
|
||||
GroupedConvTraitsType_::VectorSizeC);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
CK_TILE_ERROR("Not supported weight layout!");
|
||||
LogInfo("Not supported weight layout! Now WeiLayout is ", WeiLayout::name);
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -666,14 +716,16 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
{
|
||||
if(ConvK % GroupedConvTraitsType_::VectorSizeA != 0)
|
||||
{
|
||||
CK_TILE_ERROR("Conv K is not a multiple of vector store size "
|
||||
"for output image!");
|
||||
LogInfo("Conv K is not a multiple of vector load size for output! ConvK: ",
|
||||
ConvK,
|
||||
", VectorSizeA: ",
|
||||
GroupedConvTraitsType_::VectorSizeA);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
CK_TILE_ERROR("Not supported output layout!");
|
||||
LogInfo("Not supported output layout! Now OutLayout is ", OutLayout::name);
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -682,7 +734,10 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
const index_t ConvG = kargs.wei_g_k_c_xs_lengths[number<0>{}];
|
||||
if(ConvG % GroupedConvTraitsType_::NumGroupsToMerge != 0)
|
||||
{
|
||||
CK_TILE_ERROR("ConvG must be a multiple of NumGroupsToMerge!");
|
||||
LogInfo("Number of groups must be divisible by NumGroupsToMerge! ConvG: ",
|
||||
ConvG,
|
||||
", NumGroupsToMerge: ",
|
||||
GroupedConvTraitsType_::NumGroupsToMerge);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,208 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// UniversalGemm Policy
|
||||
struct GroupedConvUniversalPipelineAgBgCrPolicy
|
||||
: public UniversalGemmBasePolicy<GroupedConvUniversalPipelineAgBgCrPolicy>
|
||||
{
|
||||
|
||||
template <typename Problem,
|
||||
typename OverrideADataType = remove_cvref_t<typename Problem::ADataType>>
|
||||
CK_TILE_DEVICE static constexpr auto MakeALdsBlockDescriptor()
|
||||
{
|
||||
using ADataType = OverrideADataType;
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t KPack = GetSmemPackA<Problem>();
|
||||
|
||||
if constexpr(is_a_load_tr<Problem>)
|
||||
{
|
||||
// TODO: better lds descriptor for performance
|
||||
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( //
|
||||
make_tuple(number<KPerBlock>{}, number<MPerBlock>{}),
|
||||
make_tuple(number<MPerBlock>{}, number<1>{}),
|
||||
number<MPerBlock>{},
|
||||
number<1>{});
|
||||
return a_lds_block_desc_0;
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr auto DataTypeSize = sizeof(ADataType);
|
||||
constexpr uint64_t MinLdsLayer = 1ULL;
|
||||
constexpr auto MLdsLayer =
|
||||
max(MinLdsLayer,
|
||||
get_n_lds_banks() * get_n_dwords_per_128b() / KPerBlock / DataTypeSize);
|
||||
|
||||
constexpr index_t NBanks = get_n_lds_banks();
|
||||
static_assert(NBanks == 32 || NBanks == 64, "Unexpected LDS bank count");
|
||||
constexpr index_t RowMul = (NBanks == 64) ? 2 : 1;
|
||||
|
||||
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<KPerBlock / KPack * MLdsLayer>{},
|
||||
number<MPerBlock / MLdsLayer>{},
|
||||
number<KPack>{}),
|
||||
make_tuple(number<KPack>{}, number<KPerBlock * MLdsLayer>{}, number<1>{}),
|
||||
number<KPack>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
a_lds_block_desc_0,
|
||||
make_tuple(make_xor_transform(make_tuple(number<MPerBlock / MLdsLayer * RowMul>{},
|
||||
number<KPerBlock / KPack * MLdsLayer>{})),
|
||||
make_pass_through_transform(number<KPack>{})),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}));
|
||||
|
||||
constexpr auto a_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor(
|
||||
a_lds_block_desc_permuted,
|
||||
make_tuple(make_unmerge_transform(
|
||||
make_tuple(number<MLdsLayer>{}, number<KPerBlock / KPack>{})),
|
||||
make_pass_through_transform(number<MPerBlock / MLdsLayer>{}),
|
||||
make_pass_through_transform(number<KPack>{})),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}));
|
||||
|
||||
constexpr auto a_lds_block_desc = transform_tensor_descriptor(
|
||||
a_lds_block_desc_xk0_mnldslayer_mn_xk1,
|
||||
make_tuple(make_merge_transform_v3_division_mod(
|
||||
make_tuple(number<MPerBlock / MLdsLayer>{}, number<MLdsLayer>{})),
|
||||
make_merge_transform_v3_division_mod(
|
||||
make_tuple(number<KPerBlock / KPack>{}, number<KPack>{}))),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2, 3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return a_lds_block_desc;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Create LDS block descriptor for B tensor.
|
||||
*
|
||||
* @tparam Problem Gemm pipeline problem.
|
||||
* @return B tensor LDS block descriptor.
|
||||
*/
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr auto MakeBLdsBlockDescriptor()
|
||||
{
|
||||
constexpr bool IsBCastPolicyBeforeLDSWrite = IsBCastPolicyBeforeLDSWrite_v<Problem>;
|
||||
using BDataType = std::conditional_t<IsBCastPolicyBeforeLDSWrite,
|
||||
typename Problem::ADataType,
|
||||
typename Problem::BDataType>;
|
||||
|
||||
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
if constexpr(is_b_load_tr<Problem>)
|
||||
{
|
||||
// TODO: better lds descriptor for performance
|
||||
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( //
|
||||
make_tuple(number<KPerBlock>{}, number<NPerBlock>{}),
|
||||
make_tuple(number<NPerBlock>{}, number<1>{}),
|
||||
number<NPerBlock>{},
|
||||
number<1>{});
|
||||
return b_lds_block_desc_0;
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t KPack = GetSmemPackB<Problem>();
|
||||
constexpr auto BK0 = number<KPerBlock / KPack>{};
|
||||
constexpr auto DataTypeSize = sizeof(BDataType);
|
||||
constexpr uint64_t MinLdsLayer = 1ULL;
|
||||
constexpr auto NLdsLayer =
|
||||
max(MinLdsLayer,
|
||||
get_n_lds_banks() * get_n_dwords_per_128b() / KPerBlock / DataTypeSize);
|
||||
|
||||
constexpr index_t NBanks = get_n_lds_banks();
|
||||
static_assert(NBanks == 32 || NBanks == 64, "Unexpected LDS bank count");
|
||||
constexpr index_t RowMul = (NBanks == 64) ? 2 : 1;
|
||||
|
||||
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(
|
||||
BK0 * number<NLdsLayer>{}, number<NPerBlock / NLdsLayer>{}, number<KPack>{}),
|
||||
make_tuple(number<KPack>{}, number<KPerBlock * NLdsLayer>{}, number<1>{}),
|
||||
number<KPack>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
b_lds_block_desc_0,
|
||||
make_tuple(make_xor_transform(make_tuple(number<NPerBlock / NLdsLayer * RowMul>{},
|
||||
BK0 * number<NLdsLayer>{})),
|
||||
make_pass_through_transform(number<KPack>{})),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}));
|
||||
|
||||
constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor(
|
||||
b_lds_block_desc_permuted,
|
||||
make_tuple(make_unmerge_transform(make_tuple(number<NLdsLayer>{}, BK0)),
|
||||
make_pass_through_transform(number<NPerBlock / NLdsLayer>{}),
|
||||
make_pass_through_transform(number<KPack>{})),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}));
|
||||
|
||||
constexpr auto b_lds_block_desc = transform_tensor_descriptor(
|
||||
b_lds_block_desc_bk0_nldslayer_n_bk1,
|
||||
make_tuple(make_merge_transform_v3_division_mod(
|
||||
make_tuple(number<NPerBlock / NLdsLayer>{}, number<NLdsLayer>{})),
|
||||
make_merge_transform_v3_division_mod(make_tuple(BK0, number<KPack>{}))),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2, 3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
return b_lds_block_desc;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
|
||||
{
|
||||
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
|
||||
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
|
||||
|
||||
constexpr index_t vector_size =
|
||||
DS_READ_TR_SIZE() / sizeof(typename Problem::ComputeDataType);
|
||||
constexpr index_t thread_elements = WarpTile::at(I1) * WarpTile::at(I2) / get_warp_size();
|
||||
constexpr auto wg_attr_num_access =
|
||||
!(is_a_load_tr<Problem> || is_b_load_tr<Problem>) ? WGAttrNumAccessEnum::Single
|
||||
: vector_size == thread_elements ? WGAttrNumAccessEnum::Single
|
||||
: vector_size * 2 == thread_elements ? WGAttrNumAccessEnum::Double
|
||||
: vector_size * 4 == thread_elements ? WGAttrNumAccessEnum::Quad
|
||||
: WGAttrNumAccessEnum::Invalid;
|
||||
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using ATypeToUse =
|
||||
std::conditional_t<std::is_same_v<ADataType, pk_int4_t>, BDataType, ADataType>;
|
||||
using BTypeToUse = std::conditional_t<std::is_same_v<BDataType, pk_int4_t> ||
|
||||
std::is_same_v<BDataType, pk_fp4_t> ||
|
||||
sizeof(BDataType) < sizeof(ADataType),
|
||||
ADataType,
|
||||
BDataType>;
|
||||
|
||||
using WarpGemm = WarpGemmDispatcher<ATypeToUse,
|
||||
BTypeToUse,
|
||||
typename Problem::CDataType,
|
||||
WarpTile::at(I0),
|
||||
WarpTile::at(I1),
|
||||
WarpTile::at(I2),
|
||||
Problem::TransposeC,
|
||||
false,
|
||||
Problem::UseStructuredSparsity,
|
||||
wg_attr_num_access>;
|
||||
|
||||
using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy<ATypeToUse,
|
||||
BTypeToUse,
|
||||
typename Problem::CDataType,
|
||||
BlockWarps,
|
||||
WarpGemm>;
|
||||
return BlockUniversalGemmAsBsCr<Problem, BlockGemmPolicy>{};
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
Reference in New Issue
Block a user