mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
[CK TILE] Support fp8/fp16 with pk_int4_t as data types for tensors A and B (#2805)
- Add support for tensor A/B in both fp16+pk_int4_t and fp8+pk_int4_t formats
- Implement A(bf8) B(i4) support in universal GEMM
- Use new implementation for i4 to fp8 conversion in Block Scale
[ROCm/composable_kernel commit: 82890192dd]
This commit is contained in:
@@ -344,6 +344,24 @@ struct GemmTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t>
|
||||
using CDataType = ck_tile::half_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmTypeConfig<ck_tile::fp8_t, ck_tile::pk_int4_t, ck_tile::half_t>
|
||||
{
|
||||
using ADataType = ck_tile::fp8_t;
|
||||
using BDataType = ck_tile::pk_int4_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::half_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmTypeConfig<ck_tile::bf8_t, ck_tile::pk_int4_t, ck_tile::half_t>
|
||||
{
|
||||
using ADataType = ck_tile::bf8_t;
|
||||
using BDataType = ck_tile::pk_int4_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::half_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmTypeConfig<ck_tile::half_t, ck_tile::pk_int4_t, ck_tile::half_t>
|
||||
{
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
#pragma once
|
||||
#include "ck_tile/host/permute_pk_int4.hpp"
|
||||
|
||||
template <typename Layout>
|
||||
static constexpr inline auto is_row_major(Layout layout_)
|
||||
{
|
||||
@@ -90,61 +92,6 @@ void permute_tensor_b(Tensor& tensor)
|
||||
}
|
||||
}
|
||||
|
||||
/// Permute packed 4-bit values of a tensor in place.
///
/// The tensor is walked along dimension 0 (length K) in groups of 8 index
/// positions for every index of dimension 1 (length N). Within a group the
/// elements at offsets 0, 2, 4 and 6 are read, each split into a high and a
/// low nibble (8 nibbles total, numbered 0..7 in read order), and written back
/// in the order 2,0 / 6,4 / 3,1 / 7,5 (i.e. 01234567 -> 20643175).
///
/// Only complete groups of 8 are permuted. If K is not a multiple of 8 the
/// trailing partial group is left unchanged instead of being accessed past the
/// end of dimension 0.
///
/// NOTE(review): assumes `tensor(k, n)` yields a reference to an element with
/// an 8-bit `.data` member holding two int4 values and assignable from
/// `int8_t` -- confirm against the tensor type used by callers.
///
/// @param tensor  Tensor to permute; modified in place.
template <typename Tensor>
void permute_vectors_i4x4_b(Tensor& tensor)
{
    // Lengths are read as int so they compare cleanly with the int counters.
    const int K = static_cast<int>(tensor.get_length(0));
    const int N = static_cast<int>(tensor.get_length(1));

    constexpr int kGroup = 8; // nibbles handled per permutation step

    // vector pk_i4x4 permute
    for(int n = 0; n < N; ++n)
    {
        // `k0 + kGroup <= K` keeps every access of the group inside [0, K).
        for(int k0 = 0; k0 + kGroup <= K; k0 += kGroup)
        {
            int8_t nibble[kGroup];

            for(int p = 0; p < kGroup / 2; ++p)
            {
                const int8_t packed = tensor(k0 + p * 2, n).data;
                // Mask after shifting so a sign-extending shift cannot leak
                // high bits into the nibble.
                nibble[p * 2 + 0] = (packed >> 4) & 0xf;
                nibble[p * 2 + 1] = (packed >> 0) & 0xf;
            }

            // Combine two nibbles (by index) into one byte: hi in bits 7..4.
            const auto pack = [&nibble](int hi, int lo) -> int8_t {
                return static_cast<int8_t>((nibble[hi] << 4) | nibble[lo]);
            };

            // permute 01234567->20643175
            tensor(k0 + 0, n) = pack(2, 0);
            tensor(k0 + 2, n) = pack(6, 4);
            tensor(k0 + 4, n) = pack(3, 1);
            tensor(k0 + 6, n) = pack(7, 5);
        }
    }
}
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename Invoker,
|
||||
typename ADataType,
|
||||
@@ -399,7 +346,7 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser,
|
||||
BLayout,
|
||||
CLayout>(b_k_n_dev);
|
||||
}
|
||||
permute_vectors_i4x4_b(b_k_n_dev);
|
||||
ck_tile::permute_vectors_i4x4_b(b_k_n_dev);
|
||||
b_k_n_dev_buf.ToDevice(b_k_n_dev.data());
|
||||
}
|
||||
else
|
||||
|
||||
@@ -5,7 +5,7 @@ KNAME=1
|
||||
export CK_WARMUP=0
|
||||
export CK_REPEAT=1
|
||||
|
||||
COMMON_ARGS='-v=2 -warmup=0 -repeat=1'
|
||||
COMMON_ARGS='-v=1 -warmup=0 -repeat=1'
|
||||
|
||||
run_tests() {
|
||||
for m in 512 1024; do
|
||||
@@ -32,5 +32,8 @@ run_tests "fp16"
|
||||
run_tests "bf16"
|
||||
run_tests "fp8"
|
||||
run_tests "bf8"
|
||||
run_tests "fp16i4"
|
||||
run_tests "fp8i4"
|
||||
run_tests "bf8i4"
|
||||
|
||||
set +x
|
||||
|
||||
@@ -5,11 +5,8 @@
|
||||
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "gemm_utils.hpp"
|
||||
#include "run_gemm_example.inc"
|
||||
#include "run_gemm_example_common.hpp"
|
||||
@@ -58,7 +55,7 @@ int run_gemm_example(ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::int8_t,
|
||||
ck_tile::int32_t>(a_layout, b_layout, arg_parser);
|
||||
}
|
||||
else if(data_type == "pk_int4_t")
|
||||
else if(data_type == "fp16i4")
|
||||
{
|
||||
// TODO: Add support for bhalf_t ADataType
|
||||
if constexpr(GemmConfig<ck_tile::half_t>::Pipeline == CK_TILE_PIPELINE_COMPUTE_V3)
|
||||
@@ -74,6 +71,36 @@ int run_gemm_example(ck_tile::ArgParser& arg_parser)
|
||||
throw std::runtime_error("Unsupported pipeline for this operation !!!");
|
||||
}
|
||||
}
|
||||
else if(data_type == "fp8i4")
|
||||
{
|
||||
if constexpr(GemmConfig<ck_tile::fp8_t>::Pipeline == CK_TILE_PIPELINE_COMPUTE_V3)
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
Invoker,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t>(a_layout, b_layout, arg_parser);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported pipeline for this operation !!!");
|
||||
}
|
||||
}
|
||||
else if(data_type == "bf8i4")
|
||||
{
|
||||
if constexpr(GemmConfig<ck_tile::bf8_t>::Pipeline == CK_TILE_PIPELINE_COMPUTE_V3)
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
Invoker,
|
||||
ck_tile::bf8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t>(a_layout, b_layout, arg_parser);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported pipeline for this operation !!!");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data type for this operation !!!");
|
||||
|
||||
@@ -228,4 +228,4 @@ int run_gemm_example(int argc, char* argv[])
|
||||
}
|
||||
}
|
||||
|
||||
// Entry point: runs the GEMM example with GemmConfigDecode and exits with the
// logical negation of the value run_gemm_example returns.
int main(int argc, char* argv[])
{
    const auto result = run_gemm_example<GemmConfigDecode>(argc, argv);
    return !result;
}
|
||||
// Entry point: runs the GEMM example with GemmConfigQuant and exits with the
// logical negation of the value run_gemm_example returns.
int main(int argc, char* argv[])
{
    const auto result = run_gemm_example<GemmConfigQuant>(argc, argv);
    return !result;
}
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
#pragma once
|
||||
#include <random>
|
||||
#include <stdexcept>
|
||||
#include "../00_shared/host_tensor_utils.hpp"
|
||||
|
||||
template <typename Layout>
|
||||
static constexpr inline auto is_row_major(Layout layout_)
|
||||
@@ -217,7 +218,16 @@ int run_gemm_example_with_layouts(int argc,
|
||||
aq_m_aqk_dev_buf.ToDevice(aq_m_aqk.data());
|
||||
}
|
||||
|
||||
a_m_k_dev_buf.ToDevice(a_m_k.data());
|
||||
if constexpr(std::is_same_v<ADataType, ck_tile::pk_int4_t>)
|
||||
{
|
||||
ck_tile::HostTensor<ADataType> a_m_k_dev = a_m_k;
|
||||
ck_tile::permute_vectors_i4x4_b(a_m_k_dev);
|
||||
a_m_k_dev_buf.ToDevice(a_m_k_dev.data());
|
||||
}
|
||||
else
|
||||
{
|
||||
a_m_k_dev_buf.ToDevice(a_m_k.data());
|
||||
}
|
||||
b_k_n_dev_buf.ToDevice(b_k_n.data());
|
||||
c_m_n_dev_buf.SetZero();
|
||||
c_m_n_dev_result.SetZero();
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
#pragma once
|
||||
#include <random>
|
||||
#include "ck_tile/host/permute_pk_int4.hpp"
|
||||
|
||||
template <typename Layout>
|
||||
static constexpr inline auto is_row_major(Layout layout_)
|
||||
@@ -208,7 +209,17 @@ int run_gemm_example_with_layouts(int argc,
|
||||
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
|
||||
|
||||
a_m_k_dev_buf.ToDevice(a_m_k.data());
|
||||
b_k_n_dev_buf.ToDevice(b_k_n.data());
|
||||
if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
|
||||
{
|
||||
// Permute vector pk_i4x4 data for device implementation
|
||||
ck_tile::HostTensor<BDataType> b_k_n_dev = b_k_n;
|
||||
ck_tile::permute_vectors_i4x4_b(b_k_n_dev);
|
||||
b_k_n_dev_buf.ToDevice(b_k_n_dev.data());
|
||||
}
|
||||
else
|
||||
{
|
||||
b_k_n_dev_buf.ToDevice(b_k_n.data());
|
||||
}
|
||||
bq_bqk_n_dev_buf.ToDevice(bq_bqk_n.data());
|
||||
c_m_n_dev_buf.SetZero();
|
||||
c_m_n_dev_result.SetZero();
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#pragma once
|
||||
#include <random>
|
||||
#include <stdexcept>
|
||||
#include "ck_tile/host/permute_pk_int4.hpp"
|
||||
|
||||
template <typename Layout>
|
||||
static constexpr inline auto is_row_major(Layout layout_)
|
||||
@@ -308,7 +309,17 @@ int run_gemm_example_with_layouts(int argc,
|
||||
aq_dev_buf.ToDevice(aq_tensor.data());
|
||||
}
|
||||
|
||||
a_m_k_dev_buf.ToDevice(a_m_k.data());
|
||||
if constexpr(std::is_same_v<ADataType, ck_tile::pk_int4_t>)
|
||||
{
|
||||
// Permute vector pk_i4x4 data for device implementation
|
||||
ck_tile::HostTensor<ADataType> a_m_k_dev = a_m_k;
|
||||
ck_tile::permute_vectors_i4x4_b(a_m_k_dev);
|
||||
a_m_k_dev_buf.ToDevice(a_m_k_dev.data());
|
||||
}
|
||||
else
|
||||
{
|
||||
a_m_k_dev_buf.ToDevice(a_m_k.data());
|
||||
}
|
||||
b_k_n_dev_buf.ToDevice(b_k_n.data());
|
||||
c_m_n_dev_buf.SetZero();
|
||||
c_m_n_dev_result.SetZero();
|
||||
|
||||
Reference in New Issue
Block a user