mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-08 15:30:23 +00:00
Merge branch 'ck_migraphx_integration' into codegen-enable-hiprtc
This commit is contained in:
@@ -615,96 +615,6 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
|
||||
return true;
|
||||
}
|
||||
|
||||
static constexpr bool
|
||||
IsSupported(index_t MRaw_, index_t NRaw_, index_t KRaw_, index_t Gemm1NRaw_)
|
||||
{
|
||||
// check vector load/store
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
// check vector load of A
|
||||
if constexpr(is_same_v<ALayout, Row>)
|
||||
{
|
||||
if(KRaw_ % ABlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Col>)
|
||||
{
|
||||
if(MRaw_ % ABlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// check vector load of B
|
||||
if constexpr(is_same_v<BLayout, Row>)
|
||||
{
|
||||
if(NRaw_ % BBlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same_v<BLayout, Col>)
|
||||
{
|
||||
if(KRaw_ % BBlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// check vector load of B1
|
||||
if constexpr(is_same_v<B1Layout, Row>)
|
||||
{
|
||||
if(Gemm1NRaw_ % B1BlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same_v<B1Layout, Col>)
|
||||
{
|
||||
if(NRaw_ % B1BlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// check vector load of C
|
||||
if constexpr(is_same_v<CLayout, Row>)
|
||||
{
|
||||
if(Gemm1NRaw_ % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same_v<CLayout, Col>)
|
||||
{
|
||||
if(MRaw_ % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
#ifndef __HIPCC_RTC__
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
@@ -861,268 +771,6 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
|
||||
return str.str();
|
||||
}
|
||||
#endif
|
||||
|
||||
template <class ADesc, class BDesc, class B1Desc, class CDesc>
|
||||
struct Descriptor
|
||||
{
|
||||
template <class AGridDescriptor>
|
||||
static constexpr auto MakeAGridDescriptor_AK0_M_AK1(const AGridDescriptor& a_grid_desc)
|
||||
{
|
||||
const auto a_grid_desc_m_k = DeviceOp::matrix_padder.PadADescriptor_M_K(a_grid_desc);
|
||||
|
||||
const auto M = a_grid_desc_m_k.GetLength(I0);
|
||||
const auto K = a_grid_desc_m_k.GetLength(I1);
|
||||
|
||||
const auto AK0 = K / AK1;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
a_grid_desc_m_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
|
||||
make_pass_through_transform(M)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
}
|
||||
|
||||
template <class BGridDescriptor>
|
||||
static constexpr auto MakeBGridDescriptor_BK0_N_BK1(const BGridDescriptor& b_grid_desc)
|
||||
{
|
||||
const auto b_grid_desc_n_k = DeviceOp::matrix_padder.PadBDescriptor_N_K(b_grid_desc);
|
||||
|
||||
const auto N = b_grid_desc_n_k.GetLength(I0);
|
||||
const auto K = b_grid_desc_n_k.GetLength(I1);
|
||||
|
||||
const auto BK0 = K / BK1;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
b_grid_desc_n_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
|
||||
make_pass_through_transform(N)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
}
|
||||
|
||||
template <class B1GridDescriptor>
|
||||
static constexpr auto MakeB1GridDescriptor_BK0_N_BK1(const B1GridDescriptor& b1_grid_desc)
|
||||
{
|
||||
const auto b1_grid_desc_n_k = DeviceOp::matrix_padder.PadB1Descriptor_N_K(b1_grid_desc);
|
||||
|
||||
const auto N = b1_grid_desc_n_k.GetLength(I0);
|
||||
const auto K = b1_grid_desc_n_k.GetLength(I1);
|
||||
|
||||
const auto B1K0 = K / B1K1;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
b1_grid_desc_n_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)),
|
||||
make_pass_through_transform(N)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
}
|
||||
|
||||
template <class CGridDescriptor>
|
||||
static constexpr auto MakeCGridDescriptor_M_N(const CGridDescriptor& c_grid_desc)
|
||||
{
|
||||
return DeviceOp::matrix_padder.PadCDescriptor_M_N(c_grid_desc);
|
||||
}
|
||||
|
||||
using AGridDesc_AK0_M_AK1 =
|
||||
remove_cvref_t<decltype(MakeAGridDescriptor_AK0_M_AK1(ADesc{}))>;
|
||||
using BGridDesc_BK0_N_BK1 =
|
||||
remove_cvref_t<decltype(MakeBGridDescriptor_BK0_N_BK1(BDesc{}))>;
|
||||
using B1GridDesc_BK0_N_BK1 =
|
||||
remove_cvref_t<decltype(MakeB1GridDescriptor_BK0_N_BK1(B1Desc{}))>;
|
||||
using CGridDesc_M_N = remove_cvref_t<decltype(MakeCGridDescriptor_M_N(CDesc{}))>;
|
||||
|
||||
// GridwiseGemm
|
||||
using GridwiseGemm = GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle<
|
||||
ADataType, // TODO: distinguish A/B datatype
|
||||
GemmAccDataType,
|
||||
CShuffleDataType,
|
||||
CDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
AccElementwiseOperation,
|
||||
B1ElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
AGridDesc_AK0_M_AK1,
|
||||
BGridDesc_BK0_N_BK1,
|
||||
B1GridDesc_BK0_N_BK1,
|
||||
CGridDesc_M_N,
|
||||
NumGemmKPrefetchStage,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
Gemm1NPerBlock,
|
||||
Gemm1KPerBlock,
|
||||
AK1,
|
||||
BK1,
|
||||
B1K1,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
Gemm1NXdlPerWave,
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_AK1,
|
||||
true,
|
||||
ABlockLdsExtraM,
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_BK1,
|
||||
true,
|
||||
BBlockLdsExtraN,
|
||||
B1BlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
B1BlockTransferThreadClusterArrangeOrder,
|
||||
B1BlockTransferSrcAccessOrder,
|
||||
B1BlockTransferSrcVectorDim,
|
||||
B1BlockTransferSrcScalarPerVector,
|
||||
B1BlockTransferDstScalarPerVector_BK1,
|
||||
false,
|
||||
B1BlockLdsExtraN,
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
LoopSched,
|
||||
matrix_padder.PadN,
|
||||
MaskOutUpperTriangle>;
|
||||
|
||||
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1;
|
||||
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1;
|
||||
B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1;
|
||||
CGridDesc_M_N c_grid_desc_m_n;
|
||||
C0MatrixMask c0_matrix_mask;
|
||||
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map;
|
||||
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
c_grid_descriptor_mblock_mperblock_nblock_nperblock;
|
||||
|
||||
// element-wise op
|
||||
AElementwiseOperation a_element_op;
|
||||
BElementwiseOperation b_element_op;
|
||||
B1ElementwiseOperation b1_element_op;
|
||||
CElementwiseOperation c_element_op;
|
||||
|
||||
bool has_main_k_block_loop = true;
|
||||
bool is_valid = false;
|
||||
|
||||
constexpr Descriptor(ADesc a,
|
||||
BDesc b,
|
||||
B1Desc b1,
|
||||
CDesc c,
|
||||
AElementwiseOperation a_element_op_,
|
||||
BElementwiseOperation b_element_op_,
|
||||
B1ElementwiseOperation b1_element_op_,
|
||||
CElementwiseOperation c_element_op_)
|
||||
: a_grid_desc_ak0_m_ak1{MakeAGridDescriptor_AK0_M_AK1(a)},
|
||||
b_grid_desc_bk0_n_bk1{MakeBGridDescriptor_BK0_N_BK1(b)},
|
||||
b1_grid_desc_bk0_n_bk1{MakeB1GridDescriptor_BK0_N_BK1(b1)},
|
||||
c_grid_desc_m_n{MakeCGridDescriptor_M_N(c)},
|
||||
block_2_ctile_map{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n)},
|
||||
c_grid_descriptor_mblock_mperblock_nblock_nperblock{
|
||||
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
c_grid_desc_m_n)},
|
||||
has_main_k_block_loop{GridwiseGemm::CalculateHasMainKBlockLoop(
|
||||
a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2))},
|
||||
c0_matrix_mask{c.GetLength(I1)},
|
||||
a_element_op{a_element_op_},
|
||||
b_element_op{b_element_op_},
|
||||
b1_element_op{b1_element_op_},
|
||||
c_element_op{c_element_op_},
|
||||
is_valid{GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
b1_grid_desc_bk0_n_bk1,
|
||||
c_grid_desc_m_n,
|
||||
block_2_ctile_map) and
|
||||
IsSupported(a_grid_desc_ak0_m_ak1.GetLength(I1),
|
||||
b_grid_desc_bk0_n_bk1.GetLength(I1),
|
||||
a_grid_desc_ak0_m_ak1.GetLength(I0) *
|
||||
a_grid_desc_ak0_m_ak1.GetLength(I2),
|
||||
b1_grid_desc_bk0_n_bk1.GetLength(I1))}
|
||||
{
|
||||
}
|
||||
|
||||
constexpr bool IsValid() const { return is_valid; }
|
||||
};
|
||||
|
||||
template <class ADesc, class BDesc, class B1Desc, class CDesc>
|
||||
static constexpr auto
|
||||
make_descriptor(ADesc a,
|
||||
BDesc b,
|
||||
B1Desc b1,
|
||||
CDesc c,
|
||||
AElementwiseOperation a_element_op = AElementwiseOperation{},
|
||||
BElementwiseOperation b_element_op = BElementwiseOperation{},
|
||||
B1ElementwiseOperation b1_element_op = B1ElementwiseOperation{},
|
||||
CElementwiseOperation c_element_op = CElementwiseOperation{})
|
||||
{
|
||||
return Descriptor<ADesc, BDesc, B1Desc, CDesc>(
|
||||
a, b, b1, c, a_element_op, b_element_op, b1_element_op, c_element_op);
|
||||
}
|
||||
|
||||
template <class Desc>
|
||||
__device__ static void Run(const Desc& desc,
|
||||
const float scale,
|
||||
const ADataType* __restrict__ p_a_grid,
|
||||
const ADataType* __restrict__ p_b_grid,
|
||||
const ADataType* __restrict__ p_b1_grid,
|
||||
CDataType* __restrict__ p_c_grid)
|
||||
{
|
||||
#ifndef __HIPCC_RTC__
|
||||
assert(desc.is_valid);
|
||||
#endif
|
||||
__shared__ char p_shared_block[Desc::GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
AccElementwiseOperation acc_element_op{scale};
|
||||
|
||||
if(desc.has_main_k_block_loop)
|
||||
{
|
||||
Desc::GridwiseGemm::template Run<true>(
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_b1_grid,
|
||||
p_c_grid,
|
||||
p_shared_block,
|
||||
desc.a_element_op,
|
||||
desc.b_element_op,
|
||||
acc_element_op,
|
||||
desc.b1_element_op,
|
||||
desc.c_element_op,
|
||||
desc.a_grid_desc_ak0_m_ak1,
|
||||
desc.b_grid_desc_bk0_n_bk1,
|
||||
desc.b1_grid_desc_bk0_n_bk1,
|
||||
desc.c_grid_descriptor_mblock_mperblock_nblock_nperblock,
|
||||
desc.block_2_ctile_map,
|
||||
desc.c0_matrix_mask);
|
||||
}
|
||||
else
|
||||
{
|
||||
Desc::GridwiseGemm::template Run<false>(
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_b1_grid,
|
||||
p_c_grid,
|
||||
p_shared_block,
|
||||
desc.a_element_op,
|
||||
desc.b_element_op,
|
||||
acc_element_op,
|
||||
desc.b1_element_op,
|
||||
desc.c_element_op,
|
||||
desc.a_grid_desc_ak0_m_ak1,
|
||||
desc.b_grid_desc_bk0_n_bk1,
|
||||
desc.b1_grid_desc_bk0_n_bk1,
|
||||
desc.c_grid_descriptor_mblock_mperblock_nblock_nperblock,
|
||||
desc.block_2_ctile_map,
|
||||
desc.c0_matrix_mask);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/utility/enable_if.hpp"
|
||||
#include "ck/utility/statically_indexed_array.hpp"
|
||||
|
||||
#ifdef __HIPCC_RTC__
|
||||
@@ -204,7 +205,7 @@ struct scalar_type<bool>
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct vector_type<T, 1, typename std::enable_if_t<is_native_type<T>()>>
|
||||
struct vector_type<T, 1, typename ck::enable_if_t<is_native_type<T>()>>
|
||||
{
|
||||
using d1_t = T;
|
||||
using type = d1_t;
|
||||
@@ -240,7 +241,7 @@ struct vector_type<T, 1, typename std::enable_if_t<is_native_type<T>()>>
|
||||
|
||||
__device__ int static err = 0;
|
||||
template <typename T>
|
||||
struct vector_type<T, 2, typename std::enable_if_t<is_native_type<T>()>>
|
||||
struct vector_type<T, 2, typename ck::enable_if_t<is_native_type<T>()>>
|
||||
{
|
||||
using d1_t = T;
|
||||
typedef T d2_t __attribute__((ext_vector_type(2)));
|
||||
@@ -300,7 +301,7 @@ struct vector_type<T, 2, typename std::enable_if_t<is_native_type<T>()>>
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct vector_type<T, 4, typename std::enable_if_t<is_native_type<T>()>>
|
||||
struct vector_type<T, 4, typename ck::enable_if_t<is_native_type<T>()>>
|
||||
{
|
||||
using d1_t = T;
|
||||
typedef T d2_t __attribute__((ext_vector_type(2)));
|
||||
@@ -370,7 +371,7 @@ struct vector_type<T, 4, typename std::enable_if_t<is_native_type<T>()>>
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct vector_type<T, 8, typename std::enable_if_t<is_native_type<T>()>>
|
||||
struct vector_type<T, 8, typename ck::enable_if_t<is_native_type<T>()>>
|
||||
{
|
||||
using d1_t = T;
|
||||
typedef T d2_t __attribute__((ext_vector_type(2)));
|
||||
@@ -452,7 +453,7 @@ struct vector_type<T, 8, typename std::enable_if_t<is_native_type<T>()>>
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct vector_type<T, 16, typename std::enable_if_t<is_native_type<T>()>>
|
||||
struct vector_type<T, 16, typename ck::enable_if_t<is_native_type<T>()>>
|
||||
{
|
||||
using d1_t = T;
|
||||
typedef T d2_t __attribute__((ext_vector_type(2)));
|
||||
@@ -546,7 +547,7 @@ struct vector_type<T, 16, typename std::enable_if_t<is_native_type<T>()>>
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct vector_type<T, 32, typename std::enable_if_t<is_native_type<T>()>>
|
||||
struct vector_type<T, 32, typename ck::enable_if_t<is_native_type<T>()>>
|
||||
{
|
||||
using d1_t = T;
|
||||
typedef T d2_t __attribute__((ext_vector_type(2)));
|
||||
@@ -650,7 +651,7 @@ struct vector_type<T, 32, typename std::enable_if_t<is_native_type<T>()>>
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct vector_type<T, 64, typename std::enable_if_t<is_native_type<T>()>>
|
||||
struct vector_type<T, 64, typename ck::enable_if_t<is_native_type<T>()>>
|
||||
{
|
||||
using d1_t = T;
|
||||
typedef T d2_t __attribute__((ext_vector_type(2)));
|
||||
@@ -766,7 +767,7 @@ struct vector_type<T, 64, typename std::enable_if_t<is_native_type<T>()>>
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct vector_type<T, 128, typename std::enable_if_t<is_native_type<T>()>>
|
||||
struct vector_type<T, 128, typename ck::enable_if_t<is_native_type<T>()>>
|
||||
{
|
||||
using d1_t = T;
|
||||
typedef T d2_t __attribute__((ext_vector_type(2)));
|
||||
@@ -892,7 +893,7 @@ struct vector_type<T, 128, typename std::enable_if_t<is_native_type<T>()>>
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct vector_type<T, 256, typename std::enable_if_t<is_native_type<T>()>>
|
||||
struct vector_type<T, 256, typename ck::enable_if_t<is_native_type<T>()>>
|
||||
{
|
||||
using d1_t = T;
|
||||
typedef T d2_t __attribute__((ext_vector_type(2)));
|
||||
@@ -1042,7 +1043,7 @@ struct non_native_vector_base
|
||||
|
||||
// non-native vector_type implementation
|
||||
template <typename T>
|
||||
struct vector_type<T, 1, typename std::enable_if_t<!is_native_type<T>()>>
|
||||
struct vector_type<T, 1, typename ck::enable_if_t<!is_native_type<T>()>>
|
||||
{
|
||||
using d1_t = T;
|
||||
using type = d1_t;
|
||||
@@ -1077,7 +1078,7 @@ struct vector_type<T, 1, typename std::enable_if_t<!is_native_type<T>()>>
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct vector_type<T, 2, typename std::enable_if_t<!is_native_type<T>()>>
|
||||
struct vector_type<T, 2, typename ck::enable_if_t<!is_native_type<T>()>>
|
||||
{
|
||||
using d1_t = T;
|
||||
using d2_t = non_native_vector_base<T, 2>;
|
||||
@@ -1137,7 +1138,7 @@ struct vector_type<T, 2, typename std::enable_if_t<!is_native_type<T>()>>
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct vector_type<T, 4, typename std::enable_if_t<!is_native_type<T>()>>
|
||||
struct vector_type<T, 4, typename ck::enable_if_t<!is_native_type<T>()>>
|
||||
{
|
||||
using d1_t = T;
|
||||
using d2_t = non_native_vector_base<T, 2>;
|
||||
@@ -1207,7 +1208,7 @@ struct vector_type<T, 4, typename std::enable_if_t<!is_native_type<T>()>>
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct vector_type<T, 8, typename std::enable_if_t<!is_native_type<T>()>>
|
||||
struct vector_type<T, 8, typename ck::enable_if_t<!is_native_type<T>()>>
|
||||
{
|
||||
using d1_t = T;
|
||||
using d2_t = non_native_vector_base<T, 2>;
|
||||
@@ -1289,7 +1290,7 @@ struct vector_type<T, 8, typename std::enable_if_t<!is_native_type<T>()>>
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct vector_type<T, 16, typename std::enable_if_t<!is_native_type<T>()>>
|
||||
struct vector_type<T, 16, typename ck::enable_if_t<!is_native_type<T>()>>
|
||||
{
|
||||
using d1_t = T;
|
||||
using d2_t = non_native_vector_base<T, 2>;
|
||||
@@ -1383,7 +1384,7 @@ struct vector_type<T, 16, typename std::enable_if_t<!is_native_type<T>()>>
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct vector_type<T, 32, typename std::enable_if_t<!is_native_type<T>()>>
|
||||
struct vector_type<T, 32, typename ck::enable_if_t<!is_native_type<T>()>>
|
||||
{
|
||||
using d1_t = T;
|
||||
using d2_t = non_native_vector_base<T, 2>;
|
||||
@@ -1487,7 +1488,7 @@ struct vector_type<T, 32, typename std::enable_if_t<!is_native_type<T>()>>
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct vector_type<T, 64, typename std::enable_if_t<!is_native_type<T>()>>
|
||||
struct vector_type<T, 64, typename ck::enable_if_t<!is_native_type<T>()>>
|
||||
{
|
||||
using d1_t = T;
|
||||
using d2_t = non_native_vector_base<T, 2>;
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "data_type.hpp"
|
||||
#include "integral_constant.hpp"
|
||||
#include "number.hpp"
|
||||
#include "type.hpp"
|
||||
|
||||
@@ -157,8 +157,11 @@
|
||||
#endif
|
||||
#endif
|
||||
|
||||
// workaround for ROCm 6.2 and later
|
||||
#ifndef CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE
|
||||
#if HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR == 2 && HIP_VERSION_PATCH >= 41133
|
||||
#if(HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR == 2 && HIP_VERSION_PATCH >= 41133) || \
|
||||
(HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR == 3 && HIP_VERSION_PATCH >= 42131) || \
|
||||
(HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR > 3)
|
||||
#define CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE 1
|
||||
#else
|
||||
#define CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE 0
|
||||
|
||||
@@ -178,13 +178,13 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
|
||||
k_lds_ptr, Policy::template MakeKLdsWriteBlockDescriptor<Problem>());
|
||||
|
||||
auto k_lds_write_window =
|
||||
make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kK0>{}), {0, 0});
|
||||
make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kQKHeaddim>{}), {0, 0});
|
||||
|
||||
auto k_lds_read_window =
|
||||
make_tile_window(k_lds_write_window.get_bottom_tensor_view(),
|
||||
make_tuple(number<kN0>{}, number<kK0>{}),
|
||||
k_lds_write_window.get_window_origin(),
|
||||
Policy::template MakeKRegSliceBlockDescriptor<Problem>());
|
||||
Policy::template MakeKRegBlockDescriptor<Problem>());
|
||||
|
||||
auto k_reg_tensor = make_static_distributed_tensor<KDataType>(
|
||||
Policy::template MakeKRegBlockDescriptor<Problem>());
|
||||
@@ -204,16 +204,13 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
|
||||
v_lds_ptr, Policy::template MakeVLdsWriteBlockDescriptor<Problem>());
|
||||
|
||||
auto v_lds_write_window =
|
||||
make_tile_window(v_lds, make_tuple(number<kN0>{}, number<kK2>{}), {0, 0});
|
||||
make_tile_window(v_lds, make_tuple(number<kN0>{}, number<kVHeaddim>{}), {0, 0});
|
||||
|
||||
auto v_lds_read_window =
|
||||
make_tile_window(v_lds_write_window.get_bottom_tensor_view(),
|
||||
make_tuple(number<kN0>{}, number<kK2>{}),
|
||||
v_lds_write_window.get_window_origin(),
|
||||
Policy::template MakeVRegSliceBlockDescriptor<Problem>());
|
||||
|
||||
auto v_reg_tensor = make_static_distributed_tensor<VDataType>(
|
||||
Policy::template MakeVRegBlockDescriptor<Problem>());
|
||||
Policy::template MakeVRegBlockDescriptor<Problem>());
|
||||
|
||||
//------------------------------------------------------------------
|
||||
// KT, Reg ->LDS ->Reg
|
||||
@@ -227,7 +224,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
|
||||
kt_lds_ptr, Policy::template MakeShuffledKLdsWriteBlockDescriptor<Problem>());
|
||||
|
||||
auto shuffled_k_lds_write_window = make_tile_window(
|
||||
shuffled_k_lds_write, make_tuple(number<kN0>{}, number<kK0>{}), {0, 0});
|
||||
shuffled_k_lds_write, make_tuple(number<kN0>{}, number<kQKHeaddim>{}), {0, 0});
|
||||
|
||||
auto kt_lds_read = make_tensor_view<address_space_enum::lds>(
|
||||
kt_lds_ptr, Policy::template MakeKTLdsReadBlockDescriptor<Problem>());
|
||||
@@ -257,7 +254,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
v_reg_tensor = load_tile(v_lds_read_window);
|
||||
auto v_reg_tensor = load_tile(v_lds_read_window);
|
||||
block_sync_lds();
|
||||
//---------------------------- Loop Load in ----------------------------//
|
||||
// Q: HBM ->Reg ->LDS
|
||||
@@ -276,7 +273,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
|
||||
q_lds_ptr, Policy::template MakeQLdsBlockDescriptor<Problem>());
|
||||
|
||||
auto q_lds_window =
|
||||
make_tile_window(q_lds, make_tuple(number<kM0>{}, number<kK0>{}), {0, 0});
|
||||
make_tile_window(q_lds, make_tuple(number<kM0>{}, number<kQKHeaddim>{}), {0, 0});
|
||||
|
||||
auto q_lds_read_window =
|
||||
make_tile_window(q_lds_window.get_bottom_tensor_view(),
|
||||
@@ -297,7 +294,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
|
||||
qt_lds_ptr, Policy::template MakeShuffledQLdsWriteBlockDescriptor<Problem>());
|
||||
|
||||
auto shuffled_q_lds_write_window = make_tile_window(
|
||||
shuffled_q_lds_write, make_tuple(number<kM0>{}, number<kK0>{}), {0, 0});
|
||||
shuffled_q_lds_write, make_tuple(number<kM0>{}, number<kQKHeaddim>{}), {0, 0});
|
||||
|
||||
auto qt_lds_read = make_tensor_view<address_space_enum::lds>(
|
||||
qt_lds_ptr, Policy::template MakeQTLdsReadBlockDescriptor<Problem>());
|
||||
@@ -322,7 +319,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
|
||||
do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptor<Problem>());
|
||||
|
||||
auto do_lds_window =
|
||||
make_tile_window(do_lds, make_tuple(number<kM0>{}, number<kK2>{}), {0, 0});
|
||||
make_tile_window(do_lds, make_tuple(number<kM0>{}, number<kVHeaddim>{}), {0, 0});
|
||||
|
||||
auto do_lds_read_window =
|
||||
make_tile_window(do_lds_window.get_bottom_tensor_view(),
|
||||
@@ -341,7 +338,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
|
||||
dot_lds_ptr, Policy::template MakeShuffledOGradLdsWriteBlockDescriptor<Problem>());
|
||||
|
||||
auto shuffled_do_lds_write_window = make_tile_window(
|
||||
shuffled_do_lds_write, make_tuple(number<kM0>{}, number<kK2>{}), {0, 0});
|
||||
shuffled_do_lds_write, make_tuple(number<kM0>{}, number<kVHeaddim>{}), {0, 0});
|
||||
|
||||
auto dot_read_lds = make_tensor_view<address_space_enum::lds>(
|
||||
dot_lds_ptr, Policy::template MakeOGradTLdsReadBlockDescriptor<Problem>());
|
||||
@@ -483,9 +480,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
|
||||
|
||||
index_t i_total_loops = 0;
|
||||
index_t seqlen_q_step = seqlen_q_start;
|
||||
static_assert(kQKHeaddim == kK0, "kQKHeaddim should equal to kK0");
|
||||
static_assert(kQKHeaddim >= kK0, "kQKHeaddim should be equal or greater than kK0");
|
||||
static_assert(kM0 == kK1, "kM0 should equal to kK1");
|
||||
static_assert(kVHeaddim == kK2, "kVHeaddim should equal to kK2");
|
||||
static_assert(kVHeaddim >= kK2, "kVHeaddim should be equal or greater than kK2");
|
||||
static_assert(kM0 == kK3, "kM0 should equal to kK3");
|
||||
constexpr index_t k4_loops = kN0 / kK4;
|
||||
|
||||
|
||||
@@ -178,13 +178,13 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
k_lds_ptr, Policy::template MakeKLdsWriteBlockDescriptor<Problem>());
|
||||
|
||||
auto k_lds_write_window =
|
||||
make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kK0>{}), {0, 0});
|
||||
make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kQKHeaddim>{}), {0, 0});
|
||||
|
||||
auto k_lds_read_window =
|
||||
make_tile_window(k_lds_write_window.get_bottom_tensor_view(),
|
||||
make_tuple(number<kN0>{}, number<kK0>{}),
|
||||
k_lds_write_window.get_window_origin(),
|
||||
Policy::template MakeKRegSliceBlockDescriptor<Problem>());
|
||||
Policy::template MakeKRegBlockDescriptor<Problem>());
|
||||
|
||||
auto k_reg_tensor = make_static_distributed_tensor<KDataType>(
|
||||
Policy::template MakeKRegBlockDescriptor<Problem>());
|
||||
@@ -204,16 +204,13 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
v_lds_ptr, Policy::template MakeVLdsWriteBlockDescriptor<Problem>());
|
||||
|
||||
auto v_lds_write_window =
|
||||
make_tile_window(v_lds, make_tuple(number<kN0>{}, number<kK2>{}), {0, 0});
|
||||
make_tile_window(v_lds, make_tuple(number<kN0>{}, number<kVHeaddim>{}), {0, 0});
|
||||
|
||||
auto v_lds_read_window =
|
||||
make_tile_window(v_lds_write_window.get_bottom_tensor_view(),
|
||||
make_tuple(number<kN0>{}, number<kK2>{}),
|
||||
v_lds_write_window.get_window_origin(),
|
||||
Policy::template MakeVRegSliceBlockDescriptor<Problem>());
|
||||
|
||||
auto v_reg_tensor = make_static_distributed_tensor<VDataType>(
|
||||
Policy::template MakeVRegBlockDescriptor<Problem>());
|
||||
Policy::template MakeVRegBlockDescriptor<Problem>());
|
||||
|
||||
//------------------------------------------------------------------
|
||||
// KT, Reg ->LDS ->Reg
|
||||
@@ -227,7 +224,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
kt_lds_ptr, Policy::template MakeShuffledKLdsWriteBlockDescriptor<Problem>());
|
||||
|
||||
auto shuffled_k_lds_write_window = make_tile_window(
|
||||
shuffled_k_lds_write, make_tuple(number<kN0>{}, number<kK0>{}), {0, 0});
|
||||
shuffled_k_lds_write, make_tuple(number<kN0>{}, number<kQKHeaddim>{}), {0, 0});
|
||||
|
||||
auto kt_lds_read = make_tensor_view<address_space_enum::lds>(
|
||||
kt_lds_ptr, Policy::template MakeKTLdsReadBlockDescriptor<Problem>());
|
||||
@@ -257,7 +254,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
v_reg_tensor = load_tile(v_lds_read_window);
|
||||
auto v_reg_tensor = load_tile(v_lds_read_window);
|
||||
//---------------------------- Loop Load in ----------------------------//
|
||||
// Q: HBM ->Reg ->LDS
|
||||
auto q_dram_window =
|
||||
@@ -275,7 +272,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
q_lds_ptr, Policy::template MakeQLdsBlockDescriptor<Problem>());
|
||||
|
||||
auto q_lds_window =
|
||||
make_tile_window(q_lds, make_tuple(number<kM0>{}, number<kK0>{}), {0, 0});
|
||||
make_tile_window(q_lds, make_tuple(number<kM0>{}, number<kQKHeaddim>{}), {0, 0});
|
||||
|
||||
auto q_lds_read_window =
|
||||
make_tile_window(q_lds_window.get_bottom_tensor_view(),
|
||||
@@ -296,7 +293,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
qt_lds_ptr, Policy::template MakeShuffledQLdsWriteBlockDescriptor<Problem>());
|
||||
|
||||
auto shuffled_q_lds_write_window = make_tile_window(
|
||||
shuffled_q_lds_write, make_tuple(number<kM0>{}, number<kK0>{}), {0, 0});
|
||||
shuffled_q_lds_write, make_tuple(number<kM0>{}, number<kQKHeaddim>{}), {0, 0});
|
||||
|
||||
auto qt_lds_read = make_tensor_view<address_space_enum::lds>(
|
||||
qt_lds_ptr, Policy::template MakeQTLdsReadBlockDescriptor<Problem>());
|
||||
@@ -321,7 +318,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptor<Problem>());
|
||||
|
||||
auto do_lds_window =
|
||||
make_tile_window(do_lds, make_tuple(number<kM0>{}, number<kK2>{}), {0, 0});
|
||||
make_tile_window(do_lds, make_tuple(number<kM0>{}, number<kVHeaddim>{}), {0, 0});
|
||||
|
||||
auto do_lds_read_window =
|
||||
make_tile_window(do_lds_window.get_bottom_tensor_view(),
|
||||
@@ -340,7 +337,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
dot_lds_ptr, Policy::template MakeShuffledOGradLdsWriteBlockDescriptor<Problem>());
|
||||
|
||||
auto shuffled_do_lds_write_window = make_tile_window(
|
||||
shuffled_do_lds_write, make_tuple(number<kM0>{}, number<kK2>{}), {0, 0});
|
||||
shuffled_do_lds_write, make_tuple(number<kM0>{}, number<kVHeaddim>{}), {0, 0});
|
||||
|
||||
auto dot_read_lds = make_tensor_view<address_space_enum::lds>(
|
||||
dot_lds_ptr, Policy::template MakeOGradTLdsReadBlockDescriptor<Problem>());
|
||||
@@ -482,9 +479,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
|
||||
index_t i_total_loops = 0;
|
||||
index_t seqlen_q_step = seqlen_q_start;
|
||||
static_assert(kQKHeaddim == kK0, "kQKHeaddim should equal to kK0");
|
||||
static_assert(kQKHeaddim >= kK0, "kQKHeaddim should be equal or greater than kK0");
|
||||
static_assert(kM0 == kK1, "kM0 should equal to kK1");
|
||||
static_assert(kVHeaddim == kK2, "kVHeaddim should equal to kK2");
|
||||
static_assert(kVHeaddim >= kK2, "kVHeaddim should be equal or greater than kK2");
|
||||
static_assert(kM0 == kK3, "kM0 should equal to kK3");
|
||||
constexpr index_t k4_loops = kN0 / kK4;
|
||||
|
||||
|
||||
@@ -196,7 +196,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
using QDataType = remove_cvref_t<typename Problem::QDataType>;
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kMNPerBlock = Problem::BlockFmhaShape::kM0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
|
||||
constexpr index_t kMaxVecLoad = 16 / sizeof(QDataType);
|
||||
constexpr index_t kMinVecLoad = 4 / sizeof(QDataType);
|
||||
|
||||
@@ -215,7 +215,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
using KDataType = remove_cvref_t<typename Problem::KDataType>;
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kMNPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
|
||||
constexpr index_t kMaxVecLoad = 16 / sizeof(KDataType);
|
||||
constexpr index_t kMinVecLoad = 4 / sizeof(KDataType);
|
||||
|
||||
@@ -234,7 +234,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
using VDataType = remove_cvref_t<typename Problem::VDataType>;
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kMNPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
|
||||
constexpr index_t kMaxVecLoad = 16 / sizeof(VDataType);
|
||||
constexpr index_t total_pixels = kMNPerBlock * kKPerBlock / kBlockSize;
|
||||
|
||||
@@ -254,7 +254,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
using OGradDataType = remove_cvref_t<typename Problem::OGradDataType>;
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kMNPerBlock = Problem::BlockFmhaShape::kM0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
|
||||
constexpr index_t kMaxVecLoad = 16 / sizeof(OGradDataType);
|
||||
constexpr index_t kMinVecLoad = 4 / sizeof(OGradDataType);
|
||||
|
||||
@@ -315,7 +315,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
{
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kM0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
|
||||
|
||||
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
|
||||
|
||||
@@ -327,7 +327,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
{
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
|
||||
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
|
||||
|
||||
return total_pixels / GetAlignmentK<Problem>();
|
||||
@@ -338,7 +338,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
{
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kM0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
|
||||
|
||||
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
|
||||
|
||||
@@ -376,7 +376,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
|
||||
|
||||
constexpr index_t K1 = GetAlignmentK<Problem>();
|
||||
constexpr index_t K0 = kKPerBlock / K1;
|
||||
@@ -399,7 +399,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
|
||||
|
||||
constexpr index_t K1 = GetAlignmentV<Problem>();
|
||||
constexpr index_t K0 = kKPerBlock / K1;
|
||||
@@ -422,7 +422,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
|
||||
|
||||
constexpr index_t K1 = GetAlignmentQ<Problem>();
|
||||
constexpr index_t K0 = kKPerBlock / K1;
|
||||
@@ -445,7 +445,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
|
||||
|
||||
constexpr index_t K1 = GetAlignmentOGrad<Problem>();
|
||||
constexpr index_t K0 = kKPerBlock / K1;
|
||||
@@ -816,44 +816,12 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsWriteBlockDescriptor()
|
||||
{
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
|
||||
constexpr index_t kKPack = GetSmemKPackK<Problem>();
|
||||
|
||||
return MakeXLdsBlockDescriptor<kNPerBlock, kKPerBlock, kKPack>();
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeKRegSliceBlockDescriptor()
|
||||
{
|
||||
using BlockGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;
|
||||
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<0>{});
|
||||
constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<1>{});
|
||||
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
|
||||
|
||||
constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
|
||||
constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
|
||||
|
||||
constexpr auto k_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<MWarp>,
|
||||
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto k_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
k_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
|
||||
|
||||
constexpr auto k_block_dstr = make_static_tile_distribution(k_block_dstr_encode);
|
||||
|
||||
return k_block_dstr;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeKRegBlockDescriptor()
|
||||
{
|
||||
@@ -865,7 +833,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<1>{});
|
||||
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
|
||||
|
||||
constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
|
||||
constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
|
||||
@@ -890,45 +858,13 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeVLdsWriteBlockDescriptor()
|
||||
{
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
|
||||
|
||||
constexpr index_t kVPack = GetSmemKPackV<Problem>();
|
||||
|
||||
return MakeXLdsBlockDescriptor<kNPerBlock, kKPerBlock, kVPack>();
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeVRegSliceBlockDescriptor()
|
||||
{
|
||||
using BlockGemm = remove_cvref_t<decltype(GetOGradVBlockGemm<Problem>())>;
|
||||
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<0>{});
|
||||
constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<1>{});
|
||||
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
|
||||
|
||||
constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
|
||||
constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
|
||||
|
||||
constexpr auto v_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<MWarp>,
|
||||
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto v_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
v_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
|
||||
|
||||
constexpr auto v_block_dstr = make_static_tile_distribution(v_block_dstr_encode);
|
||||
|
||||
return v_block_dstr;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeVRegBlockDescriptor()
|
||||
{
|
||||
@@ -940,7 +876,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<1>{});
|
||||
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
|
||||
|
||||
constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
|
||||
constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
|
||||
@@ -966,7 +902,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
{
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
|
||||
|
||||
constexpr index_t K1 = GetAlignmentK<Problem>();
|
||||
constexpr index_t K0 = kKPerBlock / K1;
|
||||
@@ -1048,7 +984,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeQLdsBlockDescriptor()
|
||||
{
|
||||
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
|
||||
|
||||
constexpr index_t kKPack = GetSmemKPackQ<Problem>();
|
||||
|
||||
@@ -1092,7 +1028,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
{
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
|
||||
|
||||
constexpr index_t K1 = GetAlignmentQ<Problem>();
|
||||
constexpr index_t K0 = kKPerBlock / K1;
|
||||
@@ -1255,7 +1191,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
{
|
||||
// Hold full block data
|
||||
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
|
||||
|
||||
constexpr index_t kKPack = GetSmemKPackOGrad<Problem>();
|
||||
|
||||
@@ -1299,7 +1235,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
{
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
|
||||
|
||||
constexpr index_t K1 = GetAlignmentOGrad<Problem>();
|
||||
constexpr index_t K0 = kKPerBlock / K1;
|
||||
@@ -1859,6 +1795,8 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
static constexpr index_t kN0 = Problem::BlockFmhaShape::kN0;
|
||||
static constexpr index_t kQKHeaddim = Problem::BlockFmhaShape::kQKHeaddim;
|
||||
static constexpr index_t kVHeaddim = Problem::BlockFmhaShape::kVHeaddim;
|
||||
static constexpr index_t kK0 = Problem::BlockFmhaShape::kK0;
|
||||
static constexpr index_t kK2 = Problem::BlockFmhaShape::kK2;
|
||||
static constexpr index_t kK4 = Problem::BlockFmhaShape::kK4;
|
||||
|
||||
static constexpr index_t WarpGemmM =
|
||||
@@ -1873,14 +1811,12 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
|
||||
// Compute
|
||||
static constexpr index_t Gemm0MFMA =
|
||||
kM0 * kN0 * kQKHeaddim /
|
||||
(kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK);
|
||||
kM0 * kN0 * kK0 / (kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK);
|
||||
static constexpr index_t Gemm1MFMA =
|
||||
kM0 * kN0 * kVHeaddim /
|
||||
(kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK);
|
||||
static constexpr index_t Gemm2MFMA =
|
||||
kN0 * kVHeaddim * kM0 /
|
||||
(kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK);
|
||||
static constexpr index_t Gemm2MFMA =
|
||||
kM0 * kN0 * kK2 / (kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK);
|
||||
static constexpr index_t Gemm3MFMA =
|
||||
kN0 * kQKHeaddim * kM0 /
|
||||
(kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK);
|
||||
@@ -1903,13 +1839,12 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
kM0 * kQKHeaddim / get_warp_size() / GetTransposedAlignmentQ<Problem>();
|
||||
static constexpr index_t SGradT_LDS_READ_P1 =
|
||||
kM0 * kK4 / (get_warp_size() * Gemm4MWarp) / GetSmemKPackSGrad<Problem>();
|
||||
static constexpr index_t Q_LDS_READ =
|
||||
kM0 * kQKHeaddim / kBlockSize / GetAlignmentQ<Problem>();
|
||||
static constexpr index_t Q_LDS_READ = kM0 * kK0 / kBlockSize / GetAlignmentQ<Problem>();
|
||||
static constexpr index_t LSE_LDS_READ = WarpGemmM == 16 ? kM0 / (4 * 4) : kM0 / (2 * 4);
|
||||
static constexpr index_t SGradT_LDS_READ_P2 =
|
||||
kM0 * (kN0 - kK4) / (get_warp_size() * Gemm4MWarp) / GetSmemKPackSGrad<Problem>();
|
||||
static constexpr index_t OGrad_LDS_READ =
|
||||
kM0 * kVHeaddim / kBlockSize / GetAlignmentOGrad<Problem>();
|
||||
kM0 * kK2 / kBlockSize / GetAlignmentOGrad<Problem>();
|
||||
static constexpr index_t D_LDS_READ = WarpGemmM == 16 ? kM0 / (4 * 4) : kM0 / (2 * 4);
|
||||
|
||||
// LDS Write
|
||||
|
||||
@@ -23,6 +23,7 @@
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp"
|
||||
|
||||
@@ -0,0 +1,424 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// UniversalGemm Policy
|
||||
template <typename LayoutA_, typename LayoutB_, typename LayoutC_>
|
||||
struct UniversalGemmPipelineAgBgCrPolicy
|
||||
{
|
||||
using LayoutA = remove_cvref_t<LayoutA_>;
|
||||
using LayoutB = remove_cvref_t<LayoutB_>;
|
||||
using LayoutC = remove_cvref_t<LayoutC_>;
|
||||
|
||||
static constexpr auto I0 = number<0>{};
|
||||
static constexpr auto I1 = number<1>{};
|
||||
static constexpr auto I2 = number<2>{};
|
||||
|
||||
static constexpr bool TransposeC = true;
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
|
||||
{
|
||||
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ADataType,
|
||||
typename Problem::BDataType,
|
||||
typename Problem::CDataType,
|
||||
Problem::BlockGemmShape::WarpTile::at(I0),
|
||||
Problem::BlockGemmShape::WarpTile::at(I1),
|
||||
Problem::BlockGemmShape::WarpTile::at(I2),
|
||||
TransposeC>;
|
||||
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t K1 = WarpGemm::kK;
|
||||
constexpr index_t K0 = KPerBlock / K1;
|
||||
|
||||
if constexpr(std::is_same<tensor_layout::gemm::RowMajor, LayoutA>::value)
|
||||
{
|
||||
constexpr auto MLdsLayer = 32 * 4 / KPerBlock / sizeof(ADataType) < 1
|
||||
? 1
|
||||
: 32 * 4 / KPerBlock / sizeof(ADataType);
|
||||
constexpr auto a_lds_block_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(K0 * number<MLdsLayer>{}, number<MPerBlock / MLdsLayer>{}, K1),
|
||||
make_tuple(K1, number<KPerBlock * MLdsLayer>{}, I1));
|
||||
|
||||
constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
a_lds_block_desc,
|
||||
make_tuple(make_xor_transform(make_tuple(number<MPerBlock / MLdsLayer>{},
|
||||
number<K0 * MLdsLayer>{})),
|
||||
make_pass_through_transform(K1)),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}));
|
||||
|
||||
constexpr auto a_lds_block_desc_ak0_kMLdsLayer_m_ak1 = transform_tensor_descriptor(
|
||||
a_lds_block_desc_permuted,
|
||||
make_tuple(make_unmerge_transform(make_tuple(K0, number<MLdsLayer>{})),
|
||||
make_pass_through_transform(number<MPerBlock / MLdsLayer>{}),
|
||||
make_pass_through_transform(K1)),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}));
|
||||
|
||||
constexpr auto a_lds_block_desc_m_k = transform_tensor_descriptor(
|
||||
a_lds_block_desc_ak0_kMLdsLayer_m_ak1,
|
||||
make_tuple(make_merge_transform_v3_division_mod(make_tuple(K0, K1)),
|
||||
make_merge_transform_v3_division_mod(
|
||||
make_tuple(number<MPerBlock / MLdsLayer>{}, number<MLdsLayer>{}))),
|
||||
make_tuple(sequence<0, 3>{}, sequence<1, 2>{}),
|
||||
make_tuple(sequence<1>{}, sequence<0>{}));
|
||||
|
||||
return a_lds_block_desc_m_k;
|
||||
}
|
||||
else // ColumnMajor A
|
||||
{
|
||||
// kfold and mpair dimension is not always required.
|
||||
// more dimension in merge_transform increase the difficulty of generating immarg offset
|
||||
// for compiler.
|
||||
constexpr auto M0 = get_warp_size() * Problem::BlockGemmShape::BlockWarps::at(I0);
|
||||
constexpr auto M1 = MPerBlock / M0;
|
||||
|
||||
constexpr auto KThreadWrite = Problem::kBlockSize / M0;
|
||||
constexpr auto K0PerThreadWrite = K0 / KThreadWrite;
|
||||
constexpr auto KThreadRead = 64 / WarpGemm::kM;
|
||||
constexpr auto K0PerThreadRead = K0 / KThreadRead;
|
||||
|
||||
constexpr auto kfold =
|
||||
(K1 * M0 * sizeof(ADataType) > 128) ? 1 : 128 / (K1 * M0 * sizeof(ADataType));
|
||||
constexpr auto KThreadReadPerm =
|
||||
(kfold * K0PerThreadWrite / K0PerThreadRead) > 1
|
||||
? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
|
||||
: KThreadRead;
|
||||
|
||||
// 1<=mpair<=kN0
|
||||
constexpr auto mpair = (K1 * WarpGemm::kM * sizeof(ADataType) > 128)
|
||||
? 1
|
||||
: ((128 / (K1 * WarpGemm::kM * sizeof(ADataType))) > M0
|
||||
? M0
|
||||
: 128 / (K1 * WarpGemm::kM * sizeof(ADataType)));
|
||||
|
||||
constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(number<KThreadWrite / kfold / KThreadReadPerm>{},
|
||||
number<K0PerThreadWrite>{},
|
||||
number<KThreadReadPerm * M1>{},
|
||||
number<kfold * M0 / mpair>{},
|
||||
number<mpair>{},
|
||||
K1));
|
||||
|
||||
constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
a_lds_block_desc,
|
||||
make_tuple(
|
||||
make_pass_through_transform(number<KThreadWrite / kfold / KThreadReadPerm>{}),
|
||||
make_pass_through_transform(number<K0PerThreadWrite>{}),
|
||||
make_xor_transform(
|
||||
make_tuple(number<KThreadReadPerm * M1>{}, number<kfold * M0 / mpair>{})),
|
||||
make_pass_through_transform(number<mpair>{}),
|
||||
make_pass_through_transform(K1)),
|
||||
make_tuple(
|
||||
sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{}, sequence<5>{}),
|
||||
make_tuple(
|
||||
sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{}, sequence<5>{}));
|
||||
|
||||
constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor(
|
||||
a_lds_block_desc_permuted,
|
||||
make_tuple(
|
||||
make_pass_through_transform(number<KThreadWrite / kfold / KThreadReadPerm>{}),
|
||||
make_pass_through_transform(number<K0PerThreadWrite>{}),
|
||||
make_unmerge_transform(make_tuple(number<KThreadReadPerm>{}, number<M1>{})),
|
||||
make_unmerge_transform(make_tuple(number<kfold>{}, number<M0 / mpair>{})),
|
||||
make_pass_through_transform(number<mpair>{}),
|
||||
make_pass_through_transform(K1)),
|
||||
make_tuple(sequence<0>{},
|
||||
sequence<1>{},
|
||||
sequence<2>{},
|
||||
sequence<3>{},
|
||||
sequence<4>{},
|
||||
sequence<5>{}),
|
||||
make_tuple(sequence<1>{},
|
||||
sequence<2>{},
|
||||
sequence<0, 3>{},
|
||||
sequence<4, 5>{},
|
||||
sequence<6>{},
|
||||
sequence<7>{}));
|
||||
|
||||
constexpr auto a_lds_block_desc_m_k = transform_tensor_descriptor(
|
||||
a_lds_block_desc_unmerged,
|
||||
make_tuple(make_merge_transform_v3_division_mod(
|
||||
make_tuple(number<KThreadReadPerm>{},
|
||||
number<KThreadWrite / kfold / KThreadReadPerm>{},
|
||||
number<kfold>{},
|
||||
number<K0PerThreadWrite>{},
|
||||
K1)),
|
||||
make_merge_transform_v3_division_mod(
|
||||
make_tuple(number<M0 / mpair>{}, number<mpair>{}, number<M1>{}))),
|
||||
make_tuple(sequence<0, 1, 4, 2, 7>{}, sequence<5, 6, 3>{}),
|
||||
make_tuple(sequence<1>{}, sequence<0>{}));
|
||||
|
||||
return a_lds_block_desc_m_k;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor()
|
||||
{
|
||||
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ADataType,
|
||||
typename Problem::BDataType,
|
||||
typename Problem::CDataType,
|
||||
Problem::BlockGemmShape::WarpTile::at(I0),
|
||||
Problem::BlockGemmShape::WarpTile::at(I1),
|
||||
Problem::BlockGemmShape::WarpTile::at(I2),
|
||||
TransposeC>;
|
||||
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
|
||||
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
constexpr index_t K1 = WarpGemm::kK;
|
||||
constexpr index_t K0 = KPerBlock / K1;
|
||||
|
||||
if constexpr(std::is_same<tensor_layout::gemm::ColumnMajor, LayoutB>::value)
|
||||
{
|
||||
// NLdsLayer * K0 as logical Bank
|
||||
constexpr auto NLdsLayer = 32 * 4 / KPerBlock / sizeof(BDataType) < 1
|
||||
? 1
|
||||
: 32 * 4 / KPerBlock / sizeof(BDataType);
|
||||
;
|
||||
constexpr auto b_lds_block_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(K0 * number<NLdsLayer>{}, number<NPerBlock / NLdsLayer>{}, K1),
|
||||
make_tuple(K1, number<KPerBlock * NLdsLayer>{}, I1));
|
||||
|
||||
constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
b_lds_block_desc,
|
||||
make_tuple(make_xor_transform(make_tuple(number<NPerBlock / NLdsLayer>{},
|
||||
number<K0 * NLdsLayer>{})),
|
||||
make_pass_through_transform(K1)),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}));
|
||||
|
||||
constexpr auto b_lds_block_desc_bk0_kNLdsLayer_n_bk1 = transform_tensor_descriptor(
|
||||
b_lds_block_desc_permuted,
|
||||
make_tuple(make_unmerge_transform(make_tuple(K0, number<NLdsLayer>{})),
|
||||
make_pass_through_transform(number<NPerBlock / NLdsLayer>{}),
|
||||
make_pass_through_transform(K1)),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}));
|
||||
|
||||
constexpr auto b_lds_block_desc_n_k = transform_tensor_descriptor(
|
||||
b_lds_block_desc_bk0_kNLdsLayer_n_bk1,
|
||||
make_tuple(make_merge_transform_v3_division_mod(make_tuple(K0, K1)),
|
||||
make_merge_transform_v3_division_mod(
|
||||
make_tuple(number<NPerBlock / NLdsLayer>{}, number<NLdsLayer>{}))),
|
||||
make_tuple(sequence<0, 3>{}, sequence<1, 2>{}),
|
||||
make_tuple(sequence<1>{}, sequence<0>{}));
|
||||
|
||||
return b_lds_block_desc_n_k;
|
||||
}
|
||||
else // RowMajor B
|
||||
{
|
||||
constexpr auto N0 = get_warp_size() * Problem::BlockGemmShape::BlockWarps::at(I1);
|
||||
constexpr auto N1 = NPerBlock / N0;
|
||||
|
||||
constexpr auto KThreadWrite = Problem::kBlockSize / N0;
|
||||
constexpr auto K0PerThreadWrite = K0 / KThreadWrite;
|
||||
constexpr auto KThreadRead = 64 / WarpGemm::kN;
|
||||
constexpr auto K0PerThreadRead = K0 / KThreadRead;
|
||||
|
||||
constexpr auto kfold =
|
||||
(K1 * N0 * sizeof(BDataType) > 128) ? 1 : 128 / (K1 * N0 * sizeof(BDataType));
|
||||
constexpr auto KThreadReadPerm =
|
||||
(kfold * K0PerThreadWrite / K0PerThreadRead) > 1
|
||||
? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
|
||||
: KThreadRead;
|
||||
|
||||
// 1<=npair<=kN0
|
||||
constexpr auto npair = (K1 * WarpGemm::kN * sizeof(BDataType) > 128)
|
||||
? 1
|
||||
: ((128 / (K1 * WarpGemm::kN * sizeof(BDataType))) > N0
|
||||
? N0
|
||||
: 128 / (K1 * WarpGemm::kN * sizeof(BDataType)));
|
||||
|
||||
constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(number<KThreadWrite / kfold / KThreadReadPerm>{},
|
||||
number<K0PerThreadWrite>{},
|
||||
number<KThreadReadPerm * N1>{},
|
||||
number<kfold * N0 / npair>{},
|
||||
number<npair>{},
|
||||
K1));
|
||||
|
||||
constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
b_lds_block_desc,
|
||||
make_tuple(
|
||||
make_pass_through_transform(number<KThreadWrite / kfold / KThreadReadPerm>{}),
|
||||
make_pass_through_transform(number<K0PerThreadWrite>{}),
|
||||
make_xor_transform(
|
||||
make_tuple(number<KThreadReadPerm * N1>{}, number<kfold * N0 / npair>{})),
|
||||
make_pass_through_transform(number<npair>{}),
|
||||
make_pass_through_transform(K1)),
|
||||
make_tuple(
|
||||
sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{}, sequence<5>{}),
|
||||
make_tuple(
|
||||
sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{}, sequence<5>{}));
|
||||
|
||||
constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor(
|
||||
b_lds_block_desc_permuted,
|
||||
make_tuple(
|
||||
make_pass_through_transform(number<KThreadWrite / kfold / KThreadReadPerm>{}),
|
||||
make_pass_through_transform(number<K0PerThreadWrite>{}),
|
||||
make_unmerge_transform(make_tuple(number<KThreadReadPerm>{}, number<N1>{})),
|
||||
make_unmerge_transform(make_tuple(number<kfold>{}, number<N0 / npair>{})),
|
||||
make_pass_through_transform(number<npair>{}),
|
||||
make_pass_through_transform(K1)),
|
||||
make_tuple(sequence<0>{},
|
||||
sequence<1>{},
|
||||
sequence<2>{},
|
||||
sequence<3>{},
|
||||
sequence<4>{},
|
||||
sequence<5>{}),
|
||||
make_tuple(sequence<1>{},
|
||||
sequence<2>{},
|
||||
sequence<0, 3>{},
|
||||
sequence<4, 5>{},
|
||||
sequence<6>{},
|
||||
sequence<7>{}));
|
||||
|
||||
constexpr auto b_lds_block_desc_n_k = transform_tensor_descriptor(
|
||||
b_lds_block_desc_unmerged,
|
||||
make_tuple(make_merge_transform_v3_division_mod(
|
||||
make_tuple(number<KThreadReadPerm>{},
|
||||
number<KThreadWrite / kfold / KThreadReadPerm>{},
|
||||
number<kfold>{},
|
||||
number<K0PerThreadWrite>{},
|
||||
K1)),
|
||||
make_merge_transform_v3_division_mod(
|
||||
make_tuple(number<N0 / npair>{}, number<npair>{}, number<N1>{}))),
|
||||
make_tuple(sequence<0, 1, 4, 2, 7>{}, sequence<5, 6, 3>{}),
|
||||
make_tuple(sequence<1>{}, sequence<0>{}));
|
||||
|
||||
return b_lds_block_desc_n_k;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA()
|
||||
{
|
||||
constexpr index_t smem_size_a = sizeof(typename Problem::ADataType) *
|
||||
MakeALdsBlockDescriptor<Problem>().get_element_space_size();
|
||||
return smem_size_a;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeB()
|
||||
{
|
||||
constexpr index_t smem_size_b = sizeof(typename Problem::BDataType) *
|
||||
MakeBLdsBlockDescriptor<Problem>().get_element_space_size();
|
||||
return smem_size_b;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
constexpr index_t smem_size_a = GetSmemSizeA<Problem>();
|
||||
constexpr index_t smem_size_b = GetSmemSizeB<Problem>();
|
||||
index_t smem_size = 0;
|
||||
smem_size += smem_size_a + smem_size_b;
|
||||
|
||||
return smem_size;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution()
|
||||
{
|
||||
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ADataType,
|
||||
typename Problem::BDataType,
|
||||
typename Problem::CDataType,
|
||||
Problem::BlockGemmShape::WarpTile::at(I0),
|
||||
Problem::BlockGemmShape::WarpTile::at(I1),
|
||||
Problem::BlockGemmShape::WarpTile::at(I2),
|
||||
TransposeC>;
|
||||
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
constexpr index_t K1 = WarpGemm::kK;
|
||||
constexpr index_t K0 = KPerBlock / K1;
|
||||
constexpr index_t M2 = get_warp_size() / K0;
|
||||
|
||||
constexpr index_t M1 = BlockSize / get_warp_size();
|
||||
static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
|
||||
static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
|
||||
constexpr index_t M0 = MPerBlock / (M2 * M1);
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeBDramTileDistribution()
|
||||
{
|
||||
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ADataType,
|
||||
typename Problem::BDataType,
|
||||
typename Problem::CDataType,
|
||||
Problem::BlockGemmShape::WarpTile::at(I0),
|
||||
Problem::BlockGemmShape::WarpTile::at(I1),
|
||||
Problem::BlockGemmShape::WarpTile::at(I2),
|
||||
TransposeC>;
|
||||
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
|
||||
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
constexpr index_t K1 = WarpGemm::kK;
|
||||
constexpr index_t K0 = KPerBlock / K1;
|
||||
constexpr index_t N2 = get_warp_size() / K0;
|
||||
|
||||
constexpr index_t N1 = BlockSize / get_warp_size();
|
||||
static_assert(N2 != 0, "M2 is zero, which will lead to a division by zero error.");
|
||||
static_assert(N1 != 0, "M1 is zero, which will lead to a division by zero error.");
|
||||
constexpr index_t N0 = NPerBlock / (N2 * N1);
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
|
||||
{
|
||||
using AccDataType = float;
|
||||
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
|
||||
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
|
||||
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ADataType,
|
||||
typename Problem::BDataType,
|
||||
AccDataType,
|
||||
WarpTile::at(I0),
|
||||
WarpTile::at(I1),
|
||||
WarpTile::at(I2),
|
||||
TransposeC>;
|
||||
using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy<typename Problem::ADataType,
|
||||
typename Problem::BDataType,
|
||||
typename Problem::CDataType,
|
||||
BlockWarps,
|
||||
WarpGemm>;
|
||||
return BlockGemmASmemBSmemCRegV1<Problem, BlockGemmPolicy>{};
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
Reference in New Issue
Block a user