mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
Add gemm weight preshuffle pk_int_t support (#2858)
* Factor out the three separate copies of load_interleaved_pk_type into a common utility class
* Add preprocessing with optional cache flushing and clearing of output for k_batch > 1 to the weight preshuffle GEMM example
* Remove a duplicate function
* Add support for B tensor type pk_int4_t for the weight preshuffle GEMM, with tests included
* I4 support introduced more failing test cases that mirror the existing ones for F8
* Simplify the check for which tests to skip (they all have F8 as A tensor type)
* Add a changelog entry
* add the test for v2 wp pipeline, polish the code, add the support of int4 for v2 wp pipeline
* have a workable version for atomic add
* Revert "have a workable version for atomic add"
This reverts commit 792377a590c26cfff9c8f545d9a9e8484a7422eb.
---------
Co-authored-by: ThomasNing <thomas.ning@amd.com>
[ROCm/composable_kernel commit: 47cd0d5cff]
This commit is contained in:
@@ -5,6 +5,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
|
||||
## Composable Kernel 1.2.0 for ROCm 7.0.0
|
||||
|
||||
### Added
|
||||
* Added support for B Tensor type pk_int4_t in the CK TILE weight preshuffle GEMM.
|
||||
* Added support for B Tensor Preshuffle in CK TILE Grouped GEMM.
|
||||
* Added a basic copy kernel example and supporting documentation for new CK Tile developers.
|
||||
* Added support for bf16, f32, and f16 for 2D and 3D NGCHW grouped convolution backward data
|
||||
|
||||
58
include/ck_tile/ops/common/load_interleaved_pk_type.hpp
Normal file
58
include/ck_tile/ops/common/load_interleaved_pk_type.hpp
Normal file
@@ -0,0 +1,58 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/ops/elementwise.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <class T>
|
||||
struct is_pk_int4 : std::false_type
|
||||
{
|
||||
};
|
||||
template <>
|
||||
struct is_pk_int4<pk_int4_t> : std::true_type
|
||||
{
|
||||
};
|
||||
|
||||
template <typename ComputeDataType, index_t UnaryOpSize>
|
||||
struct InterleavedPKTypeLoader
|
||||
{
|
||||
template <typename WarpWindow, typename WarpTile>
|
||||
CK_TILE_DEVICE static void load_interleaved_pk_type(WarpTile& warp_tile,
|
||||
const WarpWindow& warp_window)
|
||||
{
|
||||
const element_wise::PassThroughPack8 elementwise_op{};
|
||||
|
||||
static_assert(WarpTile::get_thread_buffer_size() % UnaryOpSize == 0);
|
||||
constexpr index_t thread_buffer_size = WarpTile::get_thread_buffer_size() / UnaryOpSize;
|
||||
const auto in_dstr_tensors = load_tile(warp_window);
|
||||
|
||||
using ComputeVectorType = ComputeDataType __attribute__((ext_vector_type(UnaryOpSize)));
|
||||
static_for<0, thread_buffer_size, 1>{}([&](auto i) {
|
||||
elementwise_op(warp_tile.get_thread_buffer().template get_as<ComputeVectorType>()(i),
|
||||
in_dstr_tensors.get_thread_buffer().template get_as<pk_int4x4_t>()[i]);
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
template <typename BDataType,
|
||||
typename ComputeDataType,
|
||||
index_t UnaryOpSize,
|
||||
typename WarpTile,
|
||||
typename WarpWindow>
|
||||
CK_TILE_DEVICE void load_int4_tile(WarpTile& dst, const WarpWindow& src)
|
||||
{
|
||||
if constexpr(is_pk_int4<std::remove_cv_t<BDataType>>::value)
|
||||
{
|
||||
InterleavedPKTypeLoader<ComputeDataType, UnaryOpSize>::load_interleaved_pk_type(dst, src);
|
||||
}
|
||||
else
|
||||
{
|
||||
dst = load_tile(src);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -4,6 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
|
||||
#include "ck_tile/ops/elementwise.hpp"
|
||||
@@ -13,7 +14,9 @@ namespace ck_tile {
|
||||
// A is block window on shared memory
|
||||
// B is block window on shared memory
|
||||
// C is block distributed tensor
|
||||
template <typename Problem_, typename Policy_ = BlockGemmASmemBSmemCRegV1DefaultPolicy>
|
||||
template <typename Problem_,
|
||||
typename Policy_ = BlockGemmASmemBSmemCRegV1DefaultPolicy,
|
||||
index_t UnaryOpSize_ = 8>
|
||||
struct BlockUniversalGemmAsBsCr
|
||||
{
|
||||
private:
|
||||
@@ -91,6 +94,7 @@ struct BlockUniversalGemmAsBsCr
|
||||
using ComputeDataType = remove_cvref_t<typename Traits::ComputeDataType>;
|
||||
using CDataType = remove_cvref_t<typename Traits::CDataType>;
|
||||
|
||||
using Loader = remove_cvref_t<InterleavedPKTypeLoader<ComputeDataType, UnaryOpSize_>>;
|
||||
using WarpGemm = remove_cvref_t<typename Traits::WarpGemm>;
|
||||
|
||||
static constexpr index_t KIterPerWarp = Traits::KIterPerWarp;
|
||||
@@ -179,25 +183,6 @@ struct BlockUniversalGemmAsBsCr
|
||||
return b_block_dstr_encode;
|
||||
}
|
||||
|
||||
private:
|
||||
template <typename WarpWindow, typename WarpTile>
|
||||
CK_TILE_DEVICE static void load_interleaved_pk_type(WarpTile& warp_tile,
|
||||
const WarpWindow& warp_window)
|
||||
{
|
||||
constexpr index_t UnaryOpSize = 8;
|
||||
const element_wise::PassThroughPack8 elementwise_op{};
|
||||
constexpr index_t thread_buffer_size = WarpTile::get_thread_buffer_size() / UnaryOpSize;
|
||||
const auto in_dstr_tensors = load_tile(warp_window);
|
||||
|
||||
static_assert(WarpTile::get_thread_buffer_size() % UnaryOpSize == 0);
|
||||
|
||||
using ComputeVectorType = ComputeDataType __attribute__((ext_vector_type(UnaryOpSize)));
|
||||
static_for<0, thread_buffer_size, 1>{}([&](auto i) {
|
||||
elementwise_op(warp_tile.get_thread_buffer().template get_as<ComputeVectorType>()(i),
|
||||
in_dstr_tensors.get_thread_buffer().template get_as<pk_int4x4_t>()[i]);
|
||||
});
|
||||
}
|
||||
|
||||
template <GemmPipelineScheduler Scheduler, typename GemmTraits>
|
||||
struct BlockGemmImpl
|
||||
{
|
||||
@@ -239,7 +224,7 @@ struct BlockUniversalGemmAsBsCr
|
||||
|
||||
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
|
||||
{
|
||||
load_interleaved_pk_type(a_warp_tile_, a_block_window);
|
||||
Loader::load_interleaved_pk_type(a_warp_tile_, a_block_window);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -247,7 +232,7 @@ struct BlockUniversalGemmAsBsCr
|
||||
}
|
||||
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
|
||||
{
|
||||
load_interleaved_pk_type(b_warp_tile_, b_block_window);
|
||||
Loader::load_interleaved_pk_type(b_warp_tile_, b_block_window);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -317,7 +302,7 @@ struct BlockUniversalGemmAsBsCr
|
||||
{
|
||||
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
|
||||
{
|
||||
load_interleaved_pk_type(a_warp_tile_, a_block_window);
|
||||
Loader::load_interleaved_pk_type(a_warp_tile_, a_block_window);
|
||||
}
|
||||
else if constexpr(ALoadTranspose)
|
||||
{
|
||||
@@ -329,7 +314,7 @@ struct BlockUniversalGemmAsBsCr
|
||||
}
|
||||
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
|
||||
{
|
||||
load_interleaved_pk_type(b_warp_tile_, b_block_window);
|
||||
Loader::load_interleaved_pk_type(b_warp_tile_, b_block_window);
|
||||
}
|
||||
else if constexpr(BLoadTranspose)
|
||||
{
|
||||
@@ -468,7 +453,7 @@ struct BlockUniversalGemmAsBsCr
|
||||
|
||||
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
|
||||
{
|
||||
load_interleaved_pk_type(a_warp_tile_, a_block_window);
|
||||
Loader::load_interleaved_pk_type(a_warp_tile_, a_block_window);
|
||||
}
|
||||
else if constexpr(ALoadTranspose)
|
||||
{
|
||||
@@ -480,7 +465,7 @@ struct BlockUniversalGemmAsBsCr
|
||||
}
|
||||
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
|
||||
{
|
||||
load_interleaved_pk_type(b_warp_tile_, b_block_window);
|
||||
Loader::load_interleaved_pk_type(b_warp_tile_, b_block_window);
|
||||
}
|
||||
else if constexpr(BLoadTranspose)
|
||||
{
|
||||
|
||||
@@ -289,13 +289,17 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy
|
||||
{
|
||||
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
|
||||
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
|
||||
using WarpGemm = WarpGemmDispatcher<typename Problem::ADataType,
|
||||
typename Problem::BDataType,
|
||||
typename Problem::CDataType,
|
||||
WarpTile::at(I0),
|
||||
WarpTile::at(I1),
|
||||
WarpTile::at(I2),
|
||||
Problem::TransposeC>;
|
||||
using BTypeToUse =
|
||||
std::conditional_t<std::is_same_v<typename Problem::BDataType, ck_tile::pk_int4_t>,
|
||||
typename Problem::ADataType,
|
||||
typename Problem::BDataType>;
|
||||
using WarpGemm = WarpGemmDispatcher<typename Problem::ADataType,
|
||||
BTypeToUse,
|
||||
typename Problem::CDataType,
|
||||
WarpTile::at(I0),
|
||||
WarpTile::at(I1),
|
||||
WarpTile::at(I2),
|
||||
Problem::TransposeC>;
|
||||
|
||||
using BlockWeightPreshufflePolicy =
|
||||
BlockWeightPreshuffleASmemBSmemCRegV1CustomPolicy<typename Problem::ADataType,
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
|
||||
#include "ck_tile/host/concat.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp"
|
||||
|
||||
@@ -202,7 +203,8 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV1
|
||||
typename AElementFunction,
|
||||
typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
|
||||
!is_detected<is_tuple, BFlatBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
bool>* = nullptr,
|
||||
index_t UnaryOpSize_ = 8>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const AElementFunction& a_element_func,
|
||||
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
|
||||
@@ -310,14 +312,14 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV1
|
||||
NIterPerWarp>
|
||||
b_flat_dram_windows;
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(load_tile(b_flat_dram_window)), KIterPerWarp>,
|
||||
NIterPerWarp>
|
||||
using BTypeToUse =
|
||||
std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
|
||||
using BTileType = decltype(make_static_distributed_tensor<BTypeToUse>(b_flat_distribution));
|
||||
|
||||
statically_indexed_array<statically_indexed_array<BTileType, KIterPerWarp>, NIterPerWarp>
|
||||
b_warp_tensor;
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(load_tile(b_flat_dram_window)), KIterPerWarp>,
|
||||
NIterPerWarp>
|
||||
statically_indexed_array<statically_indexed_array<BTileType, KIterPerWarp>, NIterPerWarp>
|
||||
b_warp_tensor_2;
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
@@ -327,7 +329,8 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV1
|
||||
move_tile_window(b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
|
||||
|
||||
b_warp_tensor(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
load_int4_tile<BDataType, ADataType, UnaryOpSize_>(
|
||||
b_warp_tensor(nIter)(kIter), b_flat_dram_windows(nIter)(kIter));
|
||||
});
|
||||
});
|
||||
|
||||
@@ -375,7 +378,8 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV1
|
||||
move_tile_window(b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
|
||||
|
||||
b_warp_tensor_2(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
load_int4_tile<BDataType, ADataType, UnaryOpSize_>(
|
||||
b_warp_tensor_2(nIter)(kIter), b_flat_dram_windows(nIter)(kIter));
|
||||
});
|
||||
});
|
||||
|
||||
@@ -408,7 +412,8 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV1
|
||||
move_tile_window(b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
|
||||
|
||||
b_warp_tensor(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
load_int4_tile<BDataType, ADataType, UnaryOpSize_>(
|
||||
b_warp_tensor(nIter)(kIter), b_flat_dram_windows(nIter)(kIter));
|
||||
});
|
||||
});
|
||||
|
||||
@@ -445,7 +450,8 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV1
|
||||
move_tile_window(b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
|
||||
|
||||
b_warp_tensor_2(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
load_int4_tile<BDataType, ADataType, UnaryOpSize_>(
|
||||
b_warp_tensor_2(nIter)(kIter), b_flat_dram_windows(nIter)(kIter));
|
||||
});
|
||||
});
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
|
||||
#include "ck_tile/host/concat.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp"
|
||||
|
||||
@@ -514,7 +515,8 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
|
||||
typename AElementFunction,
|
||||
typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
|
||||
!is_detected<is_tuple, BFlatBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
bool>* = nullptr,
|
||||
index_t UnaryOpSize_ = 8>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const AElementFunction& a_element_func,
|
||||
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
|
||||
@@ -631,19 +633,19 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
|
||||
b_flat_distribution);
|
||||
|
||||
// pingpong buffer for B
|
||||
using BTypeToUse =
|
||||
std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
|
||||
using BTileType = decltype(make_static_distributed_tensor<BTypeToUse>(b_flat_distribution));
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(b_flat_dram_window), KIterPerWarp>,
|
||||
NIterPerWarp>
|
||||
b_flat_dram_windows;
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(load_tile(b_flat_dram_window)), KIterPerWarp>,
|
||||
NIterPerWarp>
|
||||
statically_indexed_array<statically_indexed_array<BTileType, KIterPerWarp>, NIterPerWarp>
|
||||
b_warp_tensor_ping;
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(load_tile(b_flat_dram_window)), KIterPerWarp>,
|
||||
NIterPerWarp>
|
||||
statically_indexed_array<statically_indexed_array<BTileType, KIterPerWarp>, NIterPerWarp>
|
||||
b_warp_tensor_pong;
|
||||
|
||||
// Prefetch A0
|
||||
@@ -659,7 +661,8 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
|
||||
move_tile_window(b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
|
||||
|
||||
b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
load_int4_tile<BDataType, ADataType, UnaryOpSize_>(
|
||||
b_warp_tensor_ping(nIter)(kIter), b_flat_dram_windows(nIter)(kIter));
|
||||
});
|
||||
});
|
||||
// move B window to next flat K
|
||||
@@ -706,7 +709,8 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
|
||||
move_tile_window(b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
|
||||
|
||||
b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
load_int4_tile<BDataType, ADataType, UnaryOpSize_>(
|
||||
b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter));
|
||||
});
|
||||
});
|
||||
|
||||
@@ -782,7 +786,8 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
|
||||
move_tile_window(b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
|
||||
|
||||
b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
load_int4_tile<BDataType, ADataType, UnaryOpSize_>(
|
||||
b_warp_tensor_ping(nIter)(kIter), b_flat_dram_windows(nIter)(kIter));
|
||||
});
|
||||
});
|
||||
|
||||
@@ -862,7 +867,8 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
|
||||
move_tile_window(b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
|
||||
|
||||
b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
load_int4_tile<BDataType, ADataType, UnaryOpSize_>(
|
||||
b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter));
|
||||
});
|
||||
});
|
||||
|
||||
|
||||
@@ -5,19 +5,19 @@
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
|
||||
#include "ck_tile/ops/elementwise.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem, index_t UnaryOpSize_ = 8>
|
||||
template <typename Problem>
|
||||
struct BlockGemmAQuantBase
|
||||
{
|
||||
using AQDataType = remove_cvref_t<typename Problem::AQDataType>;
|
||||
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
|
||||
static constexpr index_t UnaryOpSize = UnaryOpSize_;
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE static float cvt_scale_to_fp32(T scale)
|
||||
{
|
||||
@@ -42,23 +42,6 @@ struct BlockGemmAQuantBase
|
||||
}
|
||||
return scale_reg_f;
|
||||
}
|
||||
|
||||
template <typename WarpWindow, typename WarpTile>
|
||||
CK_TILE_DEVICE static void load_interleaved_pk_type(WarpTile& warp_tile,
|
||||
const WarpWindow& warp_window)
|
||||
{
|
||||
const element_wise::PassThroughPack8 elementwise_op{};
|
||||
|
||||
static_assert(WarpTile::get_thread_buffer_size() % UnaryOpSize == 0);
|
||||
constexpr index_t thread_buffer_size = WarpTile::get_thread_buffer_size() / UnaryOpSize;
|
||||
const auto in_dstr_tensors = load_tile(warp_window);
|
||||
|
||||
using ComputeVectorType = ComputeDataType __attribute__((ext_vector_type(UnaryOpSize)));
|
||||
static_for<0, thread_buffer_size, 1>{}([&](auto i) {
|
||||
elementwise_op(warp_tile.get_thread_buffer().template get_as<ComputeVectorType>()(i),
|
||||
in_dstr_tensors.get_thread_buffer().template get_as<pk_int4x4_t>()[i]);
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
// A is block window on shared memory
|
||||
@@ -66,7 +49,9 @@ struct BlockGemmAQuantBase
|
||||
// Consecutive kQuantGroupSize elements of A are quantized with a separate scale.
|
||||
// B is block window on shared memory
|
||||
// C is block distributed tensor
|
||||
template <typename Problem_, typename Policy_ = BlockGemmASmemBSmemCRegV1DefaultPolicy>
|
||||
template <typename Problem_,
|
||||
typename Policy_ = BlockGemmASmemBSmemCRegV1DefaultPolicy,
|
||||
index_t UnaryOpSize_ = 8>
|
||||
struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase<Problem_>
|
||||
{
|
||||
private:
|
||||
@@ -172,6 +157,7 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase<Problem_>
|
||||
|
||||
using Base = BlockGemmAQuantBase<Problem_>;
|
||||
|
||||
using Loader = remove_cvref_t<InterleavedPKTypeLoader<ComputeDataType, UnaryOpSize_>>;
|
||||
using WarpGemm = remove_cvref_t<typename Traits::WarpGemm>;
|
||||
|
||||
static constexpr index_t KIterPerWarp = Traits::KIterPerWarp;
|
||||
@@ -292,7 +278,7 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase<Problem_>
|
||||
{
|
||||
static_assert(std::is_same_v<ComputeDataType, fp8_t> ||
|
||||
std::is_same_v<ComputeDataType, bf8_t>);
|
||||
Base::load_interleaved_pk_type(a_warp_tile_, a_block_window);
|
||||
Loader::load_interleaved_pk_type(a_warp_tile_, a_block_window);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -302,7 +288,7 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase<Problem_>
|
||||
{
|
||||
static_assert(std::is_same_v<ComputeDataType, fp8_t> ||
|
||||
std::is_same_v<ComputeDataType, bf8_t>);
|
||||
Base::load_interleaved_pk_type(b_warp_tile_, b_block_window);
|
||||
Loader::load_interleaved_pk_type(b_warp_tile_, b_block_window);
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@@ -5,19 +5,19 @@
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
|
||||
#include "ck_tile/ops/elementwise.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem, index_t UnaryOpSize_ = 8>
|
||||
template <typename Problem>
|
||||
struct BlockGemmBQuantBase
|
||||
{
|
||||
using BQDataType = remove_cvref_t<typename Problem::BQDataType>;
|
||||
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
|
||||
static constexpr index_t UnaryOpSize = UnaryOpSize_;
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE static float cvt_scale_to_fp32(T scale)
|
||||
{
|
||||
@@ -42,24 +42,6 @@ struct BlockGemmBQuantBase
|
||||
}
|
||||
return scale_reg_f;
|
||||
}
|
||||
|
||||
// can be inherited from A
|
||||
template <typename WarpWindow, typename WarpTile>
|
||||
CK_TILE_DEVICE static void load_interleaved_pk_type(WarpTile& warp_tile,
|
||||
const WarpWindow& warp_window)
|
||||
{
|
||||
const element_wise::PassThroughPack8 elementwise_op{};
|
||||
|
||||
static_assert(WarpTile::get_thread_buffer_size() % UnaryOpSize == 0);
|
||||
constexpr index_t thread_buffer_size = WarpTile::get_thread_buffer_size() / UnaryOpSize;
|
||||
const auto in_dstr_tensors = load_tile(warp_window);
|
||||
|
||||
using ComputeVectorType = ComputeDataType __attribute__((ext_vector_type(UnaryOpSize)));
|
||||
static_for<0, thread_buffer_size, 1>{}([&](auto i) {
|
||||
elementwise_op(warp_tile.get_thread_buffer().template get_as<ComputeVectorType>()(i),
|
||||
in_dstr_tensors.get_thread_buffer().template get_as<pk_int4x4_t>()[i]);
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
// A is block window on shared memory
|
||||
@@ -67,7 +49,9 @@ struct BlockGemmBQuantBase
|
||||
// Consecutive kQuantGroupSize elements of B are quantized with a separate scale.
|
||||
// B is block window on shared memory
|
||||
// C is block distributed tensor
|
||||
template <typename Problem_, typename Policy_ = BlockGemmASmemBSmemCRegV1DefaultPolicy>
|
||||
template <typename Problem_,
|
||||
typename Policy_ = BlockGemmASmemBSmemCRegV1DefaultPolicy,
|
||||
index_t UnaryOpSize_ = 8>
|
||||
struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
|
||||
{
|
||||
private:
|
||||
@@ -170,6 +154,7 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
|
||||
|
||||
using Base = BlockGemmBQuantBase<Problem_>;
|
||||
|
||||
using Loader = remove_cvref_t<InterleavedPKTypeLoader<ComputeDataType, UnaryOpSize_>>;
|
||||
using WarpGemm = remove_cvref_t<typename Traits::WarpGemm>;
|
||||
|
||||
static constexpr index_t KIterPerWarp = Traits::KIterPerWarp;
|
||||
@@ -291,7 +276,7 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
|
||||
{
|
||||
static_assert(std::is_same_v<ComputeDataType, fp8_t> ||
|
||||
std::is_same_v<ComputeDataType, bf8_t>);
|
||||
Base::load_interleaved_pk_type(a_warp_tile_, a_block_window);
|
||||
Loader::load_interleaved_pk_type(a_warp_tile_, a_block_window);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -301,7 +286,7 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
|
||||
{
|
||||
static_assert(std::is_same_v<ComputeDataType, fp8_t> ||
|
||||
std::is_same_v<ComputeDataType, bf8_t>);
|
||||
Base::load_interleaved_pk_type(b_warp_tile_, b_block_window);
|
||||
Loader::load_interleaved_pk_type(b_warp_tile_, b_block_window);
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@@ -29,7 +29,8 @@ TYPED_TEST(TestCkTileBatchedGemm, Basic)
|
||||
{256, 256, 64, 8},
|
||||
{256, 256, 64, 16}};
|
||||
|
||||
if(ck_tile::get_device_name() != "gfx950") {
|
||||
if(ck_tile::get_device_name() != "gfx950")
|
||||
{
|
||||
gemmParams.emplace_back(256, 256, 128, 2);
|
||||
}
|
||||
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/host/permute_pk_int4.hpp"
|
||||
|
||||
template <typename Layout>
|
||||
static constexpr inline auto is_row_major(Layout layout_)
|
||||
{
|
||||
@@ -91,61 +93,6 @@ void permute_tensor_b(Tensor& tensor)
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Tensor>
|
||||
void permute_vectors_i4x4_b(Tensor& tensor)
|
||||
{
|
||||
const ck_tile::index_t K = tensor.get_length(0);
|
||||
const ck_tile::index_t N = tensor.get_length(1);
|
||||
// vector pk_i4x4 permute
|
||||
for(int i = 0; i < N; i++)
|
||||
{
|
||||
for(int j = 0; j < K; j += 8)
|
||||
{
|
||||
int8_t input[8];
|
||||
|
||||
for(int k = 0; k < 4; k++)
|
||||
{
|
||||
int8_t i4x2 = tensor(j + k * 2, i).data;
|
||||
input[k * 2 + 0] = (i4x2 >> 4) & 0xf;
|
||||
input[k * 2 + 1] = (i4x2 >> 0) & 0xf;
|
||||
}
|
||||
|
||||
// permute 01234567->20643175
|
||||
{
|
||||
int8_t hi = input[2];
|
||||
int8_t lo = input[0];
|
||||
int8_t i4x2 = (hi << 4) | lo;
|
||||
|
||||
tensor(j + 0, i) = i4x2;
|
||||
}
|
||||
|
||||
{
|
||||
int8_t hi = input[6];
|
||||
int8_t lo = input[4];
|
||||
int8_t i4x2 = (hi << 4) | lo;
|
||||
|
||||
tensor(j + 2, i) = i4x2;
|
||||
}
|
||||
|
||||
{
|
||||
int8_t hi = input[3];
|
||||
int8_t lo = input[1];
|
||||
int8_t i4x2 = (hi << 4) | lo;
|
||||
|
||||
tensor(j + 4, i) = i4x2;
|
||||
}
|
||||
|
||||
{
|
||||
int8_t hi = input[7];
|
||||
int8_t lo = input[5];
|
||||
int8_t i4x2 = (hi << 4) | lo;
|
||||
|
||||
tensor(j + 6, i) = i4x2;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
|
||||
@@ -13,6 +13,7 @@ using F16 = ck_tile::half_t;
|
||||
using F32 = float;
|
||||
using F8 = ck_tile::fp8_t;
|
||||
using BF16 = ck_tile::bf16_t;
|
||||
using I4 = ck_tile::pk_int4_t;
|
||||
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
@@ -20,20 +21,24 @@ using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using Default = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler,
|
||||
ck_tile::GemmPipelineScheduler::Default>;
|
||||
|
||||
using WeightPreshuffle =
|
||||
ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::WeightPreshuffle>;
|
||||
|
||||
// Adding alias for the F8 parameters to facilitate skipping tests.
|
||||
// This alias can be removed once test failures are fixed.
|
||||
using F8Types = std::tuple<Row, Col, Row, F8, F8, F32, F16, Default, WeightPreshuffle>;
|
||||
using WeightPreshuffleV1 =
|
||||
ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::WeightPreshuffleV1>;
|
||||
using WeightPreshuffleV2 =
|
||||
ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::WeightPreshuffleV2>;
|
||||
|
||||
// clang-format off
|
||||
|
||||
using KernelTypesWeightPreshuffle = ::testing::Types<
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, Default, WeightPreshuffle>,
|
||||
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, Default, WeightPreshuffle>
|
||||
#if !CK_TILE_USE_WMMA || CK_TILE_USE_OCP_FP8
|
||||
, F8Types
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, Default, WeightPreshuffleV1>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, Default, WeightPreshuffleV2>,
|
||||
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, Default, WeightPreshuffleV2>,
|
||||
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, Default, WeightPreshuffleV1>
|
||||
#if !CK_TILE_USE_WMMA || CK_TILE_USE_OCP_FP8
|
||||
,
|
||||
std::tuple< Row, Col, Row, F8, F8, F32, F16, Default, WeightPreshuffleV1>,
|
||||
std::tuple< Row, Col, Row, F8, F8, F32, F16, Default, WeightPreshuffleV2>,
|
||||
std::tuple< Row, Col, Row, F8, I4, F32, F16, Default, WeightPreshuffleV2>,
|
||||
std::tuple< Row, Col, Row, F8, I4, F32, F16, Default, WeightPreshuffleV1>
|
||||
#endif
|
||||
>;
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ TYPED_TEST(TEST_SUITE_NAME, GemmPreshuffle)
|
||||
|
||||
TYPED_TEST(TEST_SUITE_NAME, GemmPreshuffle_128x128x128)
|
||||
{
|
||||
if constexpr(std::is_same_v<TypeParam, F8Types>)
|
||||
if constexpr(std::is_same_v<std::tuple_element_t<3, TypeParam>, F8>)
|
||||
{
|
||||
GTEST_SKIP() << "Skipping this test due to failures with F8";
|
||||
}
|
||||
@@ -48,7 +48,7 @@ TYPED_TEST(TEST_SUITE_NAME, GemmPreshuffle_128x128x4096)
|
||||
|
||||
TYPED_TEST(TEST_SUITE_NAME, GemmPreshuffle_128x2048x128)
|
||||
{
|
||||
if constexpr(std::is_same_v<TypeParam, F8Types>)
|
||||
if constexpr(std::is_same_v<std::tuple_element_t<3, TypeParam>, F8>)
|
||||
{
|
||||
GTEST_SKIP() << "Skipping this test due to failures with F8";
|
||||
}
|
||||
@@ -77,7 +77,7 @@ TYPED_TEST(TEST_SUITE_NAME, GemmPreshuffle_128x2048x4096)
|
||||
|
||||
TYPED_TEST(TEST_SUITE_NAME, GemmPreshuffle_1024x128x128)
|
||||
{
|
||||
if constexpr(std::is_same_v<TypeParam, F8Types>)
|
||||
if constexpr(std::is_same_v<std::tuple_element_t<3, TypeParam>, F8>)
|
||||
{
|
||||
GTEST_SKIP() << "Skipping this test due to failures with F8";
|
||||
}
|
||||
@@ -106,7 +106,7 @@ TYPED_TEST(TEST_SUITE_NAME, GemmPreshuffle_1024x128x4096)
|
||||
|
||||
TYPED_TEST(TEST_SUITE_NAME, GemmPreshuffle_1024x2048x128)
|
||||
{
|
||||
if constexpr(std::is_same_v<TypeParam, F8Types>)
|
||||
if constexpr(std::is_same_v<std::tuple_element_t<3, TypeParam>, F8>)
|
||||
{
|
||||
GTEST_SKIP() << "Skipping this test due to failures with F8";
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/host/permute_pk_int4.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
|
||||
@@ -34,20 +35,31 @@ auto calculate_rtol_atol(const ck_tile::index_t K,
|
||||
|
||||
enum struct GemmPipelineType
|
||||
{
|
||||
WeightPreshuffle
|
||||
WeightPreshuffleV1,
|
||||
WeightPreshuffleV2
|
||||
};
|
||||
|
||||
template <GemmPipelineType PT, typename Problem>
|
||||
struct GemmPipelineTypeSelector;
|
||||
|
||||
template <typename Problem>
|
||||
struct GemmPipelineTypeSelector<GemmPipelineType::WeightPreshuffle, Problem>
|
||||
struct GemmPipelineTypeSelector<GemmPipelineType::WeightPreshuffleV1, Problem>
|
||||
{
|
||||
using base_pipeline = ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV1<Problem>;
|
||||
using pipeline = ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV1<Problem>;
|
||||
|
||||
static constexpr auto GetName() { return "GemmPipelineAgBgCrWeightPreshuffle"; }
|
||||
static constexpr auto GetName() { return "GemmPipelineAgBgCrWeightPreshuffleV1"; }
|
||||
};
|
||||
|
||||
template <typename Problem>
|
||||
struct GemmPipelineTypeSelector<GemmPipelineType::WeightPreshuffleV2, Problem>
|
||||
{
|
||||
using base_pipeline = ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<Problem>;
|
||||
using pipeline = ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV2<Problem>;
|
||||
|
||||
static constexpr auto GetName() { return "GemmPipelineAgBgCrWeightPreshuffleV2"; }
|
||||
};
|
||||
|
||||
template <typename Datatype>
|
||||
struct config
|
||||
{
|
||||
@@ -122,7 +134,8 @@ class TestCkTileGemmPipeline : public ::testing::Test
|
||||
constexpr bool kPadK = PadK;
|
||||
constexpr bool preshuffle = Preshuffle;
|
||||
|
||||
constexpr bool DoubleSmemBuffer = false;
|
||||
constexpr bool DoubleSmemBuffer =
|
||||
(PipelineType == GemmPipelineType::WeightPreshuffleV2) ? true : false;
|
||||
|
||||
// TODO: For now - but this should also be a test parameter
|
||||
constexpr bool TransposeC = false;
|
||||
@@ -391,10 +404,19 @@ class TestCkTileGemmPipeline : public ::testing::Test
|
||||
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
|
||||
|
||||
ck_tile::HostTensor<BDataType> b_shuffle_host = shuffle_b<GemmConfig>(b_k_n);
|
||||
|
||||
a_m_k_dev_buf.ToDevice(a_m_k.data());
|
||||
b_k_n_dev_buf.ToDevice(b_shuffle_host.data());
|
||||
ck_tile::HostTensor<BDataType> b_shuffle_host = shuffle_b<GemmConfig>(b_k_n);
|
||||
if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
|
||||
{
|
||||
// Permute vector pk_i4x4 data for device implementation
|
||||
ck_tile::HostTensor<BDataType> b_shuffle_host_dev = b_shuffle_host;
|
||||
ck_tile::permute_vectors_i4x4_b(b_shuffle_host_dev);
|
||||
b_k_n_dev_buf.ToDevice(b_shuffle_host_dev.data());
|
||||
}
|
||||
else
|
||||
{
|
||||
b_k_n_dev_buf.ToDevice(b_shuffle_host.data());
|
||||
}
|
||||
c_m_n_dev_buf.SetZero();
|
||||
c_m_n_dev_result.SetZero();
|
||||
|
||||
|
||||
Reference in New Issue
Block a user