mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 18:17:44 +00:00
Merge commit '6fc28ab4934d3668bf4ec96db1e082cf26b11384' into develop
This commit is contained in:
@@ -346,13 +346,40 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase<Problem_>
|
||||
{
|
||||
if constexpr(Traits::TransposeC) // transposed C
|
||||
{
|
||||
static_assert(false,
|
||||
"It is not supported yet to enable both Preshuffle "
|
||||
"and TransposeC.");
|
||||
// TODO:
|
||||
// A new tile distribution is needed for the Preshuffle and
|
||||
// Transpose combination. For instance, with mnk at 16x16x32, lanes
|
||||
// 0-15, 16-31, 32-47, and 48-63 must load the same elements of AQ.
|
||||
constexpr auto tbuf_offset = number<
|
||||
typename CBlockTensor::ThreadTensorDesc{}.calculate_offset(
|
||||
merge_sequences(sequence<mIter, nIter>{},
|
||||
c_warp_y_index_zeros)) /
|
||||
CBlockTensor::PackedSize>{};
|
||||
|
||||
auto& scale_reg = aq_block_tensor.get_thread_buffer()[mIter];
|
||||
auto pull_from_lane = (__lane_id() & (Traits::WarpGemm::kN - 1)) *
|
||||
Traits::AQPerBlock +
|
||||
kQScale;
|
||||
|
||||
// cross lane ops
|
||||
uint32_t scale_reg_dword;
|
||||
|
||||
if constexpr(std::is_same_v<AQDataType, float>)
|
||||
{
|
||||
scale_reg_dword = ck_tile::bit_cast<uint32_t>(scale_reg);
|
||||
}
|
||||
else
|
||||
{
|
||||
scale_reg_dword = static_cast<uint32_t>(scale_reg);
|
||||
}
|
||||
|
||||
int gathered_scale_reg = __builtin_amdgcn_ds_bpermute(
|
||||
pull_from_lane << 2, __builtin_bit_cast(int, scale_reg_dword));
|
||||
|
||||
float scale_reg_f = Base::cvt_scale_to_fp32(gathered_scale_reg);
|
||||
|
||||
static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}(
|
||||
[&](auto c_row) {
|
||||
c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] +=
|
||||
(c_warp_tensor.get_thread_buffer()[c_row] *
|
||||
scale_reg_f);
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@@ -73,7 +73,7 @@ struct is_quantpreshuffle_enabled
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct is_quantpreshuffle_enabled<T, decltype(T::PreshuffleQuant)>
|
||||
struct is_quantpreshuffle_enabled<T, std::void_t<decltype(T::PreshuffleQuant)>>
|
||||
{
|
||||
static constexpr bool value = T::PreshuffleQuant;
|
||||
};
|
||||
|
||||
@@ -39,6 +39,7 @@ template <bool kPadM_,
|
||||
QuantType QuantType_,
|
||||
typename AQLayout_ = ALayout_,
|
||||
typename BQLayout_ = BLayout_,
|
||||
bool TransposeC_ = false,
|
||||
bool DoubleSmemBuffer_ = false,
|
||||
bool UsePersistentKernel_ = false>
|
||||
struct TileGemmQuantTraits
|
||||
@@ -62,7 +63,7 @@ struct TileGemmQuantTraits
|
||||
using AsLayout = ALayout_;
|
||||
using BsLayout = BLayout_;
|
||||
|
||||
static constexpr bool TransposeC = false;
|
||||
static constexpr bool TransposeC = TransposeC_;
|
||||
static constexpr bool UseStructuredSparsity = false;
|
||||
static constexpr index_t NumWaveGroups = 1;
|
||||
static constexpr bool UsePersistentKernel = UsePersistentKernel_;
|
||||
|
||||
@@ -7,7 +7,9 @@ list(APPEND TEST_GEMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0)
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95")
|
||||
# Typed Test Suite for GEMM Quantization
|
||||
add_gtest_executable(test_tile_gemm_quant_typed test_gemm_quant_typed.cpp)
|
||||
add_gtest_executable(test_tile_gemm_quant_typed
|
||||
test_gemm_quant_typed.cpp
|
||||
)
|
||||
target_compile_options(test_tile_gemm_quant_typed PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
|
||||
else()
|
||||
message(DEBUG "Skipping ck_tile quant gemm tests for current target")
|
||||
|
||||
@@ -87,6 +87,7 @@ class TestCkTileGemmQuantBase : public ::testing::Test
|
||||
QuantType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
GemmConfig::TransposeC,
|
||||
DoubleSmemBuffer>;
|
||||
|
||||
// Let the derived class create the appropriate pipeline and epilogue
|
||||
|
||||
@@ -41,6 +41,22 @@ struct GemmConfigBase
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 32;
|
||||
};
|
||||
|
||||
struct GemmConfigPreshuffleQuant : public GemmConfigBase
|
||||
{
|
||||
static constexpr bool PreshuffleQuant = true;
|
||||
};
|
||||
|
||||
struct GemmConfigTransposeC : public GemmConfigBase
|
||||
{
|
||||
static constexpr bool TransposeC = true;
|
||||
};
|
||||
|
||||
struct GemmConfigPreshuffleQuantTransposeC : public GemmConfigBase
|
||||
{
|
||||
static constexpr bool PreshuffleQuant = true;
|
||||
static constexpr bool TransposeC = true;
|
||||
};
|
||||
|
||||
struct GemmConfigPreshuffleB
|
||||
{
|
||||
static constexpr bool kPadM = false;
|
||||
@@ -100,6 +116,24 @@ class TestCkTileGemmAQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
|
||||
void SetUpQuantTypeSpecific() {}
|
||||
void TearDownQuantTypeSpecific() {}
|
||||
|
||||
template <typename T>
|
||||
auto shuffle_aq(const ck_tile::HostTensor<T>* t, int block_aq_k)
|
||||
{
|
||||
if(t->get_lengths().size() != 2)
|
||||
{
|
||||
throw std::runtime_error("Host tensor is not rank 2 tensor.");
|
||||
}
|
||||
int m_ = t->get_lengths()[0];
|
||||
int aqk_ = t->get_lengths()[1];
|
||||
if(aqk_ % block_aq_k != 0)
|
||||
{
|
||||
throw std::runtime_error("shuffle_aq needs a aqk of multiple times of block_aq_k.");
|
||||
}
|
||||
ck_tile::HostTensor<T> t_view({m_, aqk_ / block_aq_k, block_aq_k});
|
||||
std::copy(t->begin(), t->end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {1, 0, 2});
|
||||
}
|
||||
|
||||
// AQuant-specific data generation
|
||||
void run_test_with_validation(ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t K)
|
||||
{
|
||||
@@ -150,7 +184,17 @@ class TestCkTileGemmAQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
|
||||
{
|
||||
a_m_k_dev_buf.ToDevice(a_m_k.data());
|
||||
}
|
||||
aq_m_aqk_dev_buf.ToDevice(aq_m_aqk.data());
|
||||
// aq_m_aqk_dev_buf.ToDevice(aq_m_aqk.data());
|
||||
if constexpr(Base::GemmConfig::PreshuffleQuant)
|
||||
{
|
||||
ck_tile::HostTensor<QDataType> aq_shuffle_host =
|
||||
shuffle_aq(&aq_m_aqk, Base::GemmConfig::K_Tile / QuantGroupSize);
|
||||
aq_m_aqk_dev_buf.ToDevice(aq_shuffle_host.data());
|
||||
}
|
||||
else
|
||||
{
|
||||
aq_m_aqk_dev_buf.ToDevice(aq_m_aqk.data());
|
||||
}
|
||||
b_k_n_dev_buf.ToDevice(b_k_n.data());
|
||||
|
||||
// Create args for kernel execution
|
||||
@@ -245,7 +289,7 @@ class TestCkTileGemmAQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
|
||||
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr bool transpose_c = false;
|
||||
constexpr bool transpose_c = CodegenGemmTraits::TransposeC;
|
||||
|
||||
using PipelineProblem =
|
||||
ck_tile::GemmAQuantPipelineProblem<ADataType,
|
||||
@@ -701,7 +745,7 @@ class TestCkTileGemmRowColQuant
|
||||
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr bool transpose_c = false;
|
||||
constexpr bool transpose_c = CodegenGemmTraits::TransposeC;
|
||||
|
||||
using PipelineProblem = ck_tile::GemmRowColTensorQuantPipelineProblem<
|
||||
ADataType,
|
||||
@@ -916,7 +960,7 @@ class TestCkTileGemmTensorQuant
|
||||
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr bool transpose_c = false;
|
||||
constexpr bool transpose_c = CodegenGemmTraits::TransposeC;
|
||||
|
||||
using PipelineProblem = ck_tile::GemmRowColTensorQuantPipelineProblem<
|
||||
ADataType,
|
||||
|
||||
@@ -25,10 +25,29 @@ using GroupSize = std::integral_constant<unsigned int, 128>;
|
||||
// Type combinations for each quantization type
|
||||
// clang-format off
|
||||
using AQuantTypes = ::testing::Types<
|
||||
// PreshuffleQuant = false && TransposeC = false
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, PkInt4, FP8, FP8, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, PkInt4, BF8, BF8, Half, AQuantGrouped, GemmConfigBase, GroupSize>
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, PkInt4, BF8, BF8, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
|
||||
// PreshuffleQuant = false && TransposeC = true
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, AQuantGrouped, GemmConfigTransposeC, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, AQuantGrouped, GemmConfigTransposeC, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, PkInt4, FP8, FP8, Half, AQuantGrouped, GemmConfigTransposeC, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, PkInt4, BF8, BF8, Half, AQuantGrouped, GemmConfigTransposeC, GroupSize>,
|
||||
|
||||
// PreshuffleQuant = true && TransposeC = false
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, AQuantGrouped, GemmConfigPreshuffleQuant, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, AQuantGrouped, GemmConfigPreshuffleQuant, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, PkInt4, FP8, FP8, Half, AQuantGrouped, GemmConfigPreshuffleQuant, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, PkInt4, BF8, BF8, Half, AQuantGrouped, GemmConfigPreshuffleQuant, GroupSize>,
|
||||
|
||||
// PreshuffleQuant = true && TransposeC = true
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, AQuantGrouped, GemmConfigPreshuffleQuantTransposeC, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, AQuantGrouped, GemmConfigPreshuffleQuantTransposeC, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, PkInt4, FP8, FP8, Half, AQuantGrouped, GemmConfigPreshuffleQuantTransposeC, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, PkInt4, BF8, BF8, Half, AQuantGrouped, GemmConfigPreshuffleQuantTransposeC, GroupSize>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
|
||||
Reference in New Issue
Block a user