mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[rocm-libraries] ROCm/rocm-libraries#4640 (commit 37b8c81)
Fix the Composable Kernel CI and versions incompatibility (#4640) ## Motivation This PR has 4 patches: 1. Fix the CI error of grouped gemm. 2. Fix the incompatibility of old linux version. 3. Fix the potential errors of flatmm. 4. Address the previous comments of abquant eight warps pipeline solution.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
1f6768472e
commit
5cb8109535
@@ -8,6 +8,7 @@
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <numeric>
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "flatmm_basic.hpp"
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <numeric>
|
||||
|
||||
#include "flatmm_basic.hpp"
|
||||
|
||||
|
||||
@@ -166,7 +166,7 @@ int run_contiguous_grouped_flatmm_example_with_layouts(
|
||||
}
|
||||
|
||||
ck_tile::index_t M =
|
||||
std::reduce(Ms.begin(), Ms.begin() + group_count, 0, [](auto acc, auto group_m) {
|
||||
std::accumulate(Ms.begin(), Ms.begin() + group_count, 0, [](auto acc, auto group_m) {
|
||||
// round up to the multiple of BlockM
|
||||
return acc + (group_m + BlockM - 1) / BlockM * BlockM;
|
||||
});
|
||||
|
||||
@@ -35,16 +35,19 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
|
||||
{
|
||||
static_assert(std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::RowMajor>);
|
||||
constexpr bool IS_FP8BLOCKSCALE =
|
||||
QuantMode == ck_tile::QuantType::ABQuantGrouped && BQuantGroupSize::kN == 128 &&
|
||||
QuantMode == ck_tile::QuantType::ABQuantGrouped &&
|
||||
(std::is_same_v<typename TypeConfig::ADataType, ck_tile::fp8_t> ||
|
||||
std::is_same_v<typename TypeConfig::ADataType, ck_tile::bf8_t>) &&
|
||||
(std::is_same_v<typename TypeConfig::BDataType, ck_tile::fp8_t> ||
|
||||
std::is_same_v<typename TypeConfig::BDataType, ck_tile::bf8_t>);
|
||||
constexpr bool transpose_c = GemmConfig::TransposeC;
|
||||
constexpr bool eight_warps =
|
||||
IS_FP8BLOCKSCALE && BQuantGroupSize::kN == 128 &&
|
||||
(GemmConfig::M_Warp * GemmConfig::N_Warp * GemmConfig::K_Warp == 8) &&
|
||||
#ifdef CK_GFX950_SUPPORT
|
||||
IS_FP8BLOCKSCALE && (GemmConfig::M_Warp * GemmConfig::N_Warp * GemmConfig::K_Warp == 8) &&
|
||||
GemmConfig::K_Warp_Tile == 128;
|
||||
#else
|
||||
false;
|
||||
#endif
|
||||
|
||||
using ComputeDataType =
|
||||
std::conditional_t<IS_FP8BLOCKSCALE, typename TypeConfig::ADataType, void>;
|
||||
|
||||
Reference in New Issue
Block a user