mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-28 10:47:00 +00:00
Support multiple D in quant gemm kernel
A new kernel is added QuantGemmMultiDKernel, the existing QuantGemmKernel behaves as usual.
This commit is contained in:
@@ -107,32 +107,16 @@ struct is_preshuffleB_enabled<T, std::void_t<decltype(T::PreshuffleB)>>
|
||||
};
|
||||
} // namespace detail
|
||||
|
||||
struct QuantGemmProblem
|
||||
template <index_t NumDTensor>
|
||||
struct QuantGemmMultiDHostArgs
|
||||
{
|
||||
CK_TILE_HOST QuantGemmProblem() = default;
|
||||
CK_TILE_HOST QuantGemmProblem(index_t M_,
|
||||
index_t N_,
|
||||
index_t K_,
|
||||
index_t QK_A_,
|
||||
index_t QK_B_,
|
||||
index_t stride_A_,
|
||||
index_t stride_B_,
|
||||
index_t stride_C_,
|
||||
index_t stride_AQ_,
|
||||
index_t stride_BQ_)
|
||||
: M(M_),
|
||||
N(N_),
|
||||
K(K_),
|
||||
QK_A(QK_A_),
|
||||
QK_B(QK_B_),
|
||||
stride_A(stride_A_),
|
||||
stride_B(stride_B_),
|
||||
stride_C(stride_C_),
|
||||
stride_AQ(stride_AQ_),
|
||||
stride_BQ(stride_BQ_)
|
||||
{
|
||||
}
|
||||
|
||||
const void* a_ptr;
|
||||
const void* b_ptr;
|
||||
const std::array<const void*, NumDTensor> ds_ptr;
|
||||
void* c_ptr;
|
||||
const void* aq_ptr;
|
||||
const void* bq_ptr;
|
||||
index_t k_batch;
|
||||
index_t M;
|
||||
index_t N;
|
||||
index_t K;
|
||||
@@ -140,56 +124,20 @@ struct QuantGemmProblem
|
||||
index_t QK_B;
|
||||
index_t stride_A;
|
||||
index_t stride_B;
|
||||
const std::array<index_t, NumDTensor> stride_Ds;
|
||||
index_t stride_C;
|
||||
index_t stride_AQ;
|
||||
index_t stride_BQ;
|
||||
};
|
||||
|
||||
struct QuantGemmHostArgs : public QuantGemmProblem
|
||||
{
|
||||
CK_TILE_HOST QuantGemmHostArgs() = default;
|
||||
CK_TILE_HOST QuantGemmHostArgs(const void* a_ptr_,
|
||||
const void* b_ptr_,
|
||||
void* c_ptr_,
|
||||
const void* aq_ptr_,
|
||||
const void* bq_ptr_,
|
||||
index_t k_batch_,
|
||||
index_t M_,
|
||||
index_t N_,
|
||||
index_t K_,
|
||||
index_t QK_A_,
|
||||
index_t QK_B_,
|
||||
index_t stride_A_,
|
||||
index_t stride_B_,
|
||||
index_t stride_C_,
|
||||
index_t stride_AQ_,
|
||||
index_t stride_BQ_)
|
||||
: QuantGemmProblem(
|
||||
M_, N_, K_, QK_A_, QK_B_, stride_A_, stride_B_, stride_C_, stride_AQ_, stride_BQ_),
|
||||
a_ptr(a_ptr_),
|
||||
b_ptr(b_ptr_),
|
||||
aq_ptr(aq_ptr_),
|
||||
bq_ptr(bq_ptr_),
|
||||
c_ptr(c_ptr_),
|
||||
k_batch(k_batch_)
|
||||
{
|
||||
}
|
||||
|
||||
const void* a_ptr = nullptr;
|
||||
const void* b_ptr = nullptr;
|
||||
const void* aq_ptr = nullptr;
|
||||
const void* bq_ptr = nullptr;
|
||||
void* c_ptr = nullptr;
|
||||
// k_batch must be a positive integer; defaults to 1 (no split-K).
|
||||
index_t k_batch = 1;
|
||||
};
|
||||
|
||||
struct QuantGemmKernelArgs
|
||||
template <index_t NumDTensor>
|
||||
struct QuantGemmMultiDKernelArgs
|
||||
{
|
||||
const void* a_ptr;
|
||||
const void* b_ptr;
|
||||
const void* aq_ptr;
|
||||
const void* bq_ptr;
|
||||
const std::array<const void*, NumDTensor> ds_ptr;
|
||||
void* c_ptr;
|
||||
index_t M;
|
||||
index_t N;
|
||||
@@ -198,6 +146,7 @@ struct QuantGemmKernelArgs
|
||||
index_t QK_B;
|
||||
index_t stride_A;
|
||||
index_t stride_B;
|
||||
const std::array<index_t, NumDTensor> stride_Ds;
|
||||
index_t stride_C;
|
||||
index_t stride_AQ;
|
||||
index_t stride_BQ;
|
||||
@@ -233,7 +182,7 @@ template <typename TilePartitioner_,
|
||||
typename EpiloguePipeline_,
|
||||
QuantType QuantType_,
|
||||
bool RuntimeSplitKTail_ = false>
|
||||
struct QuantGemmKernel
|
||||
struct QuantGemmMultiDKernel
|
||||
{
|
||||
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
|
||||
using GemmPipeline = remove_cvref_t<GemmPipeline_>;
|
||||
@@ -241,6 +190,7 @@ struct QuantGemmKernel
|
||||
using ALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
|
||||
using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
|
||||
using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;
|
||||
using DsLayout = remove_cvref_t<typename EpiloguePipeline::DsLayout>;
|
||||
|
||||
using AQLayout = remove_cvref_t<
|
||||
typename detail::get_aq_layout_or<GemmPipeline, typename GemmPipeline::ALayout>::type>;
|
||||
@@ -257,6 +207,7 @@ struct QuantGemmKernel
|
||||
using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
|
||||
using DsDataType = remove_cvref_t<typename EpiloguePipeline::DsDataType>;
|
||||
using AccDataType = remove_cvref_t<typename EpiloguePipeline::AccDataType>;
|
||||
|
||||
using AQDataType =
|
||||
@@ -264,15 +215,23 @@ struct QuantGemmKernel
|
||||
using BQDataType =
|
||||
remove_cvref_t<typename detail::get_bq_data_type_or<GemmPipeline, AccDataType>::type>;
|
||||
|
||||
static constexpr auto I0 = number<0>(); // A Tensor
|
||||
static constexpr auto I1 = number<1>(); // AQ Tensor
|
||||
static constexpr auto I2 = number<2>(); // B Tensor
|
||||
static constexpr auto I3 = number<3>(); // BQ Tensor
|
||||
static constexpr auto I4 = number<4>(); // C Tensor
|
||||
static_assert(is_detected<is_tuple, DsLayout>::value &&
|
||||
is_detected<is_tuple, DsDataType>::value &&
|
||||
DsLayout::size() == DsDataType::size(),
|
||||
"DsLayout and DsDataType must be tuples and must have the same size.");
|
||||
|
||||
static constexpr index_t NumDTensor = DsDataType::size();
|
||||
|
||||
static constexpr auto I0 = number<0>();
|
||||
static constexpr auto I1 = number<1>();
|
||||
static constexpr auto I2 = number<2>();
|
||||
|
||||
static constexpr auto kQuantType = QuantType_;
|
||||
static constexpr bool RuntimeSplitKTail = RuntimeSplitKTail_;
|
||||
|
||||
using HostArgs = QuantGemmMultiDHostArgs<NumDTensor>;
|
||||
using KernelArgs = QuantGemmMultiDKernelArgs<NumDTensor>;
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
@@ -290,25 +249,26 @@ struct QuantGemmKernel
|
||||
return is_wave32() ? dim3(kBlockSize / 2) : dim3(kBlockSize);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr QuantGemmKernelArgs
|
||||
MakeKernelArgs(const QuantGemmHostArgs& hostArgs)
|
||||
CK_TILE_HOST static constexpr KernelArgs MakeKernelArgs(const HostArgs& hostArgs)
|
||||
{
|
||||
return QuantGemmKernelArgs{hostArgs.a_ptr,
|
||||
hostArgs.b_ptr,
|
||||
hostArgs.aq_ptr,
|
||||
hostArgs.bq_ptr,
|
||||
hostArgs.c_ptr,
|
||||
hostArgs.M,
|
||||
hostArgs.N,
|
||||
hostArgs.K,
|
||||
hostArgs.QK_A,
|
||||
hostArgs.QK_B,
|
||||
hostArgs.stride_A,
|
||||
hostArgs.stride_B,
|
||||
hostArgs.stride_C,
|
||||
hostArgs.stride_AQ,
|
||||
hostArgs.stride_BQ,
|
||||
hostArgs.k_batch};
|
||||
return KernelArgs{hostArgs.a_ptr,
|
||||
hostArgs.b_ptr,
|
||||
hostArgs.aq_ptr,
|
||||
hostArgs.bq_ptr,
|
||||
hostArgs.ds_ptr,
|
||||
hostArgs.c_ptr,
|
||||
hostArgs.M,
|
||||
hostArgs.N,
|
||||
hostArgs.K,
|
||||
hostArgs.QK_A,
|
||||
hostArgs.QK_B,
|
||||
hostArgs.stride_A,
|
||||
hostArgs.stride_B,
|
||||
hostArgs.stride_Ds,
|
||||
hostArgs.stride_C,
|
||||
hostArgs.stride_AQ,
|
||||
hostArgs.stride_BQ,
|
||||
hostArgs.k_batch};
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
@@ -408,8 +368,8 @@ struct QuantGemmKernel
|
||||
public:
|
||||
struct SplitKBatchOffset
|
||||
{
|
||||
__device__ SplitKBatchOffset(const QuantGemmKernelArgs& kargs,
|
||||
const std::size_t k_id = blockIdx.z)
|
||||
CK_TILE_DEVICE SplitKBatchOffset(const KernelArgs& kargs,
|
||||
const std::size_t k_id = blockIdx.z)
|
||||
{
|
||||
constexpr auto K1 =
|
||||
GemmPipeline::BlockGemmShape::WarpTile::at(I2); // smallest unit of K work per block
|
||||
@@ -543,7 +503,7 @@ struct QuantGemmKernel
|
||||
};
|
||||
|
||||
CK_TILE_DEVICE static auto MakeABlockWindow(const ADataType* a_ptr,
|
||||
const QuantGemmKernelArgs& kargs,
|
||||
const KernelArgs& kargs,
|
||||
const index_t k_size,
|
||||
const index_t i_m)
|
||||
{
|
||||
@@ -609,7 +569,7 @@ struct QuantGemmKernel
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static auto MakeAQBlockWindow(const AQDataType* aq_ptr,
|
||||
const QuantGemmKernelArgs& kargs,
|
||||
const KernelArgs& kargs,
|
||||
const index_t i_m,
|
||||
const index_t i_n,
|
||||
const index_t aq_group_offset = 0)
|
||||
@@ -796,7 +756,7 @@ struct QuantGemmKernel
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static auto MakeBBlockWindow(const BDataType* b_ptr,
|
||||
const QuantGemmKernelArgs& kargs,
|
||||
const KernelArgs& kargs,
|
||||
const index_t k_size,
|
||||
const index_t i_n)
|
||||
{
|
||||
@@ -935,7 +895,7 @@ struct QuantGemmKernel
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static auto MakeBQBlockWindow(const BQDataType* bq_ptr,
|
||||
const QuantGemmKernelArgs& kargs,
|
||||
const KernelArgs& kargs,
|
||||
const index_t bq_group_offset,
|
||||
const index_t i_m,
|
||||
const index_t i_n)
|
||||
@@ -1125,9 +1085,101 @@ struct QuantGemmKernel
|
||||
return bq_block_window;
|
||||
}
|
||||
|
||||
template <typename DLayout, index_t VectorSizeD>
|
||||
CK_TILE_DEVICE static auto
|
||||
MakeDTensorDescriptor(const index_t M, const index_t N, const index_t stride)
|
||||
{
|
||||
if constexpr(std::is_same_v<DLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(M, N), make_tuple(stride, 1), number<VectorSizeD>{}, number<1>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(N, M), make_tuple(stride, 1), number<VectorSizeD>{}, number<1>{});
|
||||
}
|
||||
}
|
||||
|
||||
template <typename DsTensorDesc>
|
||||
CK_TILE_DEVICE static auto MakeDBlockWindows(const std::array<const void*, NumDTensor>& ds_ptr,
|
||||
const DsTensorDesc& ds_desc,
|
||||
const index_t i_m,
|
||||
const index_t i_n)
|
||||
{
|
||||
// Step 1: Create tensor views
|
||||
const auto& ds_tensor_view = generate_tuple(
|
||||
[&](auto i) {
|
||||
using DDataType_ = remove_cvref_t<std::tuple_element_t<i.value, DsDataType>>;
|
||||
return make_tensor_view<address_space_enum::global>(
|
||||
static_cast<const DDataType_*>(ds_ptr[i]), ds_desc[i]);
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
|
||||
// Step 2: Create padded views
|
||||
const auto& ds_pad_view = generate_tuple(
|
||||
[&](auto i) {
|
||||
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
|
||||
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return pad_tensor_view(ds_tensor_view[i],
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<false, GemmPipeline::kPadN>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return pad_tensor_view(ds_tensor_view[i],
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{},
|
||||
number<TilePartitioner::MPerBlock>{}),
|
||||
sequence<false, GemmPipeline::kPadM>{});
|
||||
}
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
|
||||
// Step 3: Create tile windows
|
||||
const auto& ds_block_window = generate_tuple(
|
||||
[&](auto i) {
|
||||
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
|
||||
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_tile_window(ds_pad_view[i],
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
{i_m, i_n});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_tile_window(ds_pad_view[i],
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{},
|
||||
number<TilePartitioner::MPerBlock>{}),
|
||||
{i_n, i_m});
|
||||
}
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
|
||||
return ds_block_window;
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static auto MakeDBlockWindows(const std::array<const void*, NumDTensor>& ds_ptr,
|
||||
const KernelArgs& kargs,
|
||||
const index_t i_m,
|
||||
const index_t i_n)
|
||||
{
|
||||
const auto& ds_tensor_desc = generate_tuple(
|
||||
[&](auto i) {
|
||||
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
|
||||
return MakeDTensorDescriptor<DiLayout, EpiloguePipeline::GetVectorSizeD(i)>(
|
||||
kargs.M, kargs.N, kargs.stride_Ds[i]);
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
|
||||
return MakeDBlockWindows(ds_ptr, ds_tensor_desc, i_m, i_n);
|
||||
}
|
||||
|
||||
template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
|
||||
CK_TILE_DEVICE static auto MakeCBlockWindow(CDataType* c_ptr,
|
||||
const QuantGemmKernelArgs& kargs,
|
||||
const KernelArgs& kargs,
|
||||
const index_t i_m,
|
||||
const index_t i_n)
|
||||
{
|
||||
@@ -1180,7 +1232,7 @@ struct QuantGemmKernel
|
||||
return c_block_window;
|
||||
}
|
||||
|
||||
CK_TILE_HOST static bool IsSupportedArgument(const QuantGemmKernelArgs& kargs)
|
||||
CK_TILE_HOST static bool IsSupportedArgument(const KernelArgs& kargs)
|
||||
{
|
||||
// k_batch must be a positive integer.
|
||||
if(kargs.k_batch <= 0)
|
||||
@@ -1438,6 +1490,59 @@ struct QuantGemmKernel
|
||||
}
|
||||
}
|
||||
|
||||
bool DTensorIsValid = {true};
|
||||
static_for<0, NumDTensor, 1>{}([&](auto index) {
|
||||
using DiLayout = remove_cvref_t<std::tuple_element_t<index.value, DsLayout>>;
|
||||
if(std::is_same_v<DiLayout, CLayout> == false)
|
||||
{
|
||||
DTensorIsValid = false;
|
||||
}
|
||||
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("Can't support N for tensor D that is not a multiple of "
|
||||
"NPerBlock without padding!");
|
||||
}
|
||||
DTensorIsValid = false;
|
||||
}
|
||||
if(kargs.N % EpiloguePipeline::GetVectorSizeD(index) != 0)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("N is not a multiple of vector load size for D tensor!");
|
||||
}
|
||||
DTensorIsValid = false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("Can't support M for tensor D that is not a multiple of "
|
||||
"MPerBlock without padding!");
|
||||
}
|
||||
DTensorIsValid = false;
|
||||
}
|
||||
if(kargs.M % EpiloguePipeline::GetVectorSizeD(index) != 0)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("M is not a multiple of vector load size for D tensor!");
|
||||
}
|
||||
DTensorIsValid = false;
|
||||
}
|
||||
}
|
||||
});
|
||||
if(!DTensorIsValid)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
// For RowMajor C, M is the row dimension - check M alignment here because
|
||||
@@ -1585,9 +1690,10 @@ struct QuantGemmKernel
|
||||
const BDataType* b_ptr,
|
||||
const AQDataType* aq_ptr,
|
||||
const BQDataType* bq_ptr,
|
||||
const std::array<const void*, NumDTensor>& ds_ptr,
|
||||
CDataType* c_ptr,
|
||||
void* smem_ptr,
|
||||
const QuantGemmKernelArgs& kargs,
|
||||
const KernelArgs& kargs,
|
||||
const SplitKBatchOffset& splitk_batch_offset,
|
||||
const index_t block_idx_m,
|
||||
const index_t block_idx_n)
|
||||
@@ -1605,6 +1711,7 @@ struct QuantGemmKernel
|
||||
// the remaining K-groups from the split-K offset position.
|
||||
const auto& bq_block_window = MakeBQBlockWindow(
|
||||
bq_ptr, kargs, splitk_batch_offset.bq_group_offset, block_idx_m, block_idx_n);
|
||||
const auto& ds_block_window = MakeDBlockWindows(ds_ptr, kargs, block_idx_m, block_idx_n);
|
||||
|
||||
const index_t num_loop =
|
||||
amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
|
||||
@@ -1667,13 +1774,13 @@ struct QuantGemmKernel
|
||||
kQuantType == QuantType::AQuantGrouped ||
|
||||
kQuantType == QuantType::BQuantGrouped)
|
||||
{
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr);
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr);
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::RowColQuant)
|
||||
{
|
||||
EpiloguePipeline{}(c_block_window,
|
||||
c_block_tile,
|
||||
c_block_window,
|
||||
ds_block_window,
|
||||
smem_ptr,
|
||||
aq_block_window,
|
||||
bq_block_window);
|
||||
@@ -1683,7 +1790,7 @@ struct QuantGemmKernel
|
||||
const AccDataType aq_scale = type_convert<AccDataType>(*aq_ptr);
|
||||
const AccDataType bq_scale = type_convert<AccDataType>(*bq_ptr);
|
||||
EpiloguePipeline{}(
|
||||
c_block_window, c_block_tile, c_block_window, smem_ptr, aq_scale, bq_scale);
|
||||
c_block_window, c_block_tile, ds_block_window, smem_ptr, aq_scale, bq_scale);
|
||||
}
|
||||
}
|
||||
else
|
||||
@@ -1695,13 +1802,13 @@ struct QuantGemmKernel
|
||||
kQuantType == QuantType::AQuantGrouped ||
|
||||
kQuantType == QuantType::BQuantGrouped)
|
||||
{
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr);
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr);
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::RowColQuant)
|
||||
{
|
||||
EpiloguePipeline{}(c_block_window,
|
||||
c_block_tile,
|
||||
c_block_window,
|
||||
ds_block_window,
|
||||
smem_ptr,
|
||||
aq_block_window,
|
||||
bq_block_window);
|
||||
@@ -1711,12 +1818,12 @@ struct QuantGemmKernel
|
||||
const AccDataType aq_scale = type_convert<AccDataType>(*aq_ptr);
|
||||
const AccDataType bq_scale = type_convert<AccDataType>(*bq_ptr);
|
||||
EpiloguePipeline{}(
|
||||
c_block_window, c_block_tile, c_block_window, smem_ptr, aq_scale, bq_scale);
|
||||
c_block_window, c_block_tile, ds_block_window, smem_ptr, aq_scale, bq_scale);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void Run_(const QuantGemmKernelArgs& kargs) const
|
||||
CK_TILE_DEVICE void Run_(const KernelArgs& kargs) const
|
||||
{
|
||||
const auto blockId = amd_wave_read_first_lane(blockIdx.x);
|
||||
const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockId);
|
||||
@@ -1739,8 +1846,17 @@ struct QuantGemmKernel
|
||||
|
||||
// allocate LDS
|
||||
__shared__ char smem_ptr[GetSmemSize()];
|
||||
RunGemm(
|
||||
a_ptr, b_ptr, aq_ptr, bq_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
|
||||
RunGemm(a_ptr,
|
||||
b_ptr,
|
||||
aq_ptr,
|
||||
bq_ptr,
|
||||
kargs.ds_ptr,
|
||||
c_ptr,
|
||||
smem_ptr,
|
||||
kargs,
|
||||
splitk_batch_offset,
|
||||
i_m,
|
||||
i_n);
|
||||
}
|
||||
|
||||
template <typename T, typename = void>
|
||||
@@ -1749,7 +1865,7 @@ struct QuantGemmKernel
|
||||
static constexpr bool kIsAvailableV<T, std::void_t<decltype(T::kIsAvailable)>> =
|
||||
T::kIsAvailable;
|
||||
|
||||
CK_TILE_DEVICE void operator()(const QuantGemmKernelArgs& kargs) const
|
||||
CK_TILE_DEVICE void operator()(const KernelArgs& kargs) const
|
||||
{
|
||||
if constexpr(!kIsAvailableV<GemmPipeline>)
|
||||
ignore = kargs;
|
||||
@@ -1758,6 +1874,99 @@ struct QuantGemmKernel
|
||||
}
|
||||
};
|
||||
|
||||
struct QuantGemmHostArgs : public QuantGemmMultiDHostArgs<0>
|
||||
{
|
||||
CK_TILE_HOST QuantGemmHostArgs() = default;
|
||||
CK_TILE_HOST QuantGemmHostArgs(const void* a_ptr_,
|
||||
const void* b_ptr_,
|
||||
void* c_ptr_,
|
||||
const void* aq_ptr_,
|
||||
const void* bq_ptr_,
|
||||
index_t k_batch_,
|
||||
index_t M_,
|
||||
index_t N_,
|
||||
index_t K_,
|
||||
index_t QK_A_,
|
||||
index_t QK_B_,
|
||||
index_t stride_A_,
|
||||
index_t stride_B_,
|
||||
index_t stride_C_,
|
||||
index_t stride_AQ_,
|
||||
index_t stride_BQ_)
|
||||
: QuantGemmMultiDHostArgs{a_ptr_,
|
||||
b_ptr_,
|
||||
std::array<const void*, 0>{},
|
||||
c_ptr_,
|
||||
aq_ptr_,
|
||||
bq_ptr_,
|
||||
k_batch_,
|
||||
M_,
|
||||
N_,
|
||||
K_,
|
||||
QK_A_,
|
||||
QK_B_,
|
||||
stride_A_,
|
||||
stride_B_,
|
||||
std::array<index_t, 0>{},
|
||||
stride_C_,
|
||||
stride_AQ_,
|
||||
stride_BQ_}
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
struct QuantGemmKernelArgs : public QuantGemmMultiDKernelArgs<0>
|
||||
{
|
||||
CK_TILE_HOST QuantGemmKernelArgs() = default;
|
||||
CK_TILE_HOST QuantGemmKernelArgs(const void* a_ptr_,
|
||||
const void* b_ptr_,
|
||||
const void* aq_ptr_,
|
||||
const void* bq_ptr_,
|
||||
void* c_ptr_,
|
||||
index_t M_,
|
||||
index_t N_,
|
||||
index_t K_,
|
||||
index_t QK_A_,
|
||||
index_t QK_B_,
|
||||
index_t stride_A_,
|
||||
index_t stride_B_,
|
||||
index_t stride_C_,
|
||||
index_t stride_AQ_,
|
||||
index_t stride_BQ_,
|
||||
index_t k_batch_)
|
||||
: QuantGemmMultiDKernelArgs<0>{a_ptr_,
|
||||
b_ptr_,
|
||||
aq_ptr_,
|
||||
bq_ptr_,
|
||||
std::array<const void*, 0>{},
|
||||
c_ptr_,
|
||||
M_,
|
||||
N_,
|
||||
K_,
|
||||
QK_A_,
|
||||
QK_B_,
|
||||
stride_A_,
|
||||
stride_B_,
|
||||
std::array<index_t, 0>{},
|
||||
stride_C_,
|
||||
stride_AQ_,
|
||||
stride_BQ_,
|
||||
k_batch_}
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
template <typename TilePartitioner_,
|
||||
typename GemmPipeline_,
|
||||
typename EpiloguePipeline_,
|
||||
QuantType QuantType_,
|
||||
bool RuntimeSplitKTail_ = false>
|
||||
using QuantGemmKernel = QuantGemmMultiDKernel<TilePartitioner_,
|
||||
GemmPipeline_,
|
||||
EpiloguePipeline_,
|
||||
QuantType_,
|
||||
RuntimeSplitKTail_>;
|
||||
|
||||
} // namespace ck_tile
|
||||
#if __clang_major__ >= 23
|
||||
#pragma clang diagnostic pop
|
||||
|
||||
@@ -366,6 +366,7 @@ struct QuantGroupedGemmKernel
|
||||
{b_ptr},
|
||||
aq_ptr,
|
||||
bq_ptr,
|
||||
{},
|
||||
c_ptr,
|
||||
smem_ptr,
|
||||
kargs,
|
||||
|
||||
Reference in New Issue
Block a user