Merge commit '47cd0d5cff77658adc1c9f184c012ec3496e8214' into develop

This commit is contained in:
assistant-librarian[bot]
2025-09-19 05:12:36 +00:00
parent 142a7e067a
commit 042cd4e556
13 changed files with 183 additions and 177 deletions

View File

@@ -0,0 +1,58 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/ops/elementwise.hpp"
namespace ck_tile {
template <class T>
struct is_pk_int4 : std::false_type
{
};
template <>
struct is_pk_int4<pk_int4_t> : std::true_type
{
};
template <typename ComputeDataType, index_t UnaryOpSize>
struct InterleavedPKTypeLoader
{
template <typename WarpWindow, typename WarpTile>
CK_TILE_DEVICE static void load_interleaved_pk_type(WarpTile& warp_tile,
const WarpWindow& warp_window)
{
const element_wise::PassThroughPack8 elementwise_op{};
static_assert(WarpTile::get_thread_buffer_size() % UnaryOpSize == 0);
constexpr index_t thread_buffer_size = WarpTile::get_thread_buffer_size() / UnaryOpSize;
const auto in_dstr_tensors = load_tile(warp_window);
using ComputeVectorType = ComputeDataType __attribute__((ext_vector_type(UnaryOpSize)));
static_for<0, thread_buffer_size, 1>{}([&](auto i) {
elementwise_op(warp_tile.get_thread_buffer().template get_as<ComputeVectorType>()(i),
in_dstr_tensors.get_thread_buffer().template get_as<pk_int4x4_t>()[i]);
});
}
};
template <typename BDataType,
typename ComputeDataType,
index_t UnaryOpSize,
typename WarpTile,
typename WarpWindow>
CK_TILE_DEVICE void load_int4_tile(WarpTile& dst, const WarpWindow& src)
{
if constexpr(is_pk_int4<std::remove_cv_t<BDataType>>::value)
{
InterleavedPKTypeLoader<ComputeDataType, UnaryOpSize>::load_interleaved_pk_type(dst, src);
}
else
{
dst = load_tile(src);
}
}
} // namespace ck_tile

View File

