mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
Merge branch 'test_copy_fix' of https://github.com/ROCm/composable_kernel into fa_decode_pipeline
This commit is contained in:
@@ -27,7 +27,9 @@ template <typename ADataType_,
|
||||
index_t KPerXdl_,
|
||||
bool isCTransposed_,
|
||||
memory_operation_enum MemoryOperation_,
|
||||
index_t kNumWaveGroups_ = 1>
|
||||
index_t kNumWaveGroups_ = 1,
|
||||
bool FixedVectorSize_ = false,
|
||||
index_t VectorSizeC_ = 1>
|
||||
struct CShuffleEpilogueProblem
|
||||
{
|
||||
using ADataType = remove_cvref_t<ADataType_>;
|
||||
@@ -48,6 +50,8 @@ struct CShuffleEpilogueProblem
|
||||
static constexpr index_t KPerXdl = KPerXdl_;
|
||||
static constexpr index_t isCTransposed = isCTransposed_;
|
||||
static constexpr memory_operation_enum MemoryOperation = MemoryOperation_;
|
||||
static constexpr bool FixedVectorSize = FixedVectorSize_;
|
||||
static constexpr index_t VectorSizeC = VectorSizeC_;
|
||||
static constexpr index_t kNumWaveGroups = kNumWaveGroups_;
|
||||
static constexpr index_t NumDTensor = DsDataType::size();
|
||||
|
||||
@@ -80,6 +84,8 @@ struct CShuffleEpilogue
|
||||
static constexpr index_t NPerXdl = Problem::NPerXdl;
|
||||
static constexpr index_t KPerXdl = Problem::KPerXdl;
|
||||
static constexpr index_t isCTransposed = Problem::isCTransposed;
|
||||
static constexpr bool FixedVectorSize = Problem::FixedVectorSize;
|
||||
static constexpr index_t VectorSizeC = Problem::VectorSizeC;
|
||||
static constexpr index_t MPerIteration = MPerXdl * MWave;
|
||||
static constexpr index_t NPerIteration = NPerXdl * NWave;
|
||||
static constexpr index_t NumDTensor = Problem::NumDTensor;
|
||||
@@ -98,6 +104,10 @@ struct CShuffleEpilogue
|
||||
*/
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeC()
|
||||
{
|
||||
if constexpr(FixedVectorSize)
|
||||
{
|
||||
return VectorSizeC;
|
||||
}
|
||||
constexpr index_t max_vector_size = 16;
|
||||
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
|
||||
@@ -75,7 +75,6 @@ struct Default2DEpilogue
|
||||
CK_TILE_DEVICE auto
|
||||
operator()(ODramWindowTmp& o_dram_window_tmp, const OAccTile& o_acc_tile, void* = nullptr)
|
||||
{
|
||||
|
||||
// TODO: this is ugly
|
||||
if constexpr(UseRawStore && (kPadM || kPadN))
|
||||
{
|
||||
@@ -101,6 +100,15 @@ struct Default2DEpilogue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ODramWindowTmp, typename OAccTile, typename DsDramWindows>
|
||||
CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp,
|
||||
const OAccTile& o_acc_tile,
|
||||
const DsDramWindows& /* unused */,
|
||||
void* = nullptr)
|
||||
{
|
||||
return operator()<ODramWindowTmp, OAccTile>(o_dram_window_tmp, o_acc_tile);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Problem_, typename Policy_ = void>
|
||||
@@ -114,6 +122,8 @@ struct DefaultGemm2DEpilogue : public Default2DEpilogue<Problem_, Policy_>
|
||||
// Used for weight-only quantization kernel, B would be dequantized to the same data type as A
|
||||
using BTypeToUse =
|
||||
std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
|
||||
using DsDataType = ck_tile::tuple<>;
|
||||
using DsLayout = ck_tile::tuple<>;
|
||||
using CLayout = remove_cvref_t<typename Problem::CLayout>;
|
||||
static constexpr index_t kMPerXdl = Problem::kMPerXdl;
|
||||
static constexpr index_t kNPerXdl = Problem::kNPerXdl;
|
||||
@@ -181,6 +191,8 @@ struct DefaultGemm2DEpilogue : public Default2DEpilogue<Problem_, Policy_>
|
||||
static_assert(false, "Unsupported CLayout!");
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeD() { return 1; }
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
357
include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp
Normal file → Executable file
357
include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp
Normal file → Executable file
@@ -12,47 +12,75 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
struct FlatmmProblem
|
||||
{
|
||||
CK_TILE_HOST FlatmmProblem() = default;
|
||||
CK_TILE_HOST FlatmmProblem(
|
||||
index_t M_, index_t N_, index_t K_, index_t stride_A_, index_t stride_B_, index_t stride_C_)
|
||||
: M(M_), N(N_), K(K_), stride_A(stride_A_), stride_B(stride_B_), stride_C(stride_C_)
|
||||
{
|
||||
}
|
||||
|
||||
index_t M;
|
||||
index_t N;
|
||||
index_t K;
|
||||
index_t stride_A;
|
||||
index_t stride_B;
|
||||
index_t stride_C;
|
||||
};
|
||||
|
||||
struct FlatmmHostArgs : public FlatmmProblem
|
||||
template <index_t NumDTensor = 0>
|
||||
struct FlatmmHostArgs
|
||||
{
|
||||
CK_TILE_HOST FlatmmHostArgs() = default;
|
||||
CK_TILE_HOST FlatmmHostArgs(const void* a_ptr_,
|
||||
const void* b_shuffle_ptr_,
|
||||
void* c_ptr_,
|
||||
const void* b_ptr_,
|
||||
const std::array<const void*, NumDTensor>& ds_ptr_,
|
||||
void* e_ptr_,
|
||||
index_t k_batch_,
|
||||
index_t M_,
|
||||
index_t N_,
|
||||
index_t K_,
|
||||
index_t stride_A_,
|
||||
index_t stride_B_,
|
||||
index_t stride_C_)
|
||||
: FlatmmProblem(M_, N_, K_, stride_A_, stride_B_, stride_C_),
|
||||
a_ptr(a_ptr_),
|
||||
b_shuffle_ptr(b_shuffle_ptr_),
|
||||
c_ptr(c_ptr_),
|
||||
const std::array<index_t, NumDTensor>& stride_Ds_,
|
||||
index_t stride_E_)
|
||||
: a_ptr(a_ptr_),
|
||||
b_ptr(b_ptr_),
|
||||
ds_ptr(ds_ptr_),
|
||||
e_ptr(e_ptr_),
|
||||
M(M_),
|
||||
N(N_),
|
||||
K(K_),
|
||||
stride_A(stride_A_),
|
||||
stride_B(stride_B_),
|
||||
stride_Ds(stride_Ds_),
|
||||
stride_E(stride_E_),
|
||||
k_batch(k_batch_)
|
||||
{
|
||||
}
|
||||
|
||||
const void* a_ptr;
|
||||
const void* b_shuffle_ptr;
|
||||
void* c_ptr;
|
||||
const void* b_ptr;
|
||||
const std::array<const void*, NumDTensor> ds_ptr;
|
||||
union
|
||||
{
|
||||
void* e_ptr;
|
||||
void* c_ptr;
|
||||
};
|
||||
index_t M;
|
||||
index_t N;
|
||||
index_t K;
|
||||
index_t stride_A;
|
||||
index_t stride_B;
|
||||
const std::array<index_t, NumDTensor> stride_Ds;
|
||||
union
|
||||
{
|
||||
index_t stride_E;
|
||||
index_t stride_C;
|
||||
};
|
||||
|
||||
index_t k_batch;
|
||||
};
|
||||
|
||||
template <index_t NumDTensor = 0>
|
||||
struct FlatmmKernelArgs
|
||||
{
|
||||
const void* a_ptr;
|
||||
// const void* b_shuffle_ptr;
|
||||
const void* b_ptr;
|
||||
const std::array<const void*, NumDTensor> ds_ptr;
|
||||
void* e_ptr;
|
||||
index_t M;
|
||||
index_t N;
|
||||
index_t K;
|
||||
index_t stride_A;
|
||||
index_t stride_B;
|
||||
std::array<index_t, NumDTensor> stride_Ds;
|
||||
index_t stride_E;
|
||||
index_t k_batch;
|
||||
};
|
||||
|
||||
@@ -63,23 +91,29 @@ struct FlatmmKernel
|
||||
using FlatmmPipeline = remove_cvref_t<FlatmmPipeline_>;
|
||||
using BlockGemmShape =
|
||||
remove_cvref_t<typename FlatmmPipeline::BlockGemmShape>; // TileFlatmmShape
|
||||
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
|
||||
using ALayout = remove_cvref_t<typename FlatmmPipeline::ALayout>;
|
||||
using BLayout = remove_cvref_t<typename FlatmmPipeline::BLayout>;
|
||||
using CLayout = remove_cvref_t<typename FlatmmPipeline::CLayout>;
|
||||
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
|
||||
using ALayout = remove_cvref_t<typename FlatmmPipeline::ALayout>;
|
||||
using BLayout = remove_cvref_t<typename FlatmmPipeline::BLayout>;
|
||||
using ELayout = remove_cvref_t<typename FlatmmPipeline::CLayout>;
|
||||
using DsLayout = remove_cvref_t<typename EpiloguePipeline::DsLayout>;
|
||||
using DsDataType = remove_cvref_t<typename EpiloguePipeline::DsDataType>;
|
||||
static constexpr index_t KernelBlockSize = FlatmmPipeline::BlockSize;
|
||||
|
||||
using ADataType = remove_cvref_t<typename FlatmmPipeline::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename FlatmmPipeline::BDataType>;
|
||||
// Below type is actually accumulation data type - the output of block GEMM.
|
||||
using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
|
||||
using EDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
|
||||
|
||||
static constexpr auto I0 = number<0>();
|
||||
static constexpr auto I1 = number<1>();
|
||||
static constexpr auto I2 = number<2>();
|
||||
static constexpr auto idxM = I0;
|
||||
static constexpr auto idxN = I1;
|
||||
static constexpr auto idxK = I2;
|
||||
static constexpr index_t NumDTensor = DsDataType::size();
|
||||
|
||||
static constexpr auto I0 = number<0>();
|
||||
static constexpr auto I1 = number<1>();
|
||||
static constexpr auto I2 = number<2>();
|
||||
static constexpr auto I3 = number<3>();
|
||||
|
||||
static_assert(DsLayout::size() == DsDataType::size(),
|
||||
"The size of DsLayout and DsDataType should be the same");
|
||||
using KernelArgs = FlatmmKernelArgs<DsLayout::size()>;
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
@@ -95,32 +129,21 @@ struct FlatmmKernel
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(KernelBlockSize); }
|
||||
|
||||
struct FlatmmKernelArgs
|
||||
CK_TILE_HOST static constexpr KernelArgs
|
||||
MakeKernelArgs(const FlatmmHostArgs<NumDTensor>& hostArgs)
|
||||
{
|
||||
const void* a_ptr;
|
||||
const void* b_shuffle_ptr;
|
||||
void* c_ptr;
|
||||
index_t M;
|
||||
index_t N;
|
||||
index_t K;
|
||||
index_t stride_A;
|
||||
index_t stride_B;
|
||||
index_t stride_C;
|
||||
index_t k_batch;
|
||||
};
|
||||
|
||||
CK_TILE_HOST static constexpr FlatmmKernelArgs MakeKernelArgs(const FlatmmHostArgs& hostArgs)
|
||||
{
|
||||
return FlatmmKernelArgs{hostArgs.a_ptr,
|
||||
hostArgs.b_shuffle_ptr,
|
||||
hostArgs.c_ptr,
|
||||
hostArgs.M,
|
||||
hostArgs.N,
|
||||
hostArgs.K,
|
||||
hostArgs.stride_A,
|
||||
hostArgs.stride_B,
|
||||
hostArgs.stride_C,
|
||||
hostArgs.k_batch};
|
||||
return KernelArgs{hostArgs.a_ptr,
|
||||
hostArgs.b_ptr,
|
||||
hostArgs.ds_ptr,
|
||||
hostArgs.e_ptr,
|
||||
hostArgs.M,
|
||||
hostArgs.N,
|
||||
hostArgs.K,
|
||||
hostArgs.stride_A,
|
||||
hostArgs.stride_B,
|
||||
hostArgs.stride_Ds,
|
||||
hostArgs.stride_E,
|
||||
hostArgs.k_batch};
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
@@ -130,8 +153,7 @@ struct FlatmmKernel
|
||||
|
||||
struct SplitKBatchOffset
|
||||
{
|
||||
__device__ SplitKBatchOffset(const FlatmmKernelArgs& kargs,
|
||||
const std::size_t k_id = blockIdx.z)
|
||||
__device__ SplitKBatchOffset(const KernelArgs& kargs, const std::size_t k_id = blockIdx.z)
|
||||
{
|
||||
constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{});
|
||||
const index_t K_t = kargs.k_batch * K1;
|
||||
@@ -170,10 +192,10 @@ struct FlatmmKernel
|
||||
index_t splitted_k;
|
||||
};
|
||||
|
||||
CK_TILE_HOST static bool IsSupportedArgument(const FlatmmKernelArgs& kargs)
|
||||
CK_TILE_HOST static bool IsSupportedArgument(const KernelArgs& kargs)
|
||||
{
|
||||
if constexpr(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
|
||||
is_any_of<CDataType, fp16_t, bf16_t>::value)
|
||||
is_any_of<EDataType, fp16_t, bf16_t>::value)
|
||||
{
|
||||
if(kargs.k_batch != 1)
|
||||
{
|
||||
@@ -244,7 +266,45 @@ struct FlatmmKernel
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
|
||||
bool DTesnorIsValid = {true};
|
||||
static_for<0, NumDTensor, 1>{}([&](auto index) {
|
||||
using DiLayout = remove_cvref_t<std::tuple_element_t<index.value, DsLayout>>;
|
||||
if(std::is_same_v<DiLayout, ELayout> == false)
|
||||
{
|
||||
DTesnorIsValid = false;
|
||||
}
|
||||
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
if(kargs.N % TilePartitioner::NPerBlock != 0 && FlatmmPipeline::kPadN == false)
|
||||
{
|
||||
CK_TILE_ERROR("Can't support N for tensor D that is not a multiple of "
|
||||
"NPerBlock without padding!");
|
||||
DTesnorIsValid = false;
|
||||
}
|
||||
if(kargs.N % EpiloguePipeline::GetVectorSizeD(index) != 0)
|
||||
{
|
||||
CK_TILE_ERROR("N is not a multiple of vector load size for D tensor!");
|
||||
DTesnorIsValid = false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(kargs.M % TilePartitioner::MPerBlock != 0 && FlatmmPipeline::kPadM == false)
|
||||
{
|
||||
CK_TILE_ERROR("Can't support M for tensor D that is not a multiple of "
|
||||
"MPerBlock without padding!");
|
||||
|
||||
DTesnorIsValid = false;
|
||||
}
|
||||
if(kargs.M % EpiloguePipeline::GetVectorSizeD(index) != 0)
|
||||
{
|
||||
CK_TILE_ERROR("M is not a multiple of vector load size for D tensor!");
|
||||
DTesnorIsValid = false;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
if(kargs.N % TilePartitioner::NPerBlock != 0 && FlatmmPipeline::kPadN == false)
|
||||
{
|
||||
@@ -274,15 +334,17 @@ struct FlatmmKernel
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
return DTesnorIsValid;
|
||||
}
|
||||
|
||||
template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
|
||||
CK_TILE_DEVICE static auto MakeGemmTensorViews(const ADataType* a_ptr,
|
||||
const BDataType* b_flat_ptr,
|
||||
CDataType* c_ptr,
|
||||
const FlatmmKernelArgs& kargs,
|
||||
const SplitKBatchOffset& splitk_batch_offset)
|
||||
CK_TILE_DEVICE static auto
|
||||
MakeGemmTensorViews(const ADataType* a_ptr,
|
||||
const BDataType* b_flat_ptr,
|
||||
const std::array<const void*, NumDTensor>& ds_ptr,
|
||||
EDataType* e_ptr,
|
||||
const KernelArgs& kargs,
|
||||
const SplitKBatchOffset& splitk_batch_offset)
|
||||
{
|
||||
const auto& a_tensor_view = [&]() {
|
||||
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
||||
@@ -317,29 +379,54 @@ struct FlatmmKernel
|
||||
number<1>{});
|
||||
}();
|
||||
|
||||
const auto& ds_tensor_view = generate_tuple(
|
||||
[&](auto i) {
|
||||
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
|
||||
using DDataType_ = remove_cvref_t<std::tuple_element_t<i.value, DsDataType>>;
|
||||
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<const DDataType_*>(ds_ptr[i]),
|
||||
make_tuple(kargs.M, kargs.N),
|
||||
make_tuple(kargs.stride_Ds[i], 1),
|
||||
number<EpiloguePipeline::GetVectorSizeD(i)>{},
|
||||
number<1>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<const DDataType_*>(ds_ptr[i]),
|
||||
make_tuple(kargs.N, kargs.M),
|
||||
make_tuple(kargs.stride_Ds[i], 1),
|
||||
number<EpiloguePipeline::GetVectorSizeD(i)>{},
|
||||
number<1>{});
|
||||
}
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
|
||||
// TODO: enable vector write for C in ColMajor
|
||||
const auto& c_tensor_view = [&]() {
|
||||
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
|
||||
const auto& e_tensor_view = [&]() {
|
||||
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
c_ptr,
|
||||
e_ptr,
|
||||
make_tuple(kargs.M, kargs.N),
|
||||
make_tuple(kargs.stride_C, 1),
|
||||
make_tuple(kargs.stride_E, 1),
|
||||
number<EpiloguePipeline::GetVectorSizeC()>{},
|
||||
number<1>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
c_ptr,
|
||||
make_tuple(kargs.M, kargs.N),
|
||||
make_tuple(1, kargs.stride_C),
|
||||
e_ptr,
|
||||
make_tuple(kargs.N, kargs.M),
|
||||
make_tuple(kargs.stride_E, 1),
|
||||
number<1>{},
|
||||
number<1>{});
|
||||
}
|
||||
}();
|
||||
|
||||
return make_tuple(a_tensor_view, b_flat_tensor_view, c_tensor_view);
|
||||
return make_tuple(a_tensor_view, b_flat_tensor_view, ds_tensor_view, e_tensor_view);
|
||||
}
|
||||
|
||||
template <typename TensorView>
|
||||
@@ -365,26 +452,47 @@ struct FlatmmKernel
|
||||
|
||||
const auto& b_flat_tensor_view = views.at(I1);
|
||||
|
||||
const auto& ds_pad_view = generate_tuple(
|
||||
[&](auto i) {
|
||||
const auto& d_tensor_view = views.at(I2);
|
||||
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
|
||||
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return pad_tensor_view(d_tensor_view[i],
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<false, FlatmmPipeline::kPadN>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return pad_tensor_view(d_tensor_view[i],
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{},
|
||||
number<TilePartitioner::MPerBlock>{}),
|
||||
sequence<false, FlatmmPipeline::kPadM>{});
|
||||
}
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
|
||||
// TODO vector write in for C in ColMajor
|
||||
const auto& c_pad_view = [&]() {
|
||||
const auto& c_tensor_view = views.at(I2);
|
||||
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
|
||||
const auto& e_pad_view = [&]() {
|
||||
const auto& e_tensor_view = views.at(I3);
|
||||
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return pad_tensor_view(c_tensor_view,
|
||||
return pad_tensor_view(e_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<false, FlatmmPipeline::kPadN>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return pad_tensor_view(c_tensor_view,
|
||||
return pad_tensor_view(e_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<FlatmmPipeline::kPadM, false>{});
|
||||
}
|
||||
}();
|
||||
|
||||
return make_tuple(a_pad_view, b_flat_tensor_view, c_pad_view);
|
||||
return make_tuple(a_pad_view, b_flat_tensor_view, ds_pad_view, e_pad_view);
|
||||
}
|
||||
|
||||
template <typename PadView>
|
||||
@@ -393,7 +501,8 @@ struct FlatmmKernel
|
||||
{
|
||||
const auto& a_pad_view = views.at(I0);
|
||||
const auto& b_flat_pad_view = views.at(I1);
|
||||
const auto& c_pad_view = views.at(I2);
|
||||
const auto& ds_pad_view = views.at(I2);
|
||||
const auto& e_pad_view = views.at(I3);
|
||||
|
||||
const auto& a_block_window = [&]() {
|
||||
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
||||
@@ -416,21 +525,43 @@ struct FlatmmKernel
|
||||
make_tile_window(b_flat_pad_view,
|
||||
make_tuple(number<FlatmmPipeline::flatNPerWarp>{},
|
||||
number<FlatmmPipeline::flatKPerWarp>{}),
|
||||
{static_cast<int>(i_n / BlockGemmShape::WarpTile::at(idxN)), 0});
|
||||
{static_cast<int>(i_n / BlockGemmShape::WarpTile::at(I1)), 0});
|
||||
|
||||
auto c_block_window = make_tile_window(
|
||||
c_pad_view,
|
||||
const auto ds_block_window = generate_tuple(
|
||||
[&](auto i) {
|
||||
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
|
||||
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_tile_window(ds_pad_view[i],
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
{i_m, i_n});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_tile_window(ds_pad_view[i],
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{},
|
||||
number<TilePartitioner::MPerBlock>{}),
|
||||
{i_n, i_m});
|
||||
}
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
|
||||
auto e_block_window = make_tile_window(
|
||||
e_pad_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
||||
{i_m, i_n});
|
||||
|
||||
return make_tuple(a_block_window, b_flat_block_window, c_block_window);
|
||||
return make_tuple(a_block_window, b_flat_block_window, ds_block_window, e_block_window);
|
||||
}
|
||||
|
||||
template <bool UseDefaultScheduler = true>
|
||||
CK_TILE_DEVICE static void RunFlatmm(const ADataType* a_ptr,
|
||||
const BDataType* b_flat_ptr,
|
||||
CDataType* c_ptr,
|
||||
const std::array<const void*, NumDTensor>& ds_ptr,
|
||||
EDataType* e_ptr,
|
||||
void* smem_ptr,
|
||||
const FlatmmKernelArgs& kargs,
|
||||
const KernelArgs& kargs,
|
||||
const SplitKBatchOffset& splitk_batch_offset,
|
||||
const index_t block_idx_m,
|
||||
const index_t block_idx_n)
|
||||
@@ -438,7 +569,7 @@ struct FlatmmKernel
|
||||
// Create Gemm tensor views, pad views and tile windows
|
||||
const auto& gemm_tensor_views_tuple =
|
||||
MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
|
||||
a_ptr, b_flat_ptr, c_ptr, kargs, splitk_batch_offset);
|
||||
a_ptr, b_flat_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset);
|
||||
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
|
||||
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
|
||||
|
||||
@@ -450,15 +581,18 @@ struct FlatmmKernel
|
||||
const auto& d_block_window = gemm_tile_windows.at(I2);
|
||||
const auto& c_block_tile = FlatmmPipeline{}.template operator()(
|
||||
a_block_window, b_flat_block_window, num_loop, smem_ptr);
|
||||
if(UseDefaultScheduler || (get_warp_id() == 0))
|
||||
{
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(I3);
|
||||
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(I2);
|
||||
|
||||
EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
|
||||
c_block_window, c_block_tile, d_block_window, smem_ptr);
|
||||
EpiloguePipeline{}.template
|
||||
operator()<decltype(c_block_window), decltype(c_block_tile), decltype(d_block_window)>(
|
||||
c_block_window, c_block_tile, d_block_window, smem_ptr);
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void operator()(FlatmmKernelArgs kargs) const
|
||||
CK_TILE_DEVICE void operator()(KernelArgs kargs) const
|
||||
{
|
||||
const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockIdx.x);
|
||||
const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
|
||||
@@ -468,18 +602,27 @@ struct FlatmmKernel
|
||||
// options
|
||||
const ADataType* a_ptr =
|
||||
static_cast<const ADataType*>(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset;
|
||||
const BDataType* b_flat_ptr = static_cast<const BDataType*>(kargs.b_shuffle_ptr) +
|
||||
splitk_batch_offset.b_k_split_offset;
|
||||
CDataType* c_ptr = static_cast<CDataType*>(kargs.c_ptr);
|
||||
const BDataType* b_flat_ptr =
|
||||
static_cast<const BDataType*>(kargs.b_ptr) + splitk_batch_offset.b_k_split_offset;
|
||||
EDataType* e_ptr = static_cast<EDataType*>(kargs.e_ptr);
|
||||
|
||||
// allocate LDS
|
||||
__shared__ char smem_ptr[GetSmemSize()];
|
||||
|
||||
if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
|
||||
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
|
||||
is_any_of<CDataType, fp16_t, bf16_t>::value))
|
||||
is_any_of<EDataType, fp16_t, bf16_t>::value))
|
||||
{
|
||||
RunFlatmm(a_ptr, b_flat_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
|
||||
constexpr auto scheduler_type = (FlatmmPipeline::NumWaveGroups == 1);
|
||||
RunFlatmm<scheduler_type>(a_ptr,
|
||||
b_flat_ptr,
|
||||
kargs.ds_ptr,
|
||||
e_ptr,
|
||||
smem_ptr,
|
||||
kargs,
|
||||
splitk_batch_offset,
|
||||
i_m,
|
||||
i_n);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -9,9 +9,33 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem, typename PipelinePolicy = UniversalFlatmmPipelineAgBgCrPolicy>
|
||||
struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
template <typename Problem>
|
||||
struct BaseFlatmmPipelineAGmemBGmemCRegV1
|
||||
{
|
||||
static constexpr index_t PrefetchStages = 1;
|
||||
static constexpr index_t PrefillStages = 1;
|
||||
static constexpr index_t GlobalBufferNum = 1;
|
||||
static constexpr bool UsePersistentKernel = Problem::Traits::UsePersistentKernel;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr bool BlockHasHotloop(index_t) { return true; }
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr TailNumber GetBlockLoopTailNum(index_t)
|
||||
{
|
||||
return TailNumber::Empty;
|
||||
}
|
||||
|
||||
template <typename RunFunction>
|
||||
CK_TILE_HOST_DEVICE static auto TailHandler(const RunFunction& run_func, bool, TailNumber)
|
||||
{
|
||||
return run_func(bool_constant<true>{}, integral_constant<TailNumber, TailNumber::Empty>{});
|
||||
}
|
||||
};
|
||||
template <typename Problem, typename PipelinePolicy = UniversalFlatmmPipelineAgBgCrPolicy>
|
||||
struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV1<Problem>
|
||||
{
|
||||
using Base = BaseFlatmmPipelineAGmemBGmemCRegV1<Problem>;
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
@@ -33,39 +57,44 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
static constexpr index_t flatKPerWarp = BlockGemmShape::flatKPerWarp;
|
||||
static constexpr index_t flatNPerWarp = BlockGemmShape::flatNPerWarp;
|
||||
|
||||
static constexpr index_t GetVectorSizeA() { return Problem::VectorSizeA; }
|
||||
static constexpr index_t GetVectorSizeB() { return Problem::VectorSizeB; }
|
||||
static constexpr index_t GetVectorSizeC() { return Problem::VectorSizeC; }
|
||||
static constexpr index_t GetVectorSizeA()
|
||||
{
|
||||
return PipelinePolicy::template GetVectorSizeA<Problem>();
|
||||
}
|
||||
static constexpr index_t GetVectorSizeB()
|
||||
{
|
||||
return PipelinePolicy::template GetVectorSizeB<Problem>();
|
||||
}
|
||||
|
||||
static constexpr bool kPadM = Problem::kPadM;
|
||||
static constexpr bool kPadN = Problem::kPadN;
|
||||
static constexpr bool kPadK = Problem::kPadK;
|
||||
|
||||
static constexpr index_t kLdsAlignmentInBytes = 16;
|
||||
static constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
|
||||
|
||||
static constexpr auto I0 = number<0>();
|
||||
static constexpr auto I1 = number<1>();
|
||||
static constexpr auto I2 = number<2>();
|
||||
static constexpr auto idxM = I0;
|
||||
static constexpr auto idxN = I1;
|
||||
static constexpr auto idxK = I2;
|
||||
using BlockTile = remove_cvref_t<typename BlockGemmShape::BlockTile>;
|
||||
using BlockWarps = remove_cvref_t<typename BlockGemmShape::BlockWarps>;
|
||||
using WarpTile = remove_cvref_t<typename BlockGemmShape::WarpTile>;
|
||||
static constexpr auto I0 = number<0>();
|
||||
static constexpr auto I1 = number<1>();
|
||||
static constexpr auto I2 = number<2>();
|
||||
|
||||
using BlockTile = remove_cvref_t<typename BlockGemmShape::BlockTile>;
|
||||
using BlockWarps = remove_cvref_t<typename BlockGemmShape::BlockWarps>;
|
||||
using WarpTile = remove_cvref_t<typename BlockGemmShape::WarpTile>;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
|
||||
static constexpr index_t Preshuffle = Problem::Preshuffle;
|
||||
using Base::UsePersistentKernel;
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
return concat('_', "pipeline_AGmemBGmemCRegV1",
|
||||
concat('x', kMPerBlock, kNPerBlock, kKPerBlock, BlockSize),
|
||||
concat('x', GetVectorSizeA(), GetVectorSizeB(), GetVectorSizeC()),
|
||||
concat('x', GetVectorSizeA(), GetVectorSizeB()),
|
||||
concat('x', kPadM, kPadN, kPadK));
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
// For the basic gemm pipelien DoubleSmemBuffer set to be false naturally.
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
@@ -75,7 +104,6 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto HotLoopScheduler()
|
||||
{
|
||||
#if defined(USING_MFMA_16x16x32) || defined(USING_MFMA_32x32x16)
|
||||
constexpr auto config = BlockFlatmm::BlockPolicy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
@@ -91,64 +119,68 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
constexpr index_t A_Buffer_Load_Inst_Num = kMPerBlock * kKPerBlock / BlockSize / KPerLoad;
|
||||
constexpr index_t A_LDS_Read_Inst_Num = MIterPerWarp * KIterPerWarp;
|
||||
constexpr index_t B_Buffer_Load_Inst_Num = NIterPerWarp * KIterPerWarp;
|
||||
#endif
|
||||
#if defined(USING_MFMA_16x16x32)
|
||||
static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
static_for<0, A_LDS_Read_Inst_Num - A_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 3, 0); // MFMA
|
||||
});
|
||||
static_for<0, B_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA
|
||||
});
|
||||
static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA
|
||||
});
|
||||
|
||||
#elif defined(USING_MFMA_32x32x16)
|
||||
static_for<0,
|
||||
A_LDS_Read_Inst_Num / 2 - A_Buffer_Load_Inst_Num - B_Buffer_Load_Inst_Num,
|
||||
1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
static_for<0, A_LDS_Read_Inst_Num / 2, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
static_for<0, B_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 3, 0); // MFMA
|
||||
});
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA
|
||||
#endif
|
||||
if constexpr(WG::kM == 16 && WG::kN == 16)
|
||||
{
|
||||
static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
static_for<0, A_LDS_Read_Inst_Num - A_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 3, 0); // MFMA
|
||||
});
|
||||
static_for<0, B_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA
|
||||
});
|
||||
static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA
|
||||
});
|
||||
}
|
||||
else if constexpr(WG::kM == 32 && WG::kN == 32 &&
|
||||
(A_LDS_Read_Inst_Num / 2 >
|
||||
A_Buffer_Load_Inst_Num + B_Buffer_Load_Inst_Num))
|
||||
{
|
||||
static_for<0,
|
||||
A_LDS_Read_Inst_Num / 2 - A_Buffer_Load_Inst_Num - B_Buffer_Load_Inst_Num,
|
||||
1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
static_for<0, A_LDS_Read_Inst_Num / 2, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
static_for<0, B_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 3, 0); // MFMA
|
||||
});
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp, typename BFlatBlockWindowTmp, typename AElementFunction>
|
||||
@@ -159,13 +191,19 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
void* p_smem) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>>,
|
||||
"wrong!");
|
||||
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<BDataType, remove_cvref_t<typename BFlatBlockWindowTmp::DataType>>,
|
||||
"A/B Dram block window should have the same data type as appropriate "
|
||||
"([A|B]DataType) defined in Problem definition!");
|
||||
|
||||
static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}],
|
||||
"wrong!");
|
||||
static_assert(kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
|
||||
"wrong!");
|
||||
constexpr bool is_a_col_major = std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
|
||||
|
||||
static_assert(is_a_col_major
|
||||
? (kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0] &&
|
||||
kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1])
|
||||
: (kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0] &&
|
||||
kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1]),
|
||||
"A block window has incorrect lengths for defined ALayout!");
|
||||
|
||||
constexpr auto config = BlockFlatmm::BlockPolicy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_custom_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -19,55 +20,61 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
|
||||
{
|
||||
using namespace ck_tile;
|
||||
#if defined(USING_MFMA_16x16x32)
|
||||
/*reduce transform layers,compare with old ck*/
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t KPack = GetSmemPackA<Problem>();
|
||||
|
||||
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<KPerBlock / KPack>{}, number<MPerBlock>{}, number<KPack>{}),
|
||||
make_tuple(number<KPack>{}, number<KPerBlock>{}, number<1>{}),
|
||||
number<KPack>{},
|
||||
number<1>{});
|
||||
constexpr index_t MPerXdl = Problem::BlockGemmShape::WarpTile::at(I0);
|
||||
constexpr index_t NPerXdl = Problem::BlockGemmShape::WarpTile::at(I1);
|
||||
if constexpr(MPerXdl == 16 && NPerXdl == 16)
|
||||
{
|
||||
/*reduce transform layers,compare with old ck*/
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t KPack = GetSmemPackA<Problem>();
|
||||
|
||||
constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
a_lds_block_desc_0,
|
||||
make_tuple(
|
||||
make_xor_transform(make_tuple(number<MPerBlock>{}, number<KPerBlock / KPack>{})),
|
||||
make_pass_through_transform(number<KPack>{})),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}));
|
||||
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<KPerBlock / KPack>{}, number<MPerBlock>{}, number<KPack>{}),
|
||||
make_tuple(number<KPack>{}, number<KPerBlock>{}, number<1>{}),
|
||||
number<KPack>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto a_lds_block_desc = transform_tensor_descriptor(
|
||||
a_lds_block_desc_permuted,
|
||||
make_tuple(make_pass_through_transform(number<MPerBlock>{}),
|
||||
make_merge_transform_v3_division_mod(
|
||||
make_tuple(number<KPerBlock / KPack>{}, number<KPack>{}))),
|
||||
make_tuple(sequence<1>{}, sequence<0, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
a_lds_block_desc_0,
|
||||
make_tuple(make_xor_transform(
|
||||
make_tuple(number<MPerBlock>{}, number<KPerBlock / KPack>{})),
|
||||
make_pass_through_transform(number<KPack>{})),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}));
|
||||
|
||||
return a_lds_block_desc;
|
||||
#elif defined(USING_MFMA_32x32x16)
|
||||
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t kKPack = GetSmemPackA<Problem>();
|
||||
constexpr auto a_lds_block_desc = transform_tensor_descriptor(
|
||||
a_lds_block_desc_permuted,
|
||||
make_tuple(make_pass_through_transform(number<MPerBlock>{}),
|
||||
make_merge_transform_v3_division_mod(
|
||||
make_tuple(number<KPerBlock / KPack>{}, number<KPack>{}))),
|
||||
make_tuple(sequence<1>{}, sequence<0, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<kKPerBlock / kKPack>{}, number<kMPerBlock>{}, number<kKPack>{}),
|
||||
make_tuple(number<(kMPerBlock + 1) * kKPack>{}, number<kKPack>{}, number<1>{}),
|
||||
number<kKPack>{},
|
||||
number<1>{});
|
||||
return a_lds_block_desc;
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t kKPack = GetSmemPackA<Problem>();
|
||||
|
||||
constexpr auto a_lds_block_desc = transform_tensor_descriptor(
|
||||
a_lds_block_desc_0,
|
||||
make_tuple(make_pass_through_transform(kMPerBlock),
|
||||
make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))),
|
||||
make_tuple(sequence<1>{}, sequence<0, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<kKPerBlock / kKPack>{}, number<kMPerBlock>{}, number<kKPack>{}),
|
||||
make_tuple(number<(kMPerBlock + 1) * kKPack>{}, number<kKPack>{}, number<1>{}),
|
||||
number<kKPack>{},
|
||||
number<1>{});
|
||||
|
||||
return a_lds_block_desc;
|
||||
#endif
|
||||
constexpr auto a_lds_block_desc = transform_tensor_descriptor(
|
||||
a_lds_block_desc_0,
|
||||
make_tuple(make_pass_through_transform(kMPerBlock),
|
||||
make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))),
|
||||
make_tuple(sequence<1>{}, sequence<0, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return a_lds_block_desc;
|
||||
}
|
||||
/*xor*/
|
||||
#if 0
|
||||
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
|
||||
@@ -116,6 +123,95 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
|
||||
#endif
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Get the maximum global memory vector load size.
|
||||
*
|
||||
* @tparam Problem The UniversalGemmPipelineProblem object.
|
||||
* @tparam DataType The tensor data type we're considering.
|
||||
* @tparam MNPerBlock The MPerBlock or NPerBlock value depending on tensor (A/B).
|
||||
* @tparam XPerTile The contiguous Tile dimension size.
|
||||
* @return Maximum DRAM vector load size.
|
||||
*/
|
||||
template <typename Problem, typename DataType, index_t MNPerBlock, index_t XPerTile>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetGlobalVectorLoadSize()
|
||||
{
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t elements_per_thread = MNPerBlock * KPerBlock / BlockSize;
|
||||
constexpr index_t PackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<DataType>>::PackedSize;
|
||||
|
||||
// Assume DataType is even!
|
||||
if constexpr(XPerTile % (PackedSize * 32 / sizeof(DataType)) == 0 &&
|
||||
elements_per_thread % (PackedSize * 32 / sizeof(DataType)) == 0 &&
|
||||
PackedSize == 2)
|
||||
{
|
||||
return (PackedSize * 32 / sizeof(DataType));
|
||||
}
|
||||
else if constexpr(XPerTile % (PackedSize * 16 / sizeof(DataType)) == 0 &&
|
||||
elements_per_thread % (PackedSize * 16 / sizeof(DataType)) == 0)
|
||||
{
|
||||
return (PackedSize * 16 / sizeof(DataType));
|
||||
}
|
||||
else if constexpr(XPerTile % (PackedSize * 8 / sizeof(DataType)) == 0 &&
|
||||
elements_per_thread % (PackedSize * 8 / sizeof(DataType)) == 0)
|
||||
{
|
||||
return (PackedSize * 8 / sizeof(DataType));
|
||||
}
|
||||
else if constexpr(sizeof(DataType) >= PackedSize * 4 &&
|
||||
XPerTile % (PackedSize * 4 / sizeof(DataType)) == 0 &&
|
||||
elements_per_thread % (PackedSize * 4 / sizeof(DataType)) == 0)
|
||||
{
|
||||
return (PackedSize * 4 / sizeof(DataType));
|
||||
}
|
||||
else if constexpr(sizeof(DataType) >= PackedSize * 2 &&
|
||||
XPerTile % (PackedSize * 2 / sizeof(DataType)) == 0 &&
|
||||
elements_per_thread % (PackedSize * 2 / sizeof(DataType)) == 0)
|
||||
{
|
||||
return (PackedSize * 2 / sizeof(DataType));
|
||||
}
|
||||
else
|
||||
{
|
||||
return PackedSize;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeA()
|
||||
{
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return GetGlobalVectorLoadSize<Problem, ADataType, MPerBlock, KPerBlock>();
|
||||
}
|
||||
else
|
||||
{
|
||||
return GetGlobalVectorLoadSize<Problem, ADataType, MPerBlock, MPerBlock>();
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeB()
|
||||
{
|
||||
using BLayout = remove_cvref_t<typename Problem::BLayout>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return GetGlobalVectorLoadSize<Problem, BDataType, NPerBlock, NPerBlock>();
|
||||
}
|
||||
else
|
||||
{
|
||||
return GetGlobalVectorLoadSize<Problem, BDataType, NPerBlock, KPerBlock>();
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA()
|
||||
{
|
||||
@@ -138,6 +234,21 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
|
||||
return Problem::VectorLoadSize / sizeof(typename Problem::ADataType);
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetKBPerLoad()
|
||||
{
|
||||
using TileShape = typename Problem::BlockGemmShape;
|
||||
if constexpr(TileShape::WarpTile::at(I1) == 32)
|
||||
{
|
||||
return TileShape::WarpTile::at(I2) / 2;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(TileShape::WarpTile::at(I1) == 16);
|
||||
return TileShape::WarpTile::at(I2) / 4;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution()
|
||||
{
|
||||
@@ -189,7 +300,7 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t K1 = 16 / sizeof(ADataType);
|
||||
constexpr index_t K1 = Problem::VectorLoadSize / sizeof(ADataType);
|
||||
constexpr index_t K0 = KPerBlock / K1;
|
||||
constexpr index_t M2 = get_warp_size() / K0;
|
||||
// coalesce reading for each blocks
|
||||
@@ -232,23 +343,21 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeBFlatDramTileDistribution()
|
||||
{
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
|
||||
using TileShape = typename Problem::BlockGemmShape; // ck_tile::TileFlatmmShape
|
||||
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t WaveSize = get_warp_size();
|
||||
constexpr index_t WaveNum = BlockSize / WaveSize;
|
||||
|
||||
constexpr index_t KBPerLoad =
|
||||
Problem::VectorLoadSize / sizeof(BDataType); // dwordx4 load B elem cnt
|
||||
constexpr index_t KThdPerWave = WaveSize; // threads cnt in K dim
|
||||
constexpr index_t KBPerLoad = GetKBPerLoad<Problem>();
|
||||
constexpr index_t KThdPerWave = WaveSize; // threads cnt in K dim
|
||||
constexpr index_t KWavePerBlk = 1;
|
||||
constexpr index_t KRepeat = 1;
|
||||
static_assert(TileShape::flatKPerWarp == KThdPerWave * KBPerLoad, "wrong");
|
||||
|
||||
constexpr index_t NBPerLoad = 1;
|
||||
constexpr index_t NThdPerWave = 1;
|
||||
constexpr index_t NWavePerBlk = TileShape::BlockWarps::at(TileShape::idxN); // N_Warp
|
||||
constexpr index_t NWavePerBlk = TileShape::BlockWarps::at(number<1>{}); // N_Warp
|
||||
constexpr index_t NRepeat = 1;
|
||||
|
||||
constexpr index_t WaveRepeat = WaveNum / TileShape::flatNPerWarp;
|
||||
@@ -318,23 +427,25 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockFlatmm()
|
||||
{
|
||||
using AccDataType = float;
|
||||
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
|
||||
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
|
||||
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ADataType,
|
||||
// using AccDataType = float;
|
||||
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
|
||||
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
|
||||
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ADataType,
|
||||
typename Problem::BDataType,
|
||||
AccDataType,
|
||||
typename Problem::CDataType,
|
||||
WarpTile::at(I0),
|
||||
WarpTile::at(I1),
|
||||
WarpTile::at(I2),
|
||||
Problem::TransposeC>;
|
||||
|
||||
using BlockFlatmmPolicy =
|
||||
BlockFlatmmASmemBSmemCRegV1CustomPolicy<typename Problem::ADataType,
|
||||
typename Problem::BDataType,
|
||||
typename Problem::CDataType,
|
||||
BlockWarps,
|
||||
WarpGemm>;
|
||||
using BlockFlatmmPolicy = BlockFlatmmASmemBSmemCRegV1CustomPolicy<
|
||||
typename Problem::ADataType,
|
||||
// BlockGemmASmemBSmemCRegV1CustomPolicy<typename
|
||||
// Problem::ADataType,
|
||||
typename Problem::BDataType,
|
||||
typename Problem::CDataType,
|
||||
BlockWarps,
|
||||
WarpGemm>;
|
||||
return BlockFlatmmASmemBSmemCRegV1<Problem, BlockFlatmmPolicy>{};
|
||||
}
|
||||
};
|
||||
|
||||
@@ -18,6 +18,7 @@
|
||||
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp"
|
||||
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp"
|
||||
#include "ck_tile/ops/fmha/kernel/fmha_fwd_decode_kernel.hpp"
|
||||
#include "ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_convert_dq.hpp"
|
||||
@@ -35,6 +36,8 @@
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp"
|
||||
|
||||
@@ -401,6 +401,196 @@ struct SimplifiedGenericAttentionMask
|
||||
index_t y_total, x_total;
|
||||
};
|
||||
|
||||
// clang-format off
|
||||
namespace impl {
|
||||
template <bool IsMasking_> struct SimplifiedRatioMaskName;
|
||||
template<> struct SimplifiedRatioMaskName<false> { static constexpr const char * name = "nomask"; };
|
||||
template<> struct SimplifiedRatioMaskName<true> { static constexpr const char * name = "mask"; };
|
||||
}
|
||||
// clang-format on
|
||||
|
||||
// this version is used for cases that the step length of y-direction changes greater than one. It
|
||||
// means that the mask is not a regular triangular matrix.
|
||||
|
||||
// clang-format off
|
||||
/* y_ratio is used to describe the step length of y-direction changes
|
||||
in certain performance optimization scenarios like merging seqlen
|
||||
and qk_head_ratio, for example:
|
||||
|
||||
x=1/y=6/y_ratio=2(top-left)
|
||||
1 * * * * * * *
|
||||
1 * * * * * * *
|
||||
1 1 * * * * * *
|
||||
1 1 * * * * * *
|
||||
1 1 1 * * * * *
|
||||
1 1 1 * * * * *
|
||||
|
||||
*/
|
||||
// clang-format on
|
||||
template <bool IsMasking_ = true>
|
||||
struct SimplifiedRatioAttentionMask
|
||||
{
|
||||
static constexpr bool IsMasking = IsMasking_; // false will disable masking
|
||||
|
||||
static constexpr const char* name = impl::SimplifiedRatioMaskName<IsMasking>::name;
|
||||
|
||||
CK_TILE_HOST_DEVICE SimplifiedRatioAttentionMask(index_t y_total_, index_t x_total_)
|
||||
: SimplifiedRatioAttentionMask(0, 0, y_total_, x_total_, 0, 1, mdiv{})
|
||||
{
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
SimplifiedRatioAttentionMask(
|
||||
index_t y_real_, index_t x_, index_t y_total_, index_t x_total_, mdiv y_ratio_mdiv_)
|
||||
: SimplifiedRatioAttentionMask(/*y_=*/y_real_ * static_cast<index_t>(y_ratio_mdiv_.get()),
|
||||
/*x_=*/x_,
|
||||
/*y_total_=*/y_total_,
|
||||
/*x_total_=*/x_total_,
|
||||
/*y_real_=*/y_real_,
|
||||
/*y_ratio_=*/static_cast<index_t>(y_ratio_mdiv_.get()),
|
||||
/*y_ratio_mdiv_=*/y_ratio_mdiv_)
|
||||
|
||||
{
|
||||
}
|
||||
CK_TILE_HOST_DEVICE
|
||||
SimplifiedRatioAttentionMask(index_t y_,
|
||||
index_t x_,
|
||||
index_t y_total_,
|
||||
index_t x_total_,
|
||||
index_t y_real_,
|
||||
index_t y_ratio_,
|
||||
mdiv y_ratio_mdiv_)
|
||||
: y(y_),
|
||||
x(x_),
|
||||
y_total(y_total_),
|
||||
x_total(x_total_),
|
||||
y_real(y_real_),
|
||||
y_ratio(y_ratio_),
|
||||
y_ratio_mdiv(y_ratio_mdiv_)
|
||||
{
|
||||
}
|
||||
|
||||
// to get the loop length along X axis, return index:[start, end), end-start=length
|
||||
// use this if need loop over X axis tile by tile (like k-seqlen loopover)
|
||||
// TODO: x_end still could be negative, so end-start could be negative(need check)
|
||||
template <index_t YTile, index_t XTile>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
GetTileRangeAlongX(index_t i_y, number<YTile>, number<XTile>) const
|
||||
{
|
||||
if constexpr(!IsMasking)
|
||||
{
|
||||
return ck_tile::make_tuple(0, x_total);
|
||||
}
|
||||
else
|
||||
{
|
||||
// get the tile start/end range assum we loop over along X tile by tile
|
||||
index_t x_start = [&]() {
|
||||
index_t tmp = -y_real +
|
||||
static_cast<index_t>(y_ratio_mdiv.div(static_cast<uint32_t>(i_y))) +
|
||||
1;
|
||||
|
||||
return (tmp / XTile) * XTile; // round to tile aligned
|
||||
}();
|
||||
|
||||
// TODO: end could be negative, we ignore clamp here, and let caller to check
|
||||
// ... in which case end-start is negative
|
||||
index_t x_end = [&]() {
|
||||
uint32_t y_offset = i_y + YTile - 1;
|
||||
index_t tmp = min(static_cast<index_t>(y_ratio_mdiv.div(y_offset)) + x, x_total);
|
||||
return ((tmp + XTile - 1) / XTile) * XTile;
|
||||
}();
|
||||
|
||||
return ck_tile::make_tuple(x_start, x_end);
|
||||
}
|
||||
}
|
||||
|
||||
// to get the loop length along Y axis, return index:[start, end), end-start=length
|
||||
// use this if need loop over Y axis tile by tile (like q-seqlen loopover)
|
||||
// TODO: y_end still could be negative, so end-start could be negative(need check)
|
||||
template <index_t YTile, index_t XTile>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
GetTileRangeAlongY(index_t i_x, number<YTile>, number<XTile>) const
|
||||
{
|
||||
if constexpr(!IsMasking)
|
||||
{
|
||||
return ck_tile::make_tuple(0, y_total);
|
||||
}
|
||||
else
|
||||
{
|
||||
// get the tile start/end range assum we loop over along Y tile by tile
|
||||
index_t y_start = [&]() {
|
||||
index_t tmp = max((-x + i_x + 1) * y_ratio, 0);
|
||||
return (tmp / YTile) * YTile; // round to tile aligned
|
||||
}();
|
||||
|
||||
// TODO: end could be negative, we ignore clamp here, and let caller to check
|
||||
// ... in which case end-start is negative
|
||||
index_t y_end = [&]() {
|
||||
index_t tmp = min((i_x + XTile - 1) * y_ratio + y, y_total);
|
||||
return ((tmp + YTile - 1) / YTile) * YTile;
|
||||
}();
|
||||
|
||||
return ck_tile::make_tuple(y_start, y_end);
|
||||
}
|
||||
}
|
||||
|
||||
// per-pixel check if out-of-bound, if true, need mask a value(like -INF)
|
||||
CK_TILE_HOST_DEVICE constexpr auto IsOutOfBound(index_t i_y, index_t i_x) const
|
||||
{
|
||||
if constexpr(!IsMasking)
|
||||
{
|
||||
return i_x >= x_total;
|
||||
}
|
||||
else
|
||||
{
|
||||
index_t x_tmp = static_cast<index_t>(y_ratio_mdiv.div(static_cast<uint32_t>(i_y)));
|
||||
index_t x_start = -y_real + x_tmp + 1;
|
||||
index_t x_end = min(x_tmp + x,
|
||||
x_total); // need min in case x is padded
|
||||
return i_x < x_start || i_x >= x_end || i_y >= y_total;
|
||||
}
|
||||
}
|
||||
|
||||
// if current tile is at the edge, means need per-pixel mask check.
|
||||
// otherwise no need to check per-pixel
|
||||
// Attention! assume the idex passed in this function is with in range of GetTileRangeAlongX/Y()
|
||||
// can be used as a fast-path to decide if do per-pixel check or not
|
||||
template <index_t TileHeight, index_t TileWidth>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
IsEdgeTile(index_t i_y, index_t i_x, number<TileHeight>, number<TileWidth>) const
|
||||
{
|
||||
if constexpr(!IsMasking)
|
||||
{
|
||||
// the only case that need do following compare is under kPadSeqLenK
|
||||
// ... for non-masking kernel.
|
||||
// return (i_x < x_total) && ((i_x + TileWidth) > x_total);
|
||||
|
||||
return (i_x + TileWidth) > x_total;
|
||||
}
|
||||
else
|
||||
{
|
||||
// check top-right corner > x or left-borrom corner < x
|
||||
index_t i_x_end = i_x + TileWidth;
|
||||
index_t i_y_end = i_y + TileHeight;
|
||||
// index_t x_end = min(i_y + x, x_total);
|
||||
uint32_t y_tmp = static_cast<uint32_t>(i_y);
|
||||
bool top_right_edge = i_x_end > min(static_cast<index_t>(y_ratio_mdiv.div(y_tmp)) + x,
|
||||
x_total); // consider right pad
|
||||
bool bottom_left_edge =
|
||||
i_y_end > min(i_x * y_ratio + y, y_total); // consider bottom pad
|
||||
return top_right_edge || bottom_left_edge;
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
index_t y, x;
|
||||
index_t y_total, x_total;
|
||||
// y_real is vertical axis before multiplying y_ratio. y_real * y_ratio = y
|
||||
index_t y_real;
|
||||
index_t y_ratio;
|
||||
mdiv y_ratio_mdiv;
|
||||
};
|
||||
|
||||
// TODO: prefer use this function in host code
|
||||
// can convert from the FA style left/right to our generic coordinate
|
||||
// if left_size < 0 && right_size = 0, it is normal causal mask
|
||||
|
||||
@@ -51,6 +51,27 @@ struct TrivialPageBlockNavigator
|
||||
return /*block_index=*/0;
|
||||
}
|
||||
|
||||
template <typename TileWindow>
|
||||
CK_TILE_HOST_DEVICE index_t
|
||||
move_tile_window(index_t /*block_index*/,
|
||||
TileWindow& tile_window,
|
||||
const typename remove_cvref_t<TileWindow>::BottomTensorIndex& step,
|
||||
index_t /*id*/) const
|
||||
{
|
||||
|
||||
ck_tile::move_tile_window(tile_window, step);
|
||||
return 0;
|
||||
}
|
||||
|
||||
template <typename TileWindow>
|
||||
CK_TILE_HOST_DEVICE index_t
|
||||
prefetch_table_id(index_t /*block_index*/,
|
||||
TileWindow /*tile_window*/,
|
||||
const typename remove_cvref_t<TileWindow>::BottomTensorIndex& /*step*/) const
|
||||
{
|
||||
return -1;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr WindowOrigin
|
||||
to_local_window_origin(const WindowOrigin& global_window_origin)
|
||||
{
|
||||
@@ -153,6 +174,56 @@ struct PageBlockNavigator
|
||||
return new_block_index;
|
||||
}
|
||||
|
||||
template <typename TileWindow>
|
||||
CK_TILE_HOST_DEVICE index_t
|
||||
move_tile_window(index_t block_index,
|
||||
TileWindow& tile_window,
|
||||
const typename remove_cvref_t<TileWindow>::BottomTensorIndex& step,
|
||||
index_t id) const
|
||||
{
|
||||
ck_tile::move_tile_window(tile_window, step);
|
||||
|
||||
const WindowOrigin global_window_origin =
|
||||
to_global_window_origin(block_index, tile_window.get_window_origin());
|
||||
const WindowOrigin local_window_origin = to_local_window_origin(global_window_origin);
|
||||
|
||||
const index_t new_block_index = get_block_index(global_window_origin);
|
||||
/// TODO: only update necessary attributes
|
||||
tile_window.bottom_tensor_view_.desc_ =
|
||||
(is_last_block(new_block_index) ? last_view : complete_view).get_tensor_descriptor();
|
||||
tile_window.set_window_origin(local_window_origin);
|
||||
if(id >= 0)
|
||||
tile_window.set_bottom_tensor_view_data_ptr(physical_blocks + id * block_stride +
|
||||
fixed_offset);
|
||||
else
|
||||
tile_window.set_bottom_tensor_view_data_ptr(nullptr);
|
||||
|
||||
return new_block_index;
|
||||
}
|
||||
|
||||
template <typename TileWindow>
|
||||
CK_TILE_HOST_DEVICE index_t
|
||||
prefetch_table_id(index_t block_index,
|
||||
TileWindow& tile_window,
|
||||
const typename remove_cvref_t<TileWindow>::BottomTensorIndex& step) const
|
||||
{
|
||||
auto local_tile_window = tile_window; // not affect origin window
|
||||
ck_tile::move_tile_window(local_tile_window, step);
|
||||
|
||||
const WindowOrigin global_window_origin =
|
||||
to_global_window_origin(block_index, local_tile_window.get_window_origin());
|
||||
const index_t new_block_index = get_block_index(global_window_origin);
|
||||
|
||||
if(new_block_index < num_blocks)
|
||||
{
|
||||
return physical_block_indices[new_block_index];
|
||||
}
|
||||
else
|
||||
{
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE bool is_last_block(index_t block_index) const
|
||||
{
|
||||
return block_index == num_blocks - 1;
|
||||
|
||||
@@ -316,56 +316,56 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
|
||||
template <bool Cond = !kIsGroupMode>
|
||||
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
|
||||
MakeKargsImpl(const void* q_ptr,
|
||||
const void* k_ptr,
|
||||
const void* v_ptr,
|
||||
const void* bias_ptr,
|
||||
void* rand_val_ptr,
|
||||
void* lse_ptr,
|
||||
void* o_ptr,
|
||||
ck_tile::index_t seqlen_q,
|
||||
ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_head_q,
|
||||
ck_tile::index_t nhead_ratio_qk,
|
||||
int32_t num_total_pages,
|
||||
const void* kv_indptr,
|
||||
const void* kv_page_indices,
|
||||
MakeKargs(const void* q_ptr,
|
||||
const void* k_ptr,
|
||||
const void* v_ptr,
|
||||
const void* bias_ptr,
|
||||
void* rand_val_ptr,
|
||||
void* lse_ptr,
|
||||
void* o_ptr,
|
||||
ck_tile::index_t seqlen_q,
|
||||
ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_head_q,
|
||||
ck_tile::index_t nhead_ratio_qk,
|
||||
int32_t num_total_pages,
|
||||
const void* kv_indptr,
|
||||
const void* kv_page_indices,
|
||||
#if 0 // we assume page_block_size=1 for now
|
||||
const void* kv_last_page_lens,
|
||||
ck_tile::index_t page_block_size,
|
||||
#endif
|
||||
float scale_s,
|
||||
float scale_p,
|
||||
float scale_o,
|
||||
float logits_soft_cap,
|
||||
ck_tile::index_t stride_q,
|
||||
ck_tile::index_t stride_k,
|
||||
ck_tile::index_t stride_v,
|
||||
ck_tile::index_t stride_bias,
|
||||
ck_tile::index_t stride_randval,
|
||||
ck_tile::index_t stride_o,
|
||||
ck_tile::index_t nhead_stride_q,
|
||||
ck_tile::index_t nhead_stride_k,
|
||||
ck_tile::index_t nhead_stride_v,
|
||||
ck_tile::index_t nhead_stride_bias,
|
||||
ck_tile::index_t nhead_stride_randval,
|
||||
ck_tile::index_t nhead_stride_lse,
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
ck_tile::index_t batch_stride_q,
|
||||
ck_tile::index_t batch_stride_k,
|
||||
ck_tile::index_t batch_stride_v,
|
||||
ck_tile::index_t batch_stride_bias,
|
||||
ck_tile::index_t batch_stride_randval,
|
||||
ck_tile::index_t batch_stride_lse,
|
||||
ck_tile::index_t batch_stride_o,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t mask_type,
|
||||
float p_drop,
|
||||
bool s_randval,
|
||||
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
|
||||
drop_seed_offset)
|
||||
float scale_s,
|
||||
float scale_p,
|
||||
float scale_o,
|
||||
float logits_soft_cap,
|
||||
ck_tile::index_t stride_q,
|
||||
ck_tile::index_t stride_k,
|
||||
ck_tile::index_t stride_v,
|
||||
ck_tile::index_t stride_bias,
|
||||
ck_tile::index_t stride_randval,
|
||||
ck_tile::index_t stride_o,
|
||||
ck_tile::index_t nhead_stride_q,
|
||||
ck_tile::index_t nhead_stride_k,
|
||||
ck_tile::index_t nhead_stride_v,
|
||||
ck_tile::index_t nhead_stride_bias,
|
||||
ck_tile::index_t nhead_stride_randval,
|
||||
ck_tile::index_t nhead_stride_lse,
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
ck_tile::index_t batch_stride_q,
|
||||
ck_tile::index_t batch_stride_k,
|
||||
ck_tile::index_t batch_stride_v,
|
||||
ck_tile::index_t batch_stride_bias,
|
||||
ck_tile::index_t batch_stride_randval,
|
||||
ck_tile::index_t batch_stride_lse,
|
||||
ck_tile::index_t batch_stride_o,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t mask_type,
|
||||
float p_drop,
|
||||
bool s_randval,
|
||||
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
|
||||
drop_seed_offset)
|
||||
{
|
||||
Kargs kargs{{q_ptr,
|
||||
k_ptr,
|
||||
@@ -468,51 +468,51 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
|
||||
template <bool Cond = kIsGroupMode>
|
||||
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
|
||||
MakeKargsImpl(const void* q_ptr,
|
||||
const void* k_ptr,
|
||||
const void* v_ptr,
|
||||
const void* bias_ptr,
|
||||
void* rand_val_ptr,
|
||||
void* lse_ptr,
|
||||
void* o_ptr,
|
||||
const void* seqstart_q_ptr,
|
||||
ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_head_q,
|
||||
ck_tile::index_t nhead_ratio_qk,
|
||||
int32_t num_total_pages,
|
||||
const void* kv_indptr,
|
||||
const void* kv_page_indices,
|
||||
MakeKargs(const void* q_ptr,
|
||||
const void* k_ptr,
|
||||
const void* v_ptr,
|
||||
const void* bias_ptr,
|
||||
void* rand_val_ptr,
|
||||
void* lse_ptr,
|
||||
void* o_ptr,
|
||||
const void* seqstart_q_ptr,
|
||||
ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_head_q,
|
||||
ck_tile::index_t nhead_ratio_qk,
|
||||
int32_t num_total_pages,
|
||||
const void* kv_indptr,
|
||||
const void* kv_page_indices,
|
||||
#if 0 // we assume page_block_size=1 for now
|
||||
const void* kv_last_page_lens,
|
||||
ck_tile::index_t page_block_size,
|
||||
#endif
|
||||
float scale_s,
|
||||
float scale_p,
|
||||
float scale_o,
|
||||
float logits_soft_cap,
|
||||
ck_tile::index_t stride_q,
|
||||
ck_tile::index_t stride_k,
|
||||
ck_tile::index_t stride_v,
|
||||
ck_tile::index_t stride_bias,
|
||||
ck_tile::index_t stride_randval,
|
||||
ck_tile::index_t stride_o,
|
||||
ck_tile::index_t nhead_stride_q,
|
||||
ck_tile::index_t nhead_stride_k,
|
||||
ck_tile::index_t nhead_stride_v,
|
||||
ck_tile::index_t nhead_stride_bias,
|
||||
ck_tile::index_t nhead_stride_randval,
|
||||
ck_tile::index_t nhead_stride_lse,
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
ck_tile::index_t batch_stride_k,
|
||||
ck_tile::index_t batch_stride_v,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t mask_type,
|
||||
float p_drop,
|
||||
bool s_randval,
|
||||
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
|
||||
drop_seed_offset)
|
||||
float scale_s,
|
||||
float scale_p,
|
||||
float scale_o,
|
||||
float logits_soft_cap,
|
||||
ck_tile::index_t stride_q,
|
||||
ck_tile::index_t stride_k,
|
||||
ck_tile::index_t stride_v,
|
||||
ck_tile::index_t stride_bias,
|
||||
ck_tile::index_t stride_randval,
|
||||
ck_tile::index_t stride_o,
|
||||
ck_tile::index_t nhead_stride_q,
|
||||
ck_tile::index_t nhead_stride_k,
|
||||
ck_tile::index_t nhead_stride_v,
|
||||
ck_tile::index_t nhead_stride_bias,
|
||||
ck_tile::index_t nhead_stride_randval,
|
||||
ck_tile::index_t nhead_stride_lse,
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
ck_tile::index_t batch_stride_k,
|
||||
ck_tile::index_t batch_stride_v,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t mask_type,
|
||||
float p_drop,
|
||||
bool s_randval,
|
||||
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
|
||||
drop_seed_offset)
|
||||
{
|
||||
Kargs kargs{{q_ptr,
|
||||
k_ptr,
|
||||
|
||||
@@ -808,6 +808,7 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t mask_type,
|
||||
ck_tile::index_t min_seqlen_q,
|
||||
float p_drop,
|
||||
bool s_randval,
|
||||
const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
|
||||
@@ -847,6 +848,7 @@ struct FmhaFwdKernel
|
||||
window_size_left,
|
||||
window_size_right,
|
||||
mask_type,
|
||||
min_seqlen_q,
|
||||
p_drop,
|
||||
s_randval,
|
||||
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
|
||||
@@ -889,6 +891,7 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t mask_type,
|
||||
ck_tile::index_t min_seqlen_q,
|
||||
float p_drop,
|
||||
bool s_randval,
|
||||
const std::tuple<const void*, const void*>& drop_seed_offset)
|
||||
@@ -928,6 +931,7 @@ struct FmhaFwdKernel
|
||||
window_size_left,
|
||||
window_size_right,
|
||||
mask_type,
|
||||
min_seqlen_q,
|
||||
p_drop,
|
||||
s_randval,
|
||||
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
|
||||
|
||||
1376
include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp
Normal file
1376
include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp
Normal file
File diff suppressed because it is too large
Load Diff
@@ -122,6 +122,9 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
{
|
||||
if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
return 1;
|
||||
// use larger K/V LDS buffer size will lower the occupancy
|
||||
else if constexpr(64 <= kK0 || 64 <= kK1)
|
||||
return 1;
|
||||
else
|
||||
return 2;
|
||||
}
|
||||
|
||||
@@ -0,0 +1,769 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs_default_policy.hpp"
|
||||
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// TODO: This class is a variant of the existing BlockFmhaFwdSplitKVPipelineQRKSVS pipeline.
|
||||
// Refactoring to extract shared logic is recommended as future work.
|
||||
|
||||
template <typename Problem_, typename Policy_ = BlockFmhaFwdPagedKVPipelineQRKSVSDefaultPolicy>
|
||||
struct BlockFmhaFwdPagedKVPipelineQRKSVS
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using Policy = remove_cvref_t<Policy_>;
|
||||
using QDataType = remove_cvref_t<typename Problem::QDataType>;
|
||||
using KDataType = remove_cvref_t<typename Problem::KDataType>;
|
||||
using VDataType = remove_cvref_t<typename Problem::VDataType>;
|
||||
using SaccDataType = remove_cvref_t<typename Problem::SaccDataType>;
|
||||
using SMPLComputeDataType = remove_cvref_t<typename Problem::SMPLComputeDataType>;
|
||||
using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
|
||||
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
|
||||
using PDataType = remove_cvref_t<typename Problem::PDataType>;
|
||||
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
|
||||
using ODataType = remove_cvref_t<typename Problem::ODataType>;
|
||||
using AttentionVariant = remove_cvref_t<typename Problem::AttentionVariant>;
|
||||
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
|
||||
|
||||
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
|
||||
using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>;
|
||||
static constexpr bool kQLoadOnce = true; // if q_tile load whole block length (hdim) at once
|
||||
static_assert(kQLoadOnce == Policy::QLoadOnce);
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
static constexpr index_t kM0 = BlockFmhaShape::kM0;
|
||||
static constexpr index_t kN0 = BlockFmhaShape::kN0;
|
||||
static constexpr index_t kK0 = BlockFmhaShape::kK0;
|
||||
static constexpr index_t kN1 = BlockFmhaShape::kN1;
|
||||
static constexpr index_t kK1 = BlockFmhaShape::kK1;
|
||||
static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
|
||||
static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim;
|
||||
|
||||
static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
|
||||
|
||||
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
|
||||
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
|
||||
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
|
||||
static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap;
|
||||
static constexpr auto BiasEnum = Problem::BiasEnum;
|
||||
static constexpr bool kStoreLSE = Problem::kStoreLSE;
|
||||
static constexpr bool kIsPagedKV = Problem::kIsPagedKV;
|
||||
|
||||
static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 &&
|
||||
(kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS ||
|
||||
!kHasLogitsSoftCap)) ||
|
||||
(!CK_TILE_FMHA_FWD_FAST_EXP2 && !kHasLogitsSoftCap));
|
||||
|
||||
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
|
||||
// ... together with tensor distribution. tensor dist should able to overwrite this
|
||||
static constexpr index_t kAlignmentQ =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
|
||||
static constexpr index_t kAlignmentK =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
|
||||
static constexpr index_t kAlignmentV = []() {
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
return kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
|
||||
else
|
||||
return kPadSeqLenK ? 1 : Policy::template GetAlignmentV<Problem>();
|
||||
}();
|
||||
|
||||
static constexpr index_t kAlignmentO =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
|
||||
static constexpr index_t kAlignmentBias =
|
||||
kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();
|
||||
|
||||
static constexpr index_t kBlockPerCu = []() {
|
||||
if constexpr(Problem::kBlockPerCu != -1)
|
||||
return Problem::kBlockPerCu;
|
||||
else
|
||||
{
|
||||
if constexpr(kQKHeaddim <= 32)
|
||||
{
|
||||
return 2;
|
||||
}
|
||||
else if constexpr(kQKHeaddim <= 64)
|
||||
{
|
||||
return 3;
|
||||
}
|
||||
else if constexpr(kQKHeaddim <= 128)
|
||||
{
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
return 1;
|
||||
else
|
||||
return 2;
|
||||
}
|
||||
else if constexpr(kQKHeaddim <= 256)
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
}();
|
||||
|
||||
static constexpr const char* name = "qr_pagedkv";
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
|
||||
{
|
||||
return Policy::template GetSmemSize<Problem>();
|
||||
}
|
||||
|
||||
template <typename QDramBlockWindowTmp,
|
||||
typename KDramBlockWindowLengths,
|
||||
typename KPageBlockNavigator,
|
||||
typename VDramBlockWindowLengths,
|
||||
typename VPageBlockNavigator,
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename LSEDramBlockWindowTmp,
|
||||
typename QElementFunction,
|
||||
typename KElementFunction,
|
||||
typename VElementFunction,
|
||||
typename BiasElementFunction,
|
||||
typename LSEElementFunction,
|
||||
typename SAccElementFunction,
|
||||
typename PComputeElementFunction,
|
||||
typename OAccElementFunction,
|
||||
typename PositionEncoding,
|
||||
typename AttentionVariantParams,
|
||||
typename BlockIndices>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
|
||||
const QElementFunction& q_element_func,
|
||||
const KDramBlockWindowLengths& k_dram_block_window_lengths, // N0*K0 tile
|
||||
const KPageBlockNavigator& k_page_block_navigator,
|
||||
const KElementFunction& k_element_func,
|
||||
const VDramBlockWindowLengths& v_dram_block_window_lengths, // N1*K1 tile
|
||||
const VPageBlockNavigator& v_page_block_navigator,
|
||||
const VElementFunction& v_element_func,
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
const BiasElementFunction& bias_element_func,
|
||||
LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile
|
||||
const LSEElementFunction& lse_element_func,
|
||||
const SAccElementFunction& s_acc_element_func,
|
||||
const PComputeElementFunction& p_compute_element_func,
|
||||
const OAccElementFunction& o_acc_element_func,
|
||||
FmhaMask mask,
|
||||
PositionEncoding position_encoding,
|
||||
float scale_s,
|
||||
const AttentionVariant& variant,
|
||||
const AttentionVariantParams& variant_params,
|
||||
const BlockIndices& block_indices,
|
||||
index_t kv_l2p_offset, // logical-to-physical offset of seqlen_k coordinate
|
||||
void* smem_ptr) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<KDataType, remove_cvref_t<typename KPageBlockNavigator::DataType>> &&
|
||||
std::is_same_v<VDataType, remove_cvref_t<typename VPageBlockNavigator::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0 == KDramBlockWindowLengths{}[number<0>{}] &&
|
||||
kK0 == KDramBlockWindowLengths{}[number<1>{}] &&
|
||||
kN1 == VDramBlockWindowLengths{}[number<0>{}] &&
|
||||
kK1 == VDramBlockWindowLengths{}[number<1>{}] &&
|
||||
kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
|
||||
"wrong!");
|
||||
|
||||
// K tile in LDS
|
||||
KDataType* k_lds_ptr = static_cast<KDataType*>(static_cast<void*>(
|
||||
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQ<Problem>()));
|
||||
auto k_lds = make_tensor_view<address_space_enum::lds>(
|
||||
k_lds_ptr, Policy::template MakeKLdsBlockDescriptor<Problem>());
|
||||
auto k_lds_window =
|
||||
make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kK0>{}), {0, 0});
|
||||
|
||||
// V tile in LDS
|
||||
auto v_lds = make_tensor_view<address_space_enum::lds>(
|
||||
reinterpret_cast<VDataType*>(smem_ptr),
|
||||
Policy::template MakeVLdsBlockDescriptor<Problem>());
|
||||
auto v_lds_window = make_tile_window(
|
||||
v_lds, Policy::template MakeVLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
|
||||
|
||||
// Block GEMM
|
||||
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
|
||||
constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>();
|
||||
|
||||
auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
q_dram_block_window_tmp.get_window_lengths(),
|
||||
q_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeQRegTileDistribution<Problem>());
|
||||
|
||||
auto q = load_tile(q_dram_window);
|
||||
|
||||
using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile());
|
||||
auto s_acc = SaccBlockTileType{};
|
||||
|
||||
// reduction function for softmax
|
||||
const auto f_max = [](auto e0, auto e1) { return max(e0, e1); };
|
||||
const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };
|
||||
|
||||
// infer Sacc, S, P, M, L, Oacc type
|
||||
using SBlockTileType = decltype(cast_tile<SMPLComputeDataType>(s_acc));
|
||||
|
||||
using MLBlockTileType = decltype(block_tile_reduce<SMPLComputeDataType>(
|
||||
SBlockTileType{}, sequence<1>{}, f_max, SMPLComputeDataType{0}));
|
||||
|
||||
using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile());
|
||||
|
||||
// init Oacc, M, L
|
||||
auto o_acc = OaccBlockTileType{};
|
||||
auto m = MLBlockTileType{};
|
||||
auto l = MLBlockTileType{};
|
||||
|
||||
clear_tile(o_acc);
|
||||
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
|
||||
clear_tile(l);
|
||||
|
||||
const auto q_origin = q_dram_window.get_window_origin();
|
||||
const auto [logical_seqlen_k_start, logical_seqlen_k_end] =
|
||||
mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
|
||||
|
||||
// check early exit if no work to do
|
||||
if constexpr(FmhaMask::IsMasking || kPadSeqLenK)
|
||||
{
|
||||
const auto num_total_loop =
|
||||
integer_divide_ceil(logical_seqlen_k_end - logical_seqlen_k_start, kN0);
|
||||
if(num_total_loop <= 0)
|
||||
{
|
||||
if constexpr(kStoreLSE)
|
||||
{
|
||||
auto lse =
|
||||
make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
|
||||
|
||||
set_tile(lse, -numeric<SMPLComputeDataType>::infinity());
|
||||
|
||||
store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse));
|
||||
}
|
||||
|
||||
// Note: here occ are all cleard, return it
|
||||
// Note: q loaded but no fence, ignore it.
|
||||
return o_acc;
|
||||
}
|
||||
}
|
||||
|
||||
// k_dram_block_window
|
||||
const index_t physical_seqlen_k_start = logical_seqlen_k_start + kv_l2p_offset;
|
||||
const index_t physical_seqlen_k_end = logical_seqlen_k_end + kv_l2p_offset;
|
||||
// make sure the first tile is completely located in page-block (page-block size should be
|
||||
// divisible by kN0)
|
||||
// relationship between each *_start variables: aligned_physical_seqlen_k_start <=
|
||||
// physical_seqlen_k_start, logical_seqlen_k_start <= physical_seqlen_k_start
|
||||
const index_t aligned_physical_seqlen_k_start =
|
||||
[&, physical_seqlen_k_start_ = physical_seqlen_k_start] {
|
||||
if constexpr(kIsPagedKV)
|
||||
{
|
||||
return kN0 * integer_divide_floor(physical_seqlen_k_start_, kN0);
|
||||
}
|
||||
else
|
||||
{
|
||||
return physical_seqlen_k_start_;
|
||||
}
|
||||
}();
|
||||
const index_t num_total_loop =
|
||||
integer_divide_ceil(physical_seqlen_k_end - aligned_physical_seqlen_k_start, kN0);
|
||||
|
||||
auto [i_page_block_k, k_dram_block_window] = k_page_block_navigator.make_tile_window(
|
||||
k_dram_block_window_lengths, {aligned_physical_seqlen_k_start, 0});
|
||||
|
||||
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
|
||||
auto bias_dram_window =
|
||||
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
bias_dram_block_window_tmp.get_window_lengths(),
|
||||
{bias_origin.at(number<0>{}),
|
||||
logical_seqlen_k_start - (physical_seqlen_k_start -
|
||||
aligned_physical_seqlen_k_start)}, // M/N
|
||||
Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());
|
||||
|
||||
// v_dram_window
|
||||
auto [i_page_block_v, v_dram_window] = v_page_block_navigator.make_tile_window(
|
||||
v_dram_block_window_lengths,
|
||||
{0, aligned_physical_seqlen_k_start}, // TODO: hdim split?
|
||||
Policy::template MakeVDramTileDistribution<Problem>());
|
||||
|
||||
auto q_tile = tile_elementwise_in(q_element_func, q);
|
||||
|
||||
// prefetch K tile
|
||||
index_t i_total_loops = 0;
|
||||
constexpr index_t k0_loops = kQKHeaddim / kK0;
|
||||
constexpr index_t k1_loops = kN0 / kK1;
|
||||
|
||||
static_assert(2 <= k0_loops);
|
||||
static_assert(1 <= k1_loops);
|
||||
do
|
||||
{
|
||||
// STAGE 1, QK gemm
|
||||
auto k_dram_window = make_tile_window(
|
||||
k_dram_block_window,
|
||||
Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window for
|
||||
// load
|
||||
|
||||
auto k_block_tile = load_tile(k_dram_window);
|
||||
{
|
||||
// moving k_dram_window is an in-page-block operation, so there is
|
||||
// no need to invoke k_page_block_navigator.move_tile_window() here.
|
||||
move_tile_window(k_dram_window, {0, kK0});
|
||||
clear_tile(s_acc); // initialize C
|
||||
store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile));
|
||||
k_block_tile = load_tile(k_dram_window);
|
||||
}
|
||||
auto physical_next_block_id_k =
|
||||
__builtin_amdgcn_readfirstlane(k_page_block_navigator.prefetch_table_id(
|
||||
i_page_block_k, k_dram_block_window, {kN0, 0}));
|
||||
auto physical_next_block_id_v = __builtin_amdgcn_readfirstlane(
|
||||
v_page_block_navigator.prefetch_table_id(i_page_block_v, v_dram_window, {0, kK1}));
|
||||
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(
|
||||
0); // prevent from messing up the order of global loads
|
||||
}
|
||||
const auto bias_tile = load_tile(bias_dram_window); // load bias tile
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(
|
||||
0); // prevent from messing up the order of global loads
|
||||
}
|
||||
|
||||
if constexpr(k0_loops > 2)
|
||||
{
|
||||
static_for<0, k0_loops - 2, 1>{}([&](auto i_k0) {
|
||||
block_sync_lds();
|
||||
gemm_0(s_acc,
|
||||
get_slice_tile(q_tile,
|
||||
sequence<0, i_k0 * kK0>{},
|
||||
sequence<kM0, (i_k0 + 1) * kK0>{}),
|
||||
k_lds_window);
|
||||
block_sync_lds();
|
||||
move_tile_window(k_dram_window, {0, kK0});
|
||||
|
||||
store_tile(
|
||||
k_lds_window,
|
||||
tile_elementwise_in(k_element_func, k_block_tile)); // LDS write i + 1
|
||||
k_block_tile = load_tile(k_dram_window); // global read i + 2
|
||||
});
|
||||
}
|
||||
|
||||
const auto v_prefetch = load_tile(v_dram_window); // prefetch load v tile
|
||||
{ // tail
|
||||
block_sync_lds();
|
||||
gemm_0(s_acc,
|
||||
get_slice_tile(q_tile,
|
||||
sequence<0, (k0_loops - 2) * kK0>{},
|
||||
sequence<kM0, (k0_loops - 1) * kK0>{}),
|
||||
k_lds_window);
|
||||
block_sync_lds();
|
||||
|
||||
store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile));
|
||||
block_sync_lds();
|
||||
|
||||
gemm_0(s_acc,
|
||||
get_slice_tile(q_tile,
|
||||
sequence<0, (k0_loops - 1) * kK0>{},
|
||||
sequence<kM0, k0_loops * kK0>{}),
|
||||
k_lds_window);
|
||||
}
|
||||
|
||||
// STAGE 2, scale_s, add bias, mask, softmax
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
|
||||
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
|
||||
tile_elementwise_inout(
|
||||
[&](auto& x, const auto& y) {
|
||||
#if !CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
x += type_convert<SaccDataType>(bias_element_func(y));
|
||||
#else
|
||||
x += log2e_v<SaccDataType> *
|
||||
type_convert<SaccDataType>(bias_element_func(y));
|
||||
#endif
|
||||
},
|
||||
s_acc,
|
||||
bias_tile);
|
||||
}
|
||||
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
{
|
||||
const auto k_origin = k_page_block_navigator.to_global_window_origin(
|
||||
i_page_block_k, k_dram_block_window.get_window_origin());
|
||||
constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
|
||||
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
|
||||
sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
|
||||
const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
s_acc.get_tile_distribution(), make_tuple(idx0, idx1));
|
||||
|
||||
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
s_acc(i_j_idx) *= scale_s;
|
||||
// position_encoding accept only logical coordinates, do conversion here
|
||||
position_encoding.update(s_acc(i_j_idx), row, col - kv_l2p_offset);
|
||||
});
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
|
||||
if constexpr(kHasLogitsSoftCap)
|
||||
{
|
||||
auto apply_logits_transform =
|
||||
[&variant, &variant_params, &block_indices](auto& x) {
|
||||
x = variant.LogitsTransform(variant_params,
|
||||
variant.QueryTransform(variant_params, x),
|
||||
block_indices.batch_idx,
|
||||
block_indices.qo_head_idx,
|
||||
block_indices.kv_head_idx);
|
||||
};
|
||||
#if !CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
tile_elementwise_inout(apply_logits_transform, s_acc);
|
||||
#else
|
||||
tile_elementwise_inout(apply_logits_transform, s_acc);
|
||||
#endif
|
||||
}
|
||||
else
|
||||
{
|
||||
#if !CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
move_tile_window(bias_dram_window, {0, kN0});
|
||||
|
||||
{
|
||||
const auto k_origin = k_page_block_navigator.to_global_window_origin(
|
||||
i_page_block_k, k_dram_block_window.get_window_origin());
|
||||
|
||||
if constexpr(kIsPagedKV)
|
||||
{
|
||||
// check columns in [aligned_physical_seqlen_k_start, physical_seqlen_k_end)
|
||||
if(kv_l2p_offset > 0)
|
||||
{
|
||||
set_tile_if(
|
||||
s_acc,
|
||||
-numeric<SMPLComputeDataType>::infinity(),
|
||||
[&, physical_seqlen_k_start_ = physical_seqlen_k_start](auto tile_idx) {
|
||||
const auto col =
|
||||
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
return col < physical_seqlen_k_start_;
|
||||
});
|
||||
};
|
||||
}
|
||||
|
||||
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
|
||||
{
|
||||
// mask accept only logical coordinates, do conversion here
|
||||
bool need_perpixel_check =
|
||||
mask.IsEdgeTile(q_origin.at(number<0>{}),
|
||||
k_origin.at(number<0>{}) - kv_l2p_offset,
|
||||
number<kM0>{},
|
||||
number<kN0>{});
|
||||
if(need_perpixel_check)
|
||||
{
|
||||
set_tile_if(
|
||||
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
|
||||
const auto row =
|
||||
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col =
|
||||
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
return mask.IsOutOfBound(row, col - kv_l2p_offset);
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const auto s = cast_tile<SMPLComputeDataType>(s_acc); // S{j}
|
||||
auto m_local = block_tile_reduce<SMPLComputeDataType>(
|
||||
s,
|
||||
sequence<1>{},
|
||||
f_max,
|
||||
-numeric<SMPLComputeDataType>::infinity()); // m_local = rowmax(S{j})
|
||||
block_tile_reduce_sync(m_local, f_max, bool_constant<false>{});
|
||||
|
||||
const auto m_old = m; // m{j-1}
|
||||
tile_elementwise_inout(
|
||||
[](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); // m{j}
|
||||
|
||||
auto p_compute = make_static_distributed_tensor<SMPLComputeDataType>(
|
||||
s.get_tile_distribution()); // Pcompute{j}
|
||||
|
||||
static const auto get_validated_m = [](SMPLComputeDataType raw_m) {
|
||||
/// NOTICE: bias might be materialized mask including -inf values, need
|
||||
/// consideration
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
FmhaMask::IsMasking)
|
||||
{
|
||||
return raw_m == -numeric<SMPLComputeDataType>::infinity()
|
||||
? type_convert<SMPLComputeDataType>(0.f)
|
||||
: raw_m;
|
||||
}
|
||||
else
|
||||
{
|
||||
return raw_m;
|
||||
}
|
||||
};
|
||||
|
||||
constexpr auto p_spans = decltype(p_compute)::get_distributed_spans();
|
||||
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
auto row_max = scale_s * get_validated_m(m[i_idx]);
|
||||
#endif
|
||||
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
{
|
||||
p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(kHasLogitsSoftCap)
|
||||
{
|
||||
p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
|
||||
}
|
||||
else
|
||||
{
|
||||
p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max);
|
||||
}
|
||||
}
|
||||
#else
|
||||
p_compute(i_j_idx) = exp(s[i_j_idx] - get_validated_m(m[i_idx]));
|
||||
#endif
|
||||
});
|
||||
});
|
||||
|
||||
auto rowsum_p = block_tile_reduce<SMPLComputeDataType>(
|
||||
p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j})
|
||||
|
||||
block_tile_reduce_sync(rowsum_p, f_sum, bool_constant<false>{});
|
||||
// l{j}, Oacc{j}
|
||||
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
|
||||
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
const auto tmp = [&]() {
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
{
|
||||
return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(kHasLogitsSoftCap)
|
||||
{
|
||||
|
||||
return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
|
||||
}
|
||||
else
|
||||
{
|
||||
auto row_max = scale_s * get_validated_m(m[i_idx]);
|
||||
return exp2(scale_s * m_old[i_idx] - row_max);
|
||||
}
|
||||
}
|
||||
}();
|
||||
#else
|
||||
const auto tmp = exp(m_old[i_idx] - get_validated_m(m[i_idx]));
|
||||
#endif
|
||||
l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx];
|
||||
sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
// FIXME: this use different equation from FA v2 paper,
|
||||
// but produce correc result.
|
||||
// Is the equation wrong?
|
||||
o_acc(i_j_idx) *= tmp;
|
||||
});
|
||||
});
|
||||
|
||||
block_sync_lds();
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
|
||||
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
|
||||
shuffle_tile(v_shuffle_tmp, v_prefetch);
|
||||
store_tile(
|
||||
v_lds_window,
|
||||
tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch
|
||||
}
|
||||
else
|
||||
{
|
||||
store_tile(v_lds_window,
|
||||
tile_elementwise_in(v_element_func, v_prefetch)); // store the prefetch
|
||||
}
|
||||
i_page_block_v = v_page_block_navigator.move_tile_window(
|
||||
i_page_block_v, v_dram_window, {0, kK1}, physical_next_block_id_v);
|
||||
|
||||
const auto p =
|
||||
cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute));
|
||||
|
||||
// STAGE 3, KV gemm
|
||||
if constexpr(k1_loops > 1)
|
||||
{
|
||||
static_for<0, k1_loops - 1, 1>{}([&,
|
||||
&i_page_block_v_ = i_page_block_v,
|
||||
&v_dram_window_ = v_dram_window](auto i_k1) {
|
||||
auto physical_next_block_id_v_ =
|
||||
__builtin_amdgcn_readfirstlane(v_page_block_navigator.prefetch_table_id(
|
||||
i_page_block_v_, v_dram_window_, {0, kK1}));
|
||||
const auto v = load_tile(v_dram_window_); // load next v
|
||||
block_sync_lds();
|
||||
gemm_1(o_acc,
|
||||
get_slice_tile(
|
||||
p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
|
||||
v_lds_window);
|
||||
block_sync_lds();
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
|
||||
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
|
||||
shuffle_tile(v_shuffle_tmp, v);
|
||||
store_tile(v_lds_window,
|
||||
tile_elementwise_in(v_element_func,
|
||||
v_shuffle_tmp)); // store the prefetch
|
||||
}
|
||||
else
|
||||
{
|
||||
store_tile(v_lds_window,
|
||||
tile_elementwise_in(v_element_func, v)); // store next v
|
||||
}
|
||||
i_page_block_v_ = v_page_block_navigator.move_tile_window(
|
||||
i_page_block_v_, v_dram_window_, {0, kK1}, physical_next_block_id_v_);
|
||||
});
|
||||
}
|
||||
// move K tile windows
|
||||
i_page_block_k = k_page_block_navigator.move_tile_window(
|
||||
i_page_block_k, k_dram_block_window, {kN0, 0}, physical_next_block_id_k);
|
||||
// tail
|
||||
{
|
||||
block_sync_lds();
|
||||
gemm_1(o_acc,
|
||||
get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence<kM0, kN0>{}),
|
||||
v_lds_window);
|
||||
block_sync_lds();
|
||||
}
|
||||
} while(++i_total_loops < num_total_loop);
|
||||
|
||||
// store lse
|
||||
if constexpr(kStoreLSE)
|
||||
{
|
||||
auto lse = make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
|
||||
|
||||
constexpr auto lse_spans = decltype(lse)::get_distributed_spans();
|
||||
sweep_tile_span(lse_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
{
|
||||
lse(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]);
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(kHasLogitsSoftCap)
|
||||
{
|
||||
lse(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]);
|
||||
}
|
||||
else
|
||||
{
|
||||
lse(i_idx) = m_[i_idx] * scale_s / C_LOG2E + log(l_[i_idx]);
|
||||
}
|
||||
}
|
||||
#else
|
||||
lse(i_idx) = m_[i_idx] + log(l_[i_idx]);
|
||||
#endif
|
||||
});
|
||||
|
||||
store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse));
|
||||
}
|
||||
|
||||
// finally, O
|
||||
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
|
||||
|
||||
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
const auto tmp = [&]() {
|
||||
if constexpr(FmhaMask::IsMasking)
|
||||
{
|
||||
return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx];
|
||||
}
|
||||
else
|
||||
return 1 / l[i_idx];
|
||||
}();
|
||||
sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
o_acc(i_j_idx) *= tmp;
|
||||
});
|
||||
});
|
||||
|
||||
o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
|
||||
|
||||
return o_acc;
|
||||
}
|
||||
|
||||
template <typename QDramBlockWindowTmp,
|
||||
typename KDramBlockWindowLengths,
|
||||
typename KPageBlockNavigator,
|
||||
typename VDramBlockWindowLengths,
|
||||
typename VPageBlockNavigator,
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename LSEDramBlockWindowTmp,
|
||||
typename PositionEncoding,
|
||||
typename AttentionVariantParams,
|
||||
typename BlockIndices>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
|
||||
const KDramBlockWindowLengths& k_dram_block_window_lengths, // N0*K0 tile
|
||||
const KPageBlockNavigator& k_page_block_navigator,
|
||||
const VDramBlockWindowLengths& v_dram_block_window_lengths, // N1*K1 tile
|
||||
const VPageBlockNavigator& v_page_block_navigator,
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile
|
||||
FmhaMask mask,
|
||||
PositionEncoding position_encoding,
|
||||
float scale_s,
|
||||
const AttentionVariant& variant,
|
||||
const AttentionVariantParams& variant_params,
|
||||
const BlockIndices& block_indices,
|
||||
index_t kv_l2p_offset, // logical-to-physical offset of seqlen_k coordinate
|
||||
void* smem_ptr) const
|
||||
{
|
||||
return operator()(q_dram_block_window_tmp,
|
||||
identity{},
|
||||
k_dram_block_window_lengths,
|
||||
k_page_block_navigator,
|
||||
identity{},
|
||||
v_dram_block_window_lengths,
|
||||
v_page_block_navigator,
|
||||
identity{},
|
||||
bias_dram_block_window_tmp,
|
||||
identity{},
|
||||
lse_dram_block_window_tmp,
|
||||
identity{},
|
||||
identity{},
|
||||
identity{},
|
||||
identity{},
|
||||
mask,
|
||||
position_encoding,
|
||||
scale_s,
|
||||
variant,
|
||||
variant_params,
|
||||
block_indices,
|
||||
kv_l2p_offset,
|
||||
smem_ptr);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,91 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2r1.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// This pipeline is qkv all located in LDS
|
||||
struct BlockFmhaFwdPagedKVPipelineQRKSVSDefaultPolicy
|
||||
: BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ true,
|
||||
/* AsyncCopy = */ false,
|
||||
/* NumPrefetchK = */ 1,
|
||||
/* NumPrefetchV = */ 1>
|
||||
{
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
|
||||
{
|
||||
using GemmProblem =
|
||||
BlockGemmProblem<typename Problem::QDataType,
|
||||
typename Problem::KDataType,
|
||||
typename Problem::SaccDataType,
|
||||
Problem::kNumGemm0Warps * get_warp_size(),
|
||||
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
|
||||
Problem::BlockFmhaShape::kN0,
|
||||
Problem::BlockFmhaShape::kK0>,
|
||||
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
|
||||
typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
|
||||
|
||||
constexpr auto warp_gemm = []() {
|
||||
constexpr index_t WarpGemmM = Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{});
|
||||
static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32);
|
||||
|
||||
if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
|
||||
std::is_same_v<typename Problem::KDataType, half_t> &&
|
||||
std::is_same_v<typename Problem::SaccDataType, float>)
|
||||
{
|
||||
if constexpr(WarpGemmM == 32)
|
||||
return WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{};
|
||||
else if constexpr(WarpGemmM == 16)
|
||||
return WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{};
|
||||
else // WarpGemmM == 4
|
||||
return WarpGemmMfmaF16F16F32M4N64K16{};
|
||||
}
|
||||
else if constexpr(std::is_same_v<typename Problem::QDataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::KDataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::SaccDataType, float>)
|
||||
{
|
||||
if constexpr(WarpGemmM == 32)
|
||||
return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution{};
|
||||
else if constexpr(WarpGemmM == 16)
|
||||
return WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution{};
|
||||
else // WarpGemmM == 4
|
||||
return WarpGemmMfmaBf16Bf16F32M4N64K16{};
|
||||
}
|
||||
else if constexpr(std::is_same_v<typename Problem::QDataType, fp8_t> &&
|
||||
std::is_same_v<typename Problem::KDataType, fp8_t> &&
|
||||
std::is_same_v<typename Problem::SaccDataType, float>)
|
||||
{
|
||||
static_assert(WarpGemmM == 32);
|
||||
|
||||
// TODO: hard coded here. Otherwise, it may incorrect result
|
||||
constexpr index_t swizzle_factor = 4;
|
||||
return WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution<
|
||||
swizzle_factor>{};
|
||||
} // TODO - bf8_t
|
||||
}();
|
||||
|
||||
using BlockGemmPolicy =
|
||||
BlockGemmARegBSmemCRegV2CustomPolicy<typename Problem::QDataType,
|
||||
typename Problem::KDataType,
|
||||
typename Problem::SaccDataType,
|
||||
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
|
||||
decltype(warp_gemm)>;
|
||||
|
||||
if constexpr(1 < Problem::kNumGemm0Warps)
|
||||
{
|
||||
if constexpr(128 >= Problem::BlockFmhaShape::kK0)
|
||||
return BlockGemmARegBSmemCRegV2R1<GemmProblem, BlockGemmPolicy>{};
|
||||
else
|
||||
return BlockGemmARegBSmemCRegV2<GemmProblem, BlockGemmPolicy>{};
|
||||
}
|
||||
else
|
||||
return BlockGemmARegBSmemCRegOneWarpV1<GemmProblem, BlockGemmPolicy>{};
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -320,6 +320,11 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile));
|
||||
k_block_tile = load_tile(k_dram_window);
|
||||
}
|
||||
auto physical_next_block_id_k =
|
||||
__builtin_amdgcn_readfirstlane(k_page_block_navigator.prefetch_table_id(
|
||||
i_page_block_k, k_dram_block_window, {kN0, 0}));
|
||||
auto physical_next_block_id_v = __builtin_amdgcn_readfirstlane(
|
||||
v_page_block_navigator.prefetch_table_id(i_page_block_v, v_dram_window, {0, kK1}));
|
||||
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
@@ -600,8 +605,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
store_tile(v_lds_window,
|
||||
tile_elementwise_in(v_element_func, v_prefetch)); // store the prefetch
|
||||
}
|
||||
i_page_block_v =
|
||||
v_page_block_navigator.move_tile_window(i_page_block_v, v_dram_window, {0, kK1});
|
||||
i_page_block_v = v_page_block_navigator.move_tile_window(
|
||||
i_page_block_v, v_dram_window, {0, kK1}, physical_next_block_id_v);
|
||||
|
||||
const auto p =
|
||||
cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute));
|
||||
@@ -612,6 +617,9 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
static_for<0, k1_loops - 1, 1>{}([&,
|
||||
&i_page_block_v_ = i_page_block_v,
|
||||
&v_dram_window_ = v_dram_window](auto i_k1) {
|
||||
auto physical_next_block_id_v_ =
|
||||
__builtin_amdgcn_readfirstlane(v_page_block_navigator.prefetch_table_id(
|
||||
i_page_block_v_, v_dram_window_, {0, kK1}));
|
||||
const auto v = load_tile(v_dram_window_); // load next v
|
||||
block_sync_lds();
|
||||
gemm_1(o_acc,
|
||||
@@ -634,12 +642,12 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
tile_elementwise_in(v_element_func, v)); // store next v
|
||||
}
|
||||
i_page_block_v_ = v_page_block_navigator.move_tile_window(
|
||||
i_page_block_v_, v_dram_window_, {0, kK1});
|
||||
i_page_block_v_, v_dram_window_, {0, kK1}, physical_next_block_id_v_);
|
||||
});
|
||||
}
|
||||
// move K tile windows
|
||||
i_page_block_k = k_page_block_navigator.move_tile_window(
|
||||
i_page_block_k, k_dram_block_window, {kN0, 0});
|
||||
i_page_block_k, k_dram_block_window, {kN0, 0}, physical_next_block_id_k);
|
||||
// tail
|
||||
{
|
||||
block_sync_lds();
|
||||
|
||||
@@ -61,6 +61,58 @@ struct BlockFmhaPipelineProblem
|
||||
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
|
||||
};
|
||||
|
||||
template <typename QDataType_,
|
||||
typename KDataType_,
|
||||
typename VDataType_,
|
||||
typename SaccDataType_,
|
||||
typename SMPLComputeDataType_,
|
||||
typename BiasDataType_,
|
||||
typename LSEDataType_,
|
||||
typename PDataType_,
|
||||
typename OaccDataType_,
|
||||
typename ODataType_,
|
||||
typename BlockFmhaShape_,
|
||||
bool kIsGroupMode_,
|
||||
typename AttentionVariant_,
|
||||
typename FmhaMask_,
|
||||
typename Traits_>
|
||||
struct BlockFmhaFwdPagedKVPipelineProblem
|
||||
{
|
||||
using QDataType = remove_cvref_t<QDataType_>;
|
||||
using KDataType = remove_cvref_t<KDataType_>;
|
||||
using VDataType = remove_cvref_t<VDataType_>;
|
||||
using SaccDataType = remove_cvref_t<SaccDataType_>;
|
||||
using SMPLComputeDataType = remove_cvref_t<SMPLComputeDataType_>;
|
||||
using BiasDataType = remove_cvref_t<BiasDataType_>;
|
||||
using LSEDataType = remove_cvref_t<LSEDataType_>;
|
||||
using PDataType = remove_cvref_t<PDataType_>;
|
||||
using OaccDataType = remove_cvref_t<OaccDataType_>;
|
||||
using ODataType = remove_cvref_t<ODataType_>;
|
||||
using BlockFmhaShape = remove_cvref_t<BlockFmhaShape_>;
|
||||
using AttentionVariant = remove_cvref_t<AttentionVariant_>;
|
||||
using FmhaMask = remove_cvref_t<FmhaMask_>;
|
||||
using Traits = remove_cvref_t<Traits_>;
|
||||
|
||||
static constexpr index_t kNumGemm0Warps = BlockFmhaShape::NumGemm0Warps;
|
||||
static constexpr index_t kNumGemm1Warps = BlockFmhaShape::NumGemm1Warps;
|
||||
static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size();
|
||||
|
||||
static constexpr bool kIsGroupMode = kIsGroupMode_;
|
||||
|
||||
// attributes from traits
|
||||
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
|
||||
static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
|
||||
static constexpr bool kHasLogitsSoftCap = Traits::kHasLogitsSoftCap;
|
||||
static constexpr bool kSkipMinSeqlenQ = Traits::kSkipMinSeqlenQ;
|
||||
static constexpr auto BiasEnum = Traits::BiasEnum;
|
||||
static constexpr bool kStoreLSE = Traits::kStoreLSE;
|
||||
static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant;
|
||||
static constexpr bool kIsPagedKV = Traits::kIsPagedKV;
|
||||
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
|
||||
};
|
||||
|
||||
template <typename QDataType_,
|
||||
typename KDataType_,
|
||||
typename VDataType_,
|
||||
|
||||
@@ -28,6 +28,7 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
|
||||
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
|
||||
using ODataType = remove_cvref_t<typename Problem::ODataType>;
|
||||
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
|
||||
using AttentionVariant = remove_cvref_t<typename Problem::AttentionVariant>;
|
||||
|
||||
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
|
||||
using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>;
|
||||
@@ -54,6 +55,7 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
|
||||
static constexpr auto BiasEnum = Problem::BiasEnum;
|
||||
static constexpr bool kStoreLSE = Problem::kStoreLSE;
|
||||
static constexpr bool kHasDropout = Problem::kHasDropout;
|
||||
static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap;
|
||||
|
||||
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
|
||||
// ... together with tensor distribution. tensor dist should able to overwrite this
|
||||
@@ -127,7 +129,9 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
|
||||
typename SAccElementFunction,
|
||||
typename PComputeElementFunction,
|
||||
typename OAccElementFunction,
|
||||
typename PositionEncoding>
|
||||
typename PositionEncoding,
|
||||
typename AttentionVariantParams,
|
||||
typename BlockIndices>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kSubQKHeaddim tile
|
||||
const QElementFunction& q_element_func,
|
||||
@@ -146,6 +150,9 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
|
||||
FmhaMask mask,
|
||||
PositionEncoding position_encoding,
|
||||
float scale_s,
|
||||
const AttentionVariant& /* unused */,
|
||||
const AttentionVariantParams& /* unused */,
|
||||
const BlockIndices& /* unused */,
|
||||
void* smem_ptr,
|
||||
DropoutType& dropout) const
|
||||
{
|
||||
@@ -890,7 +897,9 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename RandValDramBlockWindowTmp,
|
||||
typename LSEDramBlockWindowTmp,
|
||||
typename PositionEncoding>
|
||||
typename PositionEncoding,
|
||||
typename AttentionVariantParams,
|
||||
typename BlockIndices>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
|
||||
@@ -901,6 +910,9 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
|
||||
FmhaMask mask,
|
||||
PositionEncoding position_encoding,
|
||||
float scale_s,
|
||||
const AttentionVariant& variant,
|
||||
const AttentionVariantParams& variant_params,
|
||||
const BlockIndices& block_indices,
|
||||
void* smem_ptr,
|
||||
DropoutType& dropout) const
|
||||
{
|
||||
@@ -921,6 +933,9 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
|
||||
mask,
|
||||
position_encoding,
|
||||
scale_s,
|
||||
variant,
|
||||
variant_params,
|
||||
block_indices,
|
||||
smem_ptr,
|
||||
dropout);
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -787,12 +787,29 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
|
||||
constexpr index_t N0 = kNPerBlock / N1; // P
|
||||
|
||||
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
|
||||
static_assert(total_pixels % N1 == 0); // TODO: this is not always true?
|
||||
constexpr index_t K3 = total_pixels / N1;
|
||||
constexpr index_t kKPack = GetSmemKPackV<Problem>();
|
||||
static_assert(kKPack % K3 == 0);
|
||||
constexpr index_t kKPack = GetSmemKPackV<Problem>();
|
||||
constexpr index_t K3 = total_pixels / N1;
|
||||
constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave
|
||||
if constexpr(get_warp_size() % (K2 * N0) == 0)
|
||||
if constexpr(total_pixels % N1 != 0 || kKPack % K3 != 0) // if K2 or K3 is not divisible
|
||||
{
|
||||
constexpr index_t kNPack = 32;
|
||||
static_assert(kNPerBlock % kNPack == 0);
|
||||
constexpr index_t K0 = kBlockSize / get_warp_size();
|
||||
constexpr index_t N2 = 2;
|
||||
constexpr index_t N1_m = kNPack / N2;
|
||||
constexpr index_t N0_m = kNPerBlock / kNPack;
|
||||
constexpr index_t K1 = get_warp_size() / N1_m;
|
||||
constexpr index_t K2_m = kKPerBlock / K1;
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<1>,
|
||||
tuple<sequence<N0_m, N1_m, N2>, sequence<K0, K1, K2_m>>,
|
||||
tuple<sequence<2>, sequence<2, 1>>, // K0, K1 N0
|
||||
tuple<sequence<0>, sequence<1, 1>>,
|
||||
sequence<1, 2, 1>, // N0 K2 N2
|
||||
sequence<0, 2, 2>>{});
|
||||
}
|
||||
else if constexpr(get_warp_size() % (kKPack / K3 * N0) == 0)
|
||||
{
|
||||
constexpr index_t K1 = get_warp_size() / (K2 * N0);
|
||||
constexpr index_t K0 = kBlockSize / get_warp_size();
|
||||
@@ -860,12 +877,28 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
|
||||
constexpr index_t N1 = GetAlignmentV<Problem>();
|
||||
constexpr index_t N0 = kNPerBlock / N1;
|
||||
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
|
||||
static_assert(total_pixels % N1 == 0); // TODO: this is not always true?
|
||||
constexpr index_t K3 = total_pixels / N1;
|
||||
constexpr index_t kKPack = GetSmemKPackV<Problem>();
|
||||
static_assert(kKPack % K3 == 0);
|
||||
constexpr index_t K3 = total_pixels / N1;
|
||||
constexpr index_t kKPack = GetSmemKPackV<Problem>();
|
||||
constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave
|
||||
if constexpr(get_warp_size() % (K2 * N0) == 0)
|
||||
if constexpr(total_pixels % N1 != 0 || kKPack % K3 != 0) // if K2 or K3 is not divisible
|
||||
{
|
||||
constexpr index_t kNPack = 32;
|
||||
static_assert(kNPerBlock % kNPack == 0);
|
||||
constexpr index_t K0 = kBlockSize / get_warp_size();
|
||||
constexpr index_t N2 = 2;
|
||||
constexpr index_t N1_m = kNPack / N2;
|
||||
constexpr index_t N0_m = kNPerBlock / kNPack;
|
||||
constexpr index_t K1 = get_warp_size() / N1_m;
|
||||
constexpr index_t K2_m = kKPerBlock / K1;
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<N0_m, N1_m, N2>, sequence<K0, K1, K2_m>>,
|
||||
tuple<sequence<2>, sequence<2, 1>>, // K0, K1 N0
|
||||
tuple<sequence<0>, sequence<1, 1>>,
|
||||
sequence<1, 1, 2>, // N0 K2 <-> N2
|
||||
sequence<0, 2, 2>>{});
|
||||
}
|
||||
else if constexpr(get_warp_size() % (kKPack / K3 * N0) == 0)
|
||||
{
|
||||
constexpr index_t K1 = get_warp_size() / (K2 * N0);
|
||||
constexpr index_t K0 = kBlockSize / get_warp_size();
|
||||
|
||||
@@ -37,6 +37,34 @@ struct TileFmhaTraits
|
||||
static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_;
|
||||
};
|
||||
|
||||
template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
|
||||
bool kPadSeqLenK_ /* padding for seqlen_k */,
|
||||
bool kPadHeadDimQ_ /* paddding for hdim_q */,
|
||||
bool kPadHeadDimV_ /* paddding for hdim_v */,
|
||||
bool kHasLogitsSoftCap_,
|
||||
BlockAttentionBiasEnum BiasEnum_,
|
||||
bool kHasBiasGrad_,
|
||||
bool kStoreLSE_, /* set to true if either num_splits > 1 or fwd training is running */
|
||||
bool kIsPagedKV_,
|
||||
bool kDoFp8StaticQuant_,
|
||||
index_t kBlockPerCu_ = -1, /* overwrite occupancy if not -1 */
|
||||
bool kSkipMinSeqlenQ_ = false /* skip min seqlen q while chunked prefill */>
|
||||
struct TileFmhaFwdPagedKVTraits
|
||||
{
|
||||
static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
|
||||
static constexpr bool kPadSeqLenK = kPadSeqLenK_;
|
||||
static constexpr bool kPadHeadDimQ = kPadHeadDimQ_;
|
||||
static constexpr bool kPadHeadDimV = kPadHeadDimV_;
|
||||
static constexpr bool kHasLogitsSoftCap = kHasLogitsSoftCap_;
|
||||
static constexpr auto BiasEnum = BiasEnum_;
|
||||
static constexpr bool kHasBiasGrad = kHasBiasGrad_;
|
||||
static constexpr bool kStoreLSE = kStoreLSE_;
|
||||
static constexpr bool kIsPagedKV = kIsPagedKV_;
|
||||
static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
|
||||
static constexpr index_t kBlockPerCu = kBlockPerCu_;
|
||||
static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_;
|
||||
};
|
||||
|
||||
template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
|
||||
bool kPadSeqLenK_ /* padding for seqlen_k */,
|
||||
bool kPadHeadDimQ_ /* paddding for hdim_q */,
|
||||
|
||||
@@ -101,7 +101,7 @@ struct FusedMoeGemmShape
|
||||
static constexpr index_t Repeat_N1 = Block_N1 / ThreadPerBlock_N1;
|
||||
static constexpr index_t Repeat_K1 = Block_K1 / ThreadPerBlock_K1;
|
||||
|
||||
static constexpr index_t BlockSize = WarpSize * NumWarps;
|
||||
static constexpr index_t BlockSize = get_warp_size() * NumWarps;
|
||||
|
||||
// some assert
|
||||
static_assert(Block_M0 == Block_M1);
|
||||
|
||||
@@ -23,6 +23,11 @@ namespace ck_tile {
|
||||
#define MOE_SORTING_FUSE_MP_01 0
|
||||
#endif
|
||||
|
||||
// weather use 2d buffer indexing for fmoe ws or 1d
|
||||
#ifndef MOE_SORTING_FMOE_2D_BUF
|
||||
#define MOE_SORTING_FMOE_2D_BUF 1
|
||||
#endif
|
||||
|
||||
// clang-format off
|
||||
// [indexing implementation-1]
|
||||
// using M_a as constexpr block_size to partition all tokens into different slices
|
||||
@@ -171,7 +176,7 @@ struct MoeSortingHostArgs
|
||||
void* p_sorted_token_ids;
|
||||
void* p_sorted_weights;
|
||||
void* p_sorted_expert_ids;
|
||||
void* p_total_tokens_post_pad;
|
||||
void* p_total_tokens_post_pad; // [2], [0]:outputed tokens_post_padded, [1]:actual tokens on current rank (local_tokens or tokens)
|
||||
// we fused the setzero of output of fused-moe buffer
|
||||
// set this pointer to nullptr will skip this operation
|
||||
void* p_moe_buf;
|
||||
@@ -182,7 +187,18 @@ struct MoeSortingHostArgs
|
||||
index_t unit_size; // this is the M_a of fused-moe kernel
|
||||
index_t num_experts;
|
||||
index_t topk;
|
||||
#if MOE_SORTING_FMOE_2D_BUF
|
||||
// NOTE:
|
||||
// moe_buf_* is a 2d ws buffer used for the following fmoe kernel
|
||||
// arranged as row*col, where row=tokens(or local_token), col=interm_dim
|
||||
// we fuse this clearing inside sorting kernel
|
||||
// Besides, we require inter_dim to be multiple of 16 byte(make sure when alloc ws for fmoe)
|
||||
index_t moe_buf_interm_dim; // p_moe_buf interm_dim
|
||||
index_t moe_buf_elem_bytes; // p_moe_buf byte size(8bit, 16bit, 32bit, etc.)
|
||||
#else
|
||||
long_index_t moe_buf_bytes; // byte size of p_moe_buf
|
||||
#endif
|
||||
|
||||
};
|
||||
|
||||
template <typename Problem_>
|
||||
@@ -197,6 +213,9 @@ struct MoeSortingKernel
|
||||
|
||||
using Hargs = MoeSortingHostArgs;
|
||||
|
||||
static constexpr index_t BLOCK_SIZE = 256;
|
||||
static constexpr index_t OCCUPANCY = 2; // hard coded
|
||||
|
||||
struct Kargs
|
||||
{
|
||||
const void* p_topk_ids;
|
||||
@@ -210,8 +229,12 @@ struct MoeSortingKernel
|
||||
void* p_moe_buf;
|
||||
index_t tokens;
|
||||
index_t num_experts;
|
||||
#if MOE_SORTING_FMOE_2D_BUF
|
||||
index_t moe_buf_interm_dim; // p_moe_buf interm_dim
|
||||
index_t moe_buf_elem_bytes; // p_moe_buf byte size(8bit, 16bit, 32bit, etc.)
|
||||
#else
|
||||
long_index_t moe_buf_bytes;
|
||||
|
||||
#endif
|
||||
index_t tokens_per_thread;
|
||||
index_t smem_rows;
|
||||
mdiv unit_size_mdiv;
|
||||
@@ -220,10 +243,27 @@ struct MoeSortingKernel
|
||||
// mdiv sub_tokens_mdiv;
|
||||
};
|
||||
|
||||
CK_TILE_HOST static constexpr auto get_num_cu()
|
||||
{
|
||||
index_t num_cu = [&]() {
|
||||
hipDeviceProp_t dev_prop;
|
||||
hipDevice_t dev;
|
||||
HIP_CHECK_ERROR(hipGetDevice(&dev));
|
||||
HIP_CHECK_ERROR(hipGetDeviceProperties(&dev_prop, dev));
|
||||
return dev_prop.multiProcessorCount;
|
||||
}();
|
||||
return num_cu;
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(const Hargs& h)
|
||||
{
|
||||
#if MOE_SORTING_FMOE_2D_BUF
|
||||
(void)h;
|
||||
return get_num_cu() * OCCUPANCY;
|
||||
#else
|
||||
// TODO: assume num-experts not too much
|
||||
return dim3(1 + ck_tile::integer_divide_ceil(h.moe_buf_bytes, BlockSize(h).x * 16));
|
||||
#endif
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize(const Hargs& h)
|
||||
@@ -263,7 +303,12 @@ struct MoeSortingKernel
|
||||
k.p_total_tokens_post_pad = h.p_total_tokens_post_pad;
|
||||
k.tokens = h.tokens;
|
||||
k.num_experts = h.num_experts;
|
||||
#if MOE_SORTING_FMOE_2D_BUF
|
||||
k.moe_buf_interm_dim = h.moe_buf_interm_dim;
|
||||
k.moe_buf_elem_bytes = h.moe_buf_elem_bytes;
|
||||
#else
|
||||
k.moe_buf_bytes = h.moe_buf_bytes;
|
||||
#endif
|
||||
|
||||
const auto blocks = BlockSize(h);
|
||||
// NOTE: tokens could from p_local_tokens, so here this variable is useless
|
||||
@@ -388,7 +433,7 @@ struct MoeSortingKernel
|
||||
}
|
||||
|
||||
// reduce single pixel within a wave
|
||||
template <typename T, typename F, index_t wave_size_ = WarpSize>
|
||||
template <typename T, typename F, index_t wave_size_ = get_warp_size()>
|
||||
__device__ static constexpr T wave_reduce(T local, F reduce_f, number<wave_size_> = {})
|
||||
{
|
||||
// constexpr int wave_size = 64;
|
||||
@@ -431,6 +476,24 @@ struct MoeSortingKernel
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void
|
||||
moe_buf_set_zero_kernel_2d(void* buf, index_t row, index_t col, index_t elem_bytes) const
|
||||
{
|
||||
const long_index_t total_pixels = static_cast<long_index_t>(row) * col;
|
||||
const long_index_t total_bytes = total_pixels * elem_bytes;
|
||||
const long_index_t total_elems = total_bytes / 16; // always use dwordx4
|
||||
|
||||
using vector_type = ext_vector_t<index_t, 4>;
|
||||
vector_type* p_buf = reinterpret_cast<vector_type*>(buf);
|
||||
auto zero_ = vector_type{0};
|
||||
|
||||
for(long_index_t i = (blockIdx.x - 1) * BLOCK_SIZE + threadIdx.x; i < total_elems;
|
||||
i += (gridDim.x - 1) * BLOCK_SIZE)
|
||||
{
|
||||
p_buf[i] = zero_;
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void moe_align_block_size_kernel(const IndexType* __restrict__ topk_id,
|
||||
const WeightType* __restrict__ weights,
|
||||
index_t* p_sorted_token_ids,
|
||||
@@ -625,7 +688,7 @@ struct MoeSortingKernel
|
||||
{
|
||||
const index_t prefill_token = topk_mdiv.div(numel);
|
||||
// TODO: only support expert-tile like 8, 16, 32
|
||||
static constexpr index_t experts_per_wave = WarpSize / Problem::ExpertTile;
|
||||
static constexpr index_t experts_per_wave = get_warp_size() / Problem::ExpertTile;
|
||||
{
|
||||
index_t eid = tid / experts_per_wave;
|
||||
index_t expert_offset = cumsum[eid] +
|
||||
@@ -693,7 +756,7 @@ struct MoeSortingKernel
|
||||
void* smem) const
|
||||
{
|
||||
const index_t tid = static_cast<index_t>(threadIdx.x);
|
||||
const index_t wid = __builtin_amdgcn_readfirstlane(tid / WarpSize);
|
||||
const index_t wid = __builtin_amdgcn_readfirstlane(tid / get_warp_size());
|
||||
const index_t lid = __lane_id();
|
||||
constexpr index_t block_size = 256; // blockDim.x;
|
||||
const index_t sub_tokens = smem_rows - 2; // sub_tokens_mdiv.divisor;
|
||||
@@ -798,7 +861,7 @@ struct MoeSortingKernel
|
||||
// NOTE: under this block can never use __syncthreads!
|
||||
int i_e_ = 0;
|
||||
int local_cumsum_ = 0;
|
||||
for(; i_e_ < num_experts; i_e_ += WarpSize)
|
||||
for(; i_e_ < num_experts; i_e_ += get_warp_size())
|
||||
{
|
||||
int pre_cumsum_ = smem_cumsum(lid == 0 ? i_e_ : 0);
|
||||
int local_cnt = smem_cumsum(i_e_ + lid + 1);
|
||||
@@ -843,7 +906,7 @@ struct MoeSortingKernel
|
||||
// cumsum padded in case local cumsum is zero, but
|
||||
// pre_sumsum has value, which will result int
|
||||
// zero local cumsum(but we want at least padded)
|
||||
wave_cumsum<int, WarpSize>(local_cumsum_);
|
||||
wave_cumsum<int, get_warp_size()>(local_cumsum_);
|
||||
|
||||
if((i_e_ + lid) < num_experts)
|
||||
smem_cumsum(i_e_ + lid + 1) = local_cumsum_;
|
||||
@@ -851,7 +914,7 @@ struct MoeSortingKernel
|
||||
if constexpr(Problem::LocalExpertMasking)
|
||||
{
|
||||
local_masking += pre_cumsum_masking;
|
||||
wave_cumsum<int, WarpSize>(local_masking);
|
||||
wave_cumsum<int, get_warp_size()>(local_masking);
|
||||
if((i_e_ + lid) < num_experts)
|
||||
smem_cumdup(i_e_ + lid + 1) = local_masking;
|
||||
}
|
||||
@@ -861,9 +924,10 @@ struct MoeSortingKernel
|
||||
// than 0(which is not we want)
|
||||
__builtin_amdgcn_s_waitcnt(0xc07f);
|
||||
}
|
||||
if((lid + i_e_ - WarpSize) == (num_experts - 1))
|
||||
if((lid + i_e_ - get_warp_size()) == (num_experts - 1))
|
||||
{
|
||||
*p_total_tokens_post_pad = local_cumsum_;
|
||||
*p_total_tokens_post_pad = local_cumsum_;
|
||||
p_total_tokens_post_pad[1] = tokens;
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
@@ -1005,20 +1069,6 @@ struct MoeSortingKernel
|
||||
|
||||
CK_TILE_DEVICE void operator()(Kargs kargs) const
|
||||
{
|
||||
if(blockIdx.x > 0)
|
||||
{
|
||||
if(kargs.p_moe_buf)
|
||||
{
|
||||
moe_buf_set_zero_kernel(reinterpret_cast<uint8x16_t*>(kargs.p_moe_buf),
|
||||
kargs.moe_buf_bytes);
|
||||
}
|
||||
return;
|
||||
}
|
||||
const size_t numel = kargs.tokens * kargs.topk_mdiv.divisor;
|
||||
extern __shared__ char smem[];
|
||||
|
||||
#if MOE_SORTING_USE_EX_KERNEL
|
||||
(void)numel;
|
||||
index_t tokens_ = [&]() {
|
||||
if constexpr(Problem::LocalToken)
|
||||
{
|
||||
@@ -1029,6 +1079,25 @@ struct MoeSortingKernel
|
||||
return kargs.tokens;
|
||||
}
|
||||
}();
|
||||
|
||||
if(blockIdx.x > 0)
|
||||
{
|
||||
if(kargs.p_moe_buf)
|
||||
{
|
||||
#if MOE_SORTING_FMOE_2D_BUF
|
||||
moe_buf_set_zero_kernel_2d(
|
||||
kargs.p_moe_buf, tokens_, kargs.moe_buf_interm_dim, kargs.moe_buf_elem_bytes);
|
||||
#else
|
||||
moe_buf_set_zero_kernel(reinterpret_cast<uint8x16_t*>(kargs.p_moe_buf),
|
||||
kargs.moe_buf_bytes);
|
||||
#endif
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
extern __shared__ char smem[];
|
||||
|
||||
#if MOE_SORTING_USE_EX_KERNEL
|
||||
return moe_align_block_size_kernel_ex(
|
||||
static_cast<const IndexType*>(kargs.p_topk_ids),
|
||||
static_cast<const WeightType*>(kargs.p_weights),
|
||||
@@ -1045,6 +1114,7 @@ struct MoeSortingKernel
|
||||
kargs.smem_rows,
|
||||
smem);
|
||||
#else
|
||||
const size_t numel = kargs.tokens * kargs.topk_mdiv.divisor;
|
||||
return moe_align_block_size_kernel(static_cast<const IndexType*>(kargs.p_topk_ids),
|
||||
static_cast<const WeightType*>(kargs.p_weights),
|
||||
static_cast<IndexType*>(kargs.p_sorted_token_ids),
|
||||
@@ -1066,6 +1136,8 @@ namespace impl {
|
||||
// [expert, padded_tokens]
|
||||
CK_TILE_HOST_DEVICE index_t moe_sorting_mp_mesh_stride(index_t tokens)
|
||||
{
|
||||
// Pad to multiply of 32. This can make sure even if the mesh is in 8bit,
|
||||
// we can still use dwordx4 load/store
|
||||
constexpr index_t chunk = 32;
|
||||
return (tokens + chunk - 1) / chunk * chunk;
|
||||
};
|
||||
@@ -1109,7 +1181,7 @@ CK_TILE_HOST_DEVICE index_t moe_sorting_mp_sem_smem_size()
|
||||
return chunk * sizeof(index_t);
|
||||
};
|
||||
|
||||
template <typename T, typename F, index_t wave_size_ = WarpSize>
|
||||
template <typename T, typename F, index_t wave_size_ = get_warp_size()>
|
||||
CK_TILE_DEVICE constexpr T moe_sorting_wave_reduce(T local, F reduce_f, number<wave_size_> = {})
|
||||
{
|
||||
// constexpr int wave_size = 64;
|
||||
@@ -1261,6 +1333,24 @@ CK_TILE_DEVICE void moe_buf_set_zero_kernel(uint8x16_t* buf, long_index_t buf_by
|
||||
}
|
||||
}
|
||||
|
||||
template <index_t BLOCK_SIZE = 256>
|
||||
CK_TILE_DEVICE void moe_buf_set_zero_kernel_2d(
|
||||
void* buf, index_t row, index_t col, index_t elem_bytes, index_t gid, index_t blocks)
|
||||
{
|
||||
const long_index_t total_pixels = static_cast<long_index_t>(row) * col;
|
||||
const long_index_t total_bytes = total_pixels * elem_bytes;
|
||||
const long_index_t total_elems = total_bytes / 16; // always use dwordx4
|
||||
|
||||
using vector_type = ext_vector_t<index_t, 4>;
|
||||
vector_type* p_buf = reinterpret_cast<vector_type*>(buf);
|
||||
auto zero_ = vector_type{0};
|
||||
|
||||
for(long_index_t i = gid * BLOCK_SIZE + threadIdx.x; i < total_elems; i += blocks * BLOCK_SIZE)
|
||||
{
|
||||
p_buf[i] = zero_;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace impl
|
||||
|
||||
// TODO: tokens could be from
|
||||
@@ -1292,12 +1382,29 @@ CK_TILE_HOST index_t moe_sorting_mp_get_workspace_size(int tokens_, int num_expe
|
||||
}
|
||||
|
||||
// return size in byte
|
||||
CK_TILE_HOST index_t moe_sorting_get_workspace_size(int tokens_, int num_experts_, int topk_)
|
||||
// dispatch_policy: 0-automatically pick up kerel. 1-always use single kernel, 2-always use mp
|
||||
// kernel
|
||||
CK_TILE_HOST index_t moe_sorting_get_workspace_size(int tokens_,
|
||||
int num_experts_,
|
||||
int topk_,
|
||||
int dispatch_policy_)
|
||||
{
|
||||
#if 1
|
||||
if(moe_sorting_is_oneshot(tokens_, num_experts_))
|
||||
// return 0;
|
||||
if(dispatch_policy_ == 0)
|
||||
{
|
||||
return 0;
|
||||
if(moe_sorting_is_oneshot(tokens_, num_experts_))
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
else
|
||||
{
|
||||
return moe_sorting_mp_get_workspace_size(tokens_, num_experts_, topk_);
|
||||
}
|
||||
}
|
||||
else if(dispatch_policy_ == 1)
|
||||
{
|
||||
return 0; // always use single kernel
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -1308,6 +1415,98 @@ CK_TILE_HOST index_t moe_sorting_get_workspace_size(int tokens_, int num_experts
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename Problem_>
|
||||
struct MoeSortingClearWorkspaceKernel
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
static constexpr index_t BLOCK_SIZE = Problem::BlockSize;
|
||||
static constexpr index_t OCCUPANCY = Problem::Occu;
|
||||
|
||||
using Hargs = MoeSortingHostArgs;
|
||||
|
||||
struct Kargs
|
||||
{
|
||||
const void* p_local_tokens; // [1], if not nullptr, use this as actual tokens
|
||||
void* p_expert_mesh; // [expert, tokens]
|
||||
index_t tokens; // if p_local_tokens is not nullptr, this indicate the max possible tokens
|
||||
// used for ws/LDS calculation
|
||||
index_t num_experts;
|
||||
index_t mesh_stride; // mesh_stride for p_expert_mesh
|
||||
index_t mesh_byte_size;
|
||||
};
|
||||
|
||||
CK_TILE_HOST static constexpr auto get_num_cu()
|
||||
{
|
||||
index_t num_cu = [&]() {
|
||||
hipDeviceProp_t dev_prop;
|
||||
hipDevice_t dev;
|
||||
HIP_CHECK_ERROR(hipGetDevice(&dev));
|
||||
HIP_CHECK_ERROR(hipGetDeviceProperties(&dev_prop, dev));
|
||||
return dev_prop.multiProcessorCount;
|
||||
}();
|
||||
return num_cu;
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h)
|
||||
{
|
||||
Kargs k;
|
||||
k.p_local_tokens = h.p_local_tokens;
|
||||
k.p_expert_mesh = h.p_ws;
|
||||
k.tokens = h.tokens;
|
||||
k.num_experts = h.num_experts;
|
||||
k.mesh_stride = impl::moe_sorting_mp_mesh_stride(h.tokens);
|
||||
k.mesh_byte_size = impl::moe_sorting_mesh_byte_size(h.tokens, h.num_experts, h.topk);
|
||||
return k;
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(const Hargs&) { return get_num_cu() * OCCUPANCY; }
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(BLOCK_SIZE); }
|
||||
|
||||
// in byte
|
||||
CK_TILE_HOST static constexpr auto GetSmemSize() { return 0; }
|
||||
|
||||
CK_TILE_DEVICE void operator()(Kargs kargs) const
|
||||
{
|
||||
index_t tokens = [&]() {
|
||||
if constexpr(Problem::LocalToken)
|
||||
{
|
||||
return reinterpret_cast<const index_t*>(kargs.p_local_tokens)[0];
|
||||
}
|
||||
else
|
||||
{
|
||||
return kargs.tokens;
|
||||
}
|
||||
}();
|
||||
|
||||
index_t mesh_stride = [&]() {
|
||||
if constexpr(Problem::LocalToken)
|
||||
{
|
||||
return impl::moe_sorting_mp_mesh_stride(tokens);
|
||||
}
|
||||
else
|
||||
{
|
||||
return kargs.mesh_stride;
|
||||
}
|
||||
}();
|
||||
|
||||
index_t row_size = mesh_stride; // impl::moe_sorting_mp_mesh_stride(tokens);
|
||||
index_t pixels = kargs.num_experts * row_size;
|
||||
index_t total_bytes = pixels * kargs.mesh_byte_size;
|
||||
index_t total_elems = total_bytes / 16; // always use dwordx4
|
||||
|
||||
using vector_type = ext_vector_t<index_t, 4>;
|
||||
vector_type* p_expert_mesh = reinterpret_cast<vector_type*>(kargs.p_expert_mesh);
|
||||
auto zero_ = vector_type{0};
|
||||
|
||||
for(index_t i = blockIdx.x * BLOCK_SIZE + threadIdx.x; i < total_elems;
|
||||
i += gridDim.x * BLOCK_SIZE)
|
||||
{
|
||||
p_expert_mesh[i] = zero_;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// below kernel is multi-phase implementation for large token and/or expert case
|
||||
|
||||
// write into a buffer to record the token cnt
|
||||
@@ -1435,6 +1634,16 @@ struct MoeSortingMultiPhaseKernel_P0
|
||||
else
|
||||
return tokens;
|
||||
}();
|
||||
index_t mesh_stride = [&]() {
|
||||
if constexpr(Problem::LocalToken)
|
||||
{
|
||||
return impl::moe_sorting_mp_mesh_stride(tokens);
|
||||
}
|
||||
else
|
||||
{
|
||||
return kargs.mesh_stride;
|
||||
}
|
||||
}();
|
||||
index_t total_elem = rounded_tokens * kargs.topk_mdiv.divisor / Problem::SubTokenTile;
|
||||
|
||||
#pragma unroll Problem::SubTokenTile
|
||||
@@ -1449,12 +1658,11 @@ struct MoeSortingMultiPhaseKernel_P0
|
||||
if constexpr(Problem::LocalToken)
|
||||
{
|
||||
if(static_cast<index_t>(curr_token_id) < tokens)
|
||||
p_expert_mesh[eid * kargs.mesh_stride + curr_token_id] =
|
||||
p_expert_mesh[eid * mesh_stride + curr_token_id] =
|
||||
(curr_topk_id + 1) & 0xffff;
|
||||
}
|
||||
else
|
||||
p_expert_mesh[eid * kargs.mesh_stride + curr_token_id] =
|
||||
(curr_topk_id + 1) & 0xffff;
|
||||
p_expert_mesh[eid * mesh_stride + curr_token_id] = (curr_topk_id + 1) & 0xffff;
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -1479,6 +1687,7 @@ struct MoeSortingMultiPhaseKernel_P1
|
||||
struct Kargs
|
||||
{
|
||||
const void* p_local_expert_mask; // [expert]
|
||||
const void* p_local_tokens; // [1], if not nullptr, use this as actual tokens
|
||||
void* p_expert_mesh; // [expert, tokens]
|
||||
void* p_expert_cumsum;
|
||||
index_t mesh_stride; // mesh_stride for p_expert_mesh
|
||||
@@ -1488,6 +1697,7 @@ struct MoeSortingMultiPhaseKernel_P1
|
||||
{
|
||||
Kargs k;
|
||||
k.p_local_expert_mask = h.p_local_expert_mask;
|
||||
k.p_local_tokens = h.p_local_tokens;
|
||||
k.p_expert_mesh = h.p_ws;
|
||||
k.p_expert_cumsum = reinterpret_cast<void*>(
|
||||
reinterpret_cast<char*>(h.p_ws) +
|
||||
@@ -1504,19 +1714,16 @@ struct MoeSortingMultiPhaseKernel_P1
|
||||
// in byte
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize()
|
||||
{
|
||||
return BLOCK_SIZE / WarpSize * sizeof(IndexType);
|
||||
return BLOCK_SIZE / get_warp_size() * sizeof(IndexType);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void operator()(Kargs kargs) const
|
||||
{
|
||||
__shared__ char smem[GetSmemSize()];
|
||||
|
||||
int eid = blockIdx.x;
|
||||
|
||||
int eid = blockIdx.x;
|
||||
constexpr index_t index_pack = Problem::SubTokenTile; // always packed
|
||||
using r_t = ext_vector_t<MeshType, index_pack>; // always use int32x4
|
||||
r_t* p_expert_mesh = reinterpret_cast<r_t*>(
|
||||
reinterpret_cast<MeshType*>(kargs.p_expert_mesh) + eid * kargs.mesh_stride);
|
||||
|
||||
const IndexType* p_local_expert_mask =
|
||||
static_cast<const IndexType*>(kargs.p_local_expert_mask);
|
||||
@@ -1524,7 +1731,32 @@ struct MoeSortingMultiPhaseKernel_P1
|
||||
|
||||
auto f_sum = [](auto x_, auto y_) { return x_ + y_; };
|
||||
|
||||
int loops = (kargs.mesh_stride / index_pack + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
index_t tokens = [&]() {
|
||||
if constexpr(Problem::LocalToken)
|
||||
{
|
||||
return reinterpret_cast<const index_t*>(kargs.p_local_tokens)[0];
|
||||
}
|
||||
else
|
||||
{
|
||||
return 0; // will not use if not LocalToken
|
||||
}
|
||||
}();
|
||||
|
||||
index_t mesh_stride = [&]() {
|
||||
if constexpr(Problem::LocalToken)
|
||||
{
|
||||
return impl::moe_sorting_mp_mesh_stride(tokens);
|
||||
}
|
||||
else
|
||||
{
|
||||
return kargs.mesh_stride;
|
||||
}
|
||||
}();
|
||||
|
||||
r_t* p_expert_mesh = reinterpret_cast<r_t*>(
|
||||
reinterpret_cast<MeshType*>(kargs.p_expert_mesh) + eid * mesh_stride);
|
||||
|
||||
int loops = (mesh_stride / index_pack + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
|
||||
if constexpr(Problem::LocalExpertMasking)
|
||||
{
|
||||
@@ -1538,7 +1770,7 @@ struct MoeSortingMultiPhaseKernel_P1
|
||||
{
|
||||
int position = i * BLOCK_SIZE + threadIdx.x;
|
||||
r_t v{0};
|
||||
if(position < (kargs.mesh_stride / index_pack))
|
||||
if(position < (mesh_stride / index_pack))
|
||||
v = p_expert_mesh[position];
|
||||
index_t local_sum = 0;
|
||||
static_for<0, index_pack, 1>{}(
|
||||
@@ -1546,8 +1778,8 @@ struct MoeSortingMultiPhaseKernel_P1
|
||||
cnt += impl::moe_sorting_wave_reduce(local_sum, f_sum);
|
||||
}
|
||||
|
||||
index_t lane_id = threadIdx.x % WarpSize;
|
||||
index_t wave_id = threadIdx.x / WarpSize;
|
||||
index_t lane_id = threadIdx.x % get_warp_size();
|
||||
index_t wave_id = threadIdx.x / get_warp_size();
|
||||
|
||||
// reduce cross wave
|
||||
IndexType* s = reinterpret_cast<IndexType*>(smem);
|
||||
@@ -1560,7 +1792,7 @@ struct MoeSortingMultiPhaseKernel_P1
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
index_t c = 0;
|
||||
for(auto i = 0; i < (BLOCK_SIZE / WarpSize); i++)
|
||||
for(auto i = 0; i < (BLOCK_SIZE / get_warp_size()); i++)
|
||||
{
|
||||
c += s[i];
|
||||
}
|
||||
@@ -1660,7 +1892,7 @@ struct MoeSortingMultiPhaseKernel_P01
|
||||
// in byte
|
||||
CK_TILE_HOST static constexpr auto GetSmemSize()
|
||||
{
|
||||
return BLOCK_SIZE / WarpSize * sizeof(IndexType);
|
||||
return BLOCK_SIZE / get_warp_size() * sizeof(IndexType);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void operator()(Kargs kargs) const
|
||||
@@ -1786,8 +2018,8 @@ struct MoeSortingMultiPhaseKernel_P01
|
||||
cnt += impl::moe_sorting_wave_reduce(local_sum, f_sum);
|
||||
}
|
||||
|
||||
index_t lane_id = threadIdx.x % WarpSize;
|
||||
index_t wave_id = threadIdx.x / WarpSize;
|
||||
index_t lane_id = threadIdx.x % get_warp_size();
|
||||
index_t wave_id = threadIdx.x / get_warp_size();
|
||||
|
||||
// reduce cross wave
|
||||
IndexType* s = reinterpret_cast<IndexType*>(smem);
|
||||
@@ -1801,7 +2033,7 @@ struct MoeSortingMultiPhaseKernel_P01
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
index_t c = 0;
|
||||
for(auto i = 0; i < (BLOCK_SIZE / WarpSize); i++)
|
||||
for(auto i = 0; i < (BLOCK_SIZE / get_warp_size()); i++)
|
||||
{
|
||||
c += s[i];
|
||||
}
|
||||
@@ -1835,7 +2067,7 @@ struct MoeSortingMultiPhaseKernel_P2
|
||||
const void* p_local_tokens; // [1]
|
||||
void* p_expert_mesh; // [expert, tokens]
|
||||
void* p_expert_cumsum; // [expert + 1]
|
||||
void* p_total_tokens_post_pad; // [1]
|
||||
void* p_total_tokens_post_pad; // [2]
|
||||
void* p_sorted_expert_ids;
|
||||
void* p_moe_buf;
|
||||
index_t tokens;
|
||||
@@ -1863,15 +2095,36 @@ struct MoeSortingMultiPhaseKernel_P2
|
||||
k.mesh_stride = impl::moe_sorting_mp_mesh_stride(h.tokens);
|
||||
k.unit_size_mdiv = mdiv{static_cast<uint32_t>(h.unit_size)};
|
||||
|
||||
#if MOE_SORTING_FMOE_2D_BUF
|
||||
k.moe_buf_interm_dim = h.moe_buf_interm_dim;
|
||||
k.moe_buf_elem_bytes = h.moe_buf_elem_bytes;
|
||||
#else
|
||||
k.moe_buf_bytes = h.moe_buf_bytes;
|
||||
#endif
|
||||
|
||||
return k;
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto get_num_cu()
|
||||
{
|
||||
index_t num_cu = [&]() {
|
||||
hipDeviceProp_t dev_prop;
|
||||
hipDevice_t dev;
|
||||
HIP_CHECK_ERROR(hipGetDevice(&dev));
|
||||
HIP_CHECK_ERROR(hipGetDeviceProperties(&dev_prop, dev));
|
||||
return dev_prop.multiProcessorCount;
|
||||
}();
|
||||
return num_cu;
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(const Hargs& h)
|
||||
{
|
||||
#if MOE_SORTING_FMOE_2D_BUF
|
||||
return dim3(h.num_experts + get_num_cu() * OCCUPANCY);
|
||||
#else
|
||||
// use 1 block to cumsum
|
||||
return dim3(1 + ck_tile::integer_divide_ceil(h.moe_buf_bytes, BLOCK_SIZE * 16));
|
||||
#endif
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(BLOCK_SIZE); }
|
||||
@@ -1880,7 +2133,7 @@ struct MoeSortingMultiPhaseKernel_P2
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize()
|
||||
{
|
||||
// return 2 * BLOCK_SIZE * sizeof(IndexType);
|
||||
return (4 + 2 * BLOCK_SIZE / WarpSize) * sizeof(IndexType);
|
||||
return (4 + 2 * BLOCK_SIZE / get_warp_size()) * sizeof(IndexType);
|
||||
}
|
||||
|
||||
// reduce single pixel within a wave
|
||||
@@ -1888,11 +2141,21 @@ struct MoeSortingMultiPhaseKernel_P2
|
||||
{
|
||||
if(blockIdx.x > 0)
|
||||
{
|
||||
#if MOE_SORTING_FMOE_2D_BUF
|
||||
impl::moe_buf_set_zero_kernel_2d<BLOCK_SIZE>(kargs.p_moe_buf,
|
||||
kargs.tokens,
|
||||
kargs.moe_buf_interm_dim,
|
||||
kargs.moe_buf_elem_bytes,
|
||||
blockIdx.x - 1,
|
||||
gridDim.x - 1);
|
||||
return;
|
||||
#else
|
||||
impl::moe_buf_set_zero_kernel<BLOCK_SIZE>(
|
||||
reinterpret_cast<uint8x16_t*>(kargs.p_moe_buf),
|
||||
kargs.moe_buf_bytes,
|
||||
blockIdx.x - 1);
|
||||
return;
|
||||
#endif
|
||||
}
|
||||
__shared__ char smem[GetSmemSize()];
|
||||
IndexType* s = reinterpret_cast<IndexType*>(smem);
|
||||
@@ -1905,8 +2168,8 @@ struct MoeSortingMultiPhaseKernel_P2
|
||||
IndexType* p_sorted_expert_ids = reinterpret_cast<IndexType*>(kargs.p_sorted_expert_ids);
|
||||
|
||||
const index_t loops = (kargs.num_experts + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
index_t wave_id = threadIdx.x / WarpSize;
|
||||
index_t lane_id = threadIdx.x % WarpSize;
|
||||
index_t wave_id = threadIdx.x / get_warp_size();
|
||||
index_t lane_id = threadIdx.x % get_warp_size();
|
||||
|
||||
IndexType prev_cumsum_a = 0;
|
||||
IndexType prev_cumsum_b = 0;
|
||||
@@ -1951,22 +2214,22 @@ struct MoeSortingMultiPhaseKernel_P2
|
||||
IndexType cumsum_b = b_;
|
||||
|
||||
// Note: we first cumsum local round, then add previous cumsum
|
||||
impl::moe_sorting_wave_cumsum<IndexType, WarpSize>(cumsum_a);
|
||||
impl::moe_sorting_wave_cumsum<IndexType, WarpSize>(cumsum_b);
|
||||
impl::moe_sorting_wave_cumsum<IndexType, get_warp_size()>(cumsum_a);
|
||||
impl::moe_sorting_wave_cumsum<IndexType, get_warp_size()>(cumsum_b);
|
||||
|
||||
__syncthreads();
|
||||
if(lane_id == WarpSize - 1)
|
||||
if(lane_id == get_warp_size() - 1)
|
||||
{
|
||||
s[4 + wave_id] = cumsum_a;
|
||||
s[4 + wave_id + BLOCK_SIZE / WarpSize] = cumsum_b;
|
||||
s[4 + wave_id] = cumsum_a;
|
||||
s[4 + wave_id + BLOCK_SIZE / get_warp_size()] = cumsum_b;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// reduce cross wave
|
||||
static_for<0, BLOCK_SIZE / WarpSize - 1, 1>{}([&](auto i_w) {
|
||||
static_for<0, BLOCK_SIZE / get_warp_size() - 1, 1>{}([&](auto i_w) {
|
||||
IndexType prev_a = s[4 + i_w];
|
||||
IndexType prev_b = s[4 + i_w + BLOCK_SIZE / WarpSize];
|
||||
IndexType prev_b = s[4 + i_w + BLOCK_SIZE / get_warp_size()];
|
||||
prev_a = wave_id > i_w ? prev_a : 0; // mask out
|
||||
prev_b = wave_id > i_w ? prev_b : 0; // mask out
|
||||
cumsum_a += prev_a;
|
||||
@@ -2083,7 +2346,7 @@ struct MoeSortingMultiPhaseKernel_P3
|
||||
// in byte
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize()
|
||||
{
|
||||
return (4 + BLOCK_SIZE / WarpSize) * sizeof(IndexType);
|
||||
return (4 + BLOCK_SIZE / get_warp_size()) * sizeof(IndexType);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void operator()(Kargs kargs) const
|
||||
@@ -2110,8 +2373,8 @@ struct MoeSortingMultiPhaseKernel_P3
|
||||
}
|
||||
}();
|
||||
int eid = blockIdx.x;
|
||||
int wave_id = threadIdx.x / WarpSize;
|
||||
int lane_id = threadIdx.x % WarpSize;
|
||||
int wave_id = threadIdx.x / get_warp_size();
|
||||
int lane_id = threadIdx.x % get_warp_size();
|
||||
int e_start = p_expert_cumsum[eid];
|
||||
int e_end = p_expert_cumsum[eid + 1];
|
||||
if constexpr(Problem::SkipExpertsWithZeroTokens)
|
||||
@@ -2141,17 +2404,17 @@ struct MoeSortingMultiPhaseKernel_P3
|
||||
int i_topk = x - 1; // topk of this token
|
||||
int i_show = x != 0 ? 1 : 0; // has this token or not
|
||||
int cumsum = i_show;
|
||||
impl::moe_sorting_wave_cumsum<int, WarpSize>(cumsum);
|
||||
impl::moe_sorting_wave_cumsum<int, get_warp_size()>(cumsum);
|
||||
|
||||
__syncthreads();
|
||||
if(lane_id == WarpSize - 1)
|
||||
if(lane_id == get_warp_size() - 1)
|
||||
{
|
||||
s[4 + wave_id] = cumsum;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// reduce cross wave
|
||||
static_for<0, BLOCK_SIZE / WarpSize - 1, 1>{}([&](auto i_w) {
|
||||
static_for<0, BLOCK_SIZE / get_warp_size() - 1, 1>{}([&](auto i_w) {
|
||||
IndexType prev = s[4 + i_w];
|
||||
prev = wave_id > i_w ? prev : 0; // mask out
|
||||
cumsum += prev;
|
||||
@@ -2196,7 +2459,7 @@ CK_TILE_HOST constexpr auto moe_sorting_get_smem_size_p23(int num_experts_)
|
||||
{
|
||||
constexpr index_t BLOCK_SIZE = 256; // hardcoded 256
|
||||
const index_t expert_cumsum_elem = num_experts_ + 1;
|
||||
return (4 + 2 * BLOCK_SIZE / WarpSize + expert_cumsum_elem) * sizeof(int);
|
||||
return (4 + 2 * BLOCK_SIZE / get_warp_size() + expert_cumsum_elem) * sizeof(int);
|
||||
}
|
||||
} // namespace impl
|
||||
|
||||
@@ -2223,7 +2486,7 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
const void* p_local_tokens; // [1]
|
||||
void* p_expert_mesh; // [expert, tokens]
|
||||
void* p_expert_cumsum; // [expert + 1]
|
||||
void* p_total_tokens_post_pad; // [1]
|
||||
void* p_total_tokens_post_pad; // [2]
|
||||
void* p_sorted_expert_ids;
|
||||
|
||||
void* p_sorted_token_ids;
|
||||
@@ -2235,7 +2498,17 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
index_t mesh_stride; // mesh_stride for p_expert_mesh
|
||||
mdiv unit_size_mdiv;
|
||||
mdiv topk_mdiv;
|
||||
long_index_t moe_buf_bytes;
|
||||
#if MOE_SORTING_FMOE_2D_BUF
|
||||
// NOTE:
|
||||
// moe_buf_* is a 2d ws buffer used for the following fmoe kernel
|
||||
// arranged as row*col, where row=tokens(or local_token), col=interm_dim
|
||||
// we fuse this clearing inside sorting kernel
|
||||
// Besides, we require inter_dim to be multiple of 16 byte(make sure when alloc ws for fmoe)
|
||||
index_t moe_buf_interm_dim; // p_moe_buf interm_dim
|
||||
index_t moe_buf_elem_bytes; // p_moe_buf byte size(8bit, 16bit, 32bit, etc.)
|
||||
#else
|
||||
long_index_t moe_buf_bytes; // byte size of p_moe_buf
|
||||
#endif
|
||||
};
|
||||
|
||||
CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h)
|
||||
@@ -2262,16 +2535,37 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
k.unit_size_mdiv = mdiv{static_cast<uint32_t>(h.unit_size)};
|
||||
k.topk_mdiv = mdiv{static_cast<uint32_t>(h.topk)};
|
||||
|
||||
#if MOE_SORTING_FMOE_2D_BUF
|
||||
k.moe_buf_interm_dim = h.moe_buf_interm_dim;
|
||||
k.moe_buf_elem_bytes = h.moe_buf_elem_bytes;
|
||||
#else
|
||||
k.moe_buf_bytes = h.moe_buf_bytes;
|
||||
#endif
|
||||
|
||||
return k;
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto get_num_cu()
|
||||
{
|
||||
index_t num_cu = [&]() {
|
||||
hipDeviceProp_t dev_prop;
|
||||
hipDevice_t dev;
|
||||
HIP_CHECK_ERROR(hipGetDevice(&dev));
|
||||
HIP_CHECK_ERROR(hipGetDeviceProperties(&dev_prop, dev));
|
||||
return dev_prop.multiProcessorCount;
|
||||
}();
|
||||
return num_cu;
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(const Hargs& h)
|
||||
{
|
||||
#if MOE_SORTING_FMOE_2D_BUF
|
||||
return dim3(h.num_experts + get_num_cu() * OCCUPANCY);
|
||||
#else
|
||||
// use 1 block to cumsum
|
||||
// return dim3(1 + ck_tile::integer_divide_ceil(h.moe_buf_bytes, BLOCK_SIZE * 16));
|
||||
return dim3(h.num_experts + ck_tile::integer_divide_ceil(h.moe_buf_bytes, BLOCK_SIZE * 16));
|
||||
#endif
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(BLOCK_SIZE); }
|
||||
@@ -2287,13 +2581,34 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
// reduce single pixel within a wave
|
||||
CK_TILE_DEVICE void operator()(Kargs kargs) const
|
||||
{
|
||||
index_t tokens = [&]() {
|
||||
if constexpr(Problem::LocalToken)
|
||||
{
|
||||
return reinterpret_cast<const index_t*>(kargs.p_local_tokens)[0];
|
||||
}
|
||||
else
|
||||
{
|
||||
return kargs.tokens;
|
||||
}
|
||||
}();
|
||||
|
||||
if(static_cast<index_t>(blockIdx.x) >= kargs.num_experts)
|
||||
{
|
||||
#if MOE_SORTING_FMOE_2D_BUF
|
||||
impl::moe_buf_set_zero_kernel_2d<BLOCK_SIZE>(kargs.p_moe_buf,
|
||||
tokens,
|
||||
kargs.moe_buf_interm_dim,
|
||||
kargs.moe_buf_elem_bytes,
|
||||
blockIdx.x - kargs.num_experts,
|
||||
gridDim.x - kargs.num_experts);
|
||||
return;
|
||||
#else
|
||||
impl::moe_buf_set_zero_kernel<BLOCK_SIZE>(
|
||||
reinterpret_cast<uint8x16_t*>(kargs.p_moe_buf),
|
||||
kargs.moe_buf_bytes,
|
||||
blockIdx.x - kargs.num_experts);
|
||||
return;
|
||||
#endif
|
||||
}
|
||||
|
||||
extern __shared__ char smem[];
|
||||
@@ -2303,15 +2618,15 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
const IndexType* p_local_expert_mask =
|
||||
static_cast<const IndexType*>(kargs.p_local_expert_mask);
|
||||
IndexType* p_expert_cumsum = reinterpret_cast<IndexType*>(kargs.p_expert_cumsum);
|
||||
IndexType* p_expert_cumsum_smem = s + 4 + 2 * BLOCK_SIZE / WarpSize;
|
||||
IndexType* p_expert_cumsum_smem = s + 4 + 2 * BLOCK_SIZE / get_warp_size();
|
||||
IndexType* p_total_tokens_post_pad =
|
||||
reinterpret_cast<IndexType*>(kargs.p_total_tokens_post_pad);
|
||||
IndexType* p_sorted_expert_ids =
|
||||
reinterpret_cast<IndexType*>(kargs.p_sorted_expert_ids);
|
||||
|
||||
const index_t loops = (kargs.num_experts + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
index_t wave_id = threadIdx.x / WarpSize;
|
||||
index_t lane_id = threadIdx.x % WarpSize;
|
||||
index_t wave_id = threadIdx.x / get_warp_size();
|
||||
index_t lane_id = threadIdx.x % get_warp_size();
|
||||
|
||||
IndexType prev_cumsum_a = 0;
|
||||
IndexType prev_cumsum_b = 0;
|
||||
@@ -2356,22 +2671,22 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
IndexType cumsum_b = b_;
|
||||
|
||||
// Note: we first cumsum local round, then add previous cumsum
|
||||
impl::moe_sorting_wave_cumsum<IndexType, WarpSize>(cumsum_a);
|
||||
impl::moe_sorting_wave_cumsum<IndexType, WarpSize>(cumsum_b);
|
||||
impl::moe_sorting_wave_cumsum<IndexType, get_warp_size()>(cumsum_a);
|
||||
impl::moe_sorting_wave_cumsum<IndexType, get_warp_size()>(cumsum_b);
|
||||
|
||||
__syncthreads();
|
||||
if(lane_id == WarpSize - 1)
|
||||
if(lane_id == get_warp_size() - 1)
|
||||
{
|
||||
s[4 + wave_id] = cumsum_a;
|
||||
s[4 + wave_id + BLOCK_SIZE / WarpSize] = cumsum_b;
|
||||
s[4 + wave_id] = cumsum_a;
|
||||
s[4 + wave_id + BLOCK_SIZE / get_warp_size()] = cumsum_b;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// reduce cross wave
|
||||
static_for<0, BLOCK_SIZE / WarpSize - 1, 1>{}([&](auto i_w) {
|
||||
static_for<0, BLOCK_SIZE / get_warp_size() - 1, 1>{}([&](auto i_w) {
|
||||
IndexType prev_a = s[4 + i_w];
|
||||
IndexType prev_b = s[4 + i_w + BLOCK_SIZE / WarpSize];
|
||||
IndexType prev_b = s[4 + i_w + BLOCK_SIZE / get_warp_size()];
|
||||
prev_a = wave_id > i_w ? prev_a : 0; // mask out
|
||||
prev_b = wave_id > i_w ? prev_b : 0; // mask out
|
||||
cumsum_a += prev_a;
|
||||
@@ -2428,26 +2743,28 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
{
|
||||
auto total_tokens_post_pad = prev_cumsum_a * kargs.unit_size_mdiv.divisor;
|
||||
if(blockIdx.x == 0)
|
||||
{
|
||||
p_total_tokens_post_pad[0] = total_tokens_post_pad;
|
||||
p_total_tokens_post_pad[1] = tokens;
|
||||
}
|
||||
p_expert_cumsum_smem[kargs.num_experts] = total_tokens_post_pad;
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
{
|
||||
const IndexType* p_local_expert_mask =
|
||||
static_cast<const IndexType*>(kargs.p_local_expert_mask);
|
||||
IndexType* s = reinterpret_cast<IndexType*>(smem);
|
||||
MeshType* p_expert_mesh = reinterpret_cast<MeshType*>(kargs.p_expert_mesh);
|
||||
IndexType* p_sorted_token_ids = reinterpret_cast<IndexType*>(kargs.p_sorted_token_ids);
|
||||
IndexType* p_expert_cumsum_smem = s + 4 + 2 * BLOCK_SIZE / WarpSize;
|
||||
IndexType* p_expert_cumsum_smem = s + 4 + 2 * BLOCK_SIZE / get_warp_size();
|
||||
const WeightType* p_weights = static_cast<const WeightType*>(kargs.p_weights);
|
||||
WeightType* p_sorted_weights = reinterpret_cast<WeightType*>(kargs.p_sorted_weights);
|
||||
|
||||
int eid = blockIdx.x;
|
||||
int wave_id = threadIdx.x / WarpSize;
|
||||
int lane_id = threadIdx.x % WarpSize;
|
||||
int wave_id = threadIdx.x / get_warp_size();
|
||||
int lane_id = threadIdx.x % get_warp_size();
|
||||
int e_start = p_expert_cumsum_smem[eid];
|
||||
int e_end = p_expert_cumsum_smem[eid + 1];
|
||||
if constexpr(Problem::SkipExpertsWithZeroTokens)
|
||||
@@ -2463,14 +2780,14 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
return; // skip empty expert
|
||||
}
|
||||
|
||||
index_t tokens = [&]() {
|
||||
index_t mesh_stride = [&]() {
|
||||
if constexpr(Problem::LocalToken)
|
||||
{
|
||||
return reinterpret_cast<const index_t*>(kargs.p_local_tokens)[0];
|
||||
return impl::moe_sorting_mp_mesh_stride(tokens);
|
||||
}
|
||||
else
|
||||
{
|
||||
return kargs.tokens;
|
||||
return kargs.mesh_stride;
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -2478,7 +2795,8 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
constexpr index_t index_pack = Problem::SubTokenTile; // always packed
|
||||
using r_t = ext_vector_t<MeshType, index_pack>; // always use int32x4
|
||||
using d_t = ext_vector_t<index_t, index_pack>;
|
||||
int loops = (kargs.mesh_stride / index_pack + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
int loops = (mesh_stride / index_pack + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
|
||||
int prev_cumsum = 0;
|
||||
|
||||
for(int i = 0; i < loops; i++)
|
||||
@@ -2487,8 +2805,7 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
r_t x_v = 0;
|
||||
if(i_token_pack < (tokens + index_pack - 1) / index_pack)
|
||||
{
|
||||
x_v = reinterpret_cast<r_t*>(p_expert_mesh +
|
||||
eid * kargs.mesh_stride)[i_token_pack];
|
||||
x_v = reinterpret_cast<r_t*>(p_expert_mesh + eid * mesh_stride)[i_token_pack];
|
||||
}
|
||||
|
||||
r_t x_r;
|
||||
@@ -2518,17 +2835,17 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
int i_topk = x - 1; // topk of this token
|
||||
int i_show = x != 0 ? 1 : 0; // has this token or not
|
||||
int cumsum = i_show;
|
||||
impl::moe_sorting_wave_cumsum<int, WarpSize>(cumsum);
|
||||
impl::moe_sorting_wave_cumsum<int, get_warp_size()>(cumsum);
|
||||
|
||||
__syncthreads();
|
||||
if(lane_id == WarpSize - 1)
|
||||
if(lane_id == get_warp_size() - 1)
|
||||
{
|
||||
s[4 + wave_id] = cumsum;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// reduce cross wave
|
||||
static_for<0, BLOCK_SIZE / WarpSize - 1, 1>{}([&](auto i_w) {
|
||||
static_for<0, BLOCK_SIZE / get_warp_size() - 1, 1>{}([&](auto i_w) {
|
||||
IndexType prev = s[4 + i_w];
|
||||
prev = wave_id > i_w ? prev : 0; // mask out
|
||||
cumsum += prev;
|
||||
@@ -2569,17 +2886,17 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
cumsum_store += i_show[j];
|
||||
});
|
||||
int cumsum = cumsum_store;
|
||||
impl::moe_sorting_wave_cumsum<int, WarpSize>(cumsum);
|
||||
impl::moe_sorting_wave_cumsum<int, get_warp_size()>(cumsum);
|
||||
|
||||
__syncthreads();
|
||||
if(lane_id == WarpSize - 1)
|
||||
if(lane_id == get_warp_size() - 1)
|
||||
{
|
||||
s[4 + wave_id] = cumsum;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// reduce cross wave
|
||||
static_for<0, BLOCK_SIZE / WarpSize - 1, 1>{}([&](auto i_w) {
|
||||
static_for<0, BLOCK_SIZE / get_warp_size() - 1, 1>{}([&](auto i_w) {
|
||||
IndexType prev = s[4 + i_w];
|
||||
prev = wave_id > i_w ? prev : 0; // mask out
|
||||
cumsum += prev;
|
||||
@@ -2624,17 +2941,17 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
int i_topk_1 = x1 - 1; // topk of this token
|
||||
int i_show_1 = x1 != 0 ? 1 : 0; // has this token or not
|
||||
int cumsum = i_show_0 + i_show_1;
|
||||
impl::moe_sorting_wave_cumsum<int, WarpSize>(cumsum);
|
||||
impl::moe_sorting_wave_cumsum<int, get_warp_size()>(cumsum);
|
||||
|
||||
__syncthreads();
|
||||
if(lane_id == WarpSize - 1)
|
||||
if(lane_id == get_warp_size() - 1)
|
||||
{
|
||||
s[4 + wave_id] = cumsum;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// reduce cross wave
|
||||
static_for<0, BLOCK_SIZE / WarpSize - 1, 1>{}([&](auto i_w) {
|
||||
static_for<0, BLOCK_SIZE / get_warp_size() - 1, 1>{}([&](auto i_w) {
|
||||
IndexType prev = s[4 + i_w];
|
||||
prev = wave_id > i_w ? prev : 0; // mask out
|
||||
cumsum += prev;
|
||||
|
||||
@@ -73,4 +73,12 @@ struct MoeSortingProblemMp
|
||||
SubTokenTile == 8 || SubTokenTile == 16);
|
||||
};
|
||||
|
||||
template <bool LocalToken_, index_t BlockSize_ = 1024, index_t Occu_ = 1>
|
||||
struct MoeSortingClearWorkspaceProblem
|
||||
{
|
||||
static constexpr bool LocalToken = LocalToken_;
|
||||
static constexpr index_t BlockSize = BlockSize_;
|
||||
static constexpr index_t Occu = Occu_;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2r1.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_custom_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_default_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp"
|
||||
@@ -23,6 +24,8 @@
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_wp_asmem_bsmem_creg_v1.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_wp_asmem_bsmem_creg_v1_custom_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp"
|
||||
@@ -41,6 +44,8 @@
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v1.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v1_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
|
||||
|
||||
@@ -0,0 +1,247 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_default_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// A is block distributed tensor
|
||||
// B is block window on shared memory
|
||||
// C is block distributed tensor
|
||||
template <typename Problem_, typename Policy_ = BlockGemmARegBSmemCRegV2DefaultPolicy>
|
||||
struct BlockGemmARegBSmemCRegV2R1
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using Policy = remove_cvref_t<Policy_>;
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
// C += A * B
|
||||
template <typename CBlockTensor, typename ABlockTensorTmp, typename BBlockWindowTmp>
|
||||
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
|
||||
const ABlockTensorTmp& a_block_tensor_tmp,
|
||||
const BBlockWindowTmp& b_block_window_tmp) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<ADataType, remove_cv_t<typename ABlockTensorTmp::DataType>> &&
|
||||
std::is_same_v<BDataType, remove_cv_t<typename BBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}];
|
||||
constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}];
|
||||
constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}];
|
||||
|
||||
static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
|
||||
KPerBlock == BlockGemmShape::kK,
|
||||
"wrong!");
|
||||
|
||||
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
constexpr index_t MWarp = config.template at<1>();
|
||||
constexpr index_t NWarp = config.template at<2>();
|
||||
|
||||
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
|
||||
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
|
||||
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
|
||||
|
||||
constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
|
||||
constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
|
||||
|
||||
const index_t iNWarp = get_warp_id() % NWarp;
|
||||
|
||||
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
|
||||
|
||||
// constrcut from A-block-tensor from A-Block-tensor-tmp
|
||||
// FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent
|
||||
// distribution
|
||||
auto a_block_tensor = make_static_distributed_tensor<typename ABlockTensorTmp::DataType>(
|
||||
MakeABlockTileDistribution());
|
||||
|
||||
a_block_tensor.get_thread_buffer() = a_block_tensor_tmp.get_thread_buffer();
|
||||
|
||||
// construct B-warp-window
|
||||
auto b_warp_window_tmp = make_tile_window(
|
||||
b_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<WG::kN>{}, number<WG::kK>{}),
|
||||
b_block_window_tmp.get_window_origin() + multi_index<2>{iNWarp * WG::kN, 0},
|
||||
make_static_tile_distribution(typename WG::BWarpDstrEncoding{}));
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(b_warp_window_tmp), KIterPerWarp>,
|
||||
NIterPerWarp>
|
||||
b_warp_windows;
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
|
||||
|
||||
move_tile_window(b_warp_windows(nIter)(kIter),
|
||||
{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
|
||||
});
|
||||
});
|
||||
|
||||
// check C-block-distribution
|
||||
static_assert(
|
||||
std::is_same_v<remove_cvref_t<decltype(c_block_dstr_encode)>,
|
||||
remove_cvref_t<decltype(CBlockTensor::get_tile_distribution()
|
||||
.get_static_tile_distribution_encoding())>>,
|
||||
"wrong!");
|
||||
|
||||
using AWarpDstr = typename WG::AWarpDstr;
|
||||
using CWarpDstr = typename WG::CWarpDstr;
|
||||
|
||||
using AWarpTensor = typename WG::AWarpTensor;
|
||||
using CWarpTensor = typename WG::CWarpTensor;
|
||||
|
||||
constexpr auto a_warp_y_lengths =
|
||||
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto c_warp_y_lengths =
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
|
||||
constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
|
||||
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(load_tile(decltype(b_warp_window_tmp){})),
|
||||
KIterPerWarp>,
|
||||
NIterPerWarp>
|
||||
b_warp_tensors;
|
||||
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
b_warp_tensors(nIter)(kIter) = load_tile(b_warp_windows(nIter)(kIter));
|
||||
});
|
||||
});
|
||||
|
||||
// hot loop:
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read B warp tensor from B Block window
|
||||
const auto b_warp_tensor = b_warp_tensors(nIter)(kIter);
|
||||
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
// read A warp tensor from A block tensor
|
||||
AWarpTensor a_warp_tensor;
|
||||
|
||||
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
|
||||
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
|
||||
// WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor_array[nIter]);
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tensor.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto) {
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto) {
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
template <index_t MPerBlock = BlockGemmShape::kM, index_t KPerBlock = BlockGemmShape::kK>
|
||||
CK_TILE_DEVICE static constexpr auto MakeABlockTileDistribution()
|
||||
{
|
||||
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
constexpr index_t MWarp = config.template at<1>();
|
||||
constexpr index_t NWarp = config.template at<2>();
|
||||
|
||||
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
|
||||
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
|
||||
|
||||
constexpr auto a_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<NWarp>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
|
||||
|
||||
return make_static_tile_distribution(a_block_dstr_encode);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
|
||||
{
|
||||
constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
|
||||
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
constexpr index_t MWarp = config.template at<1>();
|
||||
constexpr index_t NWarp = config.template at<2>();
|
||||
|
||||
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
|
||||
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
|
||||
// constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
|
||||
|
||||
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
|
||||
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
|
||||
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
|
||||
return c_block_tensor;
|
||||
}
|
||||
|
||||
// C = A * B
|
||||
template <typename ABlockTensorTmp, typename BBlockWindowTmp>
|
||||
CK_TILE_DEVICE auto operator()(const ABlockTensorTmp& a_block_tensor_tmp,
|
||||
const BBlockWindowTmp& b_block_window_tmp) const
|
||||
{
|
||||
auto c_block_tensor = MakeCBlockTile();
|
||||
operator()(c_block_tensor, a_block_tensor_tmp, b_block_window_tmp);
|
||||
return c_block_tensor;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -215,7 +215,7 @@ struct BlockUniversalGemmAsBsCr
|
||||
using BLdsTile = decltype(make_static_distributed_tensor<ComputeDataType>(BLdsTileDistr));
|
||||
|
||||
ALdsTile a_warp_tile_;
|
||||
ALdsTile b_warp_tile_;
|
||||
BLdsTile b_warp_tile_;
|
||||
|
||||
// C += A * B
|
||||
template <typename CBlockTensor, typename ASmemBlockWindow, typename BSmemBlockWindow>
|
||||
|
||||
122
include/ck_tile/ops/gemm/block/block_wp_asmem_bsmem_creg_v1.hpp
Normal file
122
include/ck_tile/ops/gemm/block/block_wp_asmem_bsmem_creg_v1.hpp
Normal file
@@ -0,0 +1,122 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_wp_asmem_bsmem_creg_v1_custom_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// A is block window on shared memory
|
||||
// B is block window on shared memory
|
||||
// C is block distributed tensor
|
||||
template <typename Problem_, typename BlockPolicy_>
|
||||
struct BlockWeightPreshuffleASmemBSmemCRegV1
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using BlockPolicy = remove_cvref_t<BlockPolicy_>;
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
|
||||
static constexpr auto I0 = number<0>();
|
||||
static constexpr auto I1 = number<1>();
|
||||
static constexpr auto I2 = number<2>();
|
||||
static constexpr auto idxM = I0;
|
||||
static constexpr auto idxN = I1;
|
||||
static constexpr auto idxK = I2;
|
||||
using BlockTile = remove_cvref_t<typename BlockGemmShape::BlockTile>;
|
||||
using BlockWarps = remove_cvref_t<typename BlockGemmShape::BlockWarps>;
|
||||
using WarpTile = remove_cvref_t<typename BlockGemmShape::WarpTile>;
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
|
||||
{
|
||||
constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
|
||||
constexpr auto config = BlockPolicy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
constexpr index_t MWarp = config.template at<1>();
|
||||
constexpr index_t NWarp = config.template at<2>();
|
||||
|
||||
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
|
||||
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
|
||||
|
||||
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
|
||||
|
||||
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
|
||||
|
||||
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
|
||||
return c_block_tensor;
|
||||
}
|
||||
|
||||
// C += A * B
|
||||
template <typename CBlockTensor, typename ABlockWindow, typename BFlatBlockTensor>
|
||||
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
|
||||
ABlockWindow& a_warp_windows,
|
||||
BFlatBlockTensor& b_warp_tensor) const
|
||||
{
|
||||
constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
|
||||
constexpr auto config = BlockPolicy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
constexpr index_t MWarp = config.template at<1>();
|
||||
|
||||
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
|
||||
constexpr index_t NIterPerWarp =
|
||||
BlockTile::at(idxN) / (WarpTile::at(idxN) * BlockWarps::at(idxN));
|
||||
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
|
||||
|
||||
using CWarpDstr = typename WG::CWarpDstr;
|
||||
using CWarpTensor = typename WG::CWarpTensor;
|
||||
|
||||
constexpr auto c_warp_y_lengths =
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
// hot loop:
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
// read A warp tensor from A block window
|
||||
const auto a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter));
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor(nIter)(kIter));
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tensor.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,38 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Default policy for BlockGemmASmemBSmemCRegV1
|
||||
// Default policy class should not be templated, put template on member functions instead
|
||||
template <typename AType_,
|
||||
typename BType_,
|
||||
typename CType_,
|
||||
typename BlockWarps_,
|
||||
typename WarpGemm_>
|
||||
struct BlockWeightPreshuffleASmemBSmemCRegV1CustomPolicy
|
||||
{
|
||||
using AType = remove_cvref_t<AType_>;
|
||||
using BType = remove_cvref_t<BType_>;
|
||||
using CType = remove_cvref_t<CType_>;
|
||||
|
||||
using BlockWarps = remove_cvref_t<BlockWarps_>;
|
||||
|
||||
static constexpr index_t kMWarps = BlockWarps::at(number<0>{});
|
||||
static constexpr index_t kNWarps = BlockWarps::at(number<1>{});
|
||||
static constexpr index_t kKWarps = BlockWarps::at(number<2>{});
|
||||
|
||||
using WarpGemm = remove_cvref_t<WarpGemm_>;
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp()
|
||||
{
|
||||
return make_tuple(WarpGemm{}, kMWarps, kNWarps);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -74,7 +74,7 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
|
||||
// clang-format off
|
||||
using P_ = GemmPipeline;
|
||||
|
||||
return concat('_', "gemm_batched", gemm_prec_str<ADataType, BDataType>,
|
||||
return concat('_', "gemm_batched", gemm_prec_str<ADataType, BDataType>(),
|
||||
concat('x', P_::MPerBlock, P_::NPerBlock, P_::KPerBlock),
|
||||
concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()),
|
||||
concat('x', P_::kPadM, P_::kPadN, P_::kPadK));
|
||||
|
||||
101
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
Normal file → Executable file
101
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
Normal file → Executable file
@@ -59,14 +59,23 @@ struct GemmHostArgs
|
||||
const void* a_ptr;
|
||||
const void* b_ptr;
|
||||
const std::array<const void*, NumDTensor> ds_ptr;
|
||||
void* e_ptr;
|
||||
union
|
||||
{
|
||||
void* e_ptr;
|
||||
void* c_ptr;
|
||||
};
|
||||
index_t M;
|
||||
index_t N;
|
||||
index_t K;
|
||||
index_t stride_A;
|
||||
index_t stride_B;
|
||||
const std::array<index_t, NumDTensor> stride_Ds;
|
||||
index_t stride_E;
|
||||
union
|
||||
{
|
||||
index_t stride_E;
|
||||
index_t stride_C;
|
||||
};
|
||||
|
||||
index_t k_batch;
|
||||
};
|
||||
|
||||
@@ -187,7 +196,7 @@ struct GemmKernel
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
return concat('_', "gemm", gemm_prec_str<ADataType, BDataType>, GemmPipeline::GetName());
|
||||
return concat('_', "gemm", gemm_prec_str<ADataType, BDataType>(), GemmPipeline::GetName());
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
@@ -297,7 +306,7 @@ struct GemmKernel
|
||||
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 &&
|
||||
GemmPipeline::kPadK == false)
|
||||
GemmPipeline::kPadK == false) // k_batch is extra compared to flatmm
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
@@ -359,7 +368,7 @@ struct GemmKernel
|
||||
else
|
||||
{
|
||||
if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 &&
|
||||
GemmPipeline::kPadK == false)
|
||||
GemmPipeline::kPadK == false) // again k_batch is extra compared to flatmm
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
@@ -480,6 +489,7 @@ struct GemmKernel
|
||||
const SplitKBatchOffset& splitk_batch_offset)
|
||||
{
|
||||
static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!");
|
||||
|
||||
const auto& a_tensor_view = [&]() {
|
||||
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
@@ -554,12 +564,30 @@ struct GemmKernel
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
b_ptr,
|
||||
make_tuple(kargs.N, splitk_batch_offset.splitted_k),
|
||||
make_tuple(kargs.stride_B, 1),
|
||||
number<GemmPipeline::GetVectorSizeB()>{},
|
||||
number<1>{});
|
||||
if constexpr(GemmPipeline::Preshuffle)
|
||||
{
|
||||
index_t kFlatK =
|
||||
GemmPipeline::BlockGemmShape::flatKPerWarp *
|
||||
(splitk_batch_offset.splitted_k /
|
||||
TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{}));
|
||||
index_t kFlatN = kargs.N * kargs.K / kFlatK;
|
||||
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
b_ptr,
|
||||
make_tuple(kFlatN, kFlatK),
|
||||
make_tuple(kFlatK, 1),
|
||||
number<GemmPipeline::GetVectorSizeB()>{},
|
||||
number<1>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
b_ptr,
|
||||
make_tuple(kargs.N, splitk_batch_offset.splitted_k),
|
||||
make_tuple(kargs.stride_B, 1),
|
||||
number<GemmPipeline::GetVectorSizeB()>{},
|
||||
number<1>{});
|
||||
}
|
||||
}
|
||||
}
|
||||
}();
|
||||
@@ -604,7 +632,7 @@ struct GemmKernel
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
|
||||
e_ptr,
|
||||
make_tuple(kargs.M, kargs.N),
|
||||
make_tuple(kargs.M, kargs.N), // arguments not matching with flatmm.
|
||||
make_tuple(1, kargs.stride_E),
|
||||
number<1>{},
|
||||
number<1>{});
|
||||
@@ -635,6 +663,8 @@ struct GemmKernel
|
||||
}
|
||||
}();
|
||||
|
||||
const auto& b_flat_pad_view = views.at(I1);
|
||||
|
||||
const auto& b_pad_view = [&]() {
|
||||
const auto& b_tensor_view = views.at(I1);
|
||||
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
|
||||
@@ -692,8 +722,15 @@ struct GemmKernel
|
||||
sequence<GemmPipeline::kPadM, false>{});
|
||||
}
|
||||
}();
|
||||
|
||||
return make_tuple(a_pad_view, b_pad_view, ds_pad_view, e_pad_view);
|
||||
if constexpr(GemmPipeline::Preshuffle)
|
||||
{
|
||||
// For flatmm, we need to use the flat B tensor view
|
||||
return make_tuple(a_pad_view, b_flat_pad_view, ds_pad_view, e_pad_view);
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_tuple(a_pad_view, b_pad_view, ds_pad_view, e_pad_view);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename PadView>
|
||||
@@ -723,19 +760,30 @@ struct GemmKernel
|
||||
}();
|
||||
|
||||
const auto& b_block_window = [&]() {
|
||||
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
|
||||
if constexpr(GemmPipeline::Preshuffle)
|
||||
{
|
||||
return make_tile_window(b_pad_view,
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
{i_n, 0});
|
||||
return make_tile_window(
|
||||
b_pad_view,
|
||||
make_tuple(number<GemmPipeline::BlockGemmShape::flatNPerWarp>{},
|
||||
number<GemmPipeline::BlockGemmShape::flatKPerWarp>{}),
|
||||
{static_cast<int>(i_n / GemmPipeline::BlockGemmShape::WarpTile::at(I1)), 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_tile_window(b_pad_view,
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
{0, i_n});
|
||||
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
return make_tile_window(b_pad_view,
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
{i_n, 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_tile_window(b_pad_view,
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
{0, i_n});
|
||||
}
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -798,7 +846,8 @@ struct GemmKernel
|
||||
a_ptr, b_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset);
|
||||
|
||||
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
|
||||
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
|
||||
|
||||
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
|
||||
|
||||
const index_t num_loop = __builtin_amdgcn_readfirstlane(
|
||||
TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
|
||||
@@ -813,7 +862,6 @@ struct GemmKernel
|
||||
|
||||
if(UseDefaultScheduler || (get_warp_id() == 0))
|
||||
{
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(I3);
|
||||
|
||||
EpiloguePipeline{}.template
|
||||
@@ -856,7 +904,8 @@ struct GemmKernel
|
||||
a_ptr, b_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset);
|
||||
|
||||
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
|
||||
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
|
||||
|
||||
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
|
||||
|
||||
const index_t num_loop = __builtin_amdgcn_readfirstlane(
|
||||
TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
|
||||
|
||||
@@ -57,7 +57,7 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
|
||||
// clang-format off
|
||||
using P_ = GemmPipeline;
|
||||
|
||||
return concat('_', "gemm_grouped", gemm_prec_str<ADataType, BDataType>,
|
||||
return concat('_', "gemm_grouped", gemm_prec_str<ADataType, BDataType>(),
|
||||
concat('x', P_::MPerBlock, P_::NPerBlock, P_::KPerBlock),
|
||||
concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()),
|
||||
concat('x', P_::kPadM, P_::kPadN, P_::kPadK),
|
||||
@@ -95,7 +95,7 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
|
||||
return dim3(grid_size, 1, 1);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto
|
||||
CK_TILE_HOST static auto
|
||||
GridSize(const std::vector<GemmHostArgs</*NumDTensor = 0*/>>& gemm_descs)
|
||||
{
|
||||
index_t grid_size = 0;
|
||||
|
||||
@@ -32,6 +32,15 @@ struct GemmPipelineAgBgCrImplBase
|
||||
move_tile_window(dram_tile_window, dram_tile_window_step);
|
||||
}
|
||||
|
||||
template <typename DstBlockWindow, typename SrcTileWindow, typename DramTileWindowStep>
|
||||
CK_TILE_DEVICE void GlobalPrefetchAsync(DstBlockWindow& dst_block_window,
|
||||
SrcTileWindow& dram_tile_window,
|
||||
const DramTileWindowStep& dram_tile_window_step) const
|
||||
{
|
||||
async_load_tile(dst_block_window, dram_tile_window);
|
||||
move_tile_window(dram_tile_window, dram_tile_window_step);
|
||||
}
|
||||
|
||||
template <typename DstTileWindow, typename SrcBlockTile, typename ElementFunction>
|
||||
CK_TILE_DEVICE void LocalPrefill(DstTileWindow& lds_tile_window,
|
||||
const SrcBlockTile& src_block_tile,
|
||||
|
||||
@@ -112,11 +112,6 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
|
||||
static constexpr index_t APackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
|
||||
static constexpr index_t BPackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
|
||||
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
using BLayout = remove_cvref_t<typename Problem::BLayout>;
|
||||
using CLayout = remove_cvref_t<typename Problem::CLayout>;
|
||||
@@ -127,6 +122,7 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
using I2 = number<2>;
|
||||
|
||||
static constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
|
||||
static constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
@@ -135,6 +131,11 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB<Problem>(); }
|
||||
static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC<Problem>(); }
|
||||
|
||||
static constexpr index_t APackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
|
||||
static constexpr index_t BPackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
|
||||
|
||||
static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA<Problem>(); }
|
||||
static constexpr index_t GetSmemPackB() { return Policy::template GetSmemPackB<Problem>(); }
|
||||
|
||||
@@ -144,10 +145,13 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
|
||||
static constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
|
||||
static constexpr index_t Preshuffle = Problem::Preshuffle;
|
||||
|
||||
static constexpr bool HasHotLoop = Problem::HasHotLoop;
|
||||
static constexpr auto TailNum = Problem::TailNum;
|
||||
static constexpr auto Scheduler = Problem::Scheduler;
|
||||
static constexpr bool HasHotLoop =
|
||||
Problem::HasHotLoop; // Base::BlockHasHotloop(Problem::num_loop);
|
||||
static constexpr auto TailNum =
|
||||
Problem::TailNum; // Base::GetBlockLoopTailNum(Problem::num_loop);
|
||||
static constexpr auto Scheduler = Problem::Scheduler;
|
||||
|
||||
using Base::PrefetchStages;
|
||||
using Base::UsePersistentKernel;
|
||||
@@ -155,7 +159,8 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
return concat('_', "pipeline_AgBgCrCompV3", BlockSize,
|
||||
return concat('_', "pipeline_AgBgCrCompV3",
|
||||
concat('x', MPerBlock, NPerBlock, KPerBlock, BlockSize),
|
||||
concat('x', GetVectorSizeA(), GetVectorSizeB(), GetVectorSizeC()),
|
||||
concat('x', kPadM, kPadN, kPadK));
|
||||
// clang-format on
|
||||
|
||||
@@ -135,11 +135,22 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
|
||||
static constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
|
||||
static constexpr index_t Preshuffle = Problem::Preshuffle;
|
||||
|
||||
static constexpr bool HasHotLoop = Problem::HasHotLoop;
|
||||
static constexpr auto TailNum = Problem::TailNum;
|
||||
static constexpr auto Scheduler = Problem::Scheduler;
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
return concat('_', "pipeline_AgBgCrCompV3",
|
||||
concat('x', MPerBlock, NPerBlock, KPerBlock, BlockSize),
|
||||
concat('x', GetVectorSizeA(), GetVectorSizeB(), GetVectorSizeC()),
|
||||
concat('x', kPadM, kPadN, kPadK));
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
return Policy::template GetSmemSize<Problem>();
|
||||
|
||||
@@ -20,12 +20,12 @@ struct GemmPipelineAgBgCrCompV4DefaultPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
|
||||
{
|
||||
using AccDataType = float;
|
||||
// using AccDataType = float;
|
||||
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
|
||||
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
|
||||
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ADataType,
|
||||
typename Problem::BDataType,
|
||||
AccDataType,
|
||||
typename Problem::CDataType, // AccDataType
|
||||
WarpTile::at(I0),
|
||||
WarpTile::at(I1),
|
||||
WarpTile::at(I2),
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5_default_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5_default_policy.hpp"
|
||||
#include "ck_tile/host/concat.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
@@ -71,6 +70,7 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5<Problem>
|
||||
static constexpr bool kPadK = Problem::kPadK;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
|
||||
static constexpr index_t Preshuffle = Problem::Preshuffle;
|
||||
|
||||
static constexpr bool HasHotLoop = Problem::HasHotLoop;
|
||||
static constexpr auto TailNum = Problem::TailNum;
|
||||
|
||||
@@ -20,12 +20,12 @@ struct GemmPipelineAgBgCrCompV5DefaultPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
|
||||
{
|
||||
using AccDataType = float;
|
||||
// using AccDataType = float;
|
||||
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
|
||||
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
|
||||
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ADataType,
|
||||
typename Problem::BDataType,
|
||||
AccDataType,
|
||||
typename Problem::CDataType, // AccDataType
|
||||
WarpTile::at(I0),
|
||||
WarpTile::at(I1),
|
||||
WarpTile::at(I2),
|
||||
|
||||
@@ -189,6 +189,7 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
|
||||
static constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
|
||||
static constexpr index_t Preshuffle = Problem::Preshuffle;
|
||||
|
||||
// Where is the right place for HasHotLoop and TailNum ???
|
||||
static constexpr bool HasHotLoop = Problem::HasHotLoop;
|
||||
|
||||
@@ -121,7 +121,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
|
||||
|
||||
if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
constexpr index_t M1 = Problem::VectorLoadSize / sizeof(ADataType);
|
||||
constexpr index_t M1 = Problem::VectorSizeA;
|
||||
constexpr index_t M0 = MPerBlock / M1;
|
||||
constexpr index_t total_pixels = MPerBlock * KPerBlock / BlockSize;
|
||||
static_assert(total_pixels % M1 == 0);
|
||||
@@ -211,7 +211,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
|
||||
|
||||
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
constexpr index_t N1 = Problem::VectorLoadSize / sizeof(BDataType);
|
||||
constexpr index_t N1 = Problem::VectorSizeB;
|
||||
constexpr index_t N0 = NPerBlock / N1;
|
||||
constexpr index_t total_pixels = NPerBlock * KPerBlock / BlockSize;
|
||||
static_assert(total_pixels % N1 == 0);
|
||||
|
||||
@@ -14,26 +14,29 @@ template <typename ADataType_,
|
||||
typename CDataType_,
|
||||
typename BlockGemmShape_,
|
||||
typename Traits_,
|
||||
typename ComputeDataType_ = ADataType_>
|
||||
typename ComputeDataType_ = ADataType_,
|
||||
bool FixedVectorSize_ = false,
|
||||
index_t VectorSizeA_ = 1,
|
||||
index_t VectorSizeB_ = 1>
|
||||
struct GemmPipelineProblemBase
|
||||
{
|
||||
using Traits = remove_cvref_t<Traits_>;
|
||||
|
||||
using ADataType = remove_cvref_t<ADataType_>;
|
||||
using BDataType = remove_cvref_t<BDataType_>;
|
||||
using CDataType = remove_cvref_t<CDataType_>;
|
||||
using CDataType = remove_cvref_t<CDataType_>; // actually AccDataType
|
||||
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
|
||||
|
||||
static constexpr bool FixedVectorSize = FixedVectorSize_;
|
||||
|
||||
using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
|
||||
|
||||
using ALayout = remove_cvref_t<typename Traits::ALayout>;
|
||||
using BLayout = remove_cvref_t<typename Traits::BLayout>;
|
||||
using CLayout = remove_cvref_t<typename Traits::CLayout>;
|
||||
|
||||
static constexpr bool TransposeC = Traits::TransposeC;
|
||||
|
||||
static constexpr index_t NumWaveGroups = Traits::NumWaveGroups;
|
||||
|
||||
static constexpr bool TransposeC = Traits::TransposeC;
|
||||
static constexpr index_t NumWaveGroups = Traits::NumWaveGroups;
|
||||
static constexpr bool UseStructuredSparsity = Traits::UseStructuredSparsity;
|
||||
|
||||
static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size();
|
||||
@@ -115,7 +118,11 @@ struct GemmPipelineProblemBase
|
||||
}
|
||||
|
||||
static constexpr index_t VectorSizeA = []() {
|
||||
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
||||
if constexpr(FixedVectorSize)
|
||||
{
|
||||
return VectorSizeA_;
|
||||
}
|
||||
else if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return kPadK ? 1 : GetAlignmentA();
|
||||
}
|
||||
@@ -126,7 +133,11 @@ struct GemmPipelineProblemBase
|
||||
}();
|
||||
|
||||
static constexpr index_t VectorSizeB = []() {
|
||||
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
|
||||
if constexpr(FixedVectorSize)
|
||||
{
|
||||
return VectorSizeB_;
|
||||
}
|
||||
else if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
return kPadN ? 1 : GetAlignmentB();
|
||||
}
|
||||
@@ -153,13 +164,19 @@ template <typename ADataType_,
|
||||
typename CDataType_,
|
||||
typename BlockGemmShape_,
|
||||
typename Traits_,
|
||||
typename ComputeDataType_ = ADataType_>
|
||||
typename ComputeDataType_ = ADataType_,
|
||||
bool FixedVectorSize_ = false,
|
||||
index_t VectorSizeA_ = 1,
|
||||
index_t VectorSizeB_ = 1>
|
||||
using GemmPipelineProblem = GemmPipelineProblemBase<ADataType_,
|
||||
BDataType_,
|
||||
CDataType_,
|
||||
BlockGemmShape_,
|
||||
Traits_,
|
||||
ComputeDataType_>;
|
||||
ComputeDataType_,
|
||||
FixedVectorSize_,
|
||||
VectorSizeA_,
|
||||
VectorSizeB_>;
|
||||
|
||||
template <typename ADataType_,
|
||||
typename BDataType_,
|
||||
@@ -169,22 +186,31 @@ template <typename ADataType_,
|
||||
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
|
||||
bool HasHotLoop_ = true,
|
||||
TailNumber TailNum_ = TailNumber::Full,
|
||||
typename ComputeDataType_ = ADataType_>
|
||||
typename ComputeDataType_ = ADataType_,
|
||||
bool FixedVectorSize_ = false,
|
||||
index_t VectorSizeA_ = 1,
|
||||
index_t VectorSizeB_ = 1>
|
||||
struct UniversalGemmPipelineProblem
|
||||
{
|
||||
using Traits = remove_cvref_t<Traits_>;
|
||||
|
||||
using ADataType = remove_cvref_t<ADataType_>;
|
||||
using BDataType = remove_cvref_t<BDataType_>;
|
||||
using CDataType = remove_cvref_t<CDataType_>;
|
||||
using CDataType = remove_cvref_t<CDataType_>; // actually AccDataType
|
||||
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
|
||||
|
||||
static constexpr bool FixedVectorSize = FixedVectorSize_;
|
||||
|
||||
using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
|
||||
|
||||
using ALayout = remove_cvref_t<typename Traits::ALayout>;
|
||||
using BLayout = remove_cvref_t<typename Traits::BLayout>;
|
||||
using CLayout = remove_cvref_t<typename Traits::CLayout>;
|
||||
|
||||
static constexpr bool TransposeC = Traits::TransposeC;
|
||||
static constexpr index_t NumWaveGroups = Traits::NumWaveGroups;
|
||||
static constexpr bool UseStructuredSparsity = Traits::UseStructuredSparsity;
|
||||
|
||||
static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size();
|
||||
|
||||
static constexpr bool kPadM = Traits::kPadM;
|
||||
@@ -192,15 +218,24 @@ struct UniversalGemmPipelineProblem
|
||||
static constexpr bool kPadK = Traits::kPadK;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = Traits::DoubleSmemBuffer;
|
||||
static constexpr auto Scheduler = Scheduler_;
|
||||
static constexpr bool Preshuffle = Traits::Preshuffle;
|
||||
|
||||
static constexpr auto Scheduler = Scheduler_;
|
||||
static constexpr auto HasHotLoop = HasHotLoop_;
|
||||
static constexpr auto TailNum = TailNum_;
|
||||
static constexpr index_t VectorSizeA = VectorSizeA_;
|
||||
static constexpr index_t VectorSizeB = VectorSizeB_;
|
||||
|
||||
static constexpr bool TransposeC = Traits::TransposeC;
|
||||
static constexpr bool UseStructuredSparsity = Traits::UseStructuredSparsity;
|
||||
|
||||
static constexpr index_t NumWaveGroups = Traits::NumWaveGroups;
|
||||
static constexpr auto HasHotLoop = HasHotLoop_;
|
||||
static constexpr auto TailNum = TailNum_;
|
||||
static constexpr index_t VectorLoadSize = Traits::_VectorSize;
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
return concat('_', "gemm_problem",
|
||||
concat('x', kBlockSize),
|
||||
concat('x', kPadM, kPadN, kPadK),
|
||||
Scheduler);
|
||||
// clang-format on
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -426,10 +426,11 @@ struct UniversalGemmBasePolicy
|
||||
{
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t VecLoadSize = GetVectorSizeA<Problem>();
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t VecLoadSize =
|
||||
Problem::FixedVectorSize ? Problem::VectorSizeA : GetVectorSizeA<Problem>();
|
||||
constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
|
||||
|
||||
// Tile: MPerBlock X KPerBlock
|
||||
@@ -461,10 +462,11 @@ struct UniversalGemmBasePolicy
|
||||
{
|
||||
using BLayout = remove_cvref_t<typename Problem::BLayout>;
|
||||
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t VecLoadSize = GetVectorSizeB<Problem>();
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t VecLoadSize =
|
||||
Problem::FixedVectorSize ? Problem::VectorSizeB : GetVectorSizeB<Problem>();
|
||||
constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
|
||||
|
||||
// Tile: KPerBlock X NPerBlock
|
||||
|
||||
@@ -28,6 +28,10 @@ struct TileGemmShape
|
||||
static constexpr bool PermuteA = PermuteA_;
|
||||
static constexpr bool PermuteB = PermuteB_;
|
||||
|
||||
static constexpr index_t flatNPerWarp = BlockWarps::at(number<1>{});
|
||||
static constexpr index_t flatKPerWarp = WarpTile::at(number<2>{}) * WarpTile::at(number<1>{});
|
||||
static constexpr index_t flatKPerBlock = flatKPerWarp * kK / WarpTile::at(number<2>{});
|
||||
|
||||
CK_TILE_HOST static std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
|
||||
@@ -42,13 +42,14 @@ template <bool kPadM_,
|
||||
bool TransposeC_ = false,
|
||||
bool UseStructuredSparsity_ = false,
|
||||
bool UsePersistentKernel_ = false,
|
||||
index_t NumWaveGroups_ = 1>
|
||||
index_t NumWaveGroups_ = 1,
|
||||
bool Preshuffle_ = 0>
|
||||
struct TileGemmUniversalTraits
|
||||
{
|
||||
static constexpr bool kPadM = kPadM_;
|
||||
static constexpr bool kPadN = kPadN_;
|
||||
static constexpr bool kPadK = kPadK_;
|
||||
|
||||
static constexpr bool kPadM = kPadM_;
|
||||
static constexpr bool kPadN = kPadN_;
|
||||
static constexpr bool kPadK = kPadK_;
|
||||
static constexpr int _VectorSize = 16;
|
||||
static constexpr bool DoubleSmemBuffer = DoubleSmemBuffer_;
|
||||
|
||||
using ALayout = ALayout_;
|
||||
@@ -59,6 +60,7 @@ struct TileGemmUniversalTraits
|
||||
static constexpr bool UseStructuredSparsity = UseStructuredSparsity_;
|
||||
static constexpr bool UsePersistentKernel = UsePersistentKernel_;
|
||||
static constexpr index_t NumWaveGroups = NumWaveGroups_;
|
||||
static constexpr bool Preshuffle = Preshuffle_;
|
||||
};
|
||||
|
||||
template <bool kPadM_,
|
||||
|
||||
@@ -0,0 +1,472 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/concat.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v1_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem>
|
||||
struct BaseWeightPreshufflePipelineAGmemBGmemCRegV1
|
||||
{
|
||||
static constexpr index_t PrefetchStages = 1;
|
||||
static constexpr index_t PrefillStages = 1;
|
||||
static constexpr index_t GlobalBufferNum = 1;
|
||||
static constexpr bool UsePersistentKernel = Problem::Traits::UsePersistentKernel;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr bool BlockHasHotloop(index_t) { return true; }
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr TailNumber GetBlockLoopTailNum(index_t)
|
||||
{
|
||||
return TailNumber::Empty;
|
||||
}
|
||||
|
||||
template <typename RunFunction>
|
||||
CK_TILE_HOST_DEVICE static auto TailHandler(const RunFunction& run_func, bool, TailNumber)
|
||||
{
|
||||
return run_func(bool_constant<true>{}, integral_constant<TailNumber, TailNumber::Empty>{});
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Problem, typename PipelinePolicy = UniversalWeightPreshufflePipelineAgBgCrPolicy>
|
||||
struct WeightPreshufflePipelineAGmemBGmemCRegV1
|
||||
: public BaseWeightPreshufflePipelineAGmemBGmemCRegV1<Problem>
|
||||
{
|
||||
using Base = BaseWeightPreshufflePipelineAGmemBGmemCRegV1<Problem>;
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
using BLayout = remove_cvref_t<typename Problem::BLayout>;
|
||||
using CLayout = remove_cvref_t<typename Problem::CLayout>;
|
||||
|
||||
using BlockWeightPreshuffle =
|
||||
remove_cvref_t<decltype(PipelinePolicy::template GetBlockWeightPreshuffle<Problem>())>;
|
||||
|
||||
static constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
|
||||
static constexpr index_t kMPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t kNPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t kKPerBlock = BlockGemmShape::kK;
|
||||
|
||||
static constexpr index_t flatKPerWarp = BlockGemmShape::flatKPerWarp;
|
||||
static constexpr index_t flatNPerWarp = BlockGemmShape::flatNPerWarp;
|
||||
|
||||
static constexpr index_t GetVectorSizeA()
|
||||
{
|
||||
return PipelinePolicy::template GetVectorSizeA<Problem>();
|
||||
}
|
||||
static constexpr index_t GetVectorSizeB()
|
||||
{
|
||||
return PipelinePolicy::template GetVectorSizeB<Problem>();
|
||||
}
|
||||
|
||||
static constexpr bool kPadM = Problem::kPadM;
|
||||
static constexpr bool kPadN = Problem::kPadN;
|
||||
static constexpr bool kPadK = Problem::kPadK;
|
||||
|
||||
static constexpr index_t kLdsAlignmentInBytes = 16;
|
||||
static constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
|
||||
|
||||
static constexpr auto I0 = number<0>();
|
||||
static constexpr auto I1 = number<1>();
|
||||
static constexpr auto I2 = number<2>();
|
||||
|
||||
using BlockTile = remove_cvref_t<typename BlockGemmShape::BlockTile>;
|
||||
using BlockWarps = remove_cvref_t<typename BlockGemmShape::BlockWarps>;
|
||||
using WarpTile = remove_cvref_t<typename BlockGemmShape::WarpTile>;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
|
||||
static constexpr index_t Preshuffle = Problem::Preshuffle;
|
||||
using Base::UsePersistentKernel;
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
return concat('_', "pipeline_AGmemBGmemCRegV1",
|
||||
concat('x', kMPerBlock, kNPerBlock, kKPerBlock, BlockSize),
|
||||
concat('x', GetVectorSizeA(), GetVectorSizeB()),
|
||||
concat('x', kPadM, kPadN, kPadK));
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
return PipelinePolicy::template GetSmemSize<Problem>();
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto HotLoopScheduler()
|
||||
{
|
||||
constexpr auto config =
|
||||
BlockWeightPreshuffle::BlockPolicy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
constexpr index_t MWarp = config.template at<1>();
|
||||
constexpr index_t NWarp = config.template at<2>();
|
||||
|
||||
constexpr index_t KIterPerWarp = kKPerBlock / WG::kK;
|
||||
constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WG::kM);
|
||||
constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WG::kN);
|
||||
|
||||
constexpr index_t KPerLoad = Problem::VectorLoadSize / sizeof(ADataType);
|
||||
constexpr index_t A_Buffer_Load_Inst_Num = kMPerBlock * kKPerBlock / BlockSize / KPerLoad;
|
||||
constexpr index_t A_LDS_Read_Inst_Num = MIterPerWarp * KIterPerWarp;
|
||||
constexpr index_t B_Buffer_Load_Inst_Num = NIterPerWarp * KIterPerWarp;
|
||||
|
||||
if constexpr(WG::kM == 16 && WG::kN == 16)
|
||||
{
|
||||
static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
static_for<0, A_LDS_Read_Inst_Num - A_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 3, 0); // MFMA
|
||||
});
|
||||
static_for<0, B_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA
|
||||
});
|
||||
static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA
|
||||
});
|
||||
}
|
||||
else if constexpr(WG::kM == 32 && WG::kN == 32 &&
|
||||
(A_LDS_Read_Inst_Num / 2 >
|
||||
A_Buffer_Load_Inst_Num + B_Buffer_Load_Inst_Num))
|
||||
{
|
||||
static_for<0,
|
||||
A_LDS_Read_Inst_Num / 2 - A_Buffer_Load_Inst_Num - B_Buffer_Load_Inst_Num,
|
||||
1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
static_for<0, A_LDS_Read_Inst_Num / 2, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
static_for<0, B_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 3, 0); // MFMA
|
||||
});
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp, typename BFlatBlockWindowTmp, typename AElementFunction>
|
||||
CK_TILE_HOST_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const AElementFunction& a_element_func,
|
||||
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<BDataType, remove_cvref_t<typename BFlatBlockWindowTmp::DataType>>,
|
||||
"A/B Dram block window should have the same data type as appropriate "
|
||||
"([A|B]DataType) defined in Problem definition!");
|
||||
|
||||
constexpr bool is_a_col_major = std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
|
||||
|
||||
static_assert(is_a_col_major
|
||||
? (kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0] &&
|
||||
kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1])
|
||||
: (kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0] &&
|
||||
kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1]),
|
||||
"A block window has incorrect lengths for defined ALayout!");
|
||||
|
||||
constexpr auto config =
|
||||
BlockWeightPreshuffle::BlockPolicy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
constexpr index_t MWarp = config.template at<1>();
|
||||
constexpr index_t NWarp = config.template at<2>();
|
||||
|
||||
constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WG::kM);
|
||||
constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WG::kN);
|
||||
constexpr index_t KIterPerWarp = kKPerBlock / WG::kK;
|
||||
|
||||
constexpr index_t KFlatPerBlockPerIter = flatKPerWarp;
|
||||
constexpr index_t NFlatPerBlockPerIter = flatNPerWarp;
|
||||
|
||||
constexpr index_t MPerBlockPerIter = kMPerBlock / MIterPerWarp;
|
||||
constexpr index_t KPerBlockPerIter = kKPerBlock / KIterPerWarp;
|
||||
|
||||
const index_t iMWarp = get_warp_id() / NWarp;
|
||||
|
||||
// A tile in LDS
|
||||
ADataType* p_a_lds = static_cast<ADataType*>(p_smem);
|
||||
|
||||
constexpr auto a_lds_block_desc =
|
||||
PipelinePolicy::template MakeALdsBlockDescriptor<Problem>();
|
||||
|
||||
auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);
|
||||
|
||||
// A DRAM tile window for load
|
||||
auto a_copy_dram_window =
|
||||
make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
|
||||
a_dram_block_window_tmp.get_window_origin(),
|
||||
PipelinePolicy::template MakeADramTileDistribution<Problem>());
|
||||
|
||||
// A LDS tile window for store
|
||||
auto a_copy_lds_window = make_tile_window(
|
||||
a_lds_block, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});
|
||||
|
||||
// A LDS tile for block GEMM
|
||||
auto a_lds_gemm_window = make_tile_window(
|
||||
a_lds_block, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});
|
||||
|
||||
auto a_warp_window_tmp = make_tile_window(
|
||||
a_lds_gemm_window.get_bottom_tensor_view(),
|
||||
make_tuple(number<WG::kM>{}, number<WG::kK>{}),
|
||||
a_lds_gemm_window.get_window_origin() + multi_index<2>{iMWarp * WG::kM, 0},
|
||||
make_static_tile_distribution(typename WG::AWarpDstrEncoding{}));
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(a_warp_window_tmp), KIterPerWarp>,
|
||||
MIterPerWarp>
|
||||
a_warp_windows;
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
a_warp_windows(mIter)(kIter) = a_warp_window_tmp;
|
||||
|
||||
move_tile_window(a_warp_windows(mIter)(kIter),
|
||||
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
|
||||
});
|
||||
});
|
||||
|
||||
// Block GEMM
|
||||
auto block_flatmm = BlockWeightPreshuffle();
|
||||
|
||||
// B flat DRAM window for load
|
||||
auto b_flat_distribution =
|
||||
PipelinePolicy::template MakeBFlatDramTileDistribution<Problem>();
|
||||
auto b_flat_dram_window = // tile_window_with_static_distribution
|
||||
make_tile_window(
|
||||
b_flat_dram_block_window_tmp.get_bottom_tensor_view(), // from kernel gemm_pad_views
|
||||
make_tuple(number<flatNPerWarp>{}, number<flatKPerWarp>{}),
|
||||
b_flat_dram_block_window_tmp.get_window_origin(),
|
||||
b_flat_distribution);
|
||||
|
||||
// Acc register tile
|
||||
auto c_block_tile = block_flatmm.MakeCBlockTile();
|
||||
|
||||
// prefetch
|
||||
// global read 0
|
||||
auto a_block_tile = load_tile(a_copy_dram_window);
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(b_flat_dram_window), KIterPerWarp>,
|
||||
NIterPerWarp>
|
||||
b_flat_dram_windows;
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(load_tile(b_flat_dram_window)), KIterPerWarp>,
|
||||
NIterPerWarp>
|
||||
b_warp_tensor;
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(load_tile(b_flat_dram_window)), KIterPerWarp>,
|
||||
NIterPerWarp>
|
||||
b_warp_tensor_2;
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
|
||||
|
||||
move_tile_window(b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
|
||||
|
||||
b_warp_tensor(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
});
|
||||
});
|
||||
|
||||
{
|
||||
// move to 1
|
||||
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
// move to next flat K
|
||||
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
|
||||
|
||||
// initialize C
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
|
||||
|
||||
// LDS write 0
|
||||
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
PipelinePolicy::template MakeShuffledARegBlockDistribution<Problem>());
|
||||
shuffle_tile(a_shuffle_tmp, a_block_tile);
|
||||
const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_shuffle_tmp);
|
||||
store_tile(a_copy_lds_window, a_block_tile_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
store_tile(a_copy_lds_window, tile_elementwise_in(a_element_func, a_block_tile));
|
||||
}
|
||||
block_sync_lds();
|
||||
}
|
||||
|
||||
index_t iCounter = num_loop / 2 - 1;
|
||||
while(iCounter > 0)
|
||||
{
|
||||
// global read i + 1
|
||||
a_block_tile = load_tile(a_copy_dram_window);
|
||||
|
||||
// GEMM i
|
||||
block_flatmm(c_block_tile, a_warp_windows, b_warp_tensor);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
|
||||
|
||||
move_tile_window(b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
|
||||
|
||||
b_warp_tensor_2(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
});
|
||||
});
|
||||
|
||||
// move to i + 2
|
||||
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
// move to next flat K
|
||||
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
|
||||
|
||||
// LDS write i + 1
|
||||
auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
|
||||
store_tile(a_copy_lds_window, a_block_tile_tmp);
|
||||
HotLoopScheduler();
|
||||
block_sync_lds();
|
||||
|
||||
// iCounter--;
|
||||
|
||||
// global read i + 1
|
||||
a_block_tile = load_tile(a_copy_dram_window);
|
||||
|
||||
// GEMM i
|
||||
block_flatmm(c_block_tile, a_warp_windows, b_warp_tensor_2);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
|
||||
|
||||
move_tile_window(b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
|
||||
|
||||
b_warp_tensor(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
});
|
||||
});
|
||||
|
||||
// move to i + 2
|
||||
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
// move to next flat K
|
||||
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
|
||||
|
||||
// LDS write i + 1
|
||||
a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
|
||||
store_tile(a_copy_lds_window, a_block_tile_tmp);
|
||||
|
||||
HotLoopScheduler();
|
||||
block_sync_lds();
|
||||
|
||||
iCounter--;
|
||||
}
|
||||
|
||||
// tail
|
||||
{
|
||||
// global read i + 1
|
||||
a_block_tile = load_tile(a_copy_dram_window);
|
||||
|
||||
// GEMM i
|
||||
block_flatmm(c_block_tile, a_warp_windows, b_warp_tensor);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
|
||||
|
||||
move_tile_window(b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
|
||||
|
||||
b_warp_tensor_2(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
});
|
||||
});
|
||||
|
||||
// move to i + 2
|
||||
// move_tile_window(a_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
// LDS write i + 1
|
||||
const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
|
||||
store_tile(a_copy_lds_window, a_block_tile_tmp);
|
||||
|
||||
// move to next flat K
|
||||
// move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
|
||||
|
||||
HotLoopScheduler();
|
||||
block_sync_lds();
|
||||
|
||||
// GEMM num_loop - 1
|
||||
block_flatmm(c_block_tile, a_warp_windows, b_warp_tensor_2);
|
||||
}
|
||||
|
||||
return c_block_tile;
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp, typename BFlatBlockWindowTmp>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
return operator()(
|
||||
a_dram_block_window_tmp,
|
||||
[](const ADataType& a) { return a; },
|
||||
b_flat_dram_block_window_tmp,
|
||||
num_loop,
|
||||
p_smem);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,450 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
struct UniversalWeightPreshufflePipelineAgBgCrPolicy
|
||||
{
|
||||
static constexpr auto I0 = number<0>{};
|
||||
static constexpr auto I1 = number<1>{};
|
||||
static constexpr auto I2 = number<2>{};
|
||||
|
||||
// 3d + padding
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
|
||||
{
|
||||
using namespace ck_tile;
|
||||
|
||||
constexpr index_t MPerXdl = Problem::BlockGemmShape::WarpTile::at(I0);
|
||||
constexpr index_t NPerXdl = Problem::BlockGemmShape::WarpTile::at(I1);
|
||||
if constexpr(MPerXdl == 16 && NPerXdl == 16)
|
||||
{
|
||||
/*reduce transform layers,compare with old ck*/
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t KPack = GetSmemPackA<Problem>();
|
||||
|
||||
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<KPerBlock / KPack>{}, number<MPerBlock>{}, number<KPack>{}),
|
||||
make_tuple(number<KPack>{}, number<KPerBlock>{}, number<1>{}),
|
||||
number<KPack>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
a_lds_block_desc_0,
|
||||
make_tuple(make_xor_transform(
|
||||
make_tuple(number<MPerBlock>{}, number<KPerBlock / KPack>{})),
|
||||
make_pass_through_transform(number<KPack>{})),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}));
|
||||
|
||||
constexpr auto a_lds_block_desc = transform_tensor_descriptor(
|
||||
a_lds_block_desc_permuted,
|
||||
make_tuple(make_pass_through_transform(number<MPerBlock>{}),
|
||||
make_merge_transform_v3_division_mod(
|
||||
make_tuple(number<KPerBlock / KPack>{}, number<KPack>{}))),
|
||||
make_tuple(sequence<1>{}, sequence<0, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return a_lds_block_desc;
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t kKPack = GetSmemPackA<Problem>();
|
||||
|
||||
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<kKPerBlock / kKPack>{}, number<kMPerBlock>{}, number<kKPack>{}),
|
||||
make_tuple(number<(kMPerBlock + 1) * kKPack>{}, number<kKPack>{}, number<1>{}),
|
||||
number<kKPack>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto a_lds_block_desc = transform_tensor_descriptor(
|
||||
a_lds_block_desc_0,
|
||||
make_tuple(make_pass_through_transform(kMPerBlock),
|
||||
make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))),
|
||||
make_tuple(sequence<1>{}, sequence<0, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return a_lds_block_desc;
|
||||
}
|
||||
/*xor*/
|
||||
#if 0
|
||||
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t kKPack = GetSmemPackA<Problem>();
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
|
||||
constexpr auto DataTypeSize = sizeof(ADataType);
|
||||
constexpr auto MLdsLayer =
|
||||
(32 * 4 / kKPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / kKPerBlock / DataTypeSize);
|
||||
|
||||
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<kKPerBlock / kKPack * MLdsLayer>{},
|
||||
number<kMPerBlock / MLdsLayer>{},
|
||||
number<kKPack>{}),
|
||||
make_tuple(number<kKPack>{}, number<kKPerBlock * MLdsLayer>{}, number<1>{}),
|
||||
number<kKPack>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
a_lds_block_desc_0,
|
||||
make_tuple(make_xor_transform(make_tuple(number<kMPerBlock / MLdsLayer>{},
|
||||
number<kKPerBlock / kKPack * MLdsLayer>{})),
|
||||
make_pass_through_transform(number<kKPack>{})),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}));
|
||||
|
||||
constexpr auto a_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor(
|
||||
a_lds_block_desc_permuted,
|
||||
make_tuple(make_unmerge_transform(
|
||||
make_tuple(number<MLdsLayer>{}, number<kKPerBlock / kKPack>{})),
|
||||
make_pass_through_transform(number<kMPerBlock / MLdsLayer>{}),
|
||||
make_pass_through_transform(number<kKPack>{})),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}));
|
||||
|
||||
constexpr auto a_lds_block_desc = transform_tensor_descriptor(
|
||||
a_lds_block_desc_xk0_mnldslayer_mn_xk1,
|
||||
make_tuple(make_merge_transform(
|
||||
make_tuple(number<kMPerBlock / MLdsLayer>{}, number<MLdsLayer>{})),
|
||||
make_merge_transform(
|
||||
make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2, 3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
return a_lds_block_desc;
|
||||
#endif
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Get the maximum global memory vector load size.
|
||||
*
|
||||
* @tparam Problem The UniversalGemmPipelineProblem object.
|
||||
* @tparam DataType The tensor data type we're considering.
|
||||
* @tparam MNPerBlock The MPerBlock or NPerBlock value depending on tensor (A/B).
|
||||
* @tparam XPerTile The contiguous Tile dimension size.
|
||||
* @return Maximum DRAM vector load size.
|
||||
*/
|
||||
template <typename Problem, typename DataType, index_t MNPerBlock, index_t XPerTile>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetGlobalVectorLoadSize()
|
||||
{
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t elements_per_thread = MNPerBlock * KPerBlock / BlockSize;
|
||||
constexpr index_t PackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<DataType>>::PackedSize;
|
||||
|
||||
// Assume DataType is even!
|
||||
if constexpr(XPerTile % (PackedSize * 32 / sizeof(DataType)) == 0 &&
|
||||
elements_per_thread % (PackedSize * 32 / sizeof(DataType)) == 0 &&
|
||||
PackedSize == 2)
|
||||
{
|
||||
return (PackedSize * 32 / sizeof(DataType));
|
||||
}
|
||||
else if constexpr(XPerTile % (PackedSize * 16 / sizeof(DataType)) == 0 &&
|
||||
elements_per_thread % (PackedSize * 16 / sizeof(DataType)) == 0)
|
||||
{
|
||||
return (PackedSize * 16 / sizeof(DataType));
|
||||
}
|
||||
else if constexpr(XPerTile % (PackedSize * 8 / sizeof(DataType)) == 0 &&
|
||||
elements_per_thread % (PackedSize * 8 / sizeof(DataType)) == 0)
|
||||
{
|
||||
return (PackedSize * 8 / sizeof(DataType));
|
||||
}
|
||||
else if constexpr(sizeof(DataType) >= PackedSize * 4 &&
|
||||
XPerTile % (PackedSize * 4 / sizeof(DataType)) == 0 &&
|
||||
elements_per_thread % (PackedSize * 4 / sizeof(DataType)) == 0)
|
||||
{
|
||||
return (PackedSize * 4 / sizeof(DataType));
|
||||
}
|
||||
else if constexpr(sizeof(DataType) >= PackedSize * 2 &&
|
||||
XPerTile % (PackedSize * 2 / sizeof(DataType)) == 0 &&
|
||||
elements_per_thread % (PackedSize * 2 / sizeof(DataType)) == 0)
|
||||
{
|
||||
return (PackedSize * 2 / sizeof(DataType));
|
||||
}
|
||||
else
|
||||
{
|
||||
return PackedSize;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeA()
|
||||
{
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return GetGlobalVectorLoadSize<Problem, ADataType, MPerBlock, KPerBlock>();
|
||||
}
|
||||
else
|
||||
{
|
||||
return GetGlobalVectorLoadSize<Problem, ADataType, MPerBlock, MPerBlock>();
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeB()
|
||||
{
|
||||
using BLayout = remove_cvref_t<typename Problem::BLayout>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return GetGlobalVectorLoadSize<Problem, BDataType, NPerBlock, NPerBlock>();
|
||||
}
|
||||
else
|
||||
{
|
||||
return GetGlobalVectorLoadSize<Problem, BDataType, NPerBlock, KPerBlock>();
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA()
|
||||
{
|
||||
constexpr index_t smem_size_a = sizeof(typename Problem::ADataType) *
|
||||
MakeALdsBlockDescriptor<Problem>().get_element_space_size();
|
||||
return smem_size_a;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
constexpr index_t smem_size_a = GetSmemSizeA<Problem>();
|
||||
|
||||
return smem_size_a;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackA()
|
||||
{
|
||||
return Problem::VectorLoadSize / sizeof(typename Problem::ADataType);
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetKBPerLoad()
|
||||
{
|
||||
using TileShape = typename Problem::BlockGemmShape;
|
||||
if constexpr(TileShape::WarpTile::at(I1) == 32)
|
||||
{
|
||||
return TileShape::WarpTile::at(I2) / 2;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(TileShape::WarpTile::at(I1) == 16);
|
||||
return TileShape::WarpTile::at(I2) / 4;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution()
|
||||
{
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
constexpr index_t M1 = Problem::VectorLoadSize / sizeof(ADataType);
|
||||
constexpr index_t M0 = MPerBlock / M1;
|
||||
constexpr index_t total_pixels = MPerBlock * KPerBlock / BlockSize;
|
||||
static_assert(total_pixels % M1 == 0);
|
||||
constexpr index_t K3 = total_pixels / M1;
|
||||
constexpr index_t KPack = GetSmemPackA<Problem>();
|
||||
static_assert(KPack % K3 == 0);
|
||||
constexpr index_t K2 = KPack / K3;
|
||||
if constexpr(get_warp_size() >= (K2 * M0))
|
||||
{
|
||||
constexpr index_t K1 = get_warp_size() / (K2 * M0);
|
||||
constexpr index_t K0 = BlockSize / get_warp_size();
|
||||
static_assert(KPerBlock == K0 * K1 * K2 * K3);
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<M0, M1>, sequence<K0, K1, K2, K3>>,
|
||||
tuple<sequence<2>, sequence<2, 1, 2>>,
|
||||
tuple<sequence<0>, sequence<1, 0, 2>>,
|
||||
sequence<2, 1>,
|
||||
sequence<3, 1>>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t K1 = (K2 * M0) / get_warp_size();
|
||||
constexpr index_t K2_m = K2 / K1;
|
||||
constexpr index_t K0 = BlockSize / get_warp_size() / K1;
|
||||
static_assert(KPerBlock == K0 * K1 * K2_m * K3);
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<M0, M1>, sequence<K0, K1, K2_m, K3>>,
|
||||
tuple<sequence<2, 2>, sequence<1, 2>>,
|
||||
tuple<sequence<0, 1>, sequence<0, 2>>,
|
||||
sequence<2, 1>,
|
||||
sequence<3, 1>>{});
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t K1 = Problem::VectorLoadSize / sizeof(ADataType);
|
||||
constexpr index_t K0 = KPerBlock / K1;
|
||||
constexpr index_t M2 = get_warp_size() / K0;
|
||||
// coalesce reading for each blocks
|
||||
if constexpr(get_warp_size() % (M2 * K0) == 0)
|
||||
{
|
||||
constexpr index_t M1 = BlockSize / get_warp_size();
|
||||
static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
|
||||
static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
|
||||
constexpr index_t M0 = MPerBlock / (M2 * M1);
|
||||
static_assert(M0 * M1 * M2 == MPerBlock,
|
||||
"Incorrect M0, M2, M1 configuration! "
|
||||
"M0, M1, M2 must cover whole MPerBlock!");
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t M0 = BlockSize / get_warp_size();
|
||||
constexpr index_t M1 = MPerBlock / (M2 * M0);
|
||||
static_assert(M0 * M1 * M2 == MPerBlock,
|
||||
"Incorrect M0, M1, M2 configuration! "
|
||||
"M0, M1, M2 must cover whole MPerBlock!");
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<0>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<1, 1>>{});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeBFlatDramTileDistribution()
|
||||
{
|
||||
using TileShape = typename Problem::BlockGemmShape;
|
||||
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t WaveSize = get_warp_size();
|
||||
constexpr index_t WaveNum = BlockSize / WaveSize;
|
||||
|
||||
constexpr index_t KBPerLoad = GetKBPerLoad<Problem>();
|
||||
constexpr index_t KThdPerWave = WaveSize; // threads cnt in K dim
|
||||
constexpr index_t KWavePerBlk = 1;
|
||||
constexpr index_t KRepeat = 1;
|
||||
static_assert(TileShape::flatKPerWarp == KThdPerWave * KBPerLoad, "wrong");
|
||||
|
||||
constexpr index_t NBPerLoad = 1;
|
||||
constexpr index_t NThdPerWave = 1;
|
||||
constexpr index_t NWavePerBlk = TileShape::BlockWarps::at(number<1>{}); // N_Warp
|
||||
constexpr index_t NRepeat = 1;
|
||||
|
||||
constexpr index_t WaveRepeat = WaveNum / TileShape::flatNPerWarp;
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<WaveRepeat>, // ?
|
||||
tuple<sequence<NRepeat, NWavePerBlk, NThdPerWave, NBPerLoad>, // second direction
|
||||
sequence<KRepeat, KWavePerBlk, KThdPerWave, KBPerLoad>>, // first direction
|
||||
// wave in blk, // thd in wave
|
||||
// <M, K> // <M, K>
|
||||
tuple<sequence<0, 1, 2>, sequence<1, 2>>, // which direction
|
||||
tuple<sequence<0, 1, 1>, sequence<2, 2>>, // which index
|
||||
// <repeat, vec_load>
|
||||
sequence<1, 1, 2, 2>,
|
||||
sequence<0, 3, 0, 3>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledARegBlockDistribution()
|
||||
{
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
static_assert(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>);
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
constexpr index_t M1 = Problem::VectorLoadSize / sizeof(ADataType);
|
||||
constexpr index_t M0 = kMPerBlock / M1;
|
||||
constexpr index_t total_pixels = kMPerBlock * kKPerBlock / kBlockSize;
|
||||
static_assert(total_pixels % M1 == 0);
|
||||
constexpr index_t K3 = total_pixels / M1;
|
||||
constexpr index_t kKPack = GetSmemPackA<Problem>();
|
||||
static_assert(kKPack % K3 == 0);
|
||||
constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave
|
||||
constexpr index_t warp_size = get_warp_size();
|
||||
if constexpr(warp_size >= (K2 * M0))
|
||||
{
|
||||
constexpr index_t K1 = warp_size / (K2 * M0);
|
||||
constexpr index_t K0 = kBlockSize / warp_size;
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<M0, M1>, sequence<K0, K1, K2, K3>>,
|
||||
tuple<sequence<2>, sequence<2, 1, 2>>,
|
||||
tuple<sequence<0>, sequence<1, 0, 2>>,
|
||||
sequence<1, 2>,
|
||||
sequence<1, 3>>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t K1 = (K2 * M0) / get_warp_size();
|
||||
constexpr index_t K2_m = K2 / K1;
|
||||
constexpr index_t K0 = kBlockSize / get_warp_size() / K1;
|
||||
static_assert(kKPerBlock == K0 * K1 * K2_m * K3);
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<M0, M1>, sequence<K0, K1, K2_m, K3>>,
|
||||
tuple<sequence<2, 2>, sequence<1, 2>>,
|
||||
tuple<sequence<0, 1>, sequence<0, 2>>,
|
||||
sequence<1, 2>,
|
||||
sequence<1, 3>>{});
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockWeightPreshuffle()
|
||||
{
|
||||
// using AccDataType = float;
|
||||
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
|
||||
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
|
||||
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ADataType,
|
||||
typename Problem::BDataType,
|
||||
typename Problem::CDataType,
|
||||
WarpTile::at(I0),
|
||||
WarpTile::at(I1),
|
||||
WarpTile::at(I2),
|
||||
Problem::TransposeC>;
|
||||
|
||||
using BlockWeightPreshufflePolicy =
|
||||
BlockWeightPreshuffleASmemBSmemCRegV1CustomPolicy<typename Problem::ADataType,
|
||||
typename Problem::BDataType,
|
||||
typename Problem::CDataType,
|
||||
BlockWarps,
|
||||
WarpGemm>;
|
||||
return BlockWeightPreshuffleASmemBSmemCRegV1<Problem, BlockWeightPreshufflePolicy>{};
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -172,7 +172,7 @@ using WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution =
|
||||
#if defined(__gfx950__)
|
||||
using WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16<WGAttrCtlEnum::Default_>>>;
|
||||
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K32<WGAttrCtlEnum::Default_>>>;
|
||||
#else
|
||||
using WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution<
|
||||
@@ -282,4 +282,19 @@ using WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution =
|
||||
2,
|
||||
swizzle_factor>>;
|
||||
|
||||
// int8
|
||||
using WarpGemmMfma_i32_32x32x16_i8_i8 = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_i32_32x32x16_i8<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
using WarpGemmMfma_i32_32x32x16_i8_i8_CTransposed =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImpl_i32_32x32x16_i8<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
using WarpGemmMfma_i32_16x16x32_i8_i8 = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_i32_16x16x32_i8<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
using WarpGemmMfma_i32_16x16x32_i8_i8_CTransposed =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImpl_i32_16x16x32_i8<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -1095,16 +1095,16 @@ struct WarpGemmAttributeMfmaImpl_f32_16x16x32_f8_base
|
||||
#if defined(__gfx94__) or defined(__gfx95__)
|
||||
if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
|
||||
c_vec = __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
|
||||
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
|
||||
else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
|
||||
c_vec = __builtin_amdgcn_mfma_f32_16x16x32_fp8_bf8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
|
||||
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
|
||||
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
|
||||
c_vec = __builtin_amdgcn_mfma_f32_16x16x32_bf8_fp8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
|
||||
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
|
||||
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
|
||||
c_vec = __builtin_amdgcn_mfma_f32_16x16x32_bf8_bf8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
|
||||
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
|
||||
#else
|
||||
ck_tile::ignore = c_vec;
|
||||
ck_tile::ignore = a_vec;
|
||||
@@ -1119,16 +1119,16 @@ struct WarpGemmAttributeMfmaImpl_f32_16x16x32_f8_base
|
||||
#if defined(__gfx94__) or defined(__gfx95__)
|
||||
if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
|
||||
return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), CVecType{0.f}, 0, 0, 0));
|
||||
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), CVecType{0.f}, 0, 0, 0));
|
||||
else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
|
||||
return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_16x16x32_fp8_bf8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), CVecType{0.f}, 0, 0, 0));
|
||||
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), CVecType{0.f}, 0, 0, 0));
|
||||
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
|
||||
return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_16x16x32_bf8_fp8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), CVecType{0.f}, 0, 0, 0));
|
||||
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), CVecType{0.f}, 0, 0, 0));
|
||||
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
|
||||
return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_16x16x32_bf8_bf8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), CVecType{0.f}, 0, 0, 0));
|
||||
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), CVecType{0.f}, 0, 0, 0));
|
||||
#else
|
||||
ck_tile::ignore = a_vec;
|
||||
ck_tile::ignore = b_vec;
|
||||
@@ -1254,16 +1254,16 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
|
||||
#if defined(__gfx94__) or defined(__gfx95__)
|
||||
if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
|
||||
c_vec = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
|
||||
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
|
||||
else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
|
||||
c_vec = __builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
|
||||
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
|
||||
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
|
||||
c_vec = __builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
|
||||
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
|
||||
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
|
||||
c_vec = __builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
|
||||
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
|
||||
#elif defined(__gfx908__) || defined(__gfx90a__)
|
||||
static_for<0, 8, 1>{}([&](auto k) {
|
||||
float a_f32 =
|
||||
@@ -1289,16 +1289,16 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
|
||||
#if defined(__gfx94__) or defined(__gfx95__)
|
||||
if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
|
||||
return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), CVecType{0.f}, 0, 0, 0));
|
||||
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), CVecType{0.f}, 0, 0, 0));
|
||||
else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
|
||||
return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), CVecType{0.f}, 0, 0, 0));
|
||||
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), CVecType{0.f}, 0, 0, 0));
|
||||
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
|
||||
return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), CVecType{0.f}, 0, 0, 0));
|
||||
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), CVecType{0.f}, 0, 0, 0));
|
||||
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
|
||||
return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), CVecType{0.f}, 0, 0, 0));
|
||||
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), CVecType{0.f}, 0, 0, 0));
|
||||
#elif defined(__gfx908__) || defined(__gfx90a__)
|
||||
CVecType c_vec{0.f};
|
||||
static_for<0, 8, 1>{}([&](auto k) {
|
||||
@@ -1578,9 +1578,9 @@ struct WarpGemmAttributeMfmaImpl_i32_32x32x16_i8
|
||||
DISPATCH_MFMA_CTRL_("v_mfma_i32_32x32x16_i8", Ctrl)
|
||||
else
|
||||
{
|
||||
#if defined(__gfx94__)
|
||||
c_vec = __builtin_amdgcn_mfma_i32_32x32x8i8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
|
||||
#if defined(__gfx94__) or defined(__gfx95__)
|
||||
c_vec = __builtin_amdgcn_mfma_i32_32x32x16_i8(
|
||||
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
|
||||
#elif defined(__gfx908__) || defined(__gfx90a__)
|
||||
static_for<0, 8, 1>{}([&](auto k) {
|
||||
float a_f32 =
|
||||
@@ -1609,6 +1609,183 @@ struct WarpGemmAttributeMfmaImpl_i32_32x32x16_i8
|
||||
}
|
||||
};
|
||||
|
||||
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
struct WarpGemmAttributeMfmaImpl_i32_16x16x32_i8
|
||||
{
|
||||
static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
|
||||
using ADataType = int8_t;
|
||||
using BDataType = int8_t;
|
||||
using CDataType = int32_t;
|
||||
|
||||
using AVecType = ext_vector_t<ADataType, 8>;
|
||||
using BVecType = ext_vector_t<BDataType, 8>;
|
||||
using CVecType = ext_vector_t<CDataType, 4>;
|
||||
|
||||
static constexpr index_t kM = 16;
|
||||
static constexpr index_t kN = 16;
|
||||
static constexpr index_t kK = 32;
|
||||
|
||||
static constexpr index_t kAMBlock = 1;
|
||||
static constexpr index_t kBNBlock = 1;
|
||||
|
||||
static constexpr index_t kAMLane = 16;
|
||||
static constexpr index_t kBNLane = 16;
|
||||
static constexpr index_t kABKLane = 4;
|
||||
static constexpr index_t kABKPerLane = 8;
|
||||
|
||||
static constexpr index_t kCMLane = 4;
|
||||
static constexpr index_t kCNLane = 16;
|
||||
static constexpr index_t kCM0PerLane = 1;
|
||||
static constexpr index_t kCM1PerLane = 4; // write to 4x AccVGPRs
|
||||
|
||||
// c_vec += a_vec * b_vec
|
||||
template <bool post_nop_ = false>
|
||||
CK_TILE_DEVICE void operator()(CVecType& c_vec,
|
||||
const AVecType& a_vec,
|
||||
const BVecType& b_vec,
|
||||
bool_constant<post_nop_> = {}) const
|
||||
{
|
||||
DISPATCH_MFMA_CTRL_("v_mfma_i32_16x16x32_i8", Ctrl)
|
||||
else
|
||||
{
|
||||
#if defined(__gfx94__) or defined(__gfx95__)
|
||||
c_vec = __builtin_amdgcn_mfma_i32_16x16x32_i8(
|
||||
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
|
||||
#else
|
||||
ck_tile::ignore = c_vec;
|
||||
ck_tile::ignore = a_vec;
|
||||
ck_tile::ignore = b_vec;
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
// c_vec = a_vec * b_vec
|
||||
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
CVecType c_vec{0};
|
||||
operator()(c_vec, a_vec, b_vec);
|
||||
return c_vec;
|
||||
}
|
||||
};
|
||||
|
||||
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
struct WarpGemmAttributeMfmaImpl_i32_16x16x64_i8
|
||||
{
|
||||
static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
|
||||
using ADataType = int8_t;
|
||||
using BDataType = int8_t;
|
||||
using CDataType = int32_t;
|
||||
|
||||
using AVecType = ext_vector_t<ADataType, 16>;
|
||||
using BVecType = ext_vector_t<BDataType, 16>;
|
||||
using CVecType = ext_vector_t<CDataType, 4>;
|
||||
|
||||
static constexpr index_t kM = 16;
|
||||
static constexpr index_t kN = 16;
|
||||
static constexpr index_t kK = 64;
|
||||
|
||||
static constexpr index_t kAMBlock = 1;
|
||||
static constexpr index_t kBNBlock = 1;
|
||||
|
||||
static constexpr index_t kAMLane = 16;
|
||||
static constexpr index_t kBNLane = 16;
|
||||
static constexpr index_t kABKLane = 4;
|
||||
static constexpr index_t kABKPerLane = 16;
|
||||
|
||||
static constexpr index_t kCMLane = 4;
|
||||
static constexpr index_t kCNLane = 16;
|
||||
static constexpr index_t kCM0PerLane = 1;
|
||||
static constexpr index_t kCM1PerLane = 4; // write to 4x AccVGPRs
|
||||
|
||||
// c_vec += a_vec * b_vec
|
||||
template <bool post_nop_ = false>
|
||||
CK_TILE_DEVICE void operator()(CVecType& c_vec,
|
||||
const AVecType& a_vec,
|
||||
const BVecType& b_vec,
|
||||
bool_constant<post_nop_> = {}) const
|
||||
{
|
||||
DISPATCH_MFMA_CTRL_("v_mfma_i32_16x16x64_i8", Ctrl)
|
||||
else
|
||||
{
|
||||
#if defined(__gfx95__)
|
||||
c_vec = __builtin_amdgcn_mfma_i32_16x16x64_i8(
|
||||
bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
|
||||
#else
|
||||
ck_tile::ignore = c_vec;
|
||||
ck_tile::ignore = a_vec;
|
||||
ck_tile::ignore = b_vec;
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
// c_vec = a_vec * b_vec
|
||||
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
CVecType c_vec{0};
|
||||
operator()(c_vec, a_vec, b_vec);
|
||||
return c_vec;
|
||||
}
|
||||
};
|
||||
|
||||
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
struct WarpGemmAttributeMfmaImpl_i32_32x32x32_i8
|
||||
{
|
||||
static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
|
||||
using ADataType = int8_t;
|
||||
using BDataType = int8_t;
|
||||
using CDataType = int32_t;
|
||||
|
||||
using AVecType = ext_vector_t<ADataType, 16>;
|
||||
using BVecType = ext_vector_t<BDataType, 16>;
|
||||
using CVecType = ext_vector_t<CDataType, 16>;
|
||||
|
||||
static constexpr index_t kM = 32;
|
||||
static constexpr index_t kN = 32;
|
||||
static constexpr index_t kK = 32;
|
||||
|
||||
static constexpr index_t kAMBlock = 1;
|
||||
static constexpr index_t kBNBlock = 1;
|
||||
|
||||
static constexpr index_t kAMLane = 32;
|
||||
static constexpr index_t kBNLane = 32;
|
||||
static constexpr index_t kABKLane = 2;
|
||||
static constexpr index_t kABKPerLane = 16;
|
||||
|
||||
static constexpr index_t kCMLane = 2;
|
||||
static constexpr index_t kCNLane = 32;
|
||||
static constexpr index_t kCM0PerLane = 4;
|
||||
static constexpr index_t kCM1PerLane = 4;
|
||||
|
||||
// c_vec += a_vec * b_vec
|
||||
template <bool post_nop_ = false>
|
||||
CK_TILE_DEVICE void operator()(CVecType& c_vec,
|
||||
const AVecType& a_vec,
|
||||
const BVecType& b_vec,
|
||||
bool_constant<post_nop_> = {}) const
|
||||
{
|
||||
DISPATCH_MFMA_CTRL_("v_mfma_i32_32x32x32_i8", Ctrl)
|
||||
else
|
||||
{
|
||||
#if defined(__gfx95__)
|
||||
c_vec = __builtin_amdgcn_mfma_i32_32x32x32_i8(
|
||||
a_vec, bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
|
||||
#else
|
||||
ck_tile::ignore = c_vec;
|
||||
ck_tile::ignore = a_vec;
|
||||
ck_tile::ignore = b_vec;
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
// c_vec = a_vec * b_vec
|
||||
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
CVecType c_vec{0};
|
||||
operator()(c_vec, a_vec, b_vec);
|
||||
return c_vec;
|
||||
}
|
||||
};
|
||||
|
||||
#undef DISPATCH_MFMA_
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -11,7 +11,7 @@ namespace ck_tile {
|
||||
namespace impl {
|
||||
template <typename AType,
|
||||
typename BType,
|
||||
typename CType,
|
||||
typename AccType,
|
||||
index_t MPerWave,
|
||||
index_t NPerWave,
|
||||
index_t KPerWave,
|
||||
@@ -22,6 +22,7 @@ struct WarpGemmMfmaDispatcher;
|
||||
|
||||
// clang-format off
|
||||
// fp16
|
||||
// ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 8, false> { using Type = WarpGemmMfmaF16F16F32M32N32K8; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 8, true> { using Type = WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, false> { using Type = WarpGemmMfmaF16F16F32M32N32K16; };
|
||||
@@ -37,10 +38,12 @@ template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, false, true> { using Type = WarpGemmMfmaF16F16F32M32N32K16SwizzleA; };
|
||||
|
||||
// fp16 2:4 structural sparsity
|
||||
// ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, false, false, true> { using Type = WarpGemmSmfmacF16F16F32M32N32K16; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, false, false, true> { using Type = WarpGemmSmfmacF16F16F32M16N16K32; };
|
||||
|
||||
// bf16
|
||||
// ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, false> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, false> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16; };
|
||||
@@ -56,6 +59,7 @@ template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, false, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA; };
|
||||
|
||||
// fp8
|
||||
// ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 32, 32, 32, false> { using Type = WarpGemmMfma_f32_32x32x32_fp8_fp8; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 16, 16, 32, false> { using Type = WarpGemmMfma_f32_16x16x32_fp8_fp8; };
|
||||
@@ -81,12 +85,19 @@ template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::bf8_t, float,
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::fp8_t, float, 32, 32, 64, false> { using Type = WarpGemmMfma_f32_32x32x64_bf8_fp8; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 32, 32, 64, false> { using Type = WarpGemmMfma_f32_32x32x64_bf8_bf8; };
|
||||
|
||||
// int8
|
||||
// ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::int8_t, ck_tile::int8_t, ck_tile::int32_t, 32, 32, 16, false> { using Type = WarpGemmMfma_i32_32x32x16_i8_i8; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::int8_t, ck_tile::int8_t, ck_tile::int32_t, 32, 32, 16, true> { using Type = WarpGemmMfma_i32_32x32x16_i8_i8_CTransposed; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::int8_t, ck_tile::int8_t, ck_tile::int32_t, 16, 16, 32, false> { using Type = WarpGemmMfma_i32_16x16x32_i8_i8; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::int8_t, ck_tile::int8_t, ck_tile::int32_t, 16, 16, 32, true> { using Type = WarpGemmMfma_i32_16x16x32_i8_i8_CTransposed; };
|
||||
|
||||
// clang-format on
|
||||
} // namespace impl
|
||||
|
||||
template <typename AType,
|
||||
typename BType,
|
||||
typename CType,
|
||||
typename AccType,
|
||||
index_t MPerWave,
|
||||
index_t NPerWave,
|
||||
index_t KPerWave,
|
||||
@@ -95,7 +106,7 @@ template <typename AType,
|
||||
bool UseStructuredSparsity = false>
|
||||
using WarpGemmMfmaDispatcher = typename impl::WarpGemmMfmaDispatcher<AType,
|
||||
BType,
|
||||
CType,
|
||||
AccType,
|
||||
MPerWave,
|
||||
NPerWave,
|
||||
KPerWave,
|
||||
|
||||
12
include/ck_tile/ops/grouped_convolution.hpp
Normal file
12
include/ck_tile/ops/grouped_convolution.hpp
Normal file
@@ -0,0 +1,12 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp"
|
||||
#include "ck_tile/ops/grouped_convolution/utils/convolution_specialization.hpp"
|
||||
#include "ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp"
|
||||
#include "ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
@@ -0,0 +1,800 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/host/concat.hpp"
|
||||
#include "ck_tile/core/utility/env.hpp"
|
||||
#include "ck_tile/host/convolution_parameter.hpp"
|
||||
#include "ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp"
|
||||
#include "ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/// @brief The Grouped Convolution kernel device arguments.
|
||||
template <typename GroupedConvTraitsType>
|
||||
struct GroupedConvFwdKernelArgs
|
||||
{
|
||||
|
||||
using ConvToGemmFwdTransformer =
|
||||
TransformConvFwdToGemm<GroupedConvTraitsType::NDimSpatial,
|
||||
GroupedConvTraitsType::ConvSpecialization>;
|
||||
static constexpr index_t NumDTensor = GroupedConvTraitsType::NumDTensor;
|
||||
|
||||
template <
|
||||
typename InLay = typename GroupedConvTraitsType::InLayout,
|
||||
typename WeiLay = typename GroupedConvTraitsType::WeiLayout,
|
||||
typename OutLay = typename GroupedConvTraitsType::OutLayout,
|
||||
typename std::enable_if<std::is_same_v<InLay, tensor_layout::convolution::NWGC> &&
|
||||
std::is_same_v<WeiLay, tensor_layout::convolution::GKXC> &&
|
||||
std::is_same_v<OutLay, tensor_layout::convolution::NWGK>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvHostArgs& args)
|
||||
{
|
||||
in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
|
||||
static_cast<index_t>(args.N_),
|
||||
static_cast<index_t>(args.C_),
|
||||
static_cast<index_t>(args.input_spatial_lengths_[0])};
|
||||
wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
|
||||
static_cast<index_t>(args.K_),
|
||||
static_cast<index_t>(args.C_),
|
||||
static_cast<index_t>(args.filter_spatial_lengths_[0])};
|
||||
out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
|
||||
static_cast<index_t>(args.N_),
|
||||
static_cast<index_t>(args.K_),
|
||||
static_cast<index_t>(args.output_spatial_lengths_[0])};
|
||||
|
||||
conv_filter_strides = {static_cast<index_t>(args.conv_filter_strides_[0])};
|
||||
conv_filter_dilations = {static_cast<index_t>(args.conv_filter_dilations_[0])};
|
||||
input_left_pads = {static_cast<index_t>(args.input_left_pads_[0])};
|
||||
input_right_pads = {static_cast<index_t>(args.input_right_pads_[0])};
|
||||
|
||||
k_batch = args.k_batch;
|
||||
|
||||
GemmM = args.N_ * args.output_spatial_lengths_[0];
|
||||
GemmN = args.K_;
|
||||
GemmK = args.C_ * args.filter_spatial_lengths_[0];
|
||||
|
||||
in_ptr = args.in_ptr;
|
||||
wei_ptr = args.wei_ptr;
|
||||
for(index_t d = 0; d < NumDTensor; d++)
|
||||
{
|
||||
ds_ptr[d] = args.ds_ptr[d];
|
||||
}
|
||||
out_ptr = args.out_ptr;
|
||||
|
||||
ConvToGemmFwdTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths,
|
||||
wei_g_k_c_xs_lengths,
|
||||
out_g_n_k_wos_lengths,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads};
|
||||
|
||||
a_grid_desc_m_k =
|
||||
conv_to_gemm_transformer
|
||||
.template MakeADescriptor_M_K<typename GroupedConvTraitsType::InLayout>();
|
||||
b_grid_desc_n_k =
|
||||
conv_to_gemm_transformer
|
||||
.template MakeBDescriptor_N_K<typename GroupedConvTraitsType::WeiLayout>();
|
||||
c_grid_desc_m_n =
|
||||
conv_to_gemm_transformer
|
||||
.template MakeCDescriptor_M_N<typename GroupedConvTraitsType::OutLayout>();
|
||||
|
||||
group_stride_a = args.C_;
|
||||
group_stride_b = args.K_ * args.C_ *
|
||||
std::accumulate(args.filter_spatial_lengths_.begin(),
|
||||
args.filter_spatial_lengths_.end(),
|
||||
1,
|
||||
std::multiplies<index_t>());
|
||||
group_stride_c = args.K_;
|
||||
}
|
||||
|
||||
template <
|
||||
typename InLay = typename GroupedConvTraitsType::InLayout,
|
||||
typename WeiLay = typename GroupedConvTraitsType::WeiLayout,
|
||||
typename OutLay = typename GroupedConvTraitsType::OutLayout,
|
||||
typename std::enable_if<std::is_same_v<InLay, tensor_layout::convolution::NHWGC> &&
|
||||
std::is_same_v<WeiLay, tensor_layout::convolution::GKYXC> &&
|
||||
std::is_same_v<OutLay, tensor_layout::convolution::NHWGK>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvHostArgs& args)
|
||||
{
|
||||
in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
|
||||
static_cast<index_t>(args.N_),
|
||||
static_cast<index_t>(args.C_),
|
||||
static_cast<index_t>(args.input_spatial_lengths_[0]),
|
||||
static_cast<index_t>(args.input_spatial_lengths_[1])};
|
||||
wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
|
||||
static_cast<index_t>(args.K_),
|
||||
static_cast<index_t>(args.C_),
|
||||
static_cast<index_t>(args.filter_spatial_lengths_[0]),
|
||||
static_cast<index_t>(args.filter_spatial_lengths_[1])};
|
||||
out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
|
||||
static_cast<index_t>(args.N_),
|
||||
static_cast<index_t>(args.K_),
|
||||
static_cast<index_t>(args.output_spatial_lengths_[0]),
|
||||
static_cast<index_t>(args.output_spatial_lengths_[1])};
|
||||
|
||||
conv_filter_strides = {static_cast<index_t>(args.conv_filter_strides_[0]),
|
||||
static_cast<index_t>(args.conv_filter_strides_[1])};
|
||||
conv_filter_dilations = {static_cast<index_t>(args.conv_filter_dilations_[0]),
|
||||
static_cast<index_t>(args.conv_filter_dilations_[1])};
|
||||
input_left_pads = {static_cast<index_t>(args.input_left_pads_[0]),
|
||||
static_cast<index_t>(args.input_left_pads_[1])};
|
||||
input_right_pads = {static_cast<index_t>(args.input_right_pads_[0]),
|
||||
static_cast<index_t>(args.input_right_pads_[1])};
|
||||
|
||||
k_batch = args.k_batch;
|
||||
|
||||
GemmM = args.N_ * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1];
|
||||
GemmN = args.K_;
|
||||
GemmK = args.C_ * args.filter_spatial_lengths_[0] * args.filter_spatial_lengths_[1];
|
||||
|
||||
in_ptr = args.in_ptr;
|
||||
wei_ptr = args.wei_ptr;
|
||||
for(index_t d = 0; d < NumDTensor; d++)
|
||||
{
|
||||
ds_ptr[d] = args.ds_ptr[d];
|
||||
}
|
||||
out_ptr = args.out_ptr;
|
||||
|
||||
ConvToGemmFwdTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths,
|
||||
wei_g_k_c_xs_lengths,
|
||||
out_g_n_k_wos_lengths,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads};
|
||||
|
||||
a_grid_desc_m_k =
|
||||
conv_to_gemm_transformer
|
||||
.template MakeADescriptor_M_K<typename GroupedConvTraitsType::InLayout>();
|
||||
b_grid_desc_n_k =
|
||||
conv_to_gemm_transformer
|
||||
.template MakeBDescriptor_N_K<typename GroupedConvTraitsType::WeiLayout>();
|
||||
c_grid_desc_m_n =
|
||||
conv_to_gemm_transformer
|
||||
.template MakeCDescriptor_M_N<typename GroupedConvTraitsType::OutLayout>();
|
||||
|
||||
group_stride_a = args.C_;
|
||||
group_stride_b = args.K_ * args.C_ *
|
||||
std::accumulate(args.filter_spatial_lengths_.begin(),
|
||||
args.filter_spatial_lengths_.end(),
|
||||
1,
|
||||
std::multiplies<index_t>());
|
||||
group_stride_c = args.K_;
|
||||
}
|
||||
|
||||
template <
|
||||
typename InLay = typename GroupedConvTraitsType::InLayout,
|
||||
typename WeiLay = typename GroupedConvTraitsType::WeiLayout,
|
||||
typename OutLay = typename GroupedConvTraitsType::OutLayout,
|
||||
typename std::enable_if<std::is_same_v<InLay, tensor_layout::convolution::NDHWGC> &&
|
||||
std::is_same_v<WeiLay, tensor_layout::convolution::GKZYXC> &&
|
||||
std::is_same_v<OutLay, tensor_layout::convolution::NDHWGK>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvHostArgs& args)
|
||||
{
|
||||
in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
|
||||
static_cast<index_t>(args.N_),
|
||||
static_cast<index_t>(args.C_),
|
||||
static_cast<index_t>(args.input_spatial_lengths_[0]),
|
||||
static_cast<index_t>(args.input_spatial_lengths_[1]),
|
||||
static_cast<index_t>(args.input_spatial_lengths_[2])};
|
||||
wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
|
||||
static_cast<index_t>(args.K_),
|
||||
static_cast<index_t>(args.C_),
|
||||
static_cast<index_t>(args.filter_spatial_lengths_[0]),
|
||||
static_cast<index_t>(args.filter_spatial_lengths_[1]),
|
||||
static_cast<index_t>(args.filter_spatial_lengths_[2])};
|
||||
out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
|
||||
static_cast<index_t>(args.N_),
|
||||
static_cast<index_t>(args.K_),
|
||||
static_cast<index_t>(args.output_spatial_lengths_[0]),
|
||||
static_cast<index_t>(args.output_spatial_lengths_[1]),
|
||||
static_cast<index_t>(args.output_spatial_lengths_[2])};
|
||||
|
||||
conv_filter_strides = {static_cast<index_t>(args.conv_filter_strides_[0]),
|
||||
static_cast<index_t>(args.conv_filter_strides_[1]),
|
||||
static_cast<index_t>(args.conv_filter_strides_[2])};
|
||||
conv_filter_dilations = {static_cast<index_t>(args.conv_filter_dilations_[0]),
|
||||
static_cast<index_t>(args.conv_filter_dilations_[1]),
|
||||
static_cast<index_t>(args.conv_filter_dilations_[2])};
|
||||
input_left_pads = {static_cast<index_t>(args.input_left_pads_[0]),
|
||||
static_cast<index_t>(args.input_left_pads_[1]),
|
||||
static_cast<index_t>(args.input_left_pads_[2])};
|
||||
input_right_pads = {static_cast<index_t>(args.input_right_pads_[0]),
|
||||
static_cast<index_t>(args.input_right_pads_[1]),
|
||||
static_cast<index_t>(args.input_right_pads_[2])};
|
||||
|
||||
k_batch = args.k_batch;
|
||||
|
||||
GemmM = args.N_ * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1] *
|
||||
args.output_spatial_lengths_[2];
|
||||
GemmN = args.K_;
|
||||
GemmK = args.C_ * args.filter_spatial_lengths_[0] * args.filter_spatial_lengths_[1] *
|
||||
args.filter_spatial_lengths_[2];
|
||||
|
||||
in_ptr = args.in_ptr;
|
||||
wei_ptr = args.wei_ptr;
|
||||
for(index_t d = 0; d < NumDTensor; d++)
|
||||
{
|
||||
ds_ptr[d] = args.ds_ptr[d];
|
||||
}
|
||||
out_ptr = args.out_ptr;
|
||||
|
||||
ConvToGemmFwdTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths,
|
||||
wei_g_k_c_xs_lengths,
|
||||
out_g_n_k_wos_lengths,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads};
|
||||
|
||||
a_grid_desc_m_k =
|
||||
conv_to_gemm_transformer
|
||||
.template MakeADescriptor_M_K<typename GroupedConvTraitsType::InLayout>();
|
||||
b_grid_desc_n_k =
|
||||
conv_to_gemm_transformer
|
||||
.template MakeBDescriptor_N_K<typename GroupedConvTraitsType::WeiLayout>();
|
||||
c_grid_desc_m_n =
|
||||
conv_to_gemm_transformer
|
||||
.template MakeCDescriptor_M_N<typename GroupedConvTraitsType::OutLayout>();
|
||||
|
||||
group_stride_a = args.C_;
|
||||
group_stride_b = args.K_ * args.C_ *
|
||||
std::accumulate(args.filter_spatial_lengths_.begin(),
|
||||
args.filter_spatial_lengths_.end(),
|
||||
1,
|
||||
std::multiplies<index_t>());
|
||||
group_stride_c = args.K_;
|
||||
}
|
||||
|
||||
using AGridDescMK = remove_cvref_t<decltype(
|
||||
ConvToGemmFwdTransformer{}
|
||||
.template MakeADescriptor_M_K<typename GroupedConvTraitsType::InLayout>())>;
|
||||
using BGridDescNK = remove_cvref_t<decltype(
|
||||
ConvToGemmFwdTransformer{}
|
||||
.template MakeBDescriptor_N_K<typename GroupedConvTraitsType::WeiLayout>())>;
|
||||
using CGridDescMN = remove_cvref_t<decltype(
|
||||
ConvToGemmFwdTransformer{}
|
||||
.template MakeCDescriptor_M_N<typename GroupedConvTraitsType::OutLayout>())>;
|
||||
|
||||
static constexpr index_t NonSpatialDims = 3;
|
||||
array<index_t, NonSpatialDims + GroupedConvTraitsType::NDimSpatial> in_g_n_c_wis_lengths;
|
||||
array<index_t, NonSpatialDims + GroupedConvTraitsType::NDimSpatial> wei_g_k_c_xs_lengths;
|
||||
array<index_t, NonSpatialDims + GroupedConvTraitsType::NDimSpatial> out_g_n_k_wos_lengths;
|
||||
|
||||
array<index_t, GroupedConvTraitsType::NDimSpatial> conv_filter_strides;
|
||||
array<index_t, GroupedConvTraitsType::NDimSpatial> conv_filter_dilations;
|
||||
array<index_t, GroupedConvTraitsType::NDimSpatial> input_left_pads;
|
||||
array<index_t, GroupedConvTraitsType::NDimSpatial> input_right_pads;
|
||||
|
||||
index_t k_batch;
|
||||
index_t GemmM;
|
||||
index_t GemmN;
|
||||
index_t GemmK;
|
||||
|
||||
const void* in_ptr;
|
||||
const void* wei_ptr;
|
||||
std::array<const void*, NumDTensor> ds_ptr;
|
||||
void* out_ptr;
|
||||
|
||||
AGridDescMK a_grid_desc_m_k;
|
||||
BGridDescNK b_grid_desc_n_k;
|
||||
CGridDescMN c_grid_desc_m_n;
|
||||
|
||||
long_index_t group_stride_a;
|
||||
long_index_t group_stride_b;
|
||||
long_index_t group_stride_c;
|
||||
};
|
||||
|
||||
/// @brief The Grouped Convolution Forward kernel template.
|
||||
///
|
||||
/// @paragraph Overview Overview
|
||||
/// This class provides the grouped convolution forward kernel template. By semantic
|
||||
/// division of Implicit GEMM algorithm into following parts we achieve flexible,
|
||||
/// versatile and robust kernel implementation.
|
||||
///
|
||||
/// @li @b Prolog - The start of GEMM kernel implementation in @ref operator()
|
||||
/// function call operator" which determines the work scope of each workgroup.
|
||||
/// @li @b GemmPipeline - The core part @a "heart" of matrix multiplication algorithm.
|
||||
/// This is the place where each workgroup is loading data from global memory and
|
||||
/// carrying out dot products.
|
||||
/// @li @b Epilogue - The @a "final" part of matrix multiplication implementation
|
||||
/// responsible for storing results to global memory. This is also the place where
|
||||
/// any additional operator fusion may take place.
|
||||
///
|
||||
/// Additionally both @ref GemmPipeline_ "GemmPipeline" and @ref EpiloguePipeline_
|
||||
/// "EpiloguePipeline" are parameterized with so called @a Policy which determines all
|
||||
/// internal details of those functional parts. You can think of it like both gemm and
|
||||
/// epilogue pipelines provides the control-flow logic controlled by policies. Moreover
|
||||
/// the policy is responsible for definition of all necessary data layouts and thread's
|
||||
/// work distribution.
|
||||
///
|
||||
/// @tparam GroupedConvTraitsType The type of class providing traits for grouped convolution.
|
||||
/// @tparam TilePartitioner_ The type of class providing mapping of workgroup index into
|
||||
/// the
|
||||
/// output data tile to be calculated. It determines the
|
||||
/// workgroup to data relationship (or in other words - which
|
||||
/// data would be processed and calculated by which workgroup).
|
||||
/// @tparam GemmPipeline_ The type of class which provides the core part of matrix
|
||||
/// multiplication. This class should provide implementation of
|
||||
/// data loading from global memory and performing block-wise
|
||||
/// matrix multiplication. You can think of it as a work done by
|
||||
/// single workgroup point of view.
|
||||
/// @tparam EpiloguePipeline_ The type of class providing the final part of matrix
|
||||
/// multiplication implementation. It is responsible for storing
|
||||
/// results calculated by @ref GemmPipeline_ "GemmPipeline" to
|
||||
/// the output C tensor in global memory.
|
||||
template <typename GroupedConvTraitsType,
|
||||
typename TilePartitioner_,
|
||||
typename GemmPipeline_,
|
||||
typename EpiloguePipeline_>
|
||||
struct GroupedConvolutionForwardKernel
|
||||
{
|
||||
static constexpr index_t NDimSpatial = GroupedConvTraitsType::NDimSpatial;
|
||||
static constexpr ConvolutionSpecialization ConvSpecialization =
|
||||
GroupedConvTraitsType::ConvSpecialization;
|
||||
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
|
||||
using GemmPipeline = remove_cvref_t<GemmPipeline_>;
|
||||
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
|
||||
using GemmALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
|
||||
using GemmBLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
|
||||
using GemmCLayout = remove_cvref_t<typename GemmPipeline::CLayout>;
|
||||
|
||||
using InLayout = remove_cvref_t<typename GroupedConvTraitsType::InLayout>;
|
||||
using WeiLayout = remove_cvref_t<typename GroupedConvTraitsType::WeiLayout>;
|
||||
using OutLayout = remove_cvref_t<typename GroupedConvTraitsType::OutLayout>;
|
||||
using DsLayout = remove_cvref_t<typename GroupedConvTraitsType::DsLayout>;
|
||||
|
||||
using GemmDsLayout = remove_cvref_t<typename EpiloguePipeline::DsLayout>;
|
||||
|
||||
static constexpr index_t NumDTensor = GroupedConvTraitsType::NumDTensor;
|
||||
|
||||
static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize;
|
||||
|
||||
using InDataType = remove_cvref_t<typename GemmPipeline::ADataType>;
|
||||
using WeiDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
|
||||
using DsDataType = remove_cvref_t<typename EpiloguePipeline::DsDataType>;
|
||||
// Below type is actually accumulation data type - the output of block GEMM.
|
||||
using OutDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
|
||||
|
||||
using GroupedConvFwdKernelArgsSpecialized = GroupedConvFwdKernelArgs<GroupedConvTraitsType>;
|
||||
|
||||
// TODO: Enable this
|
||||
static constexpr bool IsSplitKSupported = false;
|
||||
|
||||
static constexpr auto I0 = number<0>();
|
||||
static constexpr auto I1 = number<1>();
|
||||
static constexpr auto I2 = number<2>();
|
||||
static constexpr auto I3 = number<3>();
|
||||
|
||||
static_assert(GemmPipeline::kPadM && GemmPipeline::kPadN && GemmPipeline::kPadK,
|
||||
"Not supported!");
|
||||
static_assert(std::is_same_v<GemmALayout, tensor_layout::gemm::RowMajor>, "Not supported!");
|
||||
static_assert(std::is_same_v<GemmBLayout, tensor_layout::gemm::ColumnMajor>, "Not supported!");
|
||||
static_assert(std::is_same_v<GemmCLayout, tensor_layout::gemm::RowMajor>, "Not supported!");
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
return concat('_', "grouped_convolution_forward", gemm_prec_str<InDataType, WeiDataType>, GemmPipeline::GetName());
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(const GroupedConvHostArgs& args)
|
||||
{
|
||||
const index_t GemmM = args.N_ * std::accumulate(args.output_spatial_lengths_.begin(),
|
||||
args.output_spatial_lengths_.end(),
|
||||
1,
|
||||
std::multiplies<index_t>());
|
||||
const index_t GemmN = args.K_;
|
||||
return dim3(TilePartitioner::GridSize(GemmM, GemmN), args.G_, args.k_batch);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(KernelBlockSize); }
|
||||
|
||||
CK_TILE_HOST static constexpr GroupedConvFwdKernelArgsSpecialized
|
||||
MakeKernelArgs(const GroupedConvHostArgs& hostArgs)
|
||||
{
|
||||
return GroupedConvFwdKernelArgsSpecialized(hostArgs);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
|
||||
}
|
||||
|
||||
CK_TILE_HOST static bool IsSupportedArgument(const GroupedConvFwdKernelArgsSpecialized& kargs)
|
||||
{
|
||||
if constexpr((EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
|
||||
is_any_of<OutDataType, fp16_t, bf16_t>::value) ||
|
||||
!IsSplitKSupported)
|
||||
{
|
||||
if(kargs.k_batch != 1)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("Conditions not met for Kbatch >1 !");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
const index_t ConvK = kargs.wei_g_k_c_xs_lengths[number<1>{}];
|
||||
const index_t ConvC = kargs.wei_g_k_c_xs_lengths[number<2>{}];
|
||||
|
||||
// check ConvolutionSpecialization
|
||||
if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter1x1Stride1Pad0)
|
||||
{
|
||||
// check if it's 1x1, stride=1 conv
|
||||
for(index_t i = 0; i < NDimSpatial; ++i)
|
||||
{
|
||||
const index_t SpatialDim = kargs.wei_g_k_c_xs_lengths[i + 3];
|
||||
const index_t ConvStride = kargs.conv_filter_strides[i];
|
||||
const index_t LeftPad = kargs.input_left_pads[i];
|
||||
const index_t RightPad = kargs.input_right_pads[i];
|
||||
|
||||
if(!(SpatialDim == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
else if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter1x1Pad0)
|
||||
{
|
||||
// check if it's 1x1 conv
|
||||
for(index_t i = 0; i < NDimSpatial; ++i)
|
||||
{
|
||||
const index_t SpatialDim = kargs.wei_g_k_c_xs_lengths[i + 3];
|
||||
const index_t LeftPad = kargs.input_left_pads[i];
|
||||
const index_t RightPad = kargs.input_right_pads[i];
|
||||
|
||||
if(!(SpatialDim == 1 && LeftPad == 0 && RightPad == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
else if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter3x3)
|
||||
{
|
||||
if(ConvC != 1)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
for(index_t i = 0; i < NDimSpatial; ++i)
|
||||
{
|
||||
const index_t filter_spatial_dim = kargs.wei_g_k_c_xs_lengths[i + I3];
|
||||
|
||||
if(filter_spatial_dim != I3)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
namespace ctc = tensor_layout::convolution;
|
||||
|
||||
if constexpr(std::is_same_v<InLayout, ctc::NWGC> || std::is_same_v<InLayout, ctc::NHWGC> ||
|
||||
std::is_same_v<InLayout, ctc::NDHWGC>)
|
||||
{
|
||||
// Check access per C
|
||||
if(ConvC % GemmPipeline::GetVectorSizeA() != 0)
|
||||
{
|
||||
CK_TILE_ERROR("Conv C is not a multiple of vector load size for input image!");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
CK_TILE_ERROR("Not supported input layout!");
|
||||
return false;
|
||||
}
|
||||
|
||||
// check vector access of B
|
||||
// FIXME: layout
|
||||
if constexpr(std::is_same_v<WeiLayout, ctc::GKXC> ||
|
||||
std::is_same_v<WeiLayout, ctc::GKYXC> ||
|
||||
std::is_same_v<WeiLayout, ctc::GKZYXC>)
|
||||
{
|
||||
if(ConvC % GemmPipeline::GetVectorSizeB() != 0)
|
||||
{
|
||||
CK_TILE_ERROR("Conv C is not a multiple of vector load size for weight!");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
CK_TILE_ERROR("Not supported weight layout!");
|
||||
return false;
|
||||
}
|
||||
|
||||
// check vector access of E
|
||||
if constexpr(std::is_same_v<OutLayout, ctc::NWGK> ||
|
||||
std::is_same_v<OutLayout, ctc::NHWGK> ||
|
||||
std::is_same_v<OutLayout, ctc::NDHWGK>)
|
||||
{
|
||||
if(ConvK % EpiloguePipeline::GetVectorSizeC() != 0)
|
||||
{
|
||||
CK_TILE_ERROR("Conv K is not a multiple of vector store size for output image!");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
CK_TILE_ERROR("Not supported output layout!");
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
|
||||
CK_TILE_DEVICE static auto
|
||||
MakeGemmTensorViews(const InDataType* a_ptr,
|
||||
const WeiDataType* b_ptr,
|
||||
const std::array<const void*, NumDTensor>& ds_ptr,
|
||||
OutDataType* c_ptr,
|
||||
const GroupedConvFwdKernelArgsSpecialized& kargs)
|
||||
{
|
||||
static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!");
|
||||
static_assert(!TilePartitioner::BlockGemmShape::PermuteB, "Not implemented!");
|
||||
const auto& a_tensor_view = [&]() {
|
||||
return make_tensor_view<address_space_enum::global>(a_ptr, kargs.a_grid_desc_m_k);
|
||||
}();
|
||||
|
||||
const auto& b_tensor_view = [&]() {
|
||||
return make_tensor_view<address_space_enum::global>(b_ptr, kargs.b_grid_desc_n_k);
|
||||
}();
|
||||
|
||||
// TODO: enable vector write for C in ColMajor
|
||||
const auto& c_tensor_view = [&]() {
|
||||
return make_tensor_view<address_space_enum::global>(c_ptr, kargs.c_grid_desc_m_n);
|
||||
}();
|
||||
|
||||
const auto& ds_tensor_view = generate_tuple(
|
||||
[&](auto i) {
|
||||
static_assert(std::is_same_v<std::tuple_element_t<i, DsLayout>, OutLayout>,
|
||||
"Not supported!");
|
||||
static_assert(std::is_same_v<GemmCLayout, tensor_layout::gemm::RowMajor>,
|
||||
"Not supported!");
|
||||
static_assert(std::is_same_v<std::tuple_element_t<i, DsDataType>, OutDataType>,
|
||||
"Not supported!");
|
||||
|
||||
return make_tensor_view<address_space_enum::global>(
|
||||
static_cast<OutDataType*>(ds_ptr[i]), kargs.c_grid_desc_m_n);
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
|
||||
return make_tuple(a_tensor_view, b_tensor_view, ds_tensor_view, c_tensor_view);
|
||||
}
|
||||
|
||||
template <typename TensorView>
|
||||
CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views)
|
||||
{
|
||||
const auto& a_pad_view = [&]() {
|
||||
const auto& a_tensor_view = views.at(I0);
|
||||
return pad_tensor_view(a_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
sequence<true, true>{});
|
||||
}();
|
||||
|
||||
const auto& b_pad_view = [&]() {
|
||||
const auto& b_tensor_view = views.at(I1);
|
||||
return pad_tensor_view(b_tensor_view,
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
sequence<true, true>{});
|
||||
}();
|
||||
|
||||
const auto& ds_tensor_view = views.at(I2);
|
||||
const auto& ds_pad_view = generate_tuple(
|
||||
[&](auto i) {
|
||||
return pad_tensor_view(ds_tensor_view[i],
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<true, true>{});
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
|
||||
const auto& c_pad_view = [&]() {
|
||||
const auto& c_tensor_view = views.at(I3);
|
||||
return pad_tensor_view(c_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<true, true>{});
|
||||
}();
|
||||
|
||||
return make_tuple(a_pad_view, b_pad_view, ds_pad_view, c_pad_view);
|
||||
}
|
||||
|
||||
template <typename PadView>
|
||||
CK_TILE_DEVICE static auto
|
||||
MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
|
||||
{
|
||||
const auto& a_pad_view = views.at(I0);
|
||||
const auto& b_pad_view = views.at(I1);
|
||||
const auto& ds_pad_view = views.at(I2);
|
||||
const auto& c_pad_view = views.at(I3);
|
||||
|
||||
const auto& a_block_window = [&]() {
|
||||
return make_tile_window(a_pad_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
{i_m, 0});
|
||||
}();
|
||||
|
||||
const auto& b_block_window = [&]() {
|
||||
return make_tile_window(b_pad_view,
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
{i_n, 0});
|
||||
}();
|
||||
|
||||
const auto ds_block_window = generate_tuple(
|
||||
[&](auto i) {
|
||||
return make_tile_window(ds_pad_view[i],
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
{i_m, i_n});
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
|
||||
auto c_block_window = make_tile_window(
|
||||
c_pad_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
||||
{i_m, i_n});
|
||||
|
||||
return make_tuple(a_block_window, b_block_window, ds_block_window, c_block_window);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Runs single GEMM problem cooperatively by whole workgroup.
|
||||
*
|
||||
* @param a_ptr input A pointer
|
||||
* @param b_ptr input B pointer
|
||||
* @param c_ptr output C pointer
|
||||
* @param smem_ptr_0 The start memory pointer of the shared memory block.
|
||||
* @param kargs Grouped Convolution Forward kernel arguments
|
||||
* @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
|
||||
* @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
|
||||
*
|
||||
*/
|
||||
CK_TILE_DEVICE static void RunGemm(const InDataType* a_ptr,
|
||||
const WeiDataType* b_ptr,
|
||||
const std::array<const void*, NumDTensor>& ds_ptr,
|
||||
OutDataType* c_ptr,
|
||||
void* smem_ptr_0,
|
||||
const GroupedConvFwdKernelArgsSpecialized& kargs,
|
||||
const index_t block_idx_m,
|
||||
const index_t block_idx_n)
|
||||
{
|
||||
// Create Gemm tensor views, pad views and tile windows
|
||||
const auto& gemm_tensor_views_tuple =
|
||||
MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
|
||||
a_ptr, b_ptr, ds_ptr, c_ptr, kargs);
|
||||
|
||||
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
|
||||
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
|
||||
|
||||
const index_t num_loop =
|
||||
__builtin_amdgcn_readfirstlane(TilePartitioner::GetLoopNum(kargs.GemmK));
|
||||
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
const auto& a_block_window = gemm_tile_windows.at(I0);
|
||||
const auto& b_block_window = gemm_tile_windows.at(I1);
|
||||
const auto& d_block_window = gemm_tile_windows.at(I2);
|
||||
|
||||
const auto& c_block_tile = GemmPipeline{}.template operator()(
|
||||
a_block_window, b_block_window, num_loop, smem_ptr_0);
|
||||
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(I3);
|
||||
|
||||
EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
|
||||
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Runs single GEMM problem cooperatively by whole workgroup.
|
||||
*
|
||||
* @note RunGEMM2LDS in with two shared memory buffers using the ping pong buffer mechanism.
|
||||
*
|
||||
* @param a_ptr input A pointer
|
||||
* @param b_ptr input B pointer
|
||||
* @param c_ptr output C pointer
|
||||
* @param smem_ptr_0 The starting pointer of 1st shared memory block.
|
||||
* @param smem_ptr_1 The starting pointer of 2nd shared memory block.
|
||||
* @param kargs Grouped Convolution Forward kernel arguments
|
||||
* @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
|
||||
* @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
|
||||
*
|
||||
*/
|
||||
CK_TILE_DEVICE static void RunGemm2LDS(const InDataType* a_ptr,
|
||||
const WeiDataType* b_ptr,
|
||||
const std::array<const void*, NumDTensor>& ds_ptr,
|
||||
OutDataType* c_ptr,
|
||||
void* __restrict__ smem_ptr_0,
|
||||
void* __restrict__ smem_ptr_1,
|
||||
const GroupedConvFwdKernelArgsSpecialized& kargs,
|
||||
const index_t block_idx_m,
|
||||
const index_t block_idx_n)
|
||||
{
|
||||
// Create Gemm tensor views, pad views and tile windows
|
||||
const auto& gemm_tensor_views_tuple =
|
||||
MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
|
||||
a_ptr, b_ptr, ds_ptr, c_ptr, kargs);
|
||||
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
|
||||
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
|
||||
|
||||
const index_t num_loop =
|
||||
__builtin_amdgcn_readfirstlane(TilePartitioner::GetLoopNum(kargs.GemmK));
|
||||
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
const auto& a_block_window = gemm_tile_windows.at(I0);
|
||||
const auto& b_block_window = gemm_tile_windows.at(I1);
|
||||
const auto& d_block_window = gemm_tile_windows.at(I2);
|
||||
|
||||
const auto& c_block_tile = GemmPipeline{}.template operator()(
|
||||
a_block_window, b_block_window, num_loop, smem_ptr_0, smem_ptr_1);
|
||||
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(I3);
|
||||
|
||||
EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
|
||||
c_block_window, c_block_tile, d_block_window, smem_ptr_0, smem_ptr_1);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void operator()(GroupedConvFwdKernelArgsSpecialized kargs) const
|
||||
{
|
||||
const auto blockIdX = __builtin_amdgcn_readfirstlane(blockIdx.x);
|
||||
const auto [iM, iN] =
|
||||
TilePartitioner{kargs.GemmM, kargs.GemmN}.GetOutputTileIndex(blockIdX);
|
||||
const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
|
||||
const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
|
||||
|
||||
const auto blockIdY = __builtin_amdgcn_readfirstlane(blockIdx.y);
|
||||
const auto group_offset_a = __builtin_amdgcn_readfirstlane(kargs.group_stride_a * blockIdY);
|
||||
const auto group_offset_b = __builtin_amdgcn_readfirstlane(kargs.group_stride_b * blockIdY);
|
||||
const auto group_offset_c = __builtin_amdgcn_readfirstlane(kargs.group_stride_c * blockIdY);
|
||||
|
||||
// options
|
||||
const InDataType* a_ptr = static_cast<const InDataType*>(kargs.in_ptr) + group_offset_a;
|
||||
const WeiDataType* b_ptr = static_cast<const WeiDataType*>(kargs.wei_ptr) + group_offset_b;
|
||||
OutDataType* c_ptr = static_cast<OutDataType*>(kargs.out_ptr) + group_offset_c;
|
||||
|
||||
// allocate LDS
|
||||
__shared__ char smem_ptr_0[GetSmemSize()];
|
||||
|
||||
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
|
||||
{
|
||||
__shared__ char smem_ptr_1[GetSmemSize()];
|
||||
if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
|
||||
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
|
||||
is_any_of<OutDataType, fp16_t, bf16_t>::value))
|
||||
{
|
||||
RunGemm2LDS(
|
||||
a_ptr, b_ptr, kargs.ds_ptr, c_ptr, smem_ptr_0, smem_ptr_1, kargs, i_m, i_n);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
|
||||
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
|
||||
is_any_of<OutDataType, fp16_t, bf16_t>::value))
|
||||
{
|
||||
RunGemm(a_ptr, b_ptr, kargs.ds_ptr, c_ptr, smem_ptr_0, kargs, i_m, i_n);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,30 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
enum struct ConvolutionSpecialization
|
||||
{
|
||||
Default,
|
||||
Filter1x1Pad0,
|
||||
Filter1x1Stride1Pad0,
|
||||
Filter3x3,
|
||||
};
|
||||
|
||||
CK_TILE_HOST std::string getConvSpecializationString(const ConvolutionSpecialization& s)
|
||||
{
|
||||
switch(s)
|
||||
{
|
||||
case ConvolutionSpecialization::Default: return "Default";
|
||||
case ConvolutionSpecialization::Filter1x1Pad0: return "Filter1x1Pad0";
|
||||
case ConvolutionSpecialization::Filter1x1Stride1Pad0: return "Filter1x1Stride1Pad0";
|
||||
case ConvolutionSpecialization::Filter3x3: return "Filter3x3";
|
||||
default: return "Unrecognized specialization!";
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,74 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/convolution_parameter.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/// @brief The Grouped Conv kernel host arguments.
|
||||
///
|
||||
/// @par Overview
|
||||
/// This structure is passed to Grouped Convolution Kernels when creating kernel
|
||||
/// arguments object. It contain all necessary information required to
|
||||
/// build proper kernel argument and launch kernel on GPU.
|
||||
struct GroupedConvHostArgs : public conv::ConvParam
|
||||
{
|
||||
CK_TILE_HOST GroupedConvHostArgs() = delete;
|
||||
CK_TILE_HOST GroupedConvHostArgs(ConvParam conv_param,
|
||||
const void* in_ptr_,
|
||||
const void* wei_ptr_,
|
||||
const std::vector<const void*> ds_ptr_,
|
||||
void* out_ptr_,
|
||||
index_t k_batch_)
|
||||
: conv::ConvParam(conv_param),
|
||||
in_ptr(in_ptr_),
|
||||
wei_ptr(wei_ptr_),
|
||||
ds_ptr(ds_ptr_),
|
||||
out_ptr(out_ptr_),
|
||||
k_batch(k_batch_)
|
||||
{
|
||||
}
|
||||
|
||||
const void* in_ptr;
|
||||
const void* wei_ptr;
|
||||
const std::vector<const void*> ds_ptr;
|
||||
void* out_ptr;
|
||||
index_t k_batch;
|
||||
};
|
||||
|
||||
template <index_t NDimSpatial_,
|
||||
ConvolutionSpecialization ConvSpecialization_,
|
||||
typename InLayout_,
|
||||
typename WeiLayout_,
|
||||
typename DsLayout_,
|
||||
typename OutLayout_>
|
||||
struct GroupedConvTraits
|
||||
{
|
||||
private:
|
||||
static constexpr auto generate_implicit_gemm_layout()
|
||||
{
|
||||
return generate_tuple([](auto) { return ck_tile::tensor_layout::gemm::RowMajor{}; },
|
||||
number<DsLayout_::size()>{});
|
||||
}
|
||||
|
||||
public:
|
||||
static constexpr index_t NDimSpatial = NDimSpatial_;
|
||||
static constexpr ConvolutionSpecialization ConvSpecialization = ConvSpecialization_;
|
||||
using InLayout = InLayout_;
|
||||
using WeiLayout = WeiLayout_;
|
||||
using DsLayout = DsLayout_;
|
||||
using OutLayout = OutLayout_;
|
||||
using GroupedConvImplicitGemmTraits = TileGemmTraits<true,
|
||||
true,
|
||||
true,
|
||||
ck_tile::tensor_layout::gemm::RowMajor,
|
||||
ck_tile::tensor_layout::gemm::ColumnMajor,
|
||||
ck_tile::tensor_layout::gemm::RowMajor>;
|
||||
static constexpr index_t NumDTensor = DsLayout::size();
|
||||
using ImplicitGemmDsLayout = decltype(generate_implicit_gemm_layout());
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
File diff suppressed because it is too large
Load Diff
@@ -250,7 +250,7 @@ struct BlockNormReduceCrossWarpSync
|
||||
// | w0 | w1 | w2 | w3 | -----> | w0123 |
|
||||
//
|
||||
// -> also store data from every wave into LDS
|
||||
constexpr index_t num_warps = BlockShape::BlockSize / WarpSize;
|
||||
constexpr index_t num_warps = BlockShape::BlockSize / get_warp_size();
|
||||
return num_warps * 4 * thread_buf_size * sizeof(float);
|
||||
}
|
||||
|
||||
@@ -276,7 +276,7 @@ struct BlockNormReduceCrossWarpSync
|
||||
const index_t lane_id = get_lane_id();
|
||||
const index_t warp_id = get_warp_id();
|
||||
constexpr auto num_reduce_warps = GetReduceWarps<MeanDistributedTensor_>();
|
||||
constexpr index_t num_warps = BlockShape::BlockSize / WarpSize;
|
||||
constexpr index_t num_warps = BlockShape::BlockSize / get_warp_size();
|
||||
const index_t smem_offset = warp_id;
|
||||
|
||||
// skip if nonthing to do
|
||||
|
||||
@@ -272,4 +272,137 @@ struct BlockReduce2dCrossWarpSync
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Problem_, typename Policy_ = void>
|
||||
struct BlockReduce2dTreeCrossWarpSync
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using BlockShape = typename Problem::BlockShape;
|
||||
|
||||
template <typename YDistributedTensor_>
|
||||
CK_TILE_DEVICE static constexpr index_t GetReduceWarps()
|
||||
{
|
||||
constexpr index_t num_reduce_warps = [&]() {
|
||||
using Dstr = typename YDistributedTensor_::StaticTileDistribution;
|
||||
using DstrEncode = typename Dstr::DstrEncode;
|
||||
using DstrEncodeDetail = typename DstrEncode::detail;
|
||||
|
||||
constexpr index_t NDimR = Dstr::get_num_of_dimension_r();
|
||||
|
||||
constexpr index_t idim_p_warp = 0;
|
||||
|
||||
index_t len_ = 1;
|
||||
static_for<0, NDimR, 1>{}([&](auto idim_r) {
|
||||
if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_warp][idim_r])
|
||||
{
|
||||
constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r];
|
||||
len_ *= r_length;
|
||||
}
|
||||
});
|
||||
return len_;
|
||||
}();
|
||||
return num_reduce_warps;
|
||||
}
|
||||
|
||||
// return in byte
|
||||
template <typename YDistributedTensor_>
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
using DataType = typename YDistributedTensor_::DataType;
|
||||
constexpr index_t thread_buf_size = YDistributedTensor_::get_thread_buffer_size();
|
||||
|
||||
// we need to store all data from every wave into smem
|
||||
// e.g. 2x2 reduce along N
|
||||
// -------------> reduce N
|
||||
// | w0 | w1 | ___> | w01 |
|
||||
// | w2 | w3 | | w23 |
|
||||
//
|
||||
// -> store data from every wave into LDS
|
||||
//
|
||||
//
|
||||
// -------------> reduce N
|
||||
// | w0 | w1 | w2 | w3 | -----> | w0123 |
|
||||
//
|
||||
// -> also store data from every wave into LDS
|
||||
constexpr index_t num_warps = BlockShape::BlockSize / warpSize;
|
||||
return num_warps * thread_buf_size * sizeof(DataType);
|
||||
}
|
||||
|
||||
template <typename YDistributedTensor_, typename ReduceFunc>
|
||||
CK_TILE_DEVICE void
|
||||
operator()(YDistributedTensor_& y_tensor, void* smem, const ReduceFunc& reduce_func)
|
||||
{
|
||||
using Dstr = typename YDistributedTensor_::StaticTileDistribution;
|
||||
using DstrEncode = typename Dstr::DstrEncode;
|
||||
using DstrEncodeDetail = typename DstrEncode::detail;
|
||||
using DataType = typename YDistributedTensor_::DataType;
|
||||
|
||||
constexpr index_t NDimP = Dstr::get_num_of_dimension_p();
|
||||
constexpr index_t NDimR = Dstr::get_num_of_dimension_r();
|
||||
|
||||
constexpr index_t idim_p_lane = NDimP - 1;
|
||||
constexpr index_t thread_buf_size = YDistributedTensor_::get_thread_buffer_size();
|
||||
|
||||
DataType* smem_ptr = reinterpret_cast<DataType*>(smem);
|
||||
const index_t lane_id = get_lane_id();
|
||||
const index_t warp_id = get_warp_id();
|
||||
|
||||
constexpr index_t num_warps = BlockShape::BlockSize / get_warp_size();
|
||||
constexpr index_t num_reduce_warps = GetReduceWarps<YDistributedTensor_>();
|
||||
|
||||
if constexpr(num_reduce_warps == 1)
|
||||
return;
|
||||
|
||||
// Each warp's lane 0 writes its partial results to shared memory
|
||||
const index_t smem_offset = warp_id;
|
||||
if(lane_id == 0)
|
||||
{
|
||||
static_for<0, thread_buf_size, 1>{}([&](auto i) {
|
||||
// Store the i-th element of this warp's thread_buffer into SMEM
|
||||
smem_ptr[smem_offset + i * num_warps] = y_tensor.get_thread_buffer()[i];
|
||||
});
|
||||
}
|
||||
block_sync_lds();
|
||||
|
||||
// We let each warp holds a duplication to do reduction.
|
||||
static_for<0, thread_buf_size, 1>{}([&](auto i) {
|
||||
DataType v = 0;
|
||||
if(lane_id < num_reduce_warps)
|
||||
{
|
||||
v = smem_ptr[lane_id + i * num_warps];
|
||||
}
|
||||
|
||||
// cross-lane reduce for replication
|
||||
// only reduce on R dimension correspond to lane
|
||||
// (lane id maps to this R dimension)
|
||||
static_for<0, NDimR, 1>{}([&](auto idim_r) {
|
||||
// FIXME: nasty to use does_p_own_r_
|
||||
if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_lane][idim_r])
|
||||
{
|
||||
constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r];
|
||||
|
||||
constexpr index_t lid_over_rid_derivative =
|
||||
DstrEncodeDetail::ps_over_rs_derivative_[idim_p_lane][idim_r];
|
||||
|
||||
static_assert(is_power_of_two_integer(r_length),
|
||||
"wrong! only support power of 2 reduction");
|
||||
|
||||
constexpr index_t nstage = integer_log2_floor(r_length);
|
||||
|
||||
// reduction sweep forward
|
||||
static_for<0, nstage, 1>{}([&](auto istage) {
|
||||
// pull data from remote lane
|
||||
const auto o =
|
||||
__shfl_xor(v, number<lid_over_rid_derivative << istage.value>{}.value);
|
||||
|
||||
// reduce
|
||||
v = reduce_func(v, o);
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
y_tensor.get_thread_buffer()(i) = v;
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
|
||||
#include "ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp"
|
||||
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_default_policy.hpp"
|
||||
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_model_sensitive_pass.hpp"
|
||||
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_one_pass.hpp"
|
||||
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_problem.hpp"
|
||||
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp"
|
||||
|
||||
@@ -58,13 +58,14 @@ struct Rmsnorm2dFwd
|
||||
static constexpr bool kSaveInvRms = Problem::Traits::kSaveInvRms;
|
||||
static constexpr bool kSaveUnquant = Problem::Traits::kSaveUnquant;
|
||||
|
||||
static constexpr index_t Block_M = Problem::BlockShape::Block_M;
|
||||
static constexpr index_t Block_N = Problem::BlockShape::Block_N;
|
||||
static constexpr bool kPadM = false; // always no need to pad along M
|
||||
static constexpr bool kPadN = Problem::Traits::kPadN;
|
||||
static constexpr bool kTwoPass = Problem::Traits::kTwoPass;
|
||||
static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd;
|
||||
static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant;
|
||||
static constexpr index_t Block_M = Problem::BlockShape::Block_M;
|
||||
static constexpr index_t Block_N = Problem::BlockShape::Block_N;
|
||||
static constexpr bool kPadM = false; // always no need to pad along M
|
||||
static constexpr bool kPadN = Problem::Traits::kPadN;
|
||||
static constexpr bool kTwoPass = Problem::Traits::kTwoPass;
|
||||
static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd;
|
||||
static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant;
|
||||
static constexpr auto kUseModelSensitiveRMSNorm = Problem::Traits::kUseModelSensitiveRMSNorm;
|
||||
|
||||
static constexpr index_t ThreadPerWarp_N = Problem::BlockShape::ThreadPerWarp_N;
|
||||
static constexpr index_t Vector_N = Problem::BlockShape::Vector_N;
|
||||
@@ -150,6 +151,8 @@ struct Rmsnorm2dFwd
|
||||
if (kPadN) n += "_pn";
|
||||
if (kSaveInvRms) n += "_rms";
|
||||
if (kTwoPass) n += "_2p";
|
||||
if (kUseModelSensitiveRMSNorm == Rmsnorm2dSensitiveEnum::NO_SPECIFIC_MODEL) n += "_nsm";
|
||||
else if (kUseModelSensitiveRMSNorm == Rmsnorm2dSensitiveEnum::T5_MODEL_LIKE) n += "_t5ml";
|
||||
return n; }();
|
||||
|
||||
auto prec_str = [&] () {
|
||||
|
||||
@@ -69,6 +69,15 @@ struct Rmsnorm2dFwdPipelineDefaultPolicy
|
||||
return BlockReduce2dCrossWarpSync<P_>{};
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2dTreeCrossWarpSync()
|
||||
{
|
||||
using P_ = BlockReduce2dProblem<typename Problem::ComputeDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::BlockShape>;
|
||||
return BlockReduce2dTreeCrossWarpSync<P_>{};
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
|
||||
@@ -0,0 +1,228 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_default_policy.hpp"
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/**
|
||||
* @brief This T5Pass implements the RMSNorm2d forward pipeline as a variant
|
||||
* based on Rmsnorm2dFwdPipelineOnePass and Rmsnorm2dFwdPipelineTwoPass using a T5 model-like
|
||||
* method.
|
||||
*
|
||||
* The T5 model, developed by Google, is a transformer-based architecture designed to perform
|
||||
* a variety of NLP tasks. The T5-like approach employed here is characterized by how RMS
|
||||
* normalization is handled, particularly where intermediate values are cast to BF16. This aims to
|
||||
* achieve a similar value distribution to that produced by the VLLM hip implementation, thereby
|
||||
* enhancing model accuracy.
|
||||
*
|
||||
* Note: While this implementation improves precision and can reduce discrepancies with VLLM, it is
|
||||
* not guaranteed to eliminate all differences or ensure uniform outcomes across every use case.
|
||||
*
|
||||
* This implementation is a variant based on the original one-pass and two-pass approaches,
|
||||
* allowing for both fused and non-fused add operations.
|
||||
*/
|
||||
|
||||
template <typename Problem_, typename Policy_ = Rmsnorm2dFwdPipelineDefaultPolicy>
|
||||
struct Rmsnorm2dFwdPipelineModelSensitiveT5Pass
|
||||
{
|
||||
using Problem = ck_tile::remove_cvref_t<Problem_>;
|
||||
using Policy = ck_tile::remove_cvref_t<Policy_>;
|
||||
|
||||
using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>;
|
||||
using GammaDataType = ck_tile::remove_cvref_t<typename Problem::GammaDataType>;
|
||||
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
using YDataType = ck_tile::remove_cvref_t<typename Problem::YDataType>;
|
||||
using InvRmsDataType = ck_tile::remove_cvref_t<typename Problem::InvRmsDataType>;
|
||||
|
||||
using XResidualDataType = XDataType;
|
||||
using YResidualDataType = XDataType;
|
||||
|
||||
static constexpr bool kHasGamma = !std::is_same_v<GammaDataType, ck_tile::null_type>;
|
||||
static constexpr bool kSaveInvRms = Problem::Traits::kSaveInvRms;
|
||||
static constexpr bool kSaveUnquant = Problem::Traits::kSaveUnquant;
|
||||
|
||||
static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync;
|
||||
static constexpr bool kPadM = false; // TODO - BlockRmsnorm2dFwdProblem::kPadM
|
||||
static constexpr bool kPadN = Problem::Traits::kPadN;
|
||||
static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd;
|
||||
static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant;
|
||||
|
||||
static constexpr const char* name = []() {
|
||||
if constexpr(kNeedCrossWarpSync)
|
||||
return "bpr_op"; // block per row
|
||||
else
|
||||
return "wpr_op"; // warp per row
|
||||
}();
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
return Policy::template GetSmemSize<Problem>();
|
||||
}
|
||||
|
||||
template <typename XWindow,
|
||||
typename XResidualWindow,
|
||||
typename GammaWindow,
|
||||
typename YWindow,
|
||||
typename YResidualWindow,
|
||||
typename InvRmsWindow,
|
||||
typename SmoothScaleWindow,
|
||||
typename YScaleWindow,
|
||||
typename UnquantYWindow,
|
||||
typename Epilogue>
|
||||
CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
|
||||
const XResidualWindow& x_residual_window_,
|
||||
const GammaWindow& gamma_window_,
|
||||
YWindow& y_window_,
|
||||
const YResidualWindow& y_residual_window_,
|
||||
InvRmsWindow& inv_rms_window,
|
||||
const SmoothScaleWindow& sm_scale_window_,
|
||||
YScaleWindow& y_scale_window_,
|
||||
UnquantYWindow& unquant_y_window,
|
||||
ComputeDataType epsilon,
|
||||
ck_tile::index_t row_size,
|
||||
void* smem,
|
||||
Epilogue) const
|
||||
{
|
||||
const auto x_window =
|
||||
make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>());
|
||||
const auto gamma_window = make_tile_window(
|
||||
gamma_window_, Policy::template MakeGammaBlockTileDistribution<Problem>());
|
||||
const auto x_residual_window = make_tile_window(
|
||||
x_residual_window_, Policy::template MakeXBlockTileDistribution<Problem>());
|
||||
auto y_residual_window = make_tile_window(
|
||||
y_residual_window_, Policy::template MakeXBlockTileDistribution<Problem>());
|
||||
|
||||
auto reduce_square_sum_func = ReduceOp::SquareAdd{};
|
||||
auto reduce_sum_func = ReduceOp::Add{};
|
||||
auto block_reduce2d = Policy::template GetBlockReduce2d<Problem>();
|
||||
auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync<Problem>();
|
||||
auto block_reduce2d_tree_cross_warp_sync =
|
||||
Policy::template GetBlockReduce2dTreeCrossWarpSync<Problem>();
|
||||
|
||||
auto x = load_tile(x_window);
|
||||
auto x_resi = load_tile(x_residual_window);
|
||||
|
||||
// load gamma (TODO: support no gamma?)
|
||||
const auto gamma = load_tile(gamma_window);
|
||||
|
||||
auto acc = cast_tile<ComputeDataType>(x);
|
||||
|
||||
if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD ||
|
||||
kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE)
|
||||
{
|
||||
[[maybe_unused]] auto pre_out =
|
||||
make_static_distributed_tensor<YResidualDataType>(x.get_tile_distribution());
|
||||
|
||||
sweep_tile(x_resi, [&](auto idx) {
|
||||
// compute x = x_resi + x
|
||||
acc(idx) = type_convert<ComputeDataType>(x_resi(idx)) + acc(idx);
|
||||
|
||||
// To make norm input align with residual output
|
||||
if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE)
|
||||
{
|
||||
if constexpr(std::is_same_v<YResidualDataType, ck_tile::bf16_t>)
|
||||
{
|
||||
pre_out(idx) = float_to_bf16<bf16_rounding_mode::standard>(acc(idx));
|
||||
}
|
||||
else
|
||||
{
|
||||
pre_out(idx) = type_convert<YResidualDataType>(acc(idx));
|
||||
}
|
||||
acc(idx) = type_convert<ComputeDataType>(pre_out(idx));
|
||||
}
|
||||
});
|
||||
if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE)
|
||||
{
|
||||
store_tile(y_residual_window, pre_out);
|
||||
}
|
||||
}
|
||||
|
||||
// compute mean square each-thread->cross-lane->cross-warp
|
||||
auto square_sum = block_reduce2d.template MakeYBlockTile<decltype(acc)>();
|
||||
set_tile(square_sum, 0);
|
||||
if constexpr(Problem::BlockShape::Vector_N % 2 == 0)
|
||||
{
|
||||
sweep_tile(
|
||||
acc,
|
||||
[&](auto idx_0, auto idx_1) {
|
||||
square_sum(idx_0) += acc[idx_0] * acc[idx_0] + acc[idx_1] * acc[idx_1];
|
||||
},
|
||||
sequence<1, 2>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
square_sum = block_reduce2d(acc,
|
||||
reduce_square_sum_func.GetIdentityValue<ComputeDataType>(),
|
||||
reduce_square_sum_func);
|
||||
}
|
||||
block_reduce2d_sync(square_sum, reduce_sum_func);
|
||||
block_reduce2d_tree_cross_warp_sync(square_sum, smem, reduce_sum_func);
|
||||
|
||||
// compute inv-rms
|
||||
auto inv_rms = tile_elementwise_in(
|
||||
[&](const auto& v_) { return rsqrtf(v_ / row_size + epsilon); }, square_sum);
|
||||
|
||||
if constexpr(kSaveInvRms)
|
||||
store_tile(inv_rms_window, cast_tile<InvRmsDataType>(inv_rms));
|
||||
|
||||
// rmsnorm computation
|
||||
auto rmsn = make_static_distributed_tensor<ComputeDataType>(x.get_tile_distribution());
|
||||
sweep_tile(rmsn, [&, inv_rms_ = inv_rms](auto idx) {
|
||||
constexpr auto i_idx = make_tuple(idx[number<0>{}]);
|
||||
constexpr auto j_idx = make_tuple(idx[number<1>{}]);
|
||||
|
||||
const auto gamma_ = type_convert<ComputeDataType>(gamma[j_idx]);
|
||||
|
||||
if constexpr(std::is_same_v<YResidualDataType, ck_tile::bf16_t>)
|
||||
{
|
||||
const auto tmp0 =
|
||||
float_to_bf16<bf16_rounding_mode::standard>(acc[idx] * inv_rms_[i_idx]);
|
||||
const auto tmp1 = float_to_bf16<bf16_rounding_mode::standard>(
|
||||
type_convert<ComputeDataType>(tmp0) * gamma_);
|
||||
const auto rmsn_ = type_convert<ComputeDataType>(tmp1);
|
||||
rmsn(idx) = rmsn_;
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto tmp = type_convert<YResidualDataType>(acc[idx] * inv_rms_[i_idx]);
|
||||
const auto rmsn_ = type_convert<ComputeDataType>(tmp) * gamma_;
|
||||
rmsn(idx) = rmsn_;
|
||||
}
|
||||
});
|
||||
|
||||
if constexpr(kFusedQuant == Rmsnorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT)
|
||||
{
|
||||
if constexpr(kSaveUnquant)
|
||||
{
|
||||
Epilogue{}(
|
||||
unquant_y_window, y_window_, sm_scale_window_, y_scale_window_, rmsn, smem);
|
||||
}
|
||||
else
|
||||
{
|
||||
Epilogue{}(y_window_, sm_scale_window_, y_scale_window_, rmsn, smem);
|
||||
}
|
||||
}
|
||||
else if constexpr(kFusedQuant == Rmsnorm2dFusedQuantEnum::DYNAMIC_QUANT)
|
||||
{
|
||||
if constexpr(kSaveUnquant)
|
||||
{
|
||||
Epilogue{}(unquant_y_window, y_window_, y_scale_window_, rmsn, smem);
|
||||
}
|
||||
else
|
||||
{
|
||||
Epilogue{}(y_window_, y_scale_window_, rmsn, smem);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
Epilogue{}(y_window_, rmsn);
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
@@ -117,10 +117,7 @@ struct Rmsnorm2dFwdPipelineOnePass
|
||||
|
||||
// compute inv-rms
|
||||
auto inv_rms = tile_elementwise_in(
|
||||
[&](const auto& v_) {
|
||||
return type_convert<ComputeDataType>(1.0f) / (sqrt(v_ / row_size + epsilon));
|
||||
},
|
||||
square_sum);
|
||||
[&](const auto& v_) { return rsqrtf(v_ / row_size + epsilon); }, square_sum);
|
||||
|
||||
if constexpr(kSaveInvRms)
|
||||
store_tile(inv_rms_window, cast_tile<InvRmsDataType>(inv_rms));
|
||||
|
||||
@@ -37,20 +37,37 @@ template<> struct Rmsnorm2dFusedQuantEnumName<Rmsnorm2dFusedQuantEnum::DYNAMIC_Q
|
||||
template<> struct Rmsnorm2dFusedQuantEnumName<Rmsnorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT> { static constexpr const char * name = "smdqt"; };
|
||||
// clang-format on
|
||||
|
||||
enum class Rmsnorm2dSensitiveEnum
|
||||
{
|
||||
NO_SPECIFIC_MODEL = 0,
|
||||
// T5-like model for RMSNorm. The T5 model, developed by Google, is a transformer-based
|
||||
// architecture designed for a variety of NLP tasks. This option mimics T5's approach to
|
||||
// RMSNorm, aiming to ensure similar value distributions and enhance accuracy.
|
||||
T5_MODEL_LIKE = 1,
|
||||
};
|
||||
|
||||
// clang-format off
|
||||
template<Rmsnorm2dSensitiveEnum> struct Rmsnorm2dSensitiveEnumName;
|
||||
template<> struct Rmsnorm2dSensitiveEnumName<Rmsnorm2dSensitiveEnum::NO_SPECIFIC_MODEL> { static constexpr const char * name = "nsm"; };
|
||||
template<> struct Rmsnorm2dSensitiveEnumName<Rmsnorm2dSensitiveEnum::T5_MODEL_LIKE> { static constexpr const char * name = "t5ml"; };
|
||||
// clang-format on
|
||||
|
||||
template <bool kPadN_,
|
||||
bool kSaveInvRms_,
|
||||
bool kSaveUnquant_,
|
||||
bool kTwoPass_,
|
||||
Rmsnorm2dFusedAddEnum kFusedAdd_,
|
||||
Rmsnorm2dFusedQuantEnum kFusedQuant_>
|
||||
Rmsnorm2dFusedQuantEnum kFusedQuant_,
|
||||
Rmsnorm2dSensitiveEnum kUseModelSensitiveRMSNorm_>
|
||||
struct Rmsnorm2dFwdTraits
|
||||
{
|
||||
static constexpr bool kPadN = kPadN_;
|
||||
static constexpr bool kSaveInvRms = kSaveInvRms_;
|
||||
static constexpr bool kSaveUnquant = kSaveUnquant_;
|
||||
static constexpr bool kTwoPass = kTwoPass_;
|
||||
static constexpr Rmsnorm2dFusedAddEnum kFusedAdd = kFusedAdd_;
|
||||
static constexpr Rmsnorm2dFusedQuantEnum kFusedQuant = kFusedQuant_;
|
||||
static constexpr bool kPadN = kPadN_;
|
||||
static constexpr bool kSaveInvRms = kSaveInvRms_;
|
||||
static constexpr bool kSaveUnquant = kSaveUnquant_;
|
||||
static constexpr bool kTwoPass = kTwoPass_;
|
||||
static constexpr Rmsnorm2dFusedAddEnum kFusedAdd = kFusedAdd_;
|
||||
static constexpr Rmsnorm2dFusedQuantEnum kFusedQuant = kFusedQuant_;
|
||||
static constexpr Rmsnorm2dSensitiveEnum kUseModelSensitiveRMSNorm = kUseModelSensitiveRMSNorm_;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user