mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[CK_TILE] Update flatmm related kernels (#3022)
--------- Co-authored-by: Ding, Yi <yi.ding@amd.com> Co-authored-by: felix <felix.li@amd.com>
This commit is contained in:
@@ -11,23 +11,138 @@
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
struct FlatmmProblem
|
||||
{
|
||||
CK_TILE_HOST FlatmmProblem() = default;
|
||||
CK_TILE_HOST FlatmmProblem(
|
||||
index_t M_, index_t N_, index_t K_, index_t stride_A_, index_t stride_B_, index_t stride_C_)
|
||||
: M(M_), N(N_), K(K_), stride_A(stride_A_), stride_B(stride_B_), stride_C(stride_C_)
|
||||
{
|
||||
}
|
||||
|
||||
index_t M;
|
||||
index_t N;
|
||||
index_t K;
|
||||
index_t stride_A;
|
||||
index_t stride_B;
|
||||
index_t stride_C;
|
||||
};
|
||||
|
||||
template <int SharedGranularityMN, int SharedGranularityK = 0>
|
||||
struct FlatmmScalePointer
|
||||
{
|
||||
static constexpr int GranularityMN = SharedGranularityMN;
|
||||
static constexpr int GranularityK = SharedGranularityK;
|
||||
|
||||
const float* ptr;
|
||||
|
||||
CK_TILE_HOST_DEVICE FlatmmScalePointer() = default;
|
||||
CK_TILE_HOST_DEVICE FlatmmScalePointer(const float* ptr_) : ptr(ptr_) {}
|
||||
CK_TILE_HOST_DEVICE FlatmmScalePointer(const float* ptr_, [[maybe_unused]] index_t length_)
|
||||
: ptr(ptr_)
|
||||
{
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE FlatmmScalePointer operator+(index_t offset) const
|
||||
{
|
||||
FlatmmScalePointer ret;
|
||||
if constexpr(GranularityMN == 0)
|
||||
{
|
||||
ret.ptr = ptr + offset / GranularityK;
|
||||
}
|
||||
else
|
||||
{
|
||||
ret.ptr = ptr + offset / GranularityMN / GranularityK;
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE float operator[](index_t i) const = delete;
|
||||
};
|
||||
|
||||
template <int SharedGranularityMN>
|
||||
struct FlatmmScalePointer<SharedGranularityMN, 0>
|
||||
{
|
||||
static constexpr int GranularityMN = SharedGranularityMN;
|
||||
static constexpr int GranularityK = 0;
|
||||
|
||||
static_assert(GranularityMN != 0);
|
||||
|
||||
const float* ptr;
|
||||
index_t length;
|
||||
|
||||
CK_TILE_HOST_DEVICE FlatmmScalePointer() = default;
|
||||
CK_TILE_HOST_DEVICE FlatmmScalePointer(const float* ptr_) : ptr(ptr_), length(1) {}
|
||||
CK_TILE_HOST_DEVICE FlatmmScalePointer(const float* ptr_, index_t length_)
|
||||
: ptr(ptr_), length(length_)
|
||||
{
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE FlatmmScalePointer operator+(index_t offset) const
|
||||
{
|
||||
FlatmmScalePointer ret;
|
||||
if constexpr(GranularityMN == 1)
|
||||
{
|
||||
ret.ptr = ptr + offset;
|
||||
ret.length = length - offset;
|
||||
}
|
||||
else
|
||||
{
|
||||
ret.ptr = ptr + offset / GranularityMN;
|
||||
ret.length = length - offset / GranularityMN;
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE float operator[](index_t i) const
|
||||
{
|
||||
// with additional oob check
|
||||
if constexpr(GranularityMN == 1)
|
||||
return i < length ? ptr[i] : 0;
|
||||
else
|
||||
return i / GranularityMN < length ? ptr[i / GranularityMN] : 0;
|
||||
}
|
||||
};
|
||||
|
||||
// shared granularityMN = -1 means no scale
|
||||
template <>
|
||||
struct FlatmmScalePointer<-1, 0>
|
||||
{
|
||||
static constexpr int GranularityMN = -1;
|
||||
static constexpr int GranularityK = 0;
|
||||
|
||||
const float* ptr = nullptr;
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer() = default;
|
||||
CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer(const float*) {}
|
||||
CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer(const float*, index_t) {}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer operator+(index_t) const
|
||||
{
|
||||
return FlatmmScalePointer{};
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr float operator[](index_t) const
|
||||
{
|
||||
return 1; // alway return 1, it doesn't change the result
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t NumDTensor = 0>
|
||||
struct FlatmmHostArgs
|
||||
struct BaseFlatmmHostArgs
|
||||
{
|
||||
CK_TILE_HOST FlatmmHostArgs() = default;
|
||||
CK_TILE_HOST FlatmmHostArgs(const void* a_ptr_,
|
||||
const void* b_ptr_,
|
||||
const std::array<const void*, NumDTensor>& ds_ptr_,
|
||||
void* e_ptr_,
|
||||
index_t k_batch_,
|
||||
index_t M_,
|
||||
index_t N_,
|
||||
index_t K_,
|
||||
index_t stride_A_,
|
||||
index_t stride_B_,
|
||||
const std::array<index_t, NumDTensor>& stride_Ds_,
|
||||
index_t stride_E_)
|
||||
CK_TILE_HOST BaseFlatmmHostArgs() = default;
|
||||
CK_TILE_HOST BaseFlatmmHostArgs(const void* a_ptr_,
|
||||
const void* b_ptr_,
|
||||
const std::array<const void*, NumDTensor>& ds_ptr_,
|
||||
void* e_ptr_,
|
||||
index_t k_batch_,
|
||||
index_t M_,
|
||||
index_t N_,
|
||||
index_t K_,
|
||||
index_t stride_A_,
|
||||
index_t stride_B_,
|
||||
const std::array<index_t, NumDTensor>& stride_Ds_,
|
||||
index_t stride_E_)
|
||||
: a_ptr(a_ptr_),
|
||||
b_ptr(b_ptr_),
|
||||
ds_ptr(ds_ptr_),
|
||||
@@ -65,8 +180,51 @@ struct FlatmmHostArgs
|
||||
|
||||
index_t k_batch;
|
||||
};
|
||||
template <class ScaleM = FlatmmScalePointer<-1>,
|
||||
class ScaleN = FlatmmScalePointer<-1>,
|
||||
index_t NumDTensor = 0>
|
||||
struct ScaleFlatmmHostArgs : public BaseFlatmmHostArgs<>
|
||||
{
|
||||
CK_TILE_HOST ScaleFlatmmHostArgs() = default;
|
||||
CK_TILE_HOST ScaleFlatmmHostArgs(const void* a_ptr_,
|
||||
const void* b_shuffle_ptr_,
|
||||
const std::array<const void*, NumDTensor>& ds_ptr_,
|
||||
void* c_ptr_,
|
||||
index_t k_batch_,
|
||||
index_t M_,
|
||||
index_t N_,
|
||||
index_t K_,
|
||||
index_t stride_A_,
|
||||
index_t stride_B_,
|
||||
const std::array<index_t, NumDTensor>& stride_Ds_,
|
||||
index_t stride_C_,
|
||||
ScaleM scale_m_ = nullptr,
|
||||
ScaleN scale_n_ = nullptr)
|
||||
: BaseFlatmmHostArgs(a_ptr_,
|
||||
b_shuffle_ptr_,
|
||||
ds_ptr_,
|
||||
c_ptr_,
|
||||
k_batch_,
|
||||
M_,
|
||||
N_,
|
||||
K_,
|
||||
stride_A_,
|
||||
stride_B_,
|
||||
stride_Ds_,
|
||||
stride_C_),
|
||||
scale_m(scale_m_),
|
||||
scale_n(scale_n_)
|
||||
{
|
||||
}
|
||||
ScaleM scale_m = nullptr;
|
||||
ScaleN scale_n = nullptr;
|
||||
};
|
||||
|
||||
template <index_t NumDTensor = 0>
|
||||
template <int NumberTensor = 0>
|
||||
using FlatmmHostArgs =
|
||||
ScaleFlatmmHostArgs<FlatmmScalePointer<-1>, FlatmmScalePointer<-1>, NumberTensor>;
|
||||
|
||||
template <class ScaleM, class ScaleN, index_t NumDTensor = 0>
|
||||
struct FlatmmKernelArgs
|
||||
{
|
||||
const void* a_ptr;
|
||||
@@ -82,6 +240,8 @@ struct FlatmmKernelArgs
|
||||
std::array<index_t, NumDTensor> stride_Ds;
|
||||
index_t stride_E;
|
||||
index_t k_batch;
|
||||
ScaleM scale_m_ptr = nullptr;
|
||||
ScaleN scale_n_ptr = nullptr;
|
||||
};
|
||||
|
||||
template <typename TilePartitioner_, typename FlatmmPipeline_, typename EpiloguePipeline_>
|
||||
@@ -98,6 +258,7 @@ struct FlatmmKernel
|
||||
using DsLayout = remove_cvref_t<typename EpiloguePipeline::DsLayout>;
|
||||
using DsDataType = remove_cvref_t<typename EpiloguePipeline::DsDataType>;
|
||||
static constexpr index_t kBlockSize = FlatmmPipeline::BlockSize;
|
||||
static constexpr bool UsePersistentKernel = FlatmmPipeline::UsePersistentKernel;
|
||||
|
||||
using ADataType = remove_cvref_t<typename FlatmmPipeline::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename FlatmmPipeline::BDataType>;
|
||||
@@ -113,7 +274,7 @@ struct FlatmmKernel
|
||||
|
||||
static_assert(DsLayout::size() == DsDataType::size(),
|
||||
"The size of DsLayout and DsDataType should be the same");
|
||||
using KernelArgs = FlatmmKernelArgs<DsLayout::size()>;
|
||||
// using KernelArgs = FlatmmKernelArgs<DsLayout::size()>;
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
@@ -124,40 +285,85 @@ struct FlatmmKernel
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch)
|
||||
{
|
||||
assert(!UsePersistentKernel);
|
||||
return dim3(TilePartitioner::GridSize(M, N), 1, KBatch);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize()
|
||||
template <class ScaleM, class ScaleN>
|
||||
CK_TILE_HOST static constexpr auto
|
||||
GridSize(const FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()>& kargs)
|
||||
{
|
||||
return is_wave32() ? dim3(kBlockSize / 2) : dim3(kBlockSize);
|
||||
if constexpr(UsePersistentKernel)
|
||||
{
|
||||
hipDeviceProp_t prop;
|
||||
int deviceId = 0; // default device
|
||||
|
||||
constexpr int block_size = FlatmmKernel::BlockSize().x;
|
||||
int dync_smem_size = 0;
|
||||
int maxActiveBlocksPerCU = 0;
|
||||
|
||||
[[maybe_unused]] auto e = hipGetDeviceProperties(&prop, deviceId);
|
||||
|
||||
e = hipOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&maxActiveBlocksPerCU,
|
||||
reinterpret_cast<void*>(
|
||||
kentry<1, FlatmmKernel, FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()>>),
|
||||
block_size,
|
||||
dync_smem_size);
|
||||
|
||||
const int persistent_block_size = prop.multiProcessorCount * maxActiveBlocksPerCU;
|
||||
const int total_work_tile_cnt = TilePartitioner::GridSize(kargs.M, kargs.N);
|
||||
|
||||
// std::cout << "maxActiveBlocksPerCU: " << maxActiveBlocksPerCU
|
||||
// << ", persistent_block_size: " << persistent_block_size
|
||||
// << ", total_work_tile_cnt: " << total_work_tile_cnt << std::endl;
|
||||
|
||||
assert(kargs.k_batch == 1);
|
||||
return dim3(min(persistent_block_size, total_work_tile_cnt), 1, kargs.k_batch);
|
||||
}
|
||||
else
|
||||
{
|
||||
return dim3(TilePartitioner::GridSize(kargs.M, kargs.N), 1, kargs.k_batch);
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr KernelArgs
|
||||
MakeKernelArgs(const FlatmmHostArgs<NumDTensor>& hostArgs)
|
||||
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
|
||||
|
||||
template <class ScaleM, class ScaleN>
|
||||
CK_TILE_HOST static constexpr FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()>
|
||||
MakeKernelArgs(const ScaleFlatmmHostArgs<ScaleM, ScaleN, DsDataType::size()>& hostArgs)
|
||||
{
|
||||
return KernelArgs{hostArgs.a_ptr,
|
||||
hostArgs.b_ptr,
|
||||
hostArgs.ds_ptr,
|
||||
hostArgs.e_ptr,
|
||||
hostArgs.M,
|
||||
hostArgs.N,
|
||||
hostArgs.K,
|
||||
hostArgs.stride_A,
|
||||
hostArgs.stride_B,
|
||||
hostArgs.stride_Ds,
|
||||
hostArgs.stride_E,
|
||||
hostArgs.k_batch};
|
||||
return {hostArgs.a_ptr,
|
||||
hostArgs.b_ptr,
|
||||
hostArgs.ds_ptr,
|
||||
hostArgs.e_ptr,
|
||||
hostArgs.M,
|
||||
hostArgs.N,
|
||||
hostArgs.K,
|
||||
hostArgs.stride_A,
|
||||
hostArgs.stride_B,
|
||||
hostArgs.stride_Ds,
|
||||
hostArgs.stride_E,
|
||||
hostArgs.k_batch,
|
||||
hostArgs.scale_m,
|
||||
hostArgs.scale_n};
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemPingSize()
|
||||
{
|
||||
return max(FlatmmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
|
||||
}
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemPongSize()
|
||||
{
|
||||
return FlatmmPipeline::GetSmemSize();
|
||||
}
|
||||
|
||||
struct SplitKBatchOffset
|
||||
{
|
||||
template <class KernelArgs>
|
||||
__device__ SplitKBatchOffset(const KernelArgs& kargs, const std::size_t k_id = blockIdx.z)
|
||||
{
|
||||
constexpr auto N1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<1>{});
|
||||
constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{});
|
||||
const index_t K_t = kargs.k_batch * K1;
|
||||
const index_t KRead = (kargs.K + K_t - 1) / K_t * K1;
|
||||
@@ -173,11 +379,11 @@ struct FlatmmKernel
|
||||
|
||||
if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
|
||||
{
|
||||
b_k_split_offset = k_id * KRead * kargs.stride_B;
|
||||
b_k_split_offset = k_id * KRead * kargs.stride_B * N1;
|
||||
}
|
||||
else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
|
||||
{
|
||||
b_k_split_offset = k_id * KRead;
|
||||
b_k_split_offset = k_id * KRead * N1;
|
||||
}
|
||||
|
||||
if(k_id < static_cast<uint32_t>(kargs.k_batch - 1))
|
||||
@@ -195,6 +401,7 @@ struct FlatmmKernel
|
||||
index_t splitted_k;
|
||||
};
|
||||
|
||||
template <class KernelArgs>
|
||||
CK_TILE_HOST static bool IsSupportedArgument(const KernelArgs& kargs)
|
||||
{
|
||||
if constexpr(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
|
||||
@@ -206,6 +413,14 @@ struct FlatmmKernel
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if constexpr(UsePersistentKernel)
|
||||
{
|
||||
if(kargs.k_batch != 1)
|
||||
{
|
||||
std::cerr << "Persistent mode doesn't support Kbatch >1 !" << std::endl;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
@@ -340,7 +555,7 @@ struct FlatmmKernel
|
||||
return DTesnorIsValid;
|
||||
}
|
||||
|
||||
template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
|
||||
template <memory_operation_enum DstInMemOp = memory_operation_enum::set, class KernelArgs>
|
||||
CK_TILE_DEVICE static auto
|
||||
MakeGemmTensorViews(const ADataType* a_ptr,
|
||||
const BDataType* b_flat_ptr,
|
||||
@@ -370,9 +585,9 @@ struct FlatmmKernel
|
||||
}
|
||||
}();
|
||||
|
||||
index_t kFlatK = FlatmmPipeline::flatKPerWarp * (splitk_batch_offset.splitted_k /
|
||||
BlockGemmShape::WarpTile::at(number<2>{}));
|
||||
index_t kFlatN = kargs.N * kargs.K / kFlatK;
|
||||
index_t kFlatK =
|
||||
FlatmmPipeline::flatKPerWarp * (kargs.K / BlockGemmShape::WarpTile::at(I2));
|
||||
index_t kFlatN = kargs.N * kargs.K / kFlatK;
|
||||
const auto& b_flat_tensor_view = [&]() {
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
b_flat_ptr,
|
||||
@@ -411,7 +626,7 @@ struct FlatmmKernel
|
||||
const auto& e_tensor_view = [&]() {
|
||||
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
|
||||
e_ptr,
|
||||
make_tuple(kargs.M, kargs.N),
|
||||
make_tuple(kargs.stride_E, 1),
|
||||
@@ -420,7 +635,7 @@ struct FlatmmKernel
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
|
||||
e_ptr,
|
||||
make_tuple(kargs.N, kargs.M),
|
||||
make_tuple(kargs.stride_E, 1),
|
||||
@@ -429,7 +644,45 @@ struct FlatmmKernel
|
||||
}
|
||||
}();
|
||||
|
||||
return make_tuple(a_tensor_view, b_flat_tensor_view, ds_tensor_view, e_tensor_view);
|
||||
constexpr int ScaleGranularityM = decltype(kargs.scale_m_ptr)::GranularityMN;
|
||||
constexpr int ScaleGranularityN = decltype(kargs.scale_n_ptr)::GranularityMN;
|
||||
|
||||
constexpr int ScaleGranularityKA = decltype(kargs.scale_m_ptr)::GranularityK;
|
||||
constexpr int ScaleGranularityKB = decltype(kargs.scale_n_ptr)::GranularityK;
|
||||
|
||||
auto scale_stride_m = ScaleGranularityM == 0 ? 0 // per-tensor scale
|
||||
: 1; // per-token scale
|
||||
auto scale_stride_n = ScaleGranularityN == 0 ? 0 // per-tensor scale
|
||||
: 1; // per-channel scale
|
||||
|
||||
static_assert(ScaleGranularityM == 0 || ScaleGranularityM == 1 || ScaleGranularityM == -1,
|
||||
"only support per-tensor or per-row scaling");
|
||||
static_assert(ScaleGranularityN == 0 || ScaleGranularityN == 1 || ScaleGranularityN == -1,
|
||||
"only support per-tensor or per-column scaling");
|
||||
|
||||
const auto scale_m_view = make_naive_tensor_view<address_space_enum::global>(
|
||||
kargs.scale_m_ptr.ptr,
|
||||
make_tuple(
|
||||
kargs.M / ScaleGranularityM,
|
||||
ScaleGranularityKA == 0 ? 1 : splitk_batch_offset.splitted_k / ScaleGranularityKA),
|
||||
make_tuple(scale_stride_m, 0),
|
||||
number < ScaleGranularityM == 1 ? FlatmmPipeline::GetVectorSizeA() : 1 > {},
|
||||
number<1>{});
|
||||
const auto scale_n_view = make_naive_tensor_view<address_space_enum::global>(
|
||||
kargs.scale_n_ptr.ptr,
|
||||
make_tuple(
|
||||
ScaleGranularityKB == 0 ? 1 : (splitk_batch_offset.splitted_k / ScaleGranularityKB),
|
||||
kargs.N / ScaleGranularityN),
|
||||
make_tuple(0, scale_stride_n),
|
||||
number < ScaleGranularityN == 1 ? FlatmmPipeline::GetVectorSizeB() : 1 > {},
|
||||
number<1>{});
|
||||
|
||||
return make_tuple(a_tensor_view,
|
||||
b_flat_tensor_view,
|
||||
ds_tensor_view,
|
||||
e_tensor_view,
|
||||
scale_m_view,
|
||||
scale_n_view);
|
||||
}
|
||||
|
||||
template <typename TensorView>
|
||||
@@ -495,7 +748,12 @@ struct FlatmmKernel
|
||||
}
|
||||
}();
|
||||
|
||||
return make_tuple(a_pad_view, b_flat_tensor_view, ds_pad_view, e_pad_view);
|
||||
return make_tuple(a_pad_view,
|
||||
b_flat_tensor_view,
|
||||
ds_pad_view,
|
||||
e_pad_view,
|
||||
views.at(number<4>{}),
|
||||
views.at(number<5>{}));
|
||||
}
|
||||
|
||||
template <typename PadView>
|
||||
@@ -555,19 +813,42 @@ struct FlatmmKernel
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
||||
{i_m, i_n});
|
||||
|
||||
return make_tuple(a_block_window, b_flat_block_window, ds_block_window, e_block_window);
|
||||
constexpr int ScaleGranularityKA = 0; // decltype(kargs.scale_m_ptr)::GranularityK;
|
||||
constexpr int ScaleGranularityKB = 0; // decltype(kargs.scale_n_ptr)::GranularityK;
|
||||
|
||||
auto scale_m_window = make_tile_window(views.at(number<4>{}),
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number < ScaleGranularityKA == 0
|
||||
? TilePartitioner::NPerBlock
|
||||
: TilePartitioner::KPerBlock > {}),
|
||||
{i_m, 0});
|
||||
auto scale_n_window = make_tile_window(views.at(number<5>{}),
|
||||
make_tuple(number < ScaleGranularityKB == 0
|
||||
? TilePartitioner::MPerBlock
|
||||
: TilePartitioner::KPerBlock > {},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
{0, i_n});
|
||||
|
||||
return make_tuple(a_block_window,
|
||||
b_flat_block_window,
|
||||
ds_block_window,
|
||||
e_block_window,
|
||||
scale_m_window,
|
||||
scale_n_window);
|
||||
}
|
||||
|
||||
template <bool UseDefaultScheduler = true>
|
||||
CK_TILE_DEVICE static void RunFlatmm(const ADataType* a_ptr,
|
||||
const BDataType* b_flat_ptr,
|
||||
const std::array<const void*, NumDTensor>& ds_ptr,
|
||||
EDataType* e_ptr,
|
||||
void* smem_ptr,
|
||||
const KernelArgs& kargs,
|
||||
const SplitKBatchOffset& splitk_batch_offset,
|
||||
const index_t block_idx_m,
|
||||
const index_t block_idx_n)
|
||||
template <class ScaleM, class ScaleN, bool UseDefaultScheduler = true>
|
||||
CK_TILE_DEVICE static void
|
||||
RunFlatmm(const ADataType* a_ptr,
|
||||
const BDataType* b_flat_ptr,
|
||||
const std::array<const void*, NumDTensor>& ds_ptr,
|
||||
EDataType* e_ptr,
|
||||
void* smem_ptr_ping,
|
||||
void* smem_ptr_pong,
|
||||
const FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()>& kargs,
|
||||
const SplitKBatchOffset& splitk_batch_offset,
|
||||
const index_t block_idx_m,
|
||||
const index_t block_idx_n)
|
||||
{
|
||||
// Create Gemm tensor views, pad views and tile windows
|
||||
const auto& gemm_tensor_views_tuple =
|
||||
@@ -583,50 +864,77 @@ struct FlatmmKernel
|
||||
const auto& b_flat_block_window = gemm_tile_windows.at(I1);
|
||||
const auto& d_block_window = gemm_tile_windows.at(I2);
|
||||
const auto& c_block_tile = FlatmmPipeline{}.template operator()(
|
||||
a_block_window, b_flat_block_window, num_loop, smem_ptr);
|
||||
if(UseDefaultScheduler || (get_warp_id() == 0))
|
||||
a_block_window, b_flat_block_window, num_loop, smem_ptr_ping, smem_ptr_pong);
|
||||
|
||||
auto scale_m_window = gemm_tile_windows.at(number<4>{});
|
||||
auto scale_n_window = gemm_tile_windows.at(number<5>{});
|
||||
|
||||
// Run Epilogue Pipeline
|
||||
if constexpr(ScaleM::GranularityMN != -1 || ScaleN::GranularityMN != -1)
|
||||
{
|
||||
auto& c_block_window = gemm_tile_windows.at(I3);
|
||||
EpiloguePipeline{}.template
|
||||
operator()<decltype(c_block_window), decltype(c_block_tile), decltype(d_block_window)>(
|
||||
c_block_window,
|
||||
c_block_tile,
|
||||
d_block_window,
|
||||
smem_ptr_ping,
|
||||
scale_m_window,
|
||||
scale_n_window);
|
||||
}
|
||||
else if(UseDefaultScheduler || (get_warp_id() == 0))
|
||||
{
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(I3);
|
||||
|
||||
EpiloguePipeline{}.template
|
||||
operator()<decltype(c_block_window), decltype(c_block_tile), decltype(d_block_window)>(
|
||||
c_block_window, c_block_tile, d_block_window, smem_ptr);
|
||||
c_block_window, c_block_tile, d_block_window, smem_ptr_ping);
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void operator()(KernelArgs kargs) const
|
||||
template <class ScaleM, class ScaleN>
|
||||
CK_TILE_DEVICE void operator()(FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()> kargs,
|
||||
int partition_idx = blockIdx.x) const
|
||||
{
|
||||
const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockIdx.x);
|
||||
const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
|
||||
const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);
|
||||
int total_work_tile_cnt = TilePartitioner::GridSize(kargs.M, kargs.N);
|
||||
|
||||
const SplitKBatchOffset splitk_batch_offset(kargs);
|
||||
// options
|
||||
const ADataType* a_ptr =
|
||||
static_cast<const ADataType*>(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset;
|
||||
const BDataType* b_flat_ptr =
|
||||
static_cast<const BDataType*>(kargs.b_ptr) + splitk_batch_offset.b_k_split_offset;
|
||||
EDataType* e_ptr = static_cast<EDataType*>(kargs.e_ptr);
|
||||
|
||||
// allocate LDS
|
||||
__shared__ char smem_ptr[GetSmemSize()];
|
||||
|
||||
if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
|
||||
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
|
||||
is_any_of<EDataType, fp16_t, bf16_t>::value))
|
||||
do
|
||||
{
|
||||
constexpr auto scheduler_type = (FlatmmPipeline::NumWaveGroups == 1);
|
||||
RunFlatmm<scheduler_type>(a_ptr,
|
||||
b_flat_ptr,
|
||||
kargs.ds_ptr,
|
||||
e_ptr,
|
||||
smem_ptr,
|
||||
kargs,
|
||||
splitk_batch_offset,
|
||||
i_m,
|
||||
i_n);
|
||||
}
|
||||
const auto [iM, iN] =
|
||||
TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(partition_idx);
|
||||
const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
|
||||
const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
|
||||
|
||||
const SplitKBatchOffset splitk_batch_offset(kargs);
|
||||
// options
|
||||
const ADataType* a_ptr =
|
||||
static_cast<const ADataType*>(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset;
|
||||
const BDataType* b_flat_ptr =
|
||||
static_cast<const BDataType*>(kargs.b_ptr) + splitk_batch_offset.b_k_split_offset;
|
||||
EDataType* e_ptr = static_cast<EDataType*>(kargs.e_ptr);
|
||||
|
||||
// allocate LDS
|
||||
__shared__ char smem_ptr_ping[GetSmemPingSize()];
|
||||
__shared__ char smem_ptr_pong[GetSmemPongSize()];
|
||||
|
||||
if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
|
||||
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
|
||||
is_any_of<EDataType, fp16_t, bf16_t>::value))
|
||||
{
|
||||
constexpr auto scheduler_type = (FlatmmPipeline::NumWaveGroups == 1);
|
||||
RunFlatmm<ScaleM, ScaleN, scheduler_type>(a_ptr,
|
||||
b_flat_ptr,
|
||||
kargs.ds_ptr,
|
||||
e_ptr,
|
||||
smem_ptr_ping,
|
||||
smem_ptr_pong,
|
||||
kargs,
|
||||
splitk_batch_offset,
|
||||
i_m,
|
||||
i_n);
|
||||
}
|
||||
partition_idx += gridDim.x;
|
||||
} while(UsePersistentKernel && partition_idx < total_work_tile_cnt);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user