mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 21:51:28 +00:00
[CK TILE] GEMM with packed i4 (#1885)
* [CK TILE] GEMM with packed i4 * Fixes * fixes * fixes * fixes
This commit is contained in:
@@ -1,11 +1,12 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
|
||||
#include "ck_tile/ops/elementwise.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -20,12 +21,13 @@ struct BlockUniversalGemmAsBsCr
|
||||
template <typename PipelineProblem_, typename GemmPolicy_>
|
||||
struct GemmTraits_
|
||||
{
|
||||
using Problem = remove_cvref_t<PipelineProblem_>;
|
||||
using Policy = remove_cvref_t<GemmPolicy_>;
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
using Problem = remove_cvref_t<PipelineProblem_>;
|
||||
using Policy = remove_cvref_t<GemmPolicy_>;
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
static constexpr auto Scheduler = Problem::Scheduler;
|
||||
@@ -71,10 +73,10 @@ struct BlockUniversalGemmAsBsCr
|
||||
using BWarpTileDistr = remove_cvref_t<decltype(make_static_tile_distribution(
|
||||
typename WarpGemm::BWarpDstrEncoding{}))>;
|
||||
|
||||
using AWarpTile =
|
||||
remove_cvref_t<decltype(make_static_distributed_tensor<ADataType>(AWarpTileDistr{}))>;
|
||||
using BWarpTile =
|
||||
remove_cvref_t<decltype(make_static_distributed_tensor<BDataType>(BWarpTileDistr{}))>;
|
||||
using AWarpTile = remove_cvref_t<decltype(make_static_distributed_tensor<ComputeDataType>(
|
||||
AWarpTileDistr{}))>;
|
||||
using BWarpTile = remove_cvref_t<decltype(make_static_distributed_tensor<ComputeDataType>(
|
||||
BWarpTileDistr{}))>;
|
||||
|
||||
// TODO: Should we have two policies? Interwave & Intrawave ??
|
||||
static constexpr index_t InterWaveSchedulingMacClusters = 1;
|
||||
@@ -90,9 +92,10 @@ struct BlockUniversalGemmAsBsCr
|
||||
public:
|
||||
using Traits = GemmTraits_<Problem_, Policy_>;
|
||||
|
||||
using ADataType = remove_cvref_t<typename Traits::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Traits::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename Traits::CDataType>;
|
||||
using ADataType = remove_cvref_t<typename Traits::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Traits::BDataType>;
|
||||
using ComputeDataType = remove_cvref_t<typename Traits::ComputeDataType>;
|
||||
using CDataType = remove_cvref_t<typename Traits::CDataType>;
|
||||
|
||||
using WarpGemm = remove_cvref_t<typename Traits::WarpGemm>;
|
||||
|
||||
@@ -105,10 +108,34 @@ struct BlockUniversalGemmAsBsCr
|
||||
|
||||
static constexpr auto Scheduler = Traits::Scheduler;
|
||||
|
||||
static constexpr index_t APackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
|
||||
static constexpr index_t BPackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
|
||||
|
||||
using I0 = number<0>;
|
||||
using I1 = number<1>;
|
||||
|
||||
private:
|
||||
template <typename WarpWindow, typename WarpTile>
|
||||
CK_TILE_DEVICE static void load_interleaved_pk_type(const WarpWindow& warp_window,
|
||||
WarpTile& warp_tile)
|
||||
{
|
||||
constexpr index_t UnaryOpSize = 8;
|
||||
const element_wise::PassThroughPack8 elementwise_op{};
|
||||
constexpr index_t thread_buffer_size =
|
||||
Traits::AWarpTile::get_thread_buffer_size() / UnaryOpSize;
|
||||
const auto in_dstr_tensors = load_tile(warp_window);
|
||||
|
||||
static_assert(Traits::AWarpTile::get_thread_buffer_size() % UnaryOpSize == 0);
|
||||
|
||||
using ComputeVectorType = ComputeDataType __attribute__((ext_vector_type(UnaryOpSize)));
|
||||
static_for<0, thread_buffer_size, 1>{}([&](auto i) {
|
||||
elementwise_op(warp_tile.get_thread_buffer().template get_as<ComputeVectorType>()(i),
|
||||
in_dstr_tensors.get_thread_buffer().template get_as<pk_int4x4_t>()[i]);
|
||||
});
|
||||
}
|
||||
|
||||
template <GemmPipelineScheduler Scheduler, typename GemmTraits>
|
||||
struct BlockGemmImpl
|
||||
{
|
||||
@@ -208,6 +235,8 @@ struct BlockUniversalGemmAsBsCr
|
||||
});
|
||||
|
||||
using CWarpDstr = typename WarpGemm::CWarpDstr;
|
||||
using AWarpTensor = typename WarpGemm::AWarpTensor;
|
||||
using BWarpTensor = typename WarpGemm::BWarpTensor;
|
||||
using CWarpTensor = typename WarpGemm::CWarpTensor;
|
||||
|
||||
constexpr auto c_warp_y_lengths =
|
||||
@@ -217,10 +246,26 @@ struct BlockUniversalGemmAsBsCr
|
||||
// hot loop:
|
||||
static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
const auto a_warp_tile = load_tile(a_warp_windows(mIter)(kIter));
|
||||
AWarpTensor a_warp_tile;
|
||||
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
|
||||
{
|
||||
load_interleaved_pk_type(a_warp_windows(mIter)(kIter), a_warp_tile);
|
||||
}
|
||||
else
|
||||
{
|
||||
a_warp_tile = load_tile(a_warp_windows(mIter)(kIter));
|
||||
}
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
const auto b_warp_tile = load_tile(b_warp_windows(nIter)(kIter));
|
||||
BWarpTensor b_warp_tile;
|
||||
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
|
||||
{
|
||||
load_interleaved_pk_type(b_warp_windows(nIter)(kIter), b_warp_tile);
|
||||
}
|
||||
else
|
||||
{
|
||||
b_warp_tile = load_tile(b_warp_windows(nIter)(kIter));
|
||||
}
|
||||
|
||||
// read C warp tensor from C block tensor-
|
||||
CWarpTensor c_warp_tensor;
|
||||
@@ -342,11 +387,27 @@ struct BlockUniversalGemmAsBsCr
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
// read A warp tensor from A block window
|
||||
load_tile(a_warp_tiles_(mIter)(kIter), a_warp_windows(mIter)(kIter));
|
||||
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
|
||||
{
|
||||
load_interleaved_pk_type(a_warp_windows(mIter)(kIter),
|
||||
a_warp_tiles_(mIter)(kIter));
|
||||
}
|
||||
else
|
||||
{
|
||||
a_warp_tiles_(mIter)(kIter) = load_tile(a_warp_windows(mIter)(kIter));
|
||||
}
|
||||
});
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read B warp tensor from B Block window
|
||||
load_tile(b_warp_tiles_(nIter)(kIter), b_warp_windows(nIter)(kIter));
|
||||
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
|
||||
{
|
||||
load_interleaved_pk_type(b_warp_windows(nIter)(kIter),
|
||||
b_warp_tiles_(nIter)(kIter));
|
||||
}
|
||||
else
|
||||
{
|
||||
b_warp_tiles_(nIter)(kIter) = load_tile(b_warp_windows(nIter)(kIter));
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
@@ -504,12 +565,27 @@ struct BlockUniversalGemmAsBsCr
|
||||
// TODO check if a_warp_tiles has same desc as a_warp_window
|
||||
static_for<0, KInnerLoopIter, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
// read A warp tensor from A block window
|
||||
load_tile(a_warp_tiles_(mIter)(kIter), a_warp_windows(mIter)(kIter));
|
||||
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
|
||||
{
|
||||
load_interleaved_pk_type(a_warp_windows(mIter)(kIter),
|
||||
a_warp_tiles_(mIter)(kIter));
|
||||
}
|
||||
else
|
||||
{
|
||||
a_warp_tiles_(mIter)(kIter) = load_tile(a_warp_windows(mIter)(kIter));
|
||||
}
|
||||
});
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read B warp tensor from B Block window
|
||||
load_tile(b_warp_tiles_(nIter)(kIter), b_warp_windows(nIter)(kIter));
|
||||
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
|
||||
{
|
||||
load_interleaved_pk_type(b_warp_windows(nIter)(kIter),
|
||||
b_warp_tiles_(nIter)(kIter));
|
||||
}
|
||||
else
|
||||
{
|
||||
b_warp_tiles_(nIter)(kIter) = load_tile(b_warp_windows(nIter)(kIter));
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
@@ -54,6 +54,11 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
|
||||
static constexpr index_t APackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
|
||||
static constexpr index_t BPackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
|
||||
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
using BLayout = remove_cvref_t<typename Problem::BLayout>;
|
||||
using CLayout = remove_cvref_t<typename Problem::CLayout>;
|
||||
@@ -196,12 +201,12 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
|
||||
// A/B split schedule
|
||||
// compiler is likely to use ds_read2 when instruction width smaller than 16bytes
|
||||
constexpr auto num_ds_read_inst_a = A_LDS_Read_Width * sizeof(ADataType) == 16
|
||||
? A_LDS_Read_Inst_Num
|
||||
: A_LDS_Read_Inst_Num / 2;
|
||||
constexpr auto num_ds_read_inst_b = B_LDS_Read_Width * sizeof(BDataType) == 16
|
||||
? B_LDS_Read_Inst_Num
|
||||
: B_LDS_Read_Inst_Num / 2;
|
||||
constexpr auto num_ds_read_inst_a =
|
||||
A_LDS_Read_Width * sizeof(ADataType) / APackedSize == 16 ? A_LDS_Read_Inst_Num
|
||||
: A_LDS_Read_Inst_Num / 2;
|
||||
constexpr auto num_ds_read_inst_b =
|
||||
B_LDS_Read_Width * sizeof(BDataType) / BPackedSize == 16 ? B_LDS_Read_Inst_Num
|
||||
: B_LDS_Read_Inst_Num / 2;
|
||||
|
||||
constexpr auto num_ds_write_inst_a = A_LDS_Write_Inst_Num;
|
||||
constexpr auto num_ds_write_inst_b = B_LDS_Write_Inst_Num;
|
||||
@@ -213,9 +218,9 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
|
||||
constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32;
|
||||
constexpr auto ds_read_a_issue_cycle =
|
||||
A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4;
|
||||
A_LDS_Read_Width * sizeof(ADataType) / APackedSize == 16 ? 8 : 4;
|
||||
constexpr auto ds_read_b_issue_cycle =
|
||||
B_LDS_Read_Width * sizeof(BDataType) == 16 ? 8 : 4;
|
||||
B_LDS_Read_Width * sizeof(BDataType) / BPackedSize == 16 ? 8 : 4;
|
||||
constexpr auto ds_read_a_mfma_rate =
|
||||
(mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle);
|
||||
constexpr auto ds_read_b_mfma_rate =
|
||||
|
||||
@@ -60,6 +60,13 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
|
||||
static_assert(!std::is_same_v<BDataType, pk_int4_t>, "Not implemented");
|
||||
|
||||
static constexpr index_t APackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
|
||||
static constexpr index_t BPackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
|
||||
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
using BLayout = remove_cvref_t<typename Problem::BLayout>;
|
||||
using CLayout = remove_cvref_t<typename Problem::CLayout>;
|
||||
@@ -139,12 +146,12 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
|
||||
(BlockSize / WaveSize) /
|
||||
(MPerXDL * NPerXDL * KPerXDL);
|
||||
|
||||
constexpr auto num_ds_read_inst_a = A_LDS_Read_Width * sizeof(ADataType) == 16
|
||||
? A_LDS_Read_Inst_Num
|
||||
: A_LDS_Read_Inst_Num / 2;
|
||||
constexpr auto num_ds_read_inst_b = B_LDS_Read_Width * sizeof(BDataType) == 16
|
||||
? B_LDS_Read_Inst_Num
|
||||
: B_LDS_Read_Inst_Num / 2;
|
||||
constexpr auto num_ds_read_inst_a =
|
||||
A_LDS_Read_Width * sizeof(ADataType) / APackedSize == 16 ? A_LDS_Read_Inst_Num
|
||||
: A_LDS_Read_Inst_Num / 2;
|
||||
constexpr auto num_ds_read_inst_b =
|
||||
B_LDS_Read_Width * sizeof(BDataType) / BPackedSize == 16 ? B_LDS_Read_Inst_Num
|
||||
: B_LDS_Read_Inst_Num / 2;
|
||||
|
||||
constexpr auto num_ds_read_inst = num_ds_read_inst_a + num_ds_read_inst_b;
|
||||
constexpr auto num_ds_write_inst = A_LDS_Write_Inst_Num + B_LDS_Write_Inst_Num;
|
||||
|
||||
@@ -21,6 +21,13 @@ struct BaseGemmPipelineAgBgCrMem
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
|
||||
static_assert(!std::is_same_v<BDataType, pk_int4_t>, "Not implemented");
|
||||
|
||||
static constexpr index_t APackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
|
||||
static constexpr index_t BPackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
|
||||
|
||||
static constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
@@ -33,9 +40,11 @@ struct BaseGemmPipelineAgBgCrMem
|
||||
|
||||
static constexpr index_t WgpPerCU =
|
||||
(4 * get_warp_size() / BlockSize) >= 1 ? 4 * get_warp_size() / BlockSize : 1;
|
||||
static constexpr index_t FullMemBandPrefetchStages = integer_divide_ceil(
|
||||
MinMemInFlyBytes / WgpPerCU,
|
||||
(MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock);
|
||||
static constexpr index_t FullMemBandPrefetchStages =
|
||||
integer_divide_ceil(MinMemInFlyBytes / WgpPerCU,
|
||||
(MPerBlock * sizeof(ADataType) / APackedSize +
|
||||
NPerBlock * sizeof(BDataType) / BPackedSize) *
|
||||
KPerBlock);
|
||||
static constexpr index_t PrefetchStages =
|
||||
FullMemBandPrefetchStages >= 2
|
||||
? FullMemBandPrefetchStages <= 8 ? FullMemBandPrefetchStages : 8
|
||||
|
||||
@@ -67,16 +67,22 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA()
|
||||
{
|
||||
constexpr index_t smem_size_a = sizeof(typename Problem::ADataType) *
|
||||
MakeALdsBlockDescriptor<Problem>().get_element_space_size();
|
||||
constexpr index_t PackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<typename Problem::ADataType>>::PackedSize;
|
||||
constexpr index_t smem_size_a =
|
||||
sizeof(typename Problem::ADataType) *
|
||||
MakeALdsBlockDescriptor<Problem>().get_element_space_size() / PackedSize;
|
||||
return smem_size_a;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeB()
|
||||
{
|
||||
constexpr index_t smem_size_b = sizeof(typename Problem::BDataType) *
|
||||
MakeBLdsBlockDescriptor<Problem>().get_element_space_size();
|
||||
constexpr index_t PackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<typename Problem::BDataType>>::PackedSize;
|
||||
constexpr index_t smem_size_b =
|
||||
sizeof(typename Problem::BDataType) *
|
||||
MakeBLdsBlockDescriptor<Problem>().get_element_space_size() / PackedSize;
|
||||
return smem_size_b;
|
||||
}
|
||||
|
||||
@@ -387,8 +393,8 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
|
||||
using AccDataType = float;
|
||||
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
|
||||
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
|
||||
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ADataType,
|
||||
typename Problem::BDataType,
|
||||
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ComputeDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
AccDataType,
|
||||
WarpTile::at(I0),
|
||||
WarpTile::at(I1),
|
||||
|
||||
@@ -20,6 +20,11 @@ struct GemmPipelineAGmemBGmemCRegV2
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
|
||||
static constexpr index_t APackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
|
||||
static constexpr index_t BPackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
static constexpr index_t kMPerBlock = BlockGemmShape::kM;
|
||||
@@ -37,13 +42,15 @@ struct GemmPipelineAGmemBGmemCRegV2
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetStaticLdsSize()
|
||||
{
|
||||
return integer_divide_ceil(
|
||||
sizeof(ADataType) *
|
||||
Policy::template MakeALdsBlockDescriptor<Problem>().get_element_space_size(),
|
||||
16) *
|
||||
return integer_divide_ceil(sizeof(ADataType) *
|
||||
Policy::template MakeALdsBlockDescriptor<Problem>()
|
||||
.get_element_space_size() /
|
||||
APackedSize,
|
||||
16) *
|
||||
16 +
|
||||
sizeof(BDataType) *
|
||||
Policy::template MakeBLdsBlockDescriptor<Problem>().get_element_space_size();
|
||||
Policy::template MakeBLdsBlockDescriptor<Problem>().get_element_space_size() /
|
||||
BPackedSize;
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp,
|
||||
@@ -75,7 +82,8 @@ struct GemmPipelineAGmemBGmemCRegV2
|
||||
auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);
|
||||
|
||||
constexpr index_t a_lds_block_space_size_aligned =
|
||||
integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), 16) *
|
||||
integer_divide_ceil(
|
||||
sizeof(ADataType) * a_lds_block_desc.get_element_space_size() / APackedSize, 16) *
|
||||
16;
|
||||
|
||||
// B tile in LDS
|
||||
|
||||
@@ -13,14 +13,16 @@ template <typename ADataType_,
|
||||
typename BDataType_,
|
||||
typename CDataType_,
|
||||
typename BlockGemmShape_,
|
||||
typename Traits_>
|
||||
typename Traits_,
|
||||
typename ComputeDataType_ = ADataType_>
|
||||
struct GemmPipelineProblemBase
|
||||
{
|
||||
using Traits = remove_cvref_t<Traits_>;
|
||||
|
||||
using ADataType = remove_cvref_t<ADataType_>;
|
||||
using BDataType = remove_cvref_t<BDataType_>;
|
||||
using CDataType = remove_cvref_t<CDataType_>;
|
||||
using ADataType = remove_cvref_t<ADataType_>;
|
||||
using BDataType = remove_cvref_t<BDataType_>;
|
||||
using CDataType = remove_cvref_t<CDataType_>;
|
||||
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
|
||||
|
||||
using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
|
||||
|
||||
@@ -53,13 +55,15 @@ struct GemmPipelineProblemBase
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentA()
|
||||
{
|
||||
constexpr index_t PackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
|
||||
if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
constexpr index_t pixels_per_thread =
|
||||
BlockGemmShape::kM * BlockGemmShape::kK / kBlockSize;
|
||||
return pixels_per_thread < VectorLoadSize / sizeof(ADataType)
|
||||
return pixels_per_thread < PackedSize * VectorLoadSize / sizeof(ADataType)
|
||||
? pixels_per_thread
|
||||
: VectorLoadSize / sizeof(ADataType);
|
||||
: PackedSize * VectorLoadSize / sizeof(ADataType);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -69,17 +73,19 @@ struct GemmPipelineProblemBase
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentB()
|
||||
{
|
||||
constexpr index_t PackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
|
||||
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
constexpr index_t pixels_per_thread =
|
||||
BlockGemmShape::kN * BlockGemmShape::kK / kBlockSize;
|
||||
return pixels_per_thread < VectorLoadSize / sizeof(BDataType)
|
||||
return pixels_per_thread < PackedSize * VectorLoadSize / sizeof(BDataType)
|
||||
? pixels_per_thread
|
||||
: VectorLoadSize / sizeof(BDataType);
|
||||
: PackedSize * VectorLoadSize / sizeof(BDataType);
|
||||
}
|
||||
else
|
||||
{
|
||||
return VectorLoadSize / sizeof(BDataType);
|
||||
return PackedSize * VectorLoadSize / sizeof(BDataType);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -143,9 +149,14 @@ template <typename ADataType_,
|
||||
typename BDataType_,
|
||||
typename CDataType_,
|
||||
typename BlockGemmShape_,
|
||||
typename Traits_>
|
||||
using GemmPipelineProblem =
|
||||
GemmPipelineProblemBase<ADataType_, BDataType_, CDataType_, BlockGemmShape_, Traits_>;
|
||||
typename Traits_,
|
||||
typename ComputeDataType_ = ADataType_>
|
||||
using GemmPipelineProblem = GemmPipelineProblemBase<ADataType_,
|
||||
BDataType_,
|
||||
CDataType_,
|
||||
BlockGemmShape_,
|
||||
Traits_,
|
||||
ComputeDataType_>;
|
||||
|
||||
template <typename ADataType_,
|
||||
typename BDataType_,
|
||||
@@ -154,14 +165,16 @@ template <typename ADataType_,
|
||||
typename Traits_,
|
||||
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
|
||||
bool HasHotLoop_ = true,
|
||||
TailNumber TailNum_ = TailNumber::Full>
|
||||
TailNumber TailNum_ = TailNumber::Full,
|
||||
typename ComputeDataType_ = ADataType_>
|
||||
struct UniversalGemmPipelineProblem
|
||||
{
|
||||
using Traits = remove_cvref_t<Traits_>;
|
||||
|
||||
using ADataType = remove_cvref_t<ADataType_>;
|
||||
using BDataType = remove_cvref_t<BDataType_>;
|
||||
using CDataType = remove_cvref_t<CDataType_>;
|
||||
using ADataType = remove_cvref_t<ADataType_>;
|
||||
using BDataType = remove_cvref_t<BDataType_>;
|
||||
using CDataType = remove_cvref_t<CDataType_>;
|
||||
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
|
||||
|
||||
using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
|
||||
|
||||
|
||||
@@ -34,31 +34,41 @@ struct UniversalGemmBasePolicy
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t elements_per_thread = MNPerBlock * KPerBlock / BlockSize;
|
||||
constexpr index_t PackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<DataType>>::PackedSize;
|
||||
|
||||
// Assume DataType is even!
|
||||
if constexpr(XPerTile % (16 / sizeof(DataType)) == 0 &&
|
||||
elements_per_thread % (16 / sizeof(DataType)) == 0)
|
||||
if constexpr(XPerTile % (PackedSize * 32 / sizeof(DataType)) == 0 &&
|
||||
elements_per_thread % (PackedSize * 32 / sizeof(DataType)) == 0 &&
|
||||
PackedSize == 2)
|
||||
{
|
||||
return (16 / sizeof(DataType));
|
||||
return (PackedSize * 32 / sizeof(DataType));
|
||||
}
|
||||
else if constexpr(XPerTile % (8 / sizeof(DataType)) == 0 &&
|
||||
elements_per_thread % (8 / sizeof(DataType)) == 0)
|
||||
else if constexpr(XPerTile % (PackedSize * 16 / sizeof(DataType)) == 0 &&
|
||||
elements_per_thread % (PackedSize * 16 / sizeof(DataType)) == 0)
|
||||
{
|
||||
return (8 / sizeof(DataType));
|
||||
return (PackedSize * 16 / sizeof(DataType));
|
||||
}
|
||||
else if constexpr(sizeof(DataType) >= 4 && XPerTile % (4 / sizeof(DataType)) == 0 &&
|
||||
elements_per_thread % (4 / sizeof(DataType)) == 0)
|
||||
else if constexpr(XPerTile % (PackedSize * 8 / sizeof(DataType)) == 0 &&
|
||||
elements_per_thread % (PackedSize * 8 / sizeof(DataType)) == 0)
|
||||
{
|
||||
return (4 / sizeof(DataType));
|
||||
return (PackedSize * 8 / sizeof(DataType));
|
||||
}
|
||||
else if constexpr(sizeof(DataType) >= 2 && XPerTile % (2 / sizeof(DataType)) == 0 &&
|
||||
elements_per_thread % (2 / sizeof(DataType)) == 0)
|
||||
else if constexpr(sizeof(DataType) >= PackedSize * 4 &&
|
||||
XPerTile % (PackedSize * 4 / sizeof(DataType)) == 0 &&
|
||||
elements_per_thread % (PackedSize * 4 / sizeof(DataType)) == 0)
|
||||
{
|
||||
return (2 / sizeof(DataType));
|
||||
return (PackedSize * 4 / sizeof(DataType));
|
||||
}
|
||||
else if constexpr(sizeof(DataType) >= PackedSize * 2 &&
|
||||
XPerTile % (PackedSize * 2 / sizeof(DataType)) == 0 &&
|
||||
elements_per_thread % (PackedSize * 2 / sizeof(DataType)) == 0)
|
||||
{
|
||||
return (PackedSize * 2 / sizeof(DataType));
|
||||
}
|
||||
else
|
||||
{
|
||||
return 1;
|
||||
return PackedSize;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -564,8 +574,8 @@ struct UniversalGemmPipelineAgBgCrPolicy
|
||||
{
|
||||
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
|
||||
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
|
||||
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ADataType,
|
||||
typename Problem::BDataType,
|
||||
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ComputeDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::CDataType,
|
||||
WarpTile::at(I0),
|
||||
WarpTile::at(I1),
|
||||
|
||||
Reference in New Issue
Block a user