mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-19 20:40:07 +00:00
Add host lib (#1134)
* Format
* Format
* Format
* Remove const
* Use the right template
* Format
* Format
* add row/col instances
* Add missing file
* fixed
* Format
* Updates
* Format
* fixed rrr layout
* Format
* Update test and embed modules
* Restore older version
* Update year
* Set -fPIC
* Format
* Use double for isnan
* rename host folder to codegen + minor fix
* add codegen CI test
* add option to build components without building CK
* fix the groovy syntax
* fix typo
* use the correct function for the codegen stage
---------
Co-authored-by: Jing Zhang <jizha@amd.com>
Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
Co-authored-by: illsilin <Illia.Silin@amd.com>
[ROCm/composable_kernel commit: 8eff4d62b6]
This commit is contained in:
@@ -498,6 +498,86 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
|
||||
}
|
||||
};
|
||||
|
||||
static constexpr bool IsSupported(index_t MRaw_, index_t NRaw_, index_t KRaw_)
|
||||
{
|
||||
// check vector load/store
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
// check vector load of A
|
||||
if constexpr(is_same_v<ALayout, Row> && ABlockTransferSrcVectorDim == 2)
|
||||
{
|
||||
if(KRaw_ % ABlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Col> && ABlockTransferSrcVectorDim == 1)
|
||||
{
|
||||
// FIXME: not rigorous
|
||||
if(MRaw_ % ABlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
// check vector laod of B
|
||||
if constexpr(is_same_v<BLayout, Col> && BBlockTransferSrcVectorDim == 2)
|
||||
{
|
||||
if(KRaw_ % BBlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same_v<BLayout, Row> && BBlockTransferSrcVectorDim == 1)
|
||||
{
|
||||
// FIXME: not rigorous
|
||||
if(NRaw_ % BBlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// check vector load of Ds
|
||||
// only support RowMajor for now
|
||||
bool all_valid = true;
|
||||
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
|
||||
|
||||
if constexpr(!is_same_v<DLayout, Row>)
|
||||
{
|
||||
all_valid = false;
|
||||
}
|
||||
});
|
||||
|
||||
if(!all_valid)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// check vector store of E
|
||||
// only support RowMajor for now
|
||||
if constexpr(is_same_v<ELayout, Row>)
|
||||
{
|
||||
if(NRaw_ % CDEBlockTransferScalarPerVector_NPerBlock != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(!ck::is_xdl_supported())
|
||||
@@ -505,87 +585,8 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
|
||||
return false;
|
||||
}
|
||||
|
||||
// check vector load/store
|
||||
{
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
// check vector load of A
|
||||
if constexpr(is_same_v<ALayout, Row> && ABlockTransferSrcVectorDim == 2)
|
||||
{
|
||||
if(arg.KRaw_ % ABlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Col> && ABlockTransferSrcVectorDim == 1)
|
||||
{
|
||||
// FIXME: not rigorous
|
||||
if(arg.MRaw_ % ABlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// check vector laod of B
|
||||
if constexpr(is_same_v<BLayout, Col> && BBlockTransferSrcVectorDim == 2)
|
||||
{
|
||||
if(arg.KRaw_ % BBlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same_v<BLayout, Row> && BBlockTransferSrcVectorDim == 1)
|
||||
{
|
||||
// FIXME: not rigorous
|
||||
if(arg.NRaw_ % BBlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// check vector load of Ds
|
||||
// only support RowMajor for now
|
||||
bool all_valid = true;
|
||||
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
|
||||
|
||||
if constexpr(!is_same_v<DLayout, Row>)
|
||||
{
|
||||
all_valid = false;
|
||||
}
|
||||
});
|
||||
|
||||
if(!all_valid)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// check vector store of E
|
||||
// only support RowMajor for now
|
||||
if constexpr(is_same_v<ELayout, Row>)
|
||||
{
|
||||
if(arg.NRaw_ % CDEBlockTransferScalarPerVector_NPerBlock != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
|
||||
return IsSupported(arg.MRaw_, arg.NRaw_, arg.KRaw_) and
|
||||
GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
|
||||
arg.b_grid_desc_n_k_,
|
||||
arg.ds_grid_desc_m_n_,
|
||||
arg.e_grid_desc_m_n_,
|
||||
@@ -708,6 +709,178 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
|
||||
|
||||
return str.str();
|
||||
}
|
||||
|
||||
template <class ADesc, class BDesc, class DsDesc, class EDesc>
|
||||
struct Descriptor
|
||||
{
|
||||
static constexpr auto ds_tuple()
|
||||
{
|
||||
return transform_tuples(
|
||||
[&](auto d) constexpr { return DeviceOp::matrix_padder.PadCDescriptor_M_N(d); },
|
||||
DsDesc{});
|
||||
}
|
||||
using AGridDesc_M_K =
|
||||
remove_cvref_t<decltype(DeviceOp::matrix_padder.PadADescriptor_M_K(ADesc{}))>;
|
||||
using BGridDesc_N_K =
|
||||
remove_cvref_t<decltype(DeviceOp::matrix_padder.PadBDescriptor_N_K(BDesc{}))>;
|
||||
using DsGridDesc_M_N = remove_cvref_t<decltype(ds_tuple())>;
|
||||
using EGridDesc_M_N =
|
||||
remove_cvref_t<decltype(DeviceOp::matrix_padder.PadCDescriptor_M_N(EDesc{}))>;
|
||||
using AGridDesc_AK0_M_AK1 =
|
||||
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(
|
||||
DeviceOp::matrix_padder.PadADescriptor_M_K(ADesc{})))>;
|
||||
using BGridDesc_BK0_N_BK1 =
|
||||
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(
|
||||
DeviceOp::matrix_padder.PadBDescriptor_N_K(BDesc{})))>;
|
||||
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<
|
||||
decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
ds_tuple()))>;
|
||||
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<
|
||||
decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
DeviceOp::matrix_padder.PadCDescriptor_M_N(EDesc{})))>;
|
||||
using Block2ETileMap = remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBlock2ETileMap(
|
||||
DeviceOp::matrix_padder.PadCDescriptor_M_N(EDesc{})))>;
|
||||
|
||||
// tensor descriptors for problem definiton
|
||||
AGridDesc_M_K a_grid_desc_m_k;
|
||||
BGridDesc_N_K b_grid_desc_n_k;
|
||||
DsGridDesc_M_N ds_grid_desc_m_n;
|
||||
EGridDesc_M_N e_grid_desc_m_n;
|
||||
|
||||
// tensor descriptors for block/thread-wise copy
|
||||
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1;
|
||||
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1;
|
||||
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
|
||||
// block-to-e-tile map
|
||||
Block2ETileMap block_2_etile_map;
|
||||
|
||||
// element-wise op
|
||||
AElementwiseOperation a_element_op;
|
||||
BElementwiseOperation b_element_op;
|
||||
CDEElementwiseOperation cde_element_op;
|
||||
|
||||
// for checking vector load/store
|
||||
index_t MRaw;
|
||||
index_t NRaw;
|
||||
index_t KRaw;
|
||||
|
||||
bool has_main_k_block_loop = true;
|
||||
|
||||
constexpr Descriptor(ADesc a,
|
||||
BDesc b,
|
||||
DsDesc ds,
|
||||
EDesc e,
|
||||
AElementwiseOperation a_element_op_,
|
||||
BElementwiseOperation b_element_op_,
|
||||
CDEElementwiseOperation cde_element_op_)
|
||||
: a_grid_desc_m_k{DeviceOp::matrix_padder.PadADescriptor_M_K(a)},
|
||||
b_grid_desc_n_k{DeviceOp::matrix_padder.PadBDescriptor_N_K(b)},
|
||||
ds_grid_desc_m_n{transform_tuples(
|
||||
[&](auto d) constexpr { return DeviceOp::matrix_padder.PadCDescriptor_M_N(d); },
|
||||
ds)},
|
||||
e_grid_desc_m_n{DeviceOp::matrix_padder.PadCDescriptor_M_N(e)},
|
||||
a_grid_desc_ak0_m_ak1{
|
||||
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k)},
|
||||
b_grid_desc_bk0_n_bk1{
|
||||
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k)},
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock{
|
||||
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
transform_tuples(
|
||||
[&](auto d) constexpr {
|
||||
return DeviceOp::matrix_padder.PadCDescriptor_M_N(d);
|
||||
},
|
||||
ds))},
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock{
|
||||
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
e_grid_desc_m_n)},
|
||||
block_2_etile_map{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n)},
|
||||
has_main_k_block_loop{GridwiseGemm::CalculateHasMainKBlockLoop(
|
||||
a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2))},
|
||||
a_element_op{a_element_op_},
|
||||
b_element_op{b_element_op_},
|
||||
cde_element_op{cde_element_op_},
|
||||
MRaw{e.GetLength(I0)},
|
||||
NRaw{e.GetLength(I1)},
|
||||
KRaw{a.GetLength(I1)}
|
||||
{
|
||||
}
|
||||
|
||||
constexpr bool IsValid() const
|
||||
{
|
||||
return GridwiseGemm::CheckValidity(a_grid_desc_m_k,
|
||||
b_grid_desc_n_k,
|
||||
ds_grid_desc_m_n,
|
||||
e_grid_desc_m_n,
|
||||
block_2_etile_map) and
|
||||
IsSupported(MRaw, NRaw, KRaw);
|
||||
}
|
||||
|
||||
constexpr index_t GetBlockSize() const { return BlockSize; }
|
||||
|
||||
constexpr index_t GetGridSize() const
|
||||
{
|
||||
return block_2_etile_map.CalculateGridSize(e_grid_desc_m_n);
|
||||
}
|
||||
};
|
||||
|
||||
template <class ADesc, class BDesc, class DsDesc, class EDesc>
|
||||
static constexpr auto
|
||||
make_descriptor(ADesc a,
|
||||
BDesc b,
|
||||
DsDesc ds,
|
||||
EDesc e,
|
||||
AElementwiseOperation a_element_op = AElementwiseOperation{},
|
||||
BElementwiseOperation b_element_op = BElementwiseOperation{},
|
||||
CDEElementwiseOperation cde_element_op = CDEElementwiseOperation{})
|
||||
{
|
||||
return Descriptor<ADesc, BDesc, DsDesc, EDesc>(
|
||||
a, b, ds, e, a_element_op, b_element_op, cde_element_op);
|
||||
}
|
||||
|
||||
template <class Desc, class DsPointer>
|
||||
__device__ static void Run(const Desc& desc,
|
||||
const ADataType* __restrict__ p_a_grid,
|
||||
const BDataType* __restrict__ p_b_grid,
|
||||
DsPointer p_ds_grid,
|
||||
EDataType* __restrict__ p_e_grid)
|
||||
{
|
||||
__shared__ char p_shared_block[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
assert(desc.IsValid());
|
||||
if(desc.has_main_k_block_loop)
|
||||
{
|
||||
GridwiseGemm::template Run<true>(p_a_grid,
|
||||
p_b_grid,
|
||||
p_ds_grid,
|
||||
p_e_grid,
|
||||
p_shared_block,
|
||||
desc.a_element_op,
|
||||
desc.b_element_op,
|
||||
desc.cde_element_op,
|
||||
desc.a_grid_desc_ak0_m_ak1,
|
||||
desc.b_grid_desc_bk0_n_bk1,
|
||||
desc.ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
desc.e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
desc.block_2_etile_map);
|
||||
}
|
||||
else
|
||||
{
|
||||
GridwiseGemm::template Run<false>(p_a_grid,
|
||||
p_b_grid,
|
||||
p_ds_grid,
|
||||
p_e_grid,
|
||||
p_shared_block,
|
||||
desc.a_element_op,
|
||||
desc.b_element_op,
|
||||
desc.cde_element_op,
|
||||
desc.a_grid_desc_ak0_m_ak1,
|
||||
desc.b_grid_desc_bk0_n_bk1,
|
||||
desc.ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
desc.e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
desc.block_2_etile_map);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
|
||||
@@ -24,10 +24,10 @@ struct BlockToCTileMap_M00_N0_M01
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
|
||||
__host__ __device__ BlockToCTileMap_M00_N0_M01() = default;
|
||||
__host__ __device__ constexpr BlockToCTileMap_M00_N0_M01() = default;
|
||||
|
||||
__host__ __device__ BlockToCTileMap_M00_N0_M01(const CGridDesc_M_N& c_grid_desc_m_n,
|
||||
index_t M01 = 1)
|
||||
__host__ __device__ constexpr BlockToCTileMap_M00_N0_M01(const CGridDesc_M_N& c_grid_desc_m_n,
|
||||
index_t M01 = 1)
|
||||
: M01_(M01), underlying_map_(GetBlockToCTileMap(c_grid_desc_m_n, M01))
|
||||
{
|
||||
}
|
||||
@@ -51,8 +51,8 @@ struct BlockToCTileMap_M00_N0_M01
|
||||
}
|
||||
|
||||
template <typename CTileIdx, typename CTileDim>
|
||||
__host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx,
|
||||
const CTileDim& c_tile_dim) const
|
||||
__host__ __device__ constexpr bool ValidCTileIndex(const CTileIdx& c_tile_idx,
|
||||
const CTileDim& c_tile_dim) const
|
||||
{
|
||||
if constexpr(DeviceCTileIndexCheck)
|
||||
return DefaultValidCTileIndex(c_tile_idx, c_tile_dim);
|
||||
@@ -60,7 +60,7 @@ struct BlockToCTileMap_M00_N0_M01
|
||||
return true;
|
||||
}
|
||||
|
||||
__host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
|
||||
__host__ constexpr bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
|
||||
{
|
||||
if constexpr(DeviceCTileIndexCheck)
|
||||
return true; // validity check moved to kernel
|
||||
@@ -120,18 +120,19 @@ struct BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void>
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
|
||||
__host__ __device__ BlockToCTileMap_M00_N0_M01Adapt() = default;
|
||||
__host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt() = default;
|
||||
|
||||
__host__ __device__ BlockToCTileMap_M00_N0_M01Adapt(const BlockToCTileMap_M00_N0_M01Adapt&) =
|
||||
default;
|
||||
__host__ __device__ BlockToCTileMap_M00_N0_M01Adapt(BlockToCTileMap_M00_N0_M01Adapt&&) =
|
||||
default;
|
||||
__host__ __device__ BlockToCTileMap_M00_N0_M01Adapt&
|
||||
__host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(
|
||||
const BlockToCTileMap_M00_N0_M01Adapt&) = default;
|
||||
__host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(
|
||||
BlockToCTileMap_M00_N0_M01Adapt&&) = default;
|
||||
__host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt&
|
||||
operator=(const BlockToCTileMap_M00_N0_M01Adapt&) = default;
|
||||
__host__ __device__ BlockToCTileMap_M00_N0_M01Adapt&
|
||||
__host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt&
|
||||
operator=(BlockToCTileMap_M00_N0_M01Adapt&&) = default;
|
||||
|
||||
__host__ __device__ BlockToCTileMap_M00_N0_M01Adapt(index_t M, index_t N, index_t M01 = 8)
|
||||
__host__
|
||||
__device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(index_t M, index_t N, index_t M01 = 8)
|
||||
: M_(M), N_(N), M01_(M01)
|
||||
{
|
||||
#if 0
|
||||
@@ -142,8 +143,9 @@ struct BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void>
|
||||
}
|
||||
|
||||
template <typename CGridDesc_M_N>
|
||||
__host__ __device__ BlockToCTileMap_M00_N0_M01Adapt(const CGridDesc_M_N& c_grid_desc_m_n,
|
||||
index_t M01 = 8)
|
||||
__host__
|
||||
__device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(const CGridDesc_M_N& c_grid_desc_m_n,
|
||||
index_t M01 = 8)
|
||||
: BlockToCTileMap_M00_N0_M01Adapt(
|
||||
c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), M01)
|
||||
{
|
||||
@@ -164,7 +166,7 @@ struct BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void>
|
||||
}
|
||||
|
||||
template <typename CGridDesc_M_N>
|
||||
__host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
|
||||
__host__ constexpr bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
|
||||
{
|
||||
return true;
|
||||
}
|
||||
@@ -237,8 +239,8 @@ struct BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void>
|
||||
}
|
||||
|
||||
template <typename CTileIdx, typename CTileDim>
|
||||
__host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
|
||||
const CTileDim& /* c_tile_dim */) const
|
||||
__host__ __device__ constexpr bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
|
||||
const CTileDim& /* c_tile_dim */) const
|
||||
{
|
||||
return true; // always valid provided that user gets grid size from CalculateGridSize()
|
||||
}
|
||||
@@ -616,7 +618,10 @@ struct BlockToCTileMap_KSplit_M00_N0_M01Adapt
|
||||
return true; // always valid provided that user gets grid size from CalculateGridSize()
|
||||
}
|
||||
|
||||
__host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const { return true; }
|
||||
__host__ constexpr bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
private:
|
||||
index_t M01_;
|
||||
@@ -674,7 +679,7 @@ struct BlockToCTileMap_M00_N00_M01_N01
|
||||
return true;
|
||||
}
|
||||
|
||||
__host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
|
||||
__host__ constexpr bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
|
||||
{
|
||||
if constexpr(DeviceCTileIndexCheck)
|
||||
return true; // validity check moved to kernel
|
||||
@@ -786,7 +791,7 @@ struct BlockToCTileMap_KSplit_M00_N00_M01_N01
|
||||
return true;
|
||||
}
|
||||
|
||||
__host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
|
||||
__host__ constexpr bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
|
||||
{
|
||||
if constexpr(DeviceCTileIndexCheck)
|
||||
return true; // validity check moved to kernel
|
||||
@@ -910,7 +915,7 @@ struct OffsettedBlockToCTileMap
|
||||
}
|
||||
|
||||
template <typename CGridDesc_M_N>
|
||||
__host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
|
||||
__host__ constexpr bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
|
||||
{
|
||||
return block_to_ctile_map_.CheckValidity(c_grid_desc_m_n);
|
||||
}
|
||||
@@ -967,7 +972,7 @@ struct BlockToCTileMap_3DGrid_KSplit
|
||||
}
|
||||
|
||||
template <typename CGridDesc_M_N>
|
||||
__host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
|
||||
__host__ constexpr bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -264,7 +264,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
|
||||
const BGridDesc_N_K& b_grid_desc_n_k,
|
||||
const DsGridDesc_M_N& ds_grid_desc_m_n,
|
||||
const EGridDesc_M_N& e_grid_desc_m_n,
|
||||
const Block2ETileMap& block_2_etile_map)
|
||||
const Block2ETileMap&)
|
||||
{
|
||||
static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
|
||||
(NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
|
||||
@@ -310,10 +310,10 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
|
||||
}
|
||||
|
||||
// check block-to-E-tile
|
||||
if(!block_2_etile_map.CheckValidity(e_grid_desc_m_n))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
// if(!block_2_etile_map.CheckValidity(e_grid_desc_m_n))
|
||||
//{
|
||||
// return false;
|
||||
//}
|
||||
|
||||
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
|
||||
// check tensor size: cannot be larger than 2GB each
|
||||
|
||||
Reference in New Issue
Block a user