@@ -4,6 +4,7 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/elementwise.hpp"
@@ -13,7 +14,9 @@ namespace ck_tile {
// A is block window on shared memory
// B is block window on shared memory
// C is block distributed tensor
template <typename Problem_, typename Policy_ = BlockGemmASmemBSmemCRegV1DefaultPolicy>
template <typename Problem_,
typename Policy_ = BlockGemmASmemBSmemCRegV1DefaultPolicy,
index_t UnaryOpSize_ = 8>
struct BlockUniversalGemmAsBsCr
{
private:
@@ -91,6 +94,7 @@ struct BlockUniversalGemmAsBsCr
using ComputeDataType = remove_cvref_t<typename Traits::ComputeDataType>;
using CDataType = remove_cvref_t<typename Traits::CDataType>;
using Loader = remove_cvref_t<InterleavedPKTypeLoader<ComputeDataType, UnaryOpSize_>>;
using WarpGemm = remove_cvref_t<typename Traits::WarpGemm>;
static constexpr index_t KIterPerWarp = Traits::KIterPerWarp;
@@ -179,25 +183,6 @@ struct BlockUniversalGemmAsBsCr
return b_block_dstr_encode;
}
private:
template <typename WarpWindow, typename WarpTile>
CK_TILE_DEVICE static void load_interleaved_pk_type(WarpTile& warp_tile,
const WarpWindow& warp_window)
{
constexpr index_t UnaryOpSize = 8;
const element_wise::PassThroughPack8 elementwise_op{};
constexpr index_t thread_buffer_size = WarpTile::get_thread_buffer_size() / UnaryOpSize;
const auto in_dstr_tensors = load_tile(warp_window);
static_assert(WarpTile::get_thread_buffer_size() % UnaryOpSize == 0);
using ComputeVectorType = ComputeDataType __attribute__((ext_vector_type(UnaryOpSize)));
static_for<0, thread_buffer_size, 1>{}([&](auto i) {
elementwise_op(warp_tile.get_thread_buffer().template get_as<ComputeVectorType>()(i),
in_dstr_tensors.get_thread_buffer().template get_as<pk_int4x4_t>()[i]);
});
}
template <GemmPipelineScheduler Scheduler, typename GemmTraits>
struct BlockGemmImpl
{
@@ -239,7 +224,7 @@ struct BlockUniversalGemmAsBsCr
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
{
load_interleaved_pk_type(a_warp_tile_, a_block_window);
Loader::load_interleaved_pk_type(a_warp_tile_, a_block_window);
}
else
{
@@ -247,7 +232,7 @@ struct BlockUniversalGemmAsBsCr
}
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
{
load_interleaved_pk_type(b_warp_tile_, b_block_window);
Loader::load_interleaved_pk_type(b_warp_tile_, b_block_window);
}
else
{
@@ -317,7 +302,7 @@ struct BlockUniversalGemmAsBsCr
{
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
{
load_interleaved_pk_type(a_warp_tile_, a_block_window);
Loader::load_interleaved_pk_type(a_warp_tile_, a_block_window);
}
else if constexpr(ALoadTranspose)
{
@@ -329,7 +314,7 @@ struct BlockUniversalGemmAsBsCr
}
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
{
load_interleaved_pk_type(b_warp_tile_, b_block_window);
Loader::load_interleaved_pk_type(b_warp_tile_, b_block_window);
}
else if constexpr(BLoadTranspose)
{
@@ -468,7 +453,7 @@ struct BlockUniversalGemmAsBsCr
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
{
load_interleaved_pk_type(a_warp_tile_, a_block_window);
Loader::load_interleaved_pk_type(a_warp_tile_, a_block_window);
}
else if constexpr(ALoadTranspose)
{
@@ -480,7 +465,7 @@ struct BlockUniversalGemmAsBsCr
}
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
{
load_interleaved_pk_type(b_warp_tile_, b_block_window);
Loader::load_interleaved_pk_type(b_warp_tile_, b_block_window);
}
else if constexpr(BLoadTranspose)
{

View File

@@ -289,13 +289,17 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy
{
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
using WarpGemm = WarpGemmDispatcher<typename Problem::ADataType,
typename Problem::BDataType,
typename Problem::CDataType,
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
Problem::TransposeC>;
using BTypeToUse =
std::conditional_t<std::is_same_v<typename Problem::BDataType, ck_tile::pk_int4_t>,
typename Problem::ADataType,
typename Problem::BDataType>;
using WarpGemm = WarpGemmDispatcher<typename Problem::ADataType,
BTypeToUse,
typename Problem::CDataType,
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
Problem::TransposeC>;
using BlockWeightPreshufflePolicy =
BlockWeightPreshuffleASmemBSmemCRegV1CustomPolicy<typename Problem::ADataType,

View File

@@ -4,6 +4,7 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
#include "ck_tile/host/concat.hpp"
#include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp"
@@ -202,7 +203,8 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV1
typename AElementFunction,
typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
!is_detected<is_tuple, BFlatBlockWindowTmp>::value,
bool>* = nullptr>
bool>* = nullptr,
index_t UnaryOpSize_ = 8>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const AElementFunction& a_element_func,
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
@@ -310,14 +312,14 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV1
NIterPerWarp>
b_flat_dram_windows;
statically_indexed_array<
statically_indexed_array<decltype(load_tile(b_flat_dram_window)), KIterPerWarp>,
NIterPerWarp>
using BTypeToUse =
std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
using BTileType = decltype(make_static_distributed_tensor<BTypeToUse>(b_flat_distribution));
statically_indexed_array<statically_indexed_array<BTileType, KIterPerWarp>, NIterPerWarp>
b_warp_tensor;
statically_indexed_array<
statically_indexed_array<decltype(load_tile(b_flat_dram_window)), KIterPerWarp>,
NIterPerWarp>
statically_indexed_array<statically_indexed_array<BTileType, KIterPerWarp>, NIterPerWarp>
b_warp_tensor_2;
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
@@ -327,7 +329,8 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV1
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
b_warp_tensor(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
load_int4_tile<BDataType, ADataType, UnaryOpSize_>(
b_warp_tensor(nIter)(kIter), b_flat_dram_windows(nIter)(kIter));
});
});
@@ -375,7 +378,8 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV1
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
b_warp_tensor_2(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
load_int4_tile<BDataType, ADataType, UnaryOpSize_>(
b_warp_tensor_2(nIter)(kIter), b_flat_dram_windows(nIter)(kIter));
});
});
@@ -408,7 +412,8 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV1
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
b_warp_tensor(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
load_int4_tile<BDataType, ADataType, UnaryOpSize_>(
b_warp_tensor(nIter)(kIter), b_flat_dram_windows(nIter)(kIter));
});
});
@@ -445,7 +450,8 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV1
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
b_warp_tensor_2(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
load_int4_tile<BDataType, ADataType, UnaryOpSize_>(
b_warp_tensor_2(nIter)(kIter), b_flat_dram_windows(nIter)(kIter));
});
});

