mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 11:47:48 +00:00
Support A/B Quantization in Blockscale GEMM
This commit is contained in:
@@ -113,37 +113,6 @@ CK_TILE_HOST void reference_gemm_quant(const HostTensor<ADataType>& a_m_k,
|
||||
make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency());
|
||||
std::cout << std::endl;
|
||||
}
|
||||
template <typename DataType>
|
||||
void print_tensor_elements(index_t row,
|
||||
index_t col,
|
||||
const ck_tile::HostTensor<DataType>& tensor,
|
||||
const std::string& name)
|
||||
{
|
||||
ignore = tensor;
|
||||
index_t Dim1 = row; // 第一維度 (M 或 K)
|
||||
index_t Dim2 = col; // 第二維度 (K 或 N)
|
||||
|
||||
std::cout << "\n--- 張量內容: " << name << " (" << Dim1 << "x" << Dim2 << ") ---" << std::endl;
|
||||
std::cout << std::fixed << std::setprecision(2);
|
||||
|
||||
for(index_t d1 = 0; d1 < Dim1; ++d1)
|
||||
{
|
||||
std::cout << "Row " << d1 << ": [";
|
||||
|
||||
for(index_t d2 = 0; d2 < Dim2; ++d2)
|
||||
{
|
||||
|
||||
std::cout << static_cast<float>(tensor(d1, d2));
|
||||
|
||||
if(d2 < Dim2 - 1)
|
||||
{
|
||||
std::cout << ", ";
|
||||
}
|
||||
}
|
||||
std::cout << "]" << (d1 < Dim1 - 1 ? "," : "") << std::endl;
|
||||
}
|
||||
std::cout << "---------------------------------------------------------" << std::endl;
|
||||
}
|
||||
|
||||
template <typename ADataType,
|
||||
typename AQDataType,
|
||||
@@ -194,8 +163,6 @@ CK_TILE_HOST void reference_gemm_abquant(const HostTensor<ADataType>& a_m_k,
|
||||
else
|
||||
{
|
||||
v_a = ck_tile::type_convert<AccDataType>(a_element_op(a_m_k(m, k)));
|
||||
// printf("A %f m=%d k=%d\n", static_cast<float>(v_a),static_cast<int>(m)
|
||||
// ,static_cast<int>(k));
|
||||
}
|
||||
|
||||
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
|
||||
@@ -210,8 +177,6 @@ CK_TILE_HOST void reference_gemm_abquant(const HostTensor<ADataType>& a_m_k,
|
||||
else if constexpr(std::is_same_v<BDataType, fp8_t>)
|
||||
{
|
||||
v_b = fp8_to_float_raw(b_element_op(b_k_n(k, n)));
|
||||
// printf("B %f k=%d n=%d\n", static_cast<float>(v_b),static_cast<int>(k)
|
||||
// ,static_cast<int>(n));
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -272,13 +237,6 @@ CK_TILE_HOST void reference_gemm_abquant(const HostTensor<ADataType>& a_m_k,
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency());
|
||||
|
||||
// print_tensor_elements<ADataType>(M, K, a_m_k, "float A (a_m_k)");
|
||||
// print_tensor_elements<BDataType>(K, N, b_k_n, "float B (b_k_n)");
|
||||
// print_tensor_elements<AQDataType>(M, K / QuantGroupSize::kK, a_q, "dequant A_q (a_q)");
|
||||
// print_tensor_elements<BQDataType>(N / QuantGroupSize::kK, K / QuantGroupSize::kK,b_q,
|
||||
// "dequant B_q (b_q)"); print_tensor_elements<CDataType>(M, N, c_m_n, "result C (c_m_n)");
|
||||
// printf("%f\n", static_cast<float>(a_m_k(0, 0)));
|
||||
}
|
||||
|
||||
template <typename ADataType,
|
||||
|
||||
Reference in New Issue
Block a user