mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 03:07:02 +00:00
Add constexpr IsSupported
This commit is contained in:
@@ -611,6 +611,99 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
|
||||
return true;
|
||||
}
|
||||
|
||||
static constexpr bool IsSupported(index_t MRaw_, index_t NRaw_, index_t KRaw_, index_t Gemm1NRaw_)
|
||||
{
|
||||
// check vector load/store
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
// check vector load of A
|
||||
if constexpr(is_same_v<ALayout, Row>)
|
||||
{
|
||||
if(KRaw_ % ABlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Col>)
|
||||
{
|
||||
// FIXME: not rigorous
|
||||
if(MRaw_ % ABlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// check vector load of B
|
||||
if constexpr(is_same_v<BLayout, Row>)
|
||||
{
|
||||
if(NRaw_ % BBlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same_v<BLayout, Col>)
|
||||
{
|
||||
// FIXME: not rigorous
|
||||
if(KRaw_ % BBlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// check vector load of B1
|
||||
if constexpr(is_same_v<B1Layout, Row>)
|
||||
{
|
||||
if(Gemm1NRaw_ % B1BlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same_v<B1Layout, Col>)
|
||||
{
|
||||
// FIXME: not rigorous
|
||||
if(NRaw_ % B1BlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// check vector load of C
|
||||
if constexpr(is_same_v<CLayout, Row>)
|
||||
{
|
||||
if(Gemm1NRaw_ % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same_v<CLayout, Col>)
|
||||
{
|
||||
// FIXME: not rigorous
|
||||
if(MRaw_ % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(!ck::is_xdl_supported())
|
||||
@@ -625,52 +718,12 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
|
||||
const auto KRaw = arg.raw_lengths_m_n_k_o_[2];
|
||||
const auto Gemm1NRaw = arg.raw_lengths_m_n_k_o_[3];
|
||||
|
||||
// Check scalar per vector requirement
|
||||
const auto a_extent_lowest =
|
||||
is_same_v<tensor_layout::gemm::RowMajor, ALayout> ? KRaw : MRaw;
|
||||
const auto b_extent_lowest =
|
||||
is_same_v<tensor_layout::gemm::RowMajor, BLayout> ? NRaw : KRaw;
|
||||
const auto b1_extent_lowest =
|
||||
is_same_v<tensor_layout::gemm::RowMajor, B1Layout> ? Gemm1NRaw : NRaw;
|
||||
const auto c_extent_lowest =
|
||||
is_same_v<tensor_layout::gemm::RowMajor, CLayout> ? Gemm1NRaw : MRaw;
|
||||
|
||||
if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
|
||||
b_extent_lowest % BBlockTransferSrcScalarPerVector == 0 &&
|
||||
b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 &&
|
||||
c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.b1_grid_desc_bk0_n_bk1_,
|
||||
arg.c_grid_desc_m_n_,
|
||||
arg.block_2_ctile_map_);
|
||||
}
|
||||
|
||||
static constexpr bool IsSupported(index_t MRaw, index_t NRaw, index_t KRaw, index_t Gemm1NRaw)
|
||||
{
|
||||
// Check scalar per vector requirement
|
||||
const auto a_extent_lowest =
|
||||
is_same_v<tensor_layout::gemm::RowMajor, ALayout> ? KRaw : MRaw;
|
||||
const auto b_extent_lowest =
|
||||
is_same_v<tensor_layout::gemm::RowMajor, BLayout> ? NRaw : KRaw;
|
||||
const auto b1_extent_lowest =
|
||||
is_same_v<tensor_layout::gemm::RowMajor, B1Layout> ? Gemm1NRaw : NRaw;
|
||||
const auto c_extent_lowest =
|
||||
is_same_v<tensor_layout::gemm::RowMajor, CLayout> ? Gemm1NRaw : MRaw;
|
||||
|
||||
if (!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
|
||||
b_extent_lowest % BBlockTransferSrcScalarPerVector == 0 &&
|
||||
b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 &&
|
||||
c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
arg.block_2_ctile_map_) and
|
||||
IsSupported(MRaw, NRaw, KRaw, Gemm1NRaw);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
@@ -861,7 +914,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
|
||||
remove_cvref_t<decltype(MakeCGridDescriptor_M_N(CDesc{}))>;
|
||||
|
||||
// GridwiseGemm
|
||||
using GridwiseGemmSpec = GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle<
|
||||
using GridwiseGemm = GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle<
|
||||
ADataType, // TODO: distinguish A/B datatype
|
||||
GemmAccDataType,
|
||||
CShuffleDataType,
|
||||
@@ -928,8 +981,8 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
|
||||
B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1;
|
||||
CGridDesc_M_N c_grid_desc_m_n;
|
||||
C0MatrixMask c0_matrix_mask;
|
||||
typename GridwiseGemmSpec::DefaultBlock2CTileMap block_2_ctile_map;
|
||||
typename GridwiseGemmSpec::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock c_grid_descriptor_mblock_mperblock_nblock_nperblock;
|
||||
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map;
|
||||
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock c_grid_descriptor_mblock_mperblock_nblock_nperblock;
|
||||
|
||||
// element-wise op
|
||||
AElementwiseOperation a_element_op;
|
||||
@@ -952,23 +1005,27 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
|
||||
b_grid_desc_bk0_n_bk1{MakeBGridDescriptor_BK0_N_BK1(b)},
|
||||
b1_grid_desc_bk0_n_bk1{MakeB1GridDescriptor_BK0_N_BK1(b1)},
|
||||
c_grid_desc_m_n{MakeCGridDescriptor_M_N(c)},
|
||||
block_2_ctile_map{GridwiseGemmSpec::MakeDefaultBlock2CTileMap(
|
||||
block_2_ctile_map{GridwiseGemm::MakeDefaultBlock2CTileMap(
|
||||
c_grid_desc_m_n)},
|
||||
c_grid_descriptor_mblock_mperblock_nblock_nperblock{
|
||||
GridwiseGemmSpec::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n)},
|
||||
has_main_k_block_loop{GridwiseGemmSpec::CalculateHasMainKBlockLoop(
|
||||
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n)},
|
||||
has_main_k_block_loop{GridwiseGemm::CalculateHasMainKBlockLoop(
|
||||
a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2))},
|
||||
c0_matrix_mask{c.GetLength(I1)},
|
||||
a_element_op{a_element_op_},
|
||||
b_element_op{b_element_op_},
|
||||
b1_element_op{b1_element_op_},
|
||||
c_element_op{c_element_op_},
|
||||
is_valid{GridwiseGemmSpec::CheckValidity(
|
||||
is_valid{GridwiseGemm::CheckValidity(
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
b1_grid_desc_bk0_n_bk1,
|
||||
c_grid_desc_m_n,
|
||||
block_2_ctile_map)}
|
||||
block_2_ctile_map) and
|
||||
IsSupported(a_grid_desc_ak0_m_ak1.GetLength(I1),
|
||||
b_grid_desc_bk0_n_bk1.GetLength(I1),
|
||||
a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2),
|
||||
b1_grid_desc_bk0_n_bk1.GetLength(I1))}
|
||||
{
|
||||
}
|
||||
|
||||
@@ -1001,17 +1058,13 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
|
||||
const ADataType* __restrict__ p_b1_grid,
|
||||
CDataType* __restrict__ p_c_grid)
|
||||
{
|
||||
assert(desc.is_valid and
|
||||
IsSupported(desc.a_grid_desc_ak0_m_ak1.GetLength(I1),
|
||||
desc.b_grid_desc_bk0_n_bk1.GetLength(I1),
|
||||
desc.a_grid_desc_ak0_m_ak1.GetLength(I0) * desc.a_grid_desc_ak0_m_ak1.GetLength(I2),
|
||||
desc.b1_grid_desc_bk0_n_bk1.GetLength(I1)));
|
||||
__shared__ char p_shared_block[Desc::GridwiseGemmSpec::GetSharedMemoryNumberOfByte()];
|
||||
assert(desc.is_valid);
|
||||
__shared__ char p_shared_block[Desc::GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
AccElementwiseOperation acc_element_op{scale};
|
||||
|
||||
if(desc.has_main_k_block_loop)
|
||||
{
|
||||
Desc::GridwiseGemmSpec::template Run<true>(p_a_grid,
|
||||
Desc::GridwiseGemm::template Run<true>(p_a_grid,
|
||||
p_b_grid,
|
||||
p_b1_grid,
|
||||
p_c_grid,
|
||||
@@ -1030,7 +1083,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
|
||||
}
|
||||
else
|
||||
{
|
||||
Desc::GridwiseGemmSpec::template Run<false>(p_a_grid,
|
||||
Desc::GridwiseGemm::template Run<false>(p_a_grid,
|
||||
p_b_grid,
|
||||
p_b1_grid,
|
||||
p_c_grid,
|
||||
|
||||
Reference in New Issue
Block a user