mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 04:31:25 +00:00
Support A/B Quantization in Blockscale GEMM (#3343)
* Support A/B Quantization in Blockscale GEMM * Support A/B Quantization in Blockscale GEMM * Support A/B Quantization in Blockscale GEMM * Support A/B Quantization in Blockscale GEMM * Support A/B Quantization in Blockscale GEMM * Implement review suggested changes * Implement review suggested changes * Sync with develop * fix pre-commit error * Add unit tests for blockscale AB-Quantization * fix pre-commit error * fix pre-commit error * fix compile error * fix compile error * fix clang-format * fix clang-format * fix enumeration values not handled in switch * rebase file * Add missing enums to data_type_sizeof (#3430) Fixes broken build on gfx942. This was some test code that got merged at the same time. * [CK_BUILDER] CK Tile header installation for builder, algorithm concept improvements (#3419) * Added install of CK_Tile headers when using CK_EXPERIMENTAL_BUILDER. MIOpen needs this since the builder uses features from CK Tile and the CK Tile install is excluded when doing a narrow build for MIOpen * Changed algorithm concept type checks to be concepts instead of constexpr bool functions. This improves compiler error messages when using these concepts in static_asserts --------- Co-authored-by: Daryl Hawkins <DarylHawkins@amd.com> * Add build trace diagnostics to CI. (#3432) * generate and visualize build traces for all archs * generate build traces in all cases * fix jenkins logic * fix typo * use more threads for parsing dependency map * add script to parse ninja traces and issue warnings * fix python script syntax and header * fix python syntax one more time * fix python syntax * Support A/B Quantization in Blockscale GEMM * Implement review suggested changes * Sync with develop * Add unit tests for blockscale AB-Quantization * fix enumeration values not handled in switch * rebase file * rebase file --------- Co-authored-by: John Shumway <jshumway@amd.com> Co-authored-by: DarylHawkinsAMD <Daryl.Hawkins@amd.com> Co-authored-by: Daryl Hawkins <DarylHawkins@amd.com> Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
This commit is contained in:
@@ -412,7 +412,8 @@ struct QuantGemmKernel
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(kQuantType == QuantType::AQuantGrouped)
|
||||
if constexpr(kQuantType == QuantType::AQuantGrouped ||
|
||||
kQuantType == QuantType::ABQuantGrouped)
|
||||
{
|
||||
if(kargs.QK_A % GemmPipeline::GetVectorSizeAQ() != 0)
|
||||
{
|
||||
@@ -424,7 +425,8 @@ struct QuantGemmKernel
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(kQuantType == QuantType::BQuantGrouped)
|
||||
if constexpr(kQuantType == QuantType::BQuantGrouped ||
|
||||
kQuantType == QuantType::ABQuantGrouped)
|
||||
{
|
||||
if(kargs.QK_B % GemmPipeline::GetVectorSizeBQ() != 0)
|
||||
{
|
||||
@@ -651,7 +653,9 @@ struct QuantGemmKernel
|
||||
|
||||
return make_tensor_view<address_space_enum::global>(aq_ptr, aq_merge_pad1_desc);
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::AQuantGrouped && !PreshuffleQuant)
|
||||
else if constexpr((kQuantType == QuantType::AQuantGrouped ||
|
||||
kQuantType == QuantType::ABQuantGrouped) &&
|
||||
!PreshuffleQuant)
|
||||
{
|
||||
if constexpr(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
@@ -831,6 +835,17 @@ struct QuantGemmKernel
|
||||
}
|
||||
}
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::ABQuantGrouped)
|
||||
{
|
||||
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
|
||||
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::BQuantGroupSize>;
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
bq_ptr,
|
||||
make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN), kargs.QK_B),
|
||||
make_tuple(kargs.stride_BQ, 1),
|
||||
number<GemmPipeline::GetVectorSizeBQ()>{},
|
||||
number<1>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return nullptr; // TODO: use some other "empty" type for this
|
||||
@@ -1007,6 +1022,17 @@ struct QuantGemmKernel
|
||||
{0, i_m});
|
||||
}
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::ABQuantGrouped && !PreshuffleQuant)
|
||||
{
|
||||
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
|
||||
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::AQuantGroupSize>;
|
||||
constexpr auto block_m = TilePartitioner::MPerBlock;
|
||||
constexpr auto block_k = TilePartitioner::KPerBlock;
|
||||
return make_tile_window(
|
||||
aq_pad_view,
|
||||
make_tuple(number<block_m>{}, number<block_k / QuantGroupSize::kK>{}),
|
||||
{i_m, 0});
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::RowColQuant)
|
||||
{
|
||||
return make_tile_window(aq_pad_view,
|
||||
@@ -1104,6 +1130,16 @@ struct QuantGemmKernel
|
||||
}
|
||||
}
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::ABQuantGrouped)
|
||||
{
|
||||
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
|
||||
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::BQuantGroupSize>;
|
||||
return make_tile_window(
|
||||
bq_pad_view,
|
||||
make_tuple(number<TilePartitioner::NPerBlock / QuantGroupSize::kN>{},
|
||||
number<TilePartitioner::KPerBlock / QuantGroupSize::kK>{}),
|
||||
{i_n / QuantGroupSize::kN, 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
return nullptr; // TODO: use some other "empty" type here
|
||||
@@ -1184,6 +1220,26 @@ struct QuantGemmKernel
|
||||
return GemmPipeline{}.template operator()(
|
||||
a_block_window, b_block_window, bq_block_window, num_loop, smem_ptr_0, n);
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::ABQuantGrouped)
|
||||
{
|
||||
const auto& aq_block_window = gemm_tile_windows.at(I1);
|
||||
const auto& bq_block_window = gemm_tile_windows.at(I3);
|
||||
index_t m = 0;
|
||||
index_t n = 0;
|
||||
if constexpr(PreshuffleQuant)
|
||||
{
|
||||
m = kargs.M;
|
||||
n = kargs.N;
|
||||
}
|
||||
return GemmPipeline{}.template operator()(a_block_window,
|
||||
b_block_window,
|
||||
aq_block_window,
|
||||
bq_block_window,
|
||||
num_loop,
|
||||
smem_ptr_0,
|
||||
m,
|
||||
n);
|
||||
}
|
||||
else if constexpr(kQuantType == QuantType::RowColQuant ||
|
||||
kQuantType == QuantType::TensorQuant)
|
||||
{
|
||||
@@ -1195,7 +1251,8 @@ struct QuantGemmKernel
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(I4);
|
||||
|
||||
if constexpr(kQuantType == QuantType::AQuantGrouped ||
|
||||
if constexpr(kQuantType == QuantType::ABQuantGrouped ||
|
||||
kQuantType == QuantType::AQuantGrouped ||
|
||||
kQuantType == QuantType::BQuantGrouped)
|
||||
{
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0);
|
||||
|
||||
Reference in New Issue
Block a user