mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-15 18:42:06 +00:00
Weight Preshuffle Block Scale gemm support (#2877)
* initial commit
* remove extra files
* fixing errors
* updated ReadMe file for mapping of diff quants with diff configs
* addressing review comments
* addressing review comments
* Resolved merge conflicts
* [CK TILE GEMM] Replace get_preshuffle_or with is_quantpreshuffle_enabled
The get_preshuffle_or was not working as expected, which led to incorrect behavior
in the quantization preshuffle process. This change replaces it with the more reliable
is_quantpreshuffle_enabled function to properly determine when preshuffle should be applied.
* initial commit
* debugging
* working fp8 for init constant
* fp8 working with all inits
* updated block level code with comments
* changing the loop iter
* debugging
* debugging
* debugging
* code fix
* code clean up
* clang formatted
* Add comment
* code cleanup
* clang formatted
* merge conflicts fixes
* applying the latest int4 changes to the piepline
* fixing test code for updated traits
* Adding gtest
* review comments addressed
* addressing review comments
* remove c++20 code
* added flush cache changes
---------
Co-authored-by: Cong Ma <congma13@amd.com>
Co-authored-by: root <root@banff-cyxtera-s73-2.ctr.dcgpu>
[ROCm/composable_kernel commit: 81458a6681]
This commit is contained in:
@@ -53,6 +53,9 @@ class TestCkTileGemmQuantBase : public ::testing::Test
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = GemmConfig::M_Warp_Tile;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = GemmConfig::N_Warp_Tile;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = GemmConfig::K_Warp_Tile;
|
||||
static constexpr bool PreshuffleQuant = GemmConfig::PreshuffleQuant;
|
||||
static constexpr bool PreshuffleB = GemmConfig::PreshuffleB;
|
||||
static constexpr bool DoubleSmemBuffer = GemmConfig::DoubleSmemBuffer;
|
||||
|
||||
public:
|
||||
void SetUp() override { static_cast<Derived*>(this)->SetUpQuantTypeSpecific(); }
|
||||
@@ -62,10 +65,9 @@ class TestCkTileGemmQuantBase : public ::testing::Test
|
||||
// Common test execution logic
|
||||
void invoke_quant_gemm(const ck_tile::QuantGemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
{
|
||||
constexpr bool kPadM = false;
|
||||
constexpr bool kPadN = false;
|
||||
constexpr bool kPadK = false;
|
||||
constexpr bool kPreshuffle = false;
|
||||
constexpr bool kPadM = false;
|
||||
constexpr bool kPadN = false;
|
||||
constexpr bool kPadK = false;
|
||||
|
||||
using CodegenGemmShape =
|
||||
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
|
||||
@@ -77,11 +79,15 @@ class TestCkTileGemmQuantBase : public ::testing::Test
|
||||
using CodegenGemmTraits = ck_tile::TileGemmQuantTraits<kPadM,
|
||||
kPadN,
|
||||
kPadK,
|
||||
kPreshuffle,
|
||||
PreshuffleQuant,
|
||||
PreshuffleB,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
QuantType>;
|
||||
QuantType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DoubleSmemBuffer>;
|
||||
|
||||
// Let the derived class create the appropriate pipeline and epilogue
|
||||
static_cast<Derived*>(this)
|
||||
@@ -125,6 +131,19 @@ class TestCkTileGemmQuantBase : public ::testing::Test
|
||||
// Use higher threshold
|
||||
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
auto shuffle_b(const ck_tile::HostTensor<T>& t)
|
||||
{
|
||||
assert(t.get_lengths().size() == 2);
|
||||
int n_ = t.get_lengths()[1];
|
||||
int k_ = t.get_lengths()[0];
|
||||
constexpr int divisor = N_Warp_Tile == 32 ? 2 : 4;
|
||||
ck_tile::HostTensor<T> t_view(
|
||||
{n_ / N_Warp_Tile, N_Warp_Tile, k_ / K_Warp_Tile, divisor, K_Warp_Tile / divisor});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
|
||||
}
|
||||
};
|
||||
|
||||
// Define generic QuantTypeTraits template (will be specialized)
|
||||
|
||||
@@ -24,6 +24,7 @@ struct GemmConfigBase
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
static constexpr ck_tile::index_t NumWaveGroups = 1;
|
||||
static constexpr bool PreshuffleQuant = false;
|
||||
static constexpr bool PreshuffleB = false;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
|
||||
// Default GEMM tile sizes for tests
|
||||
@@ -40,6 +41,41 @@ struct GemmConfigBase
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 32;
|
||||
};
|
||||
|
||||
struct GemmConfigPreshuffleB
|
||||
{
|
||||
static constexpr bool kPadM = false;
|
||||
static constexpr bool kPadN = false;
|
||||
static constexpr bool kPadK = false;
|
||||
|
||||
static constexpr bool PermuteA = false;
|
||||
static constexpr bool PermuteB = false;
|
||||
|
||||
static constexpr bool TransposeC = false;
|
||||
static constexpr bool UseStructuredSparsity = false;
|
||||
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
|
||||
static constexpr ck_tile::index_t TileParitionerM01 = 4;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
static constexpr ck_tile::index_t NumWaveGroups = 1;
|
||||
static constexpr bool PreshuffleQuant = false;
|
||||
static constexpr bool PreshuffleB = true;
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
|
||||
// Default GEMM tile sizes for tests
|
||||
static constexpr ck_tile::index_t M_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Tile = 64;
|
||||
static constexpr ck_tile::index_t K_Tile = 256;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 64;
|
||||
};
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileGemmAQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGemmAQuant<Tuple>>
|
||||
{
|
||||
@@ -288,6 +324,7 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
|
||||
|
||||
static constexpr auto QuantType = Base::QuantType;
|
||||
static constexpr uint32_t QuantGroupSize = Base::QuantGroupSize;
|
||||
static constexpr auto PreshuffleB = Base::PreshuffleB;
|
||||
|
||||
protected:
|
||||
void SetUpQuantTypeSpecific() {}
|
||||
@@ -324,16 +361,23 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
|
||||
|
||||
// Copy to device
|
||||
a_m_k_dev_buf.ToDevice(a_m_k.data());
|
||||
ck_tile::HostTensor<BDataType> b_k_n_dev = b_k_n;
|
||||
if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
|
||||
{
|
||||
// Permute vector pk_i4x4 data for device implementation
|
||||
ck_tile::HostTensor<BDataType> temp = b_k_n;
|
||||
ck_tile::permute_vectors_i4x4_b(temp);
|
||||
b_k_n_dev_buf.ToDevice(temp.data());
|
||||
if constexpr(PreshuffleB)
|
||||
{
|
||||
b_k_n_dev = this->shuffle_b(b_k_n);
|
||||
}
|
||||
ck_tile::permute_vectors_i4x4_b(b_k_n_dev);
|
||||
b_k_n_dev_buf.ToDevice(b_k_n_dev.data());
|
||||
}
|
||||
else
|
||||
{
|
||||
b_k_n_dev_buf.ToDevice(b_k_n.data());
|
||||
if constexpr(PreshuffleB)
|
||||
{
|
||||
b_k_n_dev = this->shuffle_b(b_k_n);
|
||||
}
|
||||
b_k_n_dev_buf.ToDevice(b_k_n_dev.data());
|
||||
}
|
||||
bq_bqk_n_dev_buf.ToDevice(bq_bqk_n.data());
|
||||
|
||||
@@ -419,7 +463,10 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
|
||||
CodegenGemmTraits,
|
||||
ComputeDataType>;
|
||||
|
||||
using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>;
|
||||
using BaseGemmPipeline = std::conditional_t<
|
||||
PreshuffleB == false,
|
||||
ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>,
|
||||
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<GemmPipelineProblem>>;
|
||||
|
||||
const ck_tile::index_t K_split = (args.K + Base::K_Tile - 1) / Base::K_Tile * Base::K_Tile;
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
@@ -443,7 +490,11 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
|
||||
has_hot_loop_v,
|
||||
tail_number_v>;
|
||||
|
||||
using GemmPipeline = ck_tile::BQuantGemmPipelineAgBgCrCompV3<PipelineProblem>;
|
||||
using GemmPipeline =
|
||||
std::conditional_t<PreshuffleB == false,
|
||||
ck_tile::BQuantGemmPipelineAgBgCrCompV3<PipelineProblem>,
|
||||
ck_tile::WPQuantBPipelineAgBgCrV2<PipelineProblem>>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
@@ -486,6 +537,11 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileGemmPreshuffleBBQuant : public TestCkTileGemmBQuant<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
// RowColQuant-specific test fixture
|
||||
template <typename Tuple>
|
||||
class TestCkTileGemmRowColQuant
|
||||
|
||||
@@ -41,6 +41,14 @@ using BQuantTypes = ::testing::Types<
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
// clang-format off
|
||||
using BPreshuffleBQuantTypes = ::testing::Types<
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPreshuffleB, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigPreshuffleB, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleB, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigPreshuffleB, GroupSize>
|
||||
>;
|
||||
|
||||
// clang-format off
|
||||
using RowColQuantTypes = ::testing::Types<
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, RowColQuant, GemmConfigBase, GroupSize>,
|
||||
@@ -58,6 +66,7 @@ using TensorQuantTypes = ::testing::Types<
|
||||
// Test suites for each quantization type
|
||||
TYPED_TEST_SUITE(TestCkTileGemmAQuant, AQuantTypes);
|
||||
TYPED_TEST_SUITE(TestCkTileGemmBQuant, BQuantTypes);
|
||||
TYPED_TEST_SUITE(TestCkTileGemmPreshuffleBBQuant, BPreshuffleBQuantTypes);
|
||||
TYPED_TEST_SUITE(TestCkTileGemmRowColQuant, RowColQuantTypes);
|
||||
TYPED_TEST_SUITE(TestCkTileGemmTensorQuant, TensorQuantTypes);
|
||||
|
||||
|
||||
@@ -15,6 +15,11 @@ TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedTest)
|
||||
this->run_test_with_validation(1024, 1024, 1024);
|
||||
}
|
||||
|
||||
// BQuant tests
|
||||
TYPED_TEST(TestCkTileGemmPreshuffleBBQuant, BQuantGroupedTest)
|
||||
{
|
||||
this->run_test_with_validation(1024, 1024, 1024);
|
||||
}
|
||||
// RowColQuant tests
|
||||
TYPED_TEST(TestCkTileGemmRowColQuant, RowColQuantTest)
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user