[rocm-libraries] ROCm/rocm-libraries#4302 (commit e62bd8a)

[CK_TILE] add tf32 support

## Proposed changes

TF32 support was previously added to CK on gfx942 and gfx950. This PR initiates TF32
support in CK_TILE on the same architectures.
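
For context: TF32 keeps fp32's 8-bit exponent range but only 10 explicit mantissa bits, so values are stored as ordinary fp32 and rounded to TF32 precision at compute time (which is why the buffer types below stay `float`). A minimal host-side sketch of that rounding, assuming simple truncation (hardware may round to nearest):

```cpp
#include <cstdint>
#include <cstring>

// Emulate tf32 precision: keep sign, 8 exponent bits, and the top 10
// mantissa bits of an IEEE-754 float; zero the remaining 13 bits.
float to_tf32_precision(float x)
{
    std::uint32_t bits;
    std::memcpy(&bits, &x, sizeof(bits));
    bits &= 0xFFFFE000u;
    float y;
    std::memcpy(&y, &bits, sizeof(y));
    return y;
}
```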

## Checklist

Please put an `x` into the boxes that apply. You can also fill these out
after creating the PR. If you're not sure, please don't hesitate to ask.

- [ ] I have added tests relevant to the introduced functionality, and
the unit tests are passing locally
- [ ] I have added the test to the REGRESSION_TESTS list defined at the top
of tests/CMakeLists.txt, **IF** the test takes more than 30 seconds to run.
- [ ] I have added inline documentation which enables the maintainers
to understand the motivation
- [ ] I have removed the stale documentation which is no longer relevant
after this pull request
- [ ] (If this change is user-facing) I have added release notes which
provide the end users with a brief summary of the improvement from this
pull request
- [x] I have run `clang-format` on all changed files
- [ ] Any dependent changes have been merged

## Discussion
Author: yinglu (2026-03-19 09:19:06 +00:00), committed by assistant-librarian[bot]
Parent: 652d3456ca, commit: d460ab35b6
30 changed files with 1164 additions and 260 deletions

@@ -4,11 +4,11 @@
#pragma once
#include <cstdlib>
#include <mutex>
#include <thread>
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include "ck_tile/host/device_prop.hpp"
namespace ck_tile {
@@ -447,24 +447,34 @@ CK_TILE_HOST void reference_mx_gemm_bquant(const HostTensor<ADataType>& a_m_k,
std::cout << std::endl;
}
template <typename ADataType,
typename BDataType,
template <typename ADataType_,
typename BDataType_,
typename AccDataType,
typename CDataType,
typename AElementOp = ck_tile::identity,
typename BElementOp = ck_tile::identity,
typename ACCElementOp = ck_tile::identity>
CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
const HostTensor<BDataType>& b_k_n,
HostTensor<CDataType>& c_m_n,
const AElementOp& a_element_op = {},
const BElementOp& b_element_op = {},
const ACCElementOp& acc_element_op = {})
CK_TILE_HOST void
reference_gemm(const HostTensor<if_select_t<ADataType_, tf32_t, float, ADataType_>>& a_m_k,
const HostTensor<if_select_t<BDataType_, tf32_t, float, BDataType_>>& b_k_n,
HostTensor<CDataType>& c_m_n,
const AElementOp& a_element_op = {},
const BElementOp& b_element_op = {},
const ACCElementOp& acc_element_op = {})
{
if constexpr(std::is_same_v<ADataType_, tf32_t> || std::is_same_v<BDataType_, tf32_t>)
static_assert(std::is_same_v<ADataType_, BDataType_>,
"ADataType and BDataType must be the same");
using ADataTypeCompute = ADataType_;
using ADataTypeBuf = if_select_t<ADataType_, tf32_t, float, ADataType_>;
using BDataTypeBuf = if_select_t<BDataType_, tf32_t, float, BDataType_>;
const std::size_t M = a_m_k.get_length(0);
const std::size_t N = b_k_n.get_length(1);
const std::size_t K = a_m_k.get_length(1);
const bool is_gfx950 = (ck_tile::get_device_name() == "gfx950");
auto f_mn = [&](auto m, auto n) {
AccDataType v_acc = 0;
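
The `if_select_t<T, Match, Then, Else>` alias used throughout this diff maps the compute type to a storage type, so `tf32_t` operands are held in plain `float` buffers. A sketch of what such a selector could look like (the actual CK_TILE definition may differ):

```cpp
#include <type_traits>

// Yields Then when T is Match, otherwise Else.
template <typename T, typename Match, typename Then, typename Else>
using if_select_t = std::conditional_t<std::is_same_v<T, Match>, Then, Else>;

// if_select_t<tf32_t, tf32_t, float, tf32_t> -> float  (tf32 stored as fp32)
// if_select_t<half_t, tf32_t, float, half_t> -> half_t (everything else unchanged)
```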
@@ -472,7 +482,7 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
{
AccDataType v_a;
AccDataType v_b;
if constexpr(std::is_same_v<ADataType, pk_fp4_t>)
if constexpr(std::is_same_v<ADataTypeBuf, pk_fp4_t>)
{
// HostTensor automatically handles packed indexing: a_m_k(m,k) divides the offset by
// PackedSize, so a_m_k(m,0) and a_m_k(m,1) return the same packed byte
@@ -481,7 +491,7 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
const float unpacked = (k % 2 == 1) ? fp32_val.hi : fp32_val.lo;
v_a = ck_tile::type_convert<AccDataType>(a_element_op(unpacked));
}
else if constexpr(std::is_same_v<ADataType, pk_int4_t>)
else if constexpr(std::is_same_v<ADataTypeBuf, pk_int4_t>)
{
// HostTensor automatically handles packed indexing
const pk_int4_t pk_val = a_m_k(m, k);
@@ -493,7 +503,7 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
{
v_a = ck_tile::type_convert<AccDataType>(a_element_op(a_m_k(m, k)));
}
if constexpr(std::is_same_v<BDataType, pk_fp4_t>)
if constexpr(std::is_same_v<BDataTypeBuf, pk_fp4_t>)
{
// HostTensor automatically handles packed indexing
const pk_fp4_t pk_val = b_k_n(k, n);
@@ -501,7 +511,7 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
const float unpacked = (k % 2 == 1) ? fp32_val.hi : fp32_val.lo;
v_b = ck_tile::type_convert<AccDataType>(b_element_op(unpacked));
}
else if constexpr(std::is_same_v<BDataType, pk_int4_t>)
else if constexpr(std::is_same_v<BDataTypeBuf, pk_int4_t>)
{
// HostTensor automatically handles packed indexing
const pk_int4_t pk_val = b_k_n(k, n);
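
In the packed 4-bit branches above, each byte holds two elements: element k is fetched from offset k / PackedSize, and k's parity selects the low or high half of the unpacked pair. A standalone sketch of that selection (`fp32x2` here mirrors the `fp32x2_t` pair returned by the CK_TILE unpack helpers):

```cpp
struct fp32x2 { float lo, hi; };

// Even k -> low half, odd k -> high half, matching the branches above.
float select_unpacked(const fp32x2& pair, int k)
{
    return (k % 2 == 1) ? pair.hi : pair.lo;
}
```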
@@ -513,7 +523,36 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
{
v_b = ck_tile::type_convert<AccDataType>(b_element_op(b_k_n(k, n)));
}
v_acc += v_a * v_b;
if constexpr(std::is_same_v<ADataTypeCompute, tf32_t>)
{
if(is_gfx950)
{
// gfx950: use 3x bf16 emulation
bf16_t v_a_bf16_big = ck_tile::type_convert<bf16_t>(v_a);
bf16_t v_a_bf16_small = ck_tile::type_convert<bf16_t>(
v_a - type_convert<AccDataType>(v_a_bf16_big));
bf16_t v_b_bf16_big = ck_tile::type_convert<bf16_t>(v_b);
bf16_t v_b_bf16_small = ck_tile::type_convert<bf16_t>(
v_b - type_convert<AccDataType>(v_b_bf16_big));
v_acc += ck_tile::type_convert<AccDataType>(v_a_bf16_big) *
ck_tile::type_convert<AccDataType>(v_b_bf16_small) +
ck_tile::type_convert<AccDataType>(v_a_bf16_small) *
ck_tile::type_convert<AccDataType>(v_b_bf16_big) +
ck_tile::type_convert<AccDataType>(v_a_bf16_big) *
ck_tile::type_convert<AccDataType>(v_b_bf16_big);
}
else
{
// Other architectures: tf32 not supported or handled via fp32 fallback
v_acc += v_a * v_b;
}
}
else
{
v_acc += v_a * v_b;
}
}
c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
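
The emulation above is the standard "split" trick: each fp32 operand is decomposed into a bf16 "big" part plus a bf16-rounded residual, and three of the four partial products are summed (the residual-times-residual term is dropped). A self-contained host sketch, assuming bf16 behaves like fp32 with the low 16 bits truncated:

```cpp
#include <cstdint>
#include <cstring>

// bf16 emulated as fp32 with the low 16 mantissa bits cleared.
float to_bf16_precision(float x)
{
    std::uint32_t bits;
    std::memcpy(&bits, &x, sizeof(bits));
    bits &= 0xFFFF0000u;
    float y;
    std::memcpy(&y, &bits, sizeof(y));
    return y;
}

// 3x bf16 product: a*b ~= a_big*b_small + a_small*b_big + a_big*b_big.
float mul_bf16x3(float a, float b)
{
    const float a_big   = to_bf16_precision(a);
    const float a_small = to_bf16_precision(a - a_big);
    const float b_big   = to_bf16_precision(b);
    const float b_small = to_bf16_precision(b - b_big);
    return a_big * b_small + a_small * b_big + a_big * b_big;
}
```

Note the host reference dispatches on this path at run time via `get_device_name()`, whereas the device kernels below select it at compile time through `CK_GFX950_SUPPORT`.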
@@ -764,15 +803,15 @@ reference_gemm_multiple_d(const HostTensor<ADataType>& a_m_k,
make_ParallelTensorFunctor(f_mk_kn_mn, M, N)(std::thread::hardware_concurrency());
}
template <typename ADataType,
typename BDataType,
template <typename ADataType_,
typename BDataType_,
typename AccDataType,
typename CDataType,
typename LayoutA,
typename LayoutB,
typename LayoutC>
__global__ void naive_gemm_kernel(ADataType* A,
BDataType* B,
__global__ void naive_gemm_kernel(if_select_t<ADataType_, tf32_t, float, ADataType_>* A,
if_select_t<BDataType_, tf32_t, float, BDataType_>* B,
CDataType* C,
ck_tile::index_t M,
ck_tile::index_t N,
@@ -781,6 +820,14 @@ __global__ void naive_gemm_kernel(ADataType* A,
ck_tile::index_t strideB,
ck_tile::index_t strideC)
{
if constexpr(std::is_same_v<ADataType_, tf32_t> || std::is_same_v<BDataType_, tf32_t>)
static_assert(std::is_same_v<ADataType_, BDataType_>,
"ADataType and BDataType must be the same");
using ADataTypeCompute = ADataType_;
// ADataTypeBuf: buffer/storage type (fp32 when tf32)
using ADataTypeBuf = if_select_t<ADataType_, tf32_t, float, ADataType_>;
using BDataTypeBuf = if_select_t<BDataType_, tf32_t, float, BDataType_>;
int idx = blockIdx.x * blockDim.x + threadIdx.x;
int row = idx / N; // Compute row index
int col = idx % N; // Compute column index
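
The `if constexpr` + `static_assert` pair above acts as a compile-time guard: because the assert's condition depends on the template parameters, it is only evaluated when the tf32 branch is instantiated, so mixed tf32/non-tf32 instantiations fail to compile while all other type combinations are untouched. A reduced sketch of the pattern with a placeholder type:

```cpp
#include <type_traits>

struct tf32_t {}; // placeholder standing in for ck_tile::tf32_t

template <typename A, typename B>
void require_matching_tf32()
{
    if constexpr(std::is_same_v<A, tf32_t> || std::is_same_v<B, tf32_t>)
        static_assert(std::is_same_v<A, B>,
                      "ADataType and BDataType must be the same");
}

// require_matching_tf32<tf32_t, tf32_t>(); // OK
// require_matching_tf32<float, float>();   // OK, guard branch discarded
// require_matching_tf32<tf32_t, float>();  // compile error
```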
@@ -790,8 +837,8 @@ __global__ void naive_gemm_kernel(ADataType* A,
AccDataType acc = 0.0;
for(int k = 0; k < K; ++k)
{
constexpr index_t packed_size_a = ck_tile::numeric_traits<ADataType>::PackedSize;
constexpr index_t packed_size_b = ck_tile::numeric_traits<BDataType>::PackedSize;
constexpr index_t packed_size_a = ck_tile::numeric_traits<ADataTypeBuf>::PackedSize;
constexpr index_t packed_size_b = ck_tile::numeric_traits<BDataTypeBuf>::PackedSize;
// Adjust indexing based on matrix layout
int a_index = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
? row * strideA + k
@@ -802,7 +849,7 @@ __global__ void naive_gemm_kernel(ADataType* A,
AccDataType v_a;
AccDataType v_b;
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
if constexpr(std::is_same_v<ADataTypeBuf, pk_int4_t>)
{
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(A[a_index / packed_size_a]);
if(k % 2 == 1)
@@ -810,7 +857,7 @@ __global__ void naive_gemm_kernel(ADataType* A,
else
v_a = fp32_val.lo;
}
else if constexpr(std::is_same_v<ADataType, pk_fp4_t>)
else if constexpr(std::is_same_v<ADataTypeBuf, pk_fp4_t>)
{
const fp32x2_t fp32_val = pk_fp4_to_fp32x2(A[a_index / packed_size_a], 1.0f);
if(k % 2 == 1)
@@ -822,7 +869,7 @@ __global__ void naive_gemm_kernel(ADataType* A,
{
v_a = ck_tile::type_convert<AccDataType>(A[a_index]);
}
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
if constexpr(std::is_same_v<BDataTypeBuf, pk_int4_t>)
{
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(B[b_index / packed_size_b]);
if(k % 2 == 1)
@@ -830,7 +877,7 @@ __global__ void naive_gemm_kernel(ADataType* A,
else
v_b = fp32_val.lo;
}
else if constexpr(std::is_same_v<BDataType, pk_fp4_t>)
else if constexpr(std::is_same_v<BDataTypeBuf, pk_fp4_t>)
{
const fp32x2_t fp32_val = pk_fp4_to_fp32x2(B[b_index / packed_size_b], 1.0f);
if(k % 2 == 1)
@@ -842,7 +889,33 @@ __global__ void naive_gemm_kernel(ADataType* A,
{
v_b = ck_tile::type_convert<AccDataType>(B[b_index]);
}
acc += v_a * v_b;
if constexpr(std::is_same_v<ADataTypeCompute, tf32_t>)
{
#ifdef CK_GFX950_SUPPORT
// gfx950: use 3x bf16 emulation
bf16_t v_a_bf16_big = ck_tile::type_convert<bf16_t>(v_a);
bf16_t v_a_bf16_small =
ck_tile::type_convert<bf16_t>(v_a - type_convert<AccDataType>(v_a_bf16_big));
bf16_t v_b_bf16_big = ck_tile::type_convert<bf16_t>(v_b);
bf16_t v_b_bf16_small =
ck_tile::type_convert<bf16_t>(v_b - type_convert<AccDataType>(v_b_bf16_big));
acc += ck_tile::type_convert<AccDataType>(v_a_bf16_big) *
ck_tile::type_convert<AccDataType>(v_b_bf16_small) +
ck_tile::type_convert<AccDataType>(v_a_bf16_small) *
ck_tile::type_convert<AccDataType>(v_b_bf16_big) +
ck_tile::type_convert<AccDataType>(v_a_bf16_big) *
ck_tile::type_convert<AccDataType>(v_b_bf16_big);
#else
// Other architectures: use fp32 fallback
acc += v_a * v_b;
#endif
}
else
{
acc += v_a * v_b;
}
}
int c_index = (std::is_same_v<LayoutC, tensor_layout::gemm::RowMajor>)
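
Aside: the same three-term product appears in the host reference and both device kernels, so one justification covers all three. A back-of-the-envelope bound (not a formal error analysis), assuming round-to-nearest bf16 conversion with 8 significand bits: with $a_h = \mathrm{bf16}(a)$ and $a_l = a - a_h$, the residuals satisfy $|a_l| \lesssim 2^{-8}|a|$, so

$$ab = a_h b_h + a_h b_l + a_l b_h + \underbrace{a_l b_l}_{\text{dropped}}, \qquad |a_l b_l| \lesssim 2^{-16}\,|a|\,|b| \ll 2^{-11}\,|a|\,|b|,$$

where $2^{-11}$ is tf32's own significand granularity; the omitted term is negligible at tf32 precision.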
@@ -852,15 +925,15 @@ __global__ void naive_gemm_kernel(ADataType* A,
}
}
template <typename ADataType,
typename BDataType,
template <typename ADataType_,
typename BDataType_,
typename AccDataType,
typename CDataType,
typename LayoutA,
typename LayoutB,
typename LayoutC>
__global__ void blockwise_gemm_kernel(ADataType* A,
BDataType* B,
__global__ void blockwise_gemm_kernel(if_select_t<ADataType_, tf32_t, float, ADataType_>* A,
if_select_t<BDataType_, tf32_t, float, BDataType_>* B,
CDataType* C,
ck_tile::index_t M,
ck_tile::index_t N,
@@ -874,6 +947,14 @@ __global__ void blockwise_gemm_kernel(ADataType* A,
float* scale_A_ptr,
float* scale_B_ptr)
{
if constexpr(std::is_same_v<ADataType_, tf32_t> || std::is_same_v<BDataType_, tf32_t>)
static_assert(std::is_same_v<ADataType_, BDataType_>,
"ADataType and BDataType must be the same");
using ADataTypeCompute = ADataType_;
// ADataTypeBuf: buffer/storage type (fp32 when tf32)
using ADataTypeBuf = if_select_t<ADataType_, tf32_t, float, ADataType_>;
using BDataTypeBuf = if_select_t<BDataType_, tf32_t, float, BDataType_>;
int idx = blockIdx.x * blockDim.x + threadIdx.x;
int row = idx / N; // Compute row index
int col = idx % N; // Compute column index
@@ -902,8 +983,8 @@ __global__ void blockwise_gemm_kernel(ADataType* A,
(k / scale_granularity_k) * scale_B_stride];
}
constexpr index_t packed_size_a = ck_tile::numeric_traits<ADataType>::PackedSize;
constexpr index_t packed_size_b = ck_tile::numeric_traits<BDataType>::PackedSize;
constexpr index_t packed_size_a = ck_tile::numeric_traits<ADataTypeBuf>::PackedSize;
constexpr index_t packed_size_b = ck_tile::numeric_traits<BDataTypeBuf>::PackedSize;
// Adjust indexing based on matrix layout
int a_index = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
? row * strideA + k
@@ -914,7 +995,7 @@ __global__ void blockwise_gemm_kernel(ADataType* A,
AccDataType v_a;
AccDataType v_b;
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
if constexpr(std::is_same_v<ADataTypeBuf, pk_int4_t>)
{
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(A[a_index / packed_size_a]);
if(k % 2 == 1)
@@ -922,7 +1003,7 @@ __global__ void blockwise_gemm_kernel(ADataType* A,
else
v_a = fp32_val.lo;
}
else if constexpr(std::is_same_v<ADataType, pk_fp4_t>)
else if constexpr(std::is_same_v<ADataTypeBuf, pk_fp4_t>)
{
const fp32x2_t fp32_val = pk_fp4_to_fp32x2(A[a_index / packed_size_a], 1.0f);
if(k % 2 == 1)
@@ -935,7 +1016,7 @@ __global__ void blockwise_gemm_kernel(ADataType* A,
v_a = ck_tile::type_convert<AccDataType>(A[a_index]);
}
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
if constexpr(std::is_same_v<BDataTypeBuf, pk_int4_t>)
{
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(B[b_index / packed_size_b]);
if(k % 2 == 1)
@@ -943,7 +1024,7 @@ __global__ void blockwise_gemm_kernel(ADataType* A,
else
v_b = fp32_val.lo;
}
else if constexpr(std::is_same_v<BDataType, pk_fp4_t>)
else if constexpr(std::is_same_v<BDataTypeBuf, pk_fp4_t>)
{
const fp32x2_t fp32_val = pk_fp4_to_fp32x2(B[b_index / packed_size_b], 1.0f);
if(k % 2 == 1)
@@ -955,7 +1036,33 @@ __global__ void blockwise_gemm_kernel(ADataType* A,
{
v_b = ck_tile::type_convert<AccDataType>(B[b_index]);
}
acc_temp += v_a * v_b;
if constexpr(std::is_same_v<ADataTypeCompute, tf32_t>)
{
#ifdef CK_GFX950_SUPPORT
// gfx950: use 3x bf16 emulation
bf16_t v_a_bf16_big = ck_tile::type_convert<bf16_t>(v_a);
bf16_t v_a_bf16_small =
ck_tile::type_convert<bf16_t>(v_a - type_convert<AccDataType>(v_a_bf16_big));
bf16_t v_b_bf16_big = ck_tile::type_convert<bf16_t>(v_b);
bf16_t v_b_bf16_small =
ck_tile::type_convert<bf16_t>(v_b - type_convert<AccDataType>(v_b_bf16_big));
acc_temp += ck_tile::type_convert<AccDataType>(v_a_bf16_big) *
ck_tile::type_convert<AccDataType>(v_b_bf16_small) +
ck_tile::type_convert<AccDataType>(v_a_bf16_small) *
ck_tile::type_convert<AccDataType>(v_b_bf16_big) +
ck_tile::type_convert<AccDataType>(v_a_bf16_big) *
ck_tile::type_convert<AccDataType>(v_b_bf16_big);
#else
// Other architectures: use fp32 fallback
acc_temp += v_a * v_b;
#endif
}
else
{
acc_temp += v_a * v_b;
}
}
// final accumulation
acc += acc_temp * scale_A * scale_B;
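
In the blockwise kernel, partial products are accumulated unscaled within each K-granularity block and then folded in with the per-block A and B scales. A simplified host sketch of that accumulation order (names are illustrative, not the CK_TILE API; scales are assumed to be laid out one per granularity_k block):

```cpp
// Blockwise-scaled dot product over K, scale applied once per block.
float blockwise_dot(const float* a, const float* b,
                    const float* scale_a, const float* scale_b,
                    int K, int granularity_k)
{
    float acc = 0.0f;
    for(int kb = 0; kb < K; kb += granularity_k)
    {
        float acc_temp = 0.0f;
        for(int k = kb; k < kb + granularity_k && k < K; ++k)
            acc_temp += a[k] * b[k];
        const int blk = kb / granularity_k;
        acc += acc_temp * scale_a[blk] * scale_b[blk];
    }
    return acc;
}
```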
@@ -974,8 +1081,8 @@ template <typename ADataType,
typename LayoutA,
typename LayoutB,
typename LayoutC>
void reference_gemm_gpu(ADataType* a_ptr,
BDataType* b_ptr,
void reference_gemm_gpu(if_select_t<ADataType, tf32_t, float, ADataType>* a_ptr,
if_select_t<BDataType, tf32_t, float, BDataType>* b_ptr,
CDataType* c_ptr,
index_t M,
index_t N,
@@ -1002,8 +1109,8 @@ template <typename ADataType,
typename LayoutA,
typename LayoutB,
typename LayoutC>
void reference_blockwise_gemm_gpu(ADataType* a_ptr,
BDataType* b_ptr,
void reference_blockwise_gemm_gpu(if_select_t<ADataType, tf32_t, float, ADataType>* a_ptr,
if_select_t<BDataType, tf32_t, float, BDataType>* b_ptr,
CDataType* c_ptr,
index_t M,
index_t N,
@@ -1040,15 +1147,15 @@ void reference_blockwise_gemm_gpu(ADataType* a_ptr,
return;
}
template <typename ADataType,
typename BDataType,
template <typename ADataType_,
typename BDataType_,
typename AccDataType,
typename CDataType,
typename LayoutA,
typename LayoutB,
typename LayoutC>
void reference_batched_gemm_gpu(ADataType* a_ptr,
BDataType* b_ptr,
void reference_batched_gemm_gpu(if_select_t<ADataType_, tf32_t, float, ADataType_>* a_ptr,
if_select_t<BDataType_, tf32_t, float, BDataType_>* b_ptr,
CDataType* c_ptr,
index_t M,
index_t N,
@@ -1061,18 +1168,29 @@ void reference_batched_gemm_gpu(ADataType* a_ptr,
index_t batch_stride_C,
index_t batch_count)
{
using ADataTypeBuf = if_select_t<ADataType_, tf32_t, float, ADataType_>;
using BDataTypeBuf = if_select_t<BDataType_, tf32_t, float, BDataType_>;
using ADataTypeCompute = ADataType_;
using BDataTypeCompute = BDataType_;
int totalElements = M * N;
int numThreadsPerBlock = 256; // Common choice for threads per block
int numBlocks = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;
for(index_t batch_id = 0; batch_id < batch_count; ++batch_id)
{
ADataType* d_ATemp = a_ptr + batch_id * batch_stride_A;
BDataType* d_BTemp = b_ptr + batch_id * batch_stride_B;
CDataType* d_CTemp = c_ptr + batch_id * batch_stride_C;
naive_gemm_kernel<ADataType, BDataType, AccDataType, CDataType, LayoutA, LayoutB, LayoutC>
<<<numBlocks, numThreadsPerBlock>>>(
d_ATemp, d_BTemp, d_CTemp, M, N, K, stride_a, stride_b, stride_c);
ADataTypeBuf* d_ATemp = a_ptr + batch_id * batch_stride_A;
BDataTypeBuf* d_BTemp = b_ptr + batch_id * batch_stride_B;
CDataType* d_CTemp = c_ptr + batch_id * batch_stride_C;
naive_gemm_kernel<ADataTypeCompute,
BDataTypeCompute,
AccDataType,
CDataType,
LayoutA,
LayoutB,
LayoutC><<<numBlocks, numThreadsPerBlock>>>(
d_ATemp, d_BTemp, d_CTemp, M, N, K, stride_a, stride_b, stride_c);
}
return;
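
A hypothetical call site, to show the buffer/compute split: with `ADataType_ = BDataType_ = tf32_t` the launcher still takes `float*` buffers, so existing fp32 allocations are reused while the kernels round through the bf16 path internally (layout choices below are arbitrary):

```cpp
// Sketch only; type and function names as used in the diff above.
reference_batched_gemm_gpu<ck_tile::tf32_t,
                           ck_tile::tf32_t,
                           float, // AccDataType
                           float, // CDataType
                           ck_tile::tensor_layout::gemm::RowMajor,
                           ck_tile::tensor_layout::gemm::ColumnMajor,
                           ck_tile::tensor_layout::gemm::RowMajor>(
    d_a, // float*, even though the compute type is tf32_t
    d_b, // float*
    d_c, // float*
    M, N, K,
    stride_a, stride_b, stride_c,
    batch_stride_A, batch_stride_B, batch_stride_C,
    batch_count);
```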