mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-16 02:54:21 +00:00
[CK TILE] Support fp8/fp16 with pk_int4_t as data types for tensors A and B (#2805)
- Add support for tensor A/B in both fp16+pk_int4_t and fp8+pk_int4_t formats
- Implement A(bf8) B(i4) support in universal GEMM
- Use new implementation for i4 to fp8 conversion in Block Scale
[ROCm/composable_kernel commit: 82890192dd]
This commit is contained in:
78
include/ck_tile/host/permute_pk_int4.hpp
Normal file
78
include/ck_tile/host/permute_pk_int4.hpp
Normal file
@@ -0,0 +1,78 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c), Advanced Micro Devices, Inc. All rights reserved.
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/utility/bit_cast.hpp"
|
||||
namespace ck_tile {
|
||||
|
||||
/**
|
||||
* @brief Permute packed int4 vectors for device implementation compatibility
|
||||
*
|
||||
* This function transforms 4 pk_int4_t values from original layout to hardware-optimized layout:
|
||||
* - Original layout (4 pk_int4_t): 0x76543210
|
||||
* - Transformed layout (4 pk_int4_t): 0x75316420
|
||||
*
|
||||
* Each pk_int4_t contains two 4-bit values packed in the high and low nibbles of an int8_t
|
||||
*
|
||||
* Example:
|
||||
* - Input: 0x76, 0x54, 0x32, 0x10
|
||||
* - Output: 0x75, 0x31, 0x64, 0x20
|
||||
*
|
||||
* @note Input tensor length must be a multiple of 4
|
||||
*
|
||||
* This transformation is required before transferring B matrix data (of type pk_int4_t) to device.
|
||||
* The device conversion functions (i4_to_half4, i4_to_bhalf4, amd_assembly_i4_to_fp8x8,
|
||||
* amd_assembly_i4_to_bf8x8) require data in 0x75316420 order to correctly convert pk_int4_t to
|
||||
* other numeric types.
|
||||
*/
|
||||
template <typename Tensor>
|
||||
void permute_vectors_i4x4_b(Tensor& tensor)
|
||||
{
|
||||
auto tensor_row_buf = tensor.data();
|
||||
for(size_t idx = 0; idx < tensor.size(); idx += 4)
|
||||
{
|
||||
int8_t input[8];
|
||||
|
||||
for(int k = 0; k < 4; k++)
|
||||
{
|
||||
int8_t i4x2 = bit_cast<int8_t>(tensor_row_buf[idx + k]);
|
||||
input[k * 2 + 0] = (i4x2 >> 4) & 0xf;
|
||||
input[k * 2 + 1] = (i4x2 >> 0) & 0xf;
|
||||
}
|
||||
|
||||
// permute 0x76543210 => 0x75316420
|
||||
{
|
||||
int8_t hi = input[2];
|
||||
int8_t lo = input[0];
|
||||
int8_t i4x2 = (hi << 4) | lo;
|
||||
|
||||
tensor_row_buf[idx + 0] = bit_cast<pk_int4_t>(i4x2);
|
||||
}
|
||||
|
||||
{
|
||||
int8_t hi = input[6];
|
||||
int8_t lo = input[4];
|
||||
int8_t i4x2 = (hi << 4) | lo;
|
||||
|
||||
tensor_row_buf[idx + 1] = bit_cast<pk_int4_t>(i4x2);
|
||||
}
|
||||
|
||||
{
|
||||
int8_t hi = input[3];
|
||||
int8_t lo = input[1];
|
||||
int8_t i4x2 = (hi << 4) | lo;
|
||||
|
||||
tensor_row_buf[idx + 2] = bit_cast<pk_int4_t>(i4x2);
|
||||
}
|
||||
|
||||
{
|
||||
int8_t hi = input[7];
|
||||
int8_t lo = input[5];
|
||||
int8_t i4x2 = (hi << 4) | lo;
|
||||
|
||||
tensor_row_buf[idx + 3] = bit_cast<pk_int4_t>(i4x2);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -50,7 +50,7 @@ CK_TILE_HOST void reference_gemm_quant(const HostTensor<ADataType>& a_m_k,
|
||||
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
|
||||
{
|
||||
const pk_int4_t pk_val = a_element_op(a_m_k(m, k));
|
||||
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t_signed_conversion(pk_val);
|
||||
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(pk_val);
|
||||
if(k % 2 == 1)
|
||||
v_a = fp32_val.hi;
|
||||
else
|
||||
@@ -63,7 +63,7 @@ CK_TILE_HOST void reference_gemm_quant(const HostTensor<ADataType>& a_m_k,
|
||||
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
|
||||
{
|
||||
const pk_int4_t pk_val = b_element_op(b_k_n(k, n));
|
||||
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t_signed_conversion(pk_val);
|
||||
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(pk_val);
|
||||
if(k % 2 == 1)
|
||||
v_b = fp32_val.hi;
|
||||
else
|
||||
|
||||
Reference in New Issue
Block a user