mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-17 19:40:04 +00:00
Weight Preshuffle Block Scale gemm support (#2877)
* initial commit
* remove extra files
* fixing errors
* updated ReadMe file for mapping of diff quants with diff configs
* addressing review comments
* addressing review comments
* Resolved merge conflicts
* [CK TILE GEMM] Replace get_preshuffle_or with is_quantpreshuffle_enabled
The get_preshuffle_or was not working as expected, which led to incorrect behavior
in the quantization preshuffle process. This change replaces it with the more reliable
is_quantpreshuffle_enabled function to properly determine when preshuffle should be applied.
* initial commit
* debugging
* working fp8 for init constant
* fp8 working with all inits
* updated block level code with comments
* changing the loop iter
* debugging
* debugging
* debugging
* code fix
* code clean up
* clang formatted
* Add comment
* code cleanup
* clang formatted
* merge conflicts fixes
* applying the latest int4 changes to the piepline
* fixing test code for updated traits
* Adding gtest
* review comments addressed
* addressing review comments
* remove c++20 code
* added flush cache changes
---------
Co-authored-by: Cong Ma <congma13@amd.com>
Co-authored-by: root <root@banff-cyxtera-s73-2.ctr.dcgpu>
[ROCm/composable_kernel commit: 81458a6681]
This commit is contained in:
@@ -77,6 +77,18 @@ struct is_quantpreshuffle_enabled<T, decltype(T::PreshuffleQuant)>
|
||||
{
|
||||
static constexpr bool value = T::PreshuffleQuant;
|
||||
};
|
||||
|
||||
template <typename, typename = void>
|
||||
struct is_preshuffleB_enabled
|
||||
{
|
||||
static constexpr bool value = false;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct is_preshuffleB_enabled<T, std::void_t<decltype(T::PreshuffleB)>>
|
||||
{
|
||||
static constexpr bool value = T::PreshuffleB;
|
||||
};
|
||||
} // namespace detail
|
||||
|
||||
struct QuantGemmProblem
|
||||
@@ -196,6 +208,7 @@ struct QuantGemmKernel
|
||||
static constexpr index_t kBlockSize = GemmPipeline::BlockSize;
|
||||
static constexpr bool PreshuffleQuant =
|
||||
detail::is_quantpreshuffle_enabled<GemmPipeline_>::value;
|
||||
static constexpr bool PreshuffleB = detail::is_preshuffleB_enabled<GemmPipeline_>::value;
|
||||
|
||||
using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
|
||||
@@ -630,12 +643,30 @@ struct QuantGemmKernel
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
b_ptr,
|
||||
make_tuple(kargs.N, splitk_batch_offset.splitted_k),
|
||||
make_tuple(kargs.stride_B, 1),
|
||||
number<GemmPipeline::GetVectorSizeB()>{},
|
||||
number<1>{});
|
||||
if constexpr(PreshuffleB)
|
||||
{
|
||||
index_t kFlatK =
|
||||
GemmPipeline::flatKPerWarp *
|
||||
(splitk_batch_offset.splitted_k /
|
||||
TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{}));
|
||||
index_t kFlatN = kargs.N * kargs.K / kFlatK;
|
||||
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
b_ptr,
|
||||
make_tuple(kFlatN, kFlatK),
|
||||
make_tuple(kFlatK, 1),
|
||||
number<GemmPipeline::GetVectorSizeB()>{},
|
||||
number<1>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
b_ptr,
|
||||
make_tuple(kargs.N, splitk_batch_offset.splitted_k),
|
||||
make_tuple(kargs.stride_B, 1),
|
||||
number<GemmPipeline::GetVectorSizeB()>{},
|
||||
number<1>{});
|
||||
}
|
||||
}
|
||||
}
|
||||
}();
|
||||
@@ -716,6 +747,8 @@ struct QuantGemmKernel
|
||||
// no padding
|
||||
const auto& aq_pad_view = [&]() { return views.at(I1); }();
|
||||
|
||||
const auto& b_flat_view = views.at(I2); // not applying any padding to flat B view
|
||||
|
||||
const auto& b_pad_view = [&]() {
|
||||
const auto& b_tensor_view = views.at(I2);
|
||||
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
|
||||
@@ -755,8 +788,14 @@ struct QuantGemmKernel
|
||||
sequence<GemmPipeline::kPadM, false>{});
|
||||
}
|
||||
}();
|
||||
|
||||
return make_tuple(a_pad_view, aq_pad_view, b_pad_view, bq_pad_view, c_pad_view);
|
||||
if constexpr(PreshuffleB)
|
||||
{
|
||||
return make_tuple(a_pad_view, aq_pad_view, b_flat_view, bq_pad_view, c_pad_view);
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_tuple(a_pad_view, aq_pad_view, b_pad_view, bq_pad_view, c_pad_view);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename PadView>
|
||||
@@ -826,19 +865,30 @@ struct QuantGemmKernel
|
||||
}();
|
||||
|
||||
const auto& b_block_window = [&]() {
|
||||
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
|
||||
if constexpr(PreshuffleB)
|
||||
{
|
||||
return make_tile_window(b_pad_view,
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
{i_n, 0});
|
||||
return make_tile_window(
|
||||
b_pad_view,
|
||||
make_tuple(number<GemmPipeline::flatNPerWarp>{},
|
||||
number<GemmPipeline::flatKPerWarp>{}),
|
||||
{static_cast<int>(i_n / TilePartitioner::BlockGemmShape::WarpTile::at(I1)), 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_tile_window(b_pad_view,
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
{0, i_n});
|
||||
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
return make_tile_window(b_pad_view,
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
{i_n, 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_tile_window(b_pad_view,
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
{0, i_n});
|
||||
}
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -969,6 +1019,80 @@ struct QuantGemmKernel
|
||||
c_block_window, c_block_tile, c_block_window, smem_ptr_0, aq_scale, bq_scale);
|
||||
}
|
||||
}
|
||||
/**
|
||||
* @brief Runs single GEMM problem cooperatively by whole workgroup.
|
||||
*
|
||||
* @param a_ptr input A pointer
|
||||
* @param b_ptr input B pointer
|
||||
* @param aq_ptr input AQ pointer
|
||||
* @param c_ptr output C pointer
|
||||
* @param smem_ptr_0 The start memory pointer of the shared memory block.
|
||||
* @param kargs GEMM kernel arguments
|
||||
* @param splitk_batch_offset splitk_batch_offset Utility structure used to calculate k batch.
|
||||
* @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
|
||||
* @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
|
||||
*
|
||||
* @tparam DstInMemOp Destination memory operation (default: set).
|
||||
*/
|
||||
template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
|
||||
CK_TILE_DEVICE static void RunGemm2LDS(const ADataType* a_ptr,
|
||||
const BDataType* b_ptr,
|
||||
const AQDataType* aq_ptr,
|
||||
const BQDataType* bq_ptr,
|
||||
CDataType* c_ptr,
|
||||
void* smem_ptr_0,
|
||||
void* smem_ptr_1,
|
||||
const QuantGemmKernelArgs& kargs,
|
||||
const SplitKBatchOffset& splitk_batch_offset,
|
||||
const index_t block_idx_m,
|
||||
const index_t block_idx_n)
|
||||
{
|
||||
// Create Gemm tensor views, pad views and tile windows
|
||||
const auto& gemm_tensor_views_tuple = MakeGemmTensorViews<DstInMemOp>(
|
||||
a_ptr, b_ptr, aq_ptr, bq_ptr, c_ptr, kargs, splitk_batch_offset);
|
||||
|
||||
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
|
||||
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
|
||||
|
||||
const index_t num_loop = __builtin_amdgcn_readfirstlane(
|
||||
TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
|
||||
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
const auto& a_block_window = gemm_tile_windows.at(I0);
|
||||
const auto& b_block_window = gemm_tile_windows.at(I2);
|
||||
|
||||
const auto& c_block_tile = [&]() {
|
||||
if constexpr(kQuantType == QuantType::BQuantGrouped)
|
||||
{
|
||||
const auto& bq_block_window = gemm_tile_windows.at(I3);
|
||||
return GemmPipeline{}.template operator()(a_block_window,
|
||||
b_block_window,
|
||||
bq_block_window,
|
||||
num_loop,
|
||||
smem_ptr_0,
|
||||
smem_ptr_1);
|
||||
}
|
||||
else
|
||||
{
|
||||
return nullptr;
|
||||
}
|
||||
}();
|
||||
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(I4);
|
||||
|
||||
if constexpr(kQuantType == QuantType::BQuantGrouped)
|
||||
{
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0);
|
||||
}
|
||||
else
|
||||
{
|
||||
return;
|
||||
// throw std::runtime_error("DoubleSmemBuffer Not implemented for AQuantGrouped or
|
||||
// RowColQuant"); static_assert(kQuantType == QuantType::BQuantGrouped,
|
||||
// "DoubleSmemBuffer Not implemented");
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void operator()(QuantGemmKernelArgs kargs) const
|
||||
{
|
||||
@@ -989,8 +1113,35 @@ struct QuantGemmKernel
|
||||
__shared__ char smem_ptr_0[GetSmemSize()];
|
||||
|
||||
assert(kargs.k_batch == 1);
|
||||
RunGemm(
|
||||
a_ptr, b_ptr, aq_ptr, bq_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n);
|
||||
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
|
||||
{
|
||||
__shared__ char smem_ptr_1[GetSmemSize()];
|
||||
|
||||
RunGemm2LDS(a_ptr,
|
||||
b_ptr,
|
||||
aq_ptr,
|
||||
bq_ptr,
|
||||
c_ptr,
|
||||
smem_ptr_0,
|
||||
smem_ptr_1,
|
||||
kargs,
|
||||
splitk_batch_offset,
|
||||
i_m,
|
||||
i_n);
|
||||
}
|
||||
else
|
||||
{
|
||||
RunGemm(a_ptr,
|
||||
b_ptr,
|
||||
aq_ptr,
|
||||
bq_ptr,
|
||||
c_ptr,
|
||||
smem_ptr_0,
|
||||
kargs,
|
||||
splitk_batch_offset,
|
||||
i_m,
|
||||
i_n);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user