View File

@@ -4,6 +4,7 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
#include "ck_tile/host/concat.hpp"
#include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp"
@@ -514,7 +515,8 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
typename AElementFunction,
typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
!is_detected<is_tuple, BFlatBlockWindowTmp>::value,
bool>* = nullptr>
bool>* = nullptr,
index_t UnaryOpSize_ = 8>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const AElementFunction& a_element_func,
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
@@ -631,19 +633,19 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
b_flat_distribution);
// pingpong buffer for B
using BTypeToUse =
std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
using BTileType = decltype(make_static_distributed_tensor<BTypeToUse>(b_flat_distribution));
statically_indexed_array<
statically_indexed_array<decltype(b_flat_dram_window), KIterPerWarp>,
NIterPerWarp>
b_flat_dram_windows;
statically_indexed_array<
statically_indexed_array<decltype(load_tile(b_flat_dram_window)), KIterPerWarp>,
NIterPerWarp>
statically_indexed_array<statically_indexed_array<BTileType, KIterPerWarp>, NIterPerWarp>
b_warp_tensor_ping;
statically_indexed_array<
statically_indexed_array<decltype(load_tile(b_flat_dram_window)), KIterPerWarp>,
NIterPerWarp>
statically_indexed_array<statically_indexed_array<BTileType, KIterPerWarp>, NIterPerWarp>
b_warp_tensor_pong;
// Prefetch A0
@@ -659,7 +661,8 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
load_int4_tile<BDataType, ADataType, UnaryOpSize_>(
b_warp_tensor_ping(nIter)(kIter), b_flat_dram_windows(nIter)(kIter));
});
});
// move B window to next flat K
@@ -706,7 +709,8 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
load_int4_tile<BDataType, ADataType, UnaryOpSize_>(
b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter));
});
});
@@ -782,7 +786,8 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
load_int4_tile<BDataType, ADataType, UnaryOpSize_>(
b_warp_tensor_ping(nIter)(kIter), b_flat_dram_windows(nIter)(kIter));
});
});
@@ -862,7 +867,8 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
load_int4_tile<BDataType, ADataType, UnaryOpSize_>(
b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter));
});
});

View File

@@ -5,19 +5,19 @@
#include "ck_tile/core.hpp"
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/elementwise.hpp"
namespace ck_tile {
template <typename Problem, index_t UnaryOpSize_ = 8>
template <typename Problem>
struct BlockGemmAQuantBase
{
using AQDataType = remove_cvref_t<typename Problem::AQDataType>;
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
static constexpr index_t UnaryOpSize = UnaryOpSize_;
template <typename T>
CK_TILE_DEVICE static float cvt_scale_to_fp32(T scale)
{
@@ -42,23 +42,6 @@ struct BlockGemmAQuantBase
}
return scale_reg_f;
}
template <typename WarpWindow, typename WarpTile>
CK_TILE_DEVICE static void load_interleaved_pk_type(WarpTile& warp_tile,
const WarpWindow& warp_window)
{
const element_wise::PassThroughPack8 elementwise_op{};
static_assert(WarpTile::get_thread_buffer_size() % UnaryOpSize == 0);
constexpr index_t thread_buffer_size = WarpTile::get_thread_buffer_size() / UnaryOpSize;
const auto in_dstr_tensors = load_tile(warp_window);
using ComputeVectorType = ComputeDataType __attribute__((ext_vector_type(UnaryOpSize)));
static_for<0, thread_buffer_size, 1>{}([&](auto i) {
elementwise_op(warp_tile.get_thread_buffer().template get_as<ComputeVectorType>()(i),
in_dstr_tensors.get_thread_buffer().template get_as<pk_int4x4_t>()[i]);
});
}
};
// A is block window on shared memory
@@ -66,7 +49,9 @@ struct BlockGemmAQuantBase
// Consecutive kQuantGroupSize elements of A are quantized with a separate scale.
// B is block window on shared memory
// C is block distributed tensor
template <typename Problem_, typename Policy_ = BlockGemmASmemBSmemCRegV1DefaultPolicy>
template <typename Problem_,
typename Policy_ = BlockGemmASmemBSmemCRegV1DefaultPolicy,
index_t UnaryOpSize_ = 8>
struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase<Problem_>
{
private:
@@ -172,6 +157,7 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase<Problem_>
using Base = BlockGemmAQuantBase<Problem_>;
using Loader = remove_cvref_t<InterleavedPKTypeLoader<ComputeDataType, UnaryOpSize_>>;
using WarpGemm = remove_cvref_t<typename Traits::WarpGemm>;
static constexpr index_t KIterPerWarp = Traits::KIterPerWarp;
@@ -292,7 +278,7 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase<Problem_>
{
static_assert(std::is_same_v<ComputeDataType, fp8_t> ||
std::is_same_v<ComputeDataType, bf8_t>);
Base::load_interleaved_pk_type(a_warp_tile_, a_block_window);
Loader::load_interleaved_pk_type(a_warp_tile_, a_block_window);
}
else
{
@@ -302,7 +288,7 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase<Problem_>
{
static_assert(std::is_same_v<ComputeDataType, fp8_t> ||
std::is_same_v<ComputeDataType, bf8_t>);
Base::load_interleaved_pk_type(b_warp_tile_, b_block_window);
Loader::load_interleaved_pk_type(b_warp_tile_, b_block_window);
}
else
{

View File

@@ -5,19 +5,19 @@
#include "ck_tile/core.hpp"
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/elementwise.hpp"
namespace ck_tile {
template <typename Problem, index_t UnaryOpSize_ = 8>
template <typename Problem>
struct BlockGemmBQuantBase
{
using BQDataType = remove_cvref_t<typename Problem::BQDataType>;
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
static constexpr index_t UnaryOpSize = UnaryOpSize_;
template <typename T>
CK_TILE_DEVICE static float cvt_scale_to_fp32(T scale)
{
@@ -42,24 +42,6 @@ struct BlockGemmBQuantBase
}
return scale_reg_f;
}
// can be inherited from A
template <typename WarpWindow, typename WarpTile>
CK_TILE_DEVICE static void load_interleaved_pk_type(WarpTile& warp_tile,
const WarpWindow& warp_window)
{
const element_wise::PassThroughPack8 elementwise_op{};
static_assert(WarpTile::get_thread_buffer_size() % UnaryOpSize == 0);
constexpr index_t thread_buffer_size = WarpTile::get_thread_buffer_size() / UnaryOpSize;
const auto in_dstr_tensors = load_tile(warp_window);
using ComputeVectorType = ComputeDataType __attribute__((ext_vector_type(UnaryOpSize)));
static_for<0, thread_buffer_size, 1>{}([&](auto i) {
elementwise_op(warp_tile.get_thread_buffer().template get_as<ComputeVectorType>()(i),
in_dstr_tensors.get_thread_buffer().template get_as<pk_int4x4_t>()[i]);
});
}
};
// A is block window on shared memory
@@ -67,7 +49,9 @@ struct BlockGemmBQuantBase
// Consecutive kQuantGroupSize elements of B are quantized with a separate scale.
// B is block window on shared memory
// C is block distributed tensor
template <typename Problem_, typename Policy_ = BlockGemmASmemBSmemCRegV1DefaultPolicy>
template <typename Problem_,
typename Policy_ = BlockGemmASmemBSmemCRegV1DefaultPolicy,
index_t UnaryOpSize_ = 8>
struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
{
private:
@@ -170,6 +154,7 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
using Base = BlockGemmBQuantBase<Problem_>;
using Loader = remove_cvref_t<InterleavedPKTypeLoader<ComputeDataType, UnaryOpSize_>>;
using WarpGemm = remove_cvref_t<typename Traits::WarpGemm>;
static constexpr index_t KIterPerWarp = Traits::KIterPerWarp;
@@ -291,7 +276,7 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
{
static_assert(std::is_same_v<ComputeDataType, fp8_t> ||
std::is_same_v<ComputeDataType, bf8_t>);
Base::load_interleaved_pk_type(a_warp_tile_, a_block_window);
Loader::load_interleaved_pk_type(a_warp_tile_, a_block_window);
}
else
{
@@ -301,7 +286,7 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
{
static_assert(std::is_same_v<ComputeDataType, fp8_t> ||
std::is_same_v<ComputeDataType, bf8_t>);
Base::load_interleaved_pk_type(b_warp_tile_, b_block_window);
Loader::load_interleaved_pk_type(b_warp_tile_, b_block_window);
}
else
{