custom vector size

This commit is contained in:
Bartlomiej Kocot
2025-05-13 14:38:36 +00:00
parent aaca2e2b08
commit 976e32ccfa
5 changed files with 77 additions and 23 deletions

View File

@@ -38,6 +38,10 @@ float grouped_conv_fwd_calc(const ck_tile::GroupedConvHostArgs& args,
constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 16;
constexpr ck_tile::index_t VectorSizeA = 8;
constexpr ck_tile::index_t VectorSizeB = 8;
constexpr ck_tile::index_t VectorSizeC = 8;
// Implicit GEMM Traits
using CodegenShape =
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
@@ -47,9 +51,16 @@ float grouped_conv_fwd_calc(const ck_tile::GroupedConvHostArgs& args,
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenShape>;
using CodegenTraits = ck_tile::GroupedConvImplicitGemmTraits;
using CodegenPipelineProblem = ck_tile::
GemmPipelineProblem<InDataType, WeiDataType, AccDataType, CodegenShape, CodegenTraits>;
using CodegenPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
using CodegenPipelineProblem = ck_tile::GemmPipelineProblem<InDataType,
WeiDataType,
AccDataType,
CodegenShape,
CodegenTraits,
InDataType,
true,
VectorSizeA,
VectorSizeB>;
using CodegenPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
const auto Run = [&](const auto memory_operation_) {
constexpr auto memory_operation = memory_operation_.value;
@@ -69,7 +80,9 @@ float grouped_conv_fwd_calc(const ck_tile::GroupedConvHostArgs& args,
N_Warp_Tile,
K_Warp_Tile,
CodegenPipelineProblem::TransposeC,
memory_operation>>;
memory_operation,
true,
VectorSizeC>>;
constexpr auto ConvSpec = ck_tile::ConvolutionForwardSpecialization::Default;
@@ -99,7 +112,10 @@ float grouped_conv_fwd_calc(const ck_tile::GroupedConvHostArgs& args,
<< "pipeline: " << CodegenPipeline::GetName() << '\n'
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
<< std::endl;
<< '\n'
<< "Vector size A: " << CodegenPipeline::GetVectorSizeA()
<< ", Vector size B: " << CodegenPipeline::GetVectorSizeB()
<< ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl;
}
float ave_time = ck_tile::launch_kernel(

View File

@@ -23,7 +23,9 @@ template <typename ADataType_,
index_t kNPerXdl_,
index_t kKPerXdl_,
bool isCTransposed_,
memory_operation_enum MemoryOperation_>
memory_operation_enum MemoryOperation_,
bool FixedVectorSize_ = false,
index_t VectorSizeC_ = 1>
struct CShuffleEpilogueProblem
{
using ADataType = remove_cvref_t<ADataType_>;
@@ -41,6 +43,8 @@ struct CShuffleEpilogueProblem
static constexpr index_t kKPerXdl = kKPerXdl_;
static constexpr index_t isCTransposed = isCTransposed_;
static constexpr memory_operation_enum MemoryOperation = MemoryOperation_;
static constexpr bool FixedVectorSize = FixedVectorSize_;
static constexpr index_t VectorSizeC = VectorSizeC_;
};
template <typename Problem_, typename Policy_ = void>
@@ -65,6 +69,8 @@ struct CShuffleEpilogue
static constexpr index_t kNPerXdl = Problem::kNPerXdl;
static constexpr index_t kKPerXdl = Problem::kKPerXdl;
static constexpr index_t isCTransposed = Problem::isCTransposed;
static constexpr bool FixedVectorSize = Problem::FixedVectorSize;
static constexpr index_t VectorSizeC = Problem::VectorSizeC;
static constexpr index_t kMPerIteration = kMPerXdl * kMWave;
static constexpr index_t kNPerIteration = kNPerXdl * kNWave;
@@ -91,8 +97,12 @@ struct CShuffleEpilogue
*/
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeC()
{
if constexpr(FixedVectorSize)
{
return VectorSizeC;
}
constexpr index_t MaxVectorStoreSize = 16;
return MaxVectorStoreSize / sizeof(ODataType);
return static_cast<index_t>(MaxVectorStoreSize / sizeof(ODataType));
}
template <typename Problem>

View File

@@ -121,7 +121,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
{
constexpr index_t M1 = Problem::VectorLoadSize / sizeof(ADataType);
constexpr index_t M1 = Problem::VectorSizeA;
constexpr index_t M0 = MPerBlock / M1;
constexpr index_t total_pixels = MPerBlock * KPerBlock / BlockSize;
static_assert(total_pixels % M1 == 0);
@@ -211,7 +211,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
constexpr index_t N1 = Problem::VectorLoadSize / sizeof(BDataType);
constexpr index_t N1 = Problem::VectorSizeB;
constexpr index_t N0 = NPerBlock / N1;
constexpr index_t total_pixels = NPerBlock * KPerBlock / BlockSize;
static_assert(total_pixels % N1 == 0);

View File

@@ -14,7 +14,10 @@ template <typename ADataType_,
typename CDataType_,
typename BlockGemmShape_,
typename Traits_,
typename ComputeDataType_ = ADataType_>
typename ComputeDataType_ = ADataType_,
bool FixedVectorSize_ = false,
index_t VectorSizeA_ = 1,
index_t VectorSizeB_ = 1>
struct GemmPipelineProblemBase
{
using Traits = remove_cvref_t<Traits_>;
@@ -24,6 +27,8 @@ struct GemmPipelineProblemBase
using CDataType = remove_cvref_t<CDataType_>;
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
static constexpr bool FixedVectorSize = FixedVectorSize_;
using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
using ALayout = remove_cvref_t<typename Traits::ALayout>;
@@ -114,7 +119,11 @@ struct GemmPipelineProblemBase
}
static constexpr index_t VectorSizeA = []() {
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
if constexpr(FixedVectorSize)
{
return VectorSizeA_;
}
else if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
return kPadK ? 1 : GetAlignmentA();
}
@@ -125,7 +134,11 @@ struct GemmPipelineProblemBase
}();
static constexpr index_t VectorSizeB = []() {
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
if constexpr(FixedVectorSize)
{
return VectorSizeB_;
}
else if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
{
return kPadN ? 1 : GetAlignmentB();
}
@@ -152,13 +165,19 @@ template <typename ADataType_,
typename CDataType_,
typename BlockGemmShape_,
typename Traits_,
typename ComputeDataType_ = ADataType_>
typename ComputeDataType_ = ADataType_,
bool FixedVectorSize_ = false,
index_t VectorSizeA_ = 1,
index_t VectorSizeB_ = 1>
using GemmPipelineProblem = GemmPipelineProblemBase<ADataType_,
BDataType_,
CDataType_,
BlockGemmShape_,
Traits_,
ComputeDataType_>;
ComputeDataType_,
FixedVectorSize_,
VectorSizeA_,
VectorSizeB_>;
template <typename ADataType_,
typename BDataType_,
@@ -168,7 +187,10 @@ template <typename ADataType_,
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
bool HasHotLoop_ = true,
TailNumber TailNum_ = TailNumber::Full,
typename ComputeDataType_ = ADataType_>
typename ComputeDataType_ = ADataType_,
bool FixedVectorSize_ = false,
index_t VectorSizeA_ = 1,
index_t VectorSizeB_ = 1>
struct UniversalGemmPipelineProblem
{
using Traits = remove_cvref_t<Traits_>;
@@ -178,6 +200,10 @@ struct UniversalGemmPipelineProblem
using CDataType = remove_cvref_t<CDataType_>;
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
static constexpr bool FixedVectorSize = FixedVectorSize_;
static constexpr bool VectorSizeA = VectorSizeA_;
static constexpr bool VectorSizeB = VectorSizeB_;
using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
using ALayout = remove_cvref_t<typename Traits::ALayout>;

View File

@@ -426,10 +426,11 @@ struct UniversalGemmBasePolicy
{
using ALayout = remove_cvref_t<typename Problem::ALayout>;
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t VecLoadSize = GetVectorSizeA<Problem>();
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t VecLoadSize =
Problem::FixedVectorSize ? Problem::VectorSizeA : GetVectorSizeA<Problem>();
// Tile: MPerBlock X KPerBlock
if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::RowMajor>)
@@ -458,10 +459,11 @@ struct UniversalGemmBasePolicy
{
using BLayout = remove_cvref_t<typename Problem::BLayout>;
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t VecLoadSize = GetVectorSizeB<Problem>();
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t VecLoadSize =
Problem::FixedVectorSize ? Problem::VectorSizeB : GetVectorSizeB<Problem>();
// Tile: KPerBlock X NPerBlock
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)