mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[rocm-libraries] ROCm/rocm-libraries#4302 (commit e62bd8a)
[CK_TILE] add tf32 support MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Proposed changes TF32 is added in CK on gfx942 and gfx950. This PR is to initiate tf32 in CK_TILE on gfx942 and gfx950. ## Checklist Please put an into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask. - [ ] I have added tests relevant to the introduced functionality, and the unit tests are passing locally - [ ] I have added the test to REGRESSION_TESTS list defined at the top of CMakeLists.txt in tests/CMakeLists.txt, **IF** the test takes more than 30 seconds to run. - [ ] I have added inline documentation which enables the maintainers with understanding the motivation - [ ] I have removed the stale documentation which is no longer relevant after this pull request - [ ] (If this change is user-facing) I have added release notes which provide the end users with a brief summary of the improvement from this pull request - [x] I have run on all changed files - [ ] Any dependent changes have been merged ## Discussion
This commit is contained in:
committed by
assistant-librarian[bot]
parent
652d3456ca
commit
d460ab35b6
@@ -4,11 +4,11 @@
|
||||
#pragma once
|
||||
|
||||
#include <cstdlib>
|
||||
#include <mutex>
|
||||
#include <thread>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/host_tensor.hpp"
|
||||
#include "ck_tile/host/device_prop.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -447,24 +447,34 @@ CK_TILE_HOST void reference_mx_gemm_bquant(const HostTensor<ADataType>& a_m_k,
|
||||
std::cout << std::endl;
|
||||
}
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
template <typename ADataType_,
|
||||
typename BDataType_,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename AElementOp = ck_tile::identity,
|
||||
typename BElementOp = ck_tile::identity,
|
||||
typename ACCElementOp = ck_tile::identity>
|
||||
CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
|
||||
const HostTensor<BDataType>& b_k_n,
|
||||
HostTensor<CDataType>& c_m_n,
|
||||
const AElementOp& a_element_op = {},
|
||||
const BElementOp& b_element_op = {},
|
||||
const ACCElementOp& acc_element_op = {})
|
||||
CK_TILE_HOST void
|
||||
reference_gemm(const HostTensor<if_select_t<ADataType_, tf32_t, float, ADataType_>>& a_m_k,
|
||||
const HostTensor<if_select_t<BDataType_, tf32_t, float, BDataType_>>& b_k_n,
|
||||
HostTensor<CDataType>& c_m_n,
|
||||
const AElementOp& a_element_op = {},
|
||||
const BElementOp& b_element_op = {},
|
||||
const ACCElementOp& acc_element_op = {})
|
||||
{
|
||||
if constexpr(std::is_same_v<ADataType_, tf32_t> || std::is_same_v<BDataType_, tf32_t>)
|
||||
static_assert(std::is_same_v<ADataType_, BDataType_>,
|
||||
"ADataType and BDataType must be the same");
|
||||
using ADataTypeCompute = ADataType_;
|
||||
using ADataTypeBuf = if_select_t<ADataType_, tf32_t, float, ADataType_>;
|
||||
using BDataTypeBuf = if_select_t<BDataType_, tf32_t, float, BDataType_>;
|
||||
|
||||
const std::size_t M = a_m_k.get_length(0);
|
||||
const std::size_t N = b_k_n.get_length(1);
|
||||
const std::size_t K = a_m_k.get_length(1);
|
||||
|
||||
const bool is_gfx950 = (ck_tile::get_device_name() == "gfx950");
|
||||
|
||||
auto f_mn = [&](auto m, auto n) {
|
||||
AccDataType v_acc = 0;
|
||||
|
||||
@@ -472,7 +482,7 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
|
||||
{
|
||||
AccDataType v_a;
|
||||
AccDataType v_b;
|
||||
if constexpr(std::is_same_v<ADataType, pk_fp4_t>)
|
||||
if constexpr(std::is_same_v<ADataTypeBuf, pk_fp4_t>)
|
||||
{
|
||||
// HostTensor automatically handles packed indexing: a_m_k(m,k) divides offset by
|
||||
// PackedSize So a_m_k(m,0) and a_m_k(m,1) return the same packed byte
|
||||
@@ -481,7 +491,7 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
|
||||
const float unpacked = (k % 2 == 1) ? fp32_val.hi : fp32_val.lo;
|
||||
v_a = ck_tile::type_convert<AccDataType>(a_element_op(unpacked));
|
||||
}
|
||||
else if constexpr(std::is_same_v<ADataType, pk_int4_t>)
|
||||
else if constexpr(std::is_same_v<ADataTypeBuf, pk_int4_t>)
|
||||
{
|
||||
// HostTensor automatically handles packed indexing
|
||||
const pk_int4_t pk_val = a_m_k(m, k);
|
||||
@@ -493,7 +503,7 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
|
||||
{
|
||||
v_a = ck_tile::type_convert<AccDataType>(a_element_op(a_m_k(m, k)));
|
||||
}
|
||||
if constexpr(std::is_same_v<BDataType, pk_fp4_t>)
|
||||
if constexpr(std::is_same_v<BDataTypeBuf, pk_fp4_t>)
|
||||
{
|
||||
// HostTensor automatically handles packed indexing
|
||||
const pk_fp4_t pk_val = b_k_n(k, n);
|
||||
@@ -501,7 +511,7 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
|
||||
const float unpacked = (k % 2 == 1) ? fp32_val.hi : fp32_val.lo;
|
||||
v_b = ck_tile::type_convert<AccDataType>(b_element_op(unpacked));
|
||||
}
|
||||
else if constexpr(std::is_same_v<BDataType, pk_int4_t>)
|
||||
else if constexpr(std::is_same_v<BDataTypeBuf, pk_int4_t>)
|
||||
{
|
||||
// HostTensor automatically handles packed indexing
|
||||
const pk_int4_t pk_val = b_k_n(k, n);
|
||||
@@ -513,7 +523,36 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
|
||||
{
|
||||
v_b = ck_tile::type_convert<AccDataType>(b_element_op(b_k_n(k, n)));
|
||||
}
|
||||
v_acc += v_a * v_b;
|
||||
|
||||
if constexpr(std::is_same_v<ADataTypeCompute, tf32_t>)
|
||||
{
|
||||
if(is_gfx950)
|
||||
{
|
||||
// gfx950: use 3x bf16 emulation
|
||||
bf16_t v_a_bf16_big = ck_tile::type_convert<bf16_t>(v_a);
|
||||
bf16_t v_a_bf16_small = ck_tile::type_convert<bf16_t>(
|
||||
v_a - type_convert<AccDataType>(v_a_bf16_big));
|
||||
bf16_t v_b_bf16_big = ck_tile::type_convert<bf16_t>(v_b);
|
||||
bf16_t v_b_bf16_small = ck_tile::type_convert<bf16_t>(
|
||||
v_b - type_convert<AccDataType>(v_b_bf16_big));
|
||||
|
||||
v_acc += ck_tile::type_convert<AccDataType>(v_a_bf16_big) *
|
||||
ck_tile::type_convert<AccDataType>(v_b_bf16_small) +
|
||||
ck_tile::type_convert<AccDataType>(v_a_bf16_small) *
|
||||
ck_tile::type_convert<AccDataType>(v_b_bf16_big) +
|
||||
ck_tile::type_convert<AccDataType>(v_a_bf16_big) *
|
||||
ck_tile::type_convert<AccDataType>(v_b_bf16_big);
|
||||
}
|
||||
else
|
||||
{
|
||||
// Other architectures: tf32 not supported or handled via fp32 fallback
|
||||
v_acc += v_a * v_b;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
v_acc += v_a * v_b;
|
||||
}
|
||||
}
|
||||
|
||||
c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
|
||||
@@ -764,15 +803,15 @@ reference_gemm_multiple_d(const HostTensor<ADataType>& a_m_k,
|
||||
make_ParallelTensorFunctor(f_mk_kn_mn, M, N)(std::thread::hardware_concurrency());
|
||||
}
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
template <typename ADataType_,
|
||||
typename BDataType_,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename LayoutA,
|
||||
typename LayoutB,
|
||||
typename LayoutC>
|
||||
__global__ void naive_gemm_kernel(ADataType* A,
|
||||
BDataType* B,
|
||||
__global__ void naive_gemm_kernel(if_select_t<ADataType_, tf32_t, float, ADataType_>* A,
|
||||
if_select_t<BDataType_, tf32_t, float, BDataType_>* B,
|
||||
CDataType* C,
|
||||
ck_tile::index_t M,
|
||||
ck_tile::index_t N,
|
||||
@@ -781,6 +820,14 @@ __global__ void naive_gemm_kernel(ADataType* A,
|
||||
ck_tile::index_t strideB,
|
||||
ck_tile::index_t strideC)
|
||||
{
|
||||
if constexpr(std::is_same_v<ADataType_, tf32_t> || std::is_same_v<BDataType_, tf32_t>)
|
||||
static_assert(std::is_same_v<ADataType_, BDataType_>,
|
||||
"ADataType and BDataType must be the same");
|
||||
using ADataTypeCompute = ADataType_;
|
||||
// ADataTypeBuf: buffer/storage type (fp32 when tf32)
|
||||
using ADataTypeBuf = if_select_t<ADataType_, tf32_t, float, ADataType_>;
|
||||
using BDataTypeBuf = if_select_t<BDataType_, tf32_t, float, BDataType_>;
|
||||
|
||||
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int row = idx / N; // Compute row index
|
||||
int col = idx % N; // Compute column index
|
||||
@@ -790,8 +837,8 @@ __global__ void naive_gemm_kernel(ADataType* A,
|
||||
AccDataType acc = 0.0;
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
constexpr index_t packed_size_a = ck_tile::numeric_traits<ADataType>::PackedSize;
|
||||
constexpr index_t packed_size_b = ck_tile::numeric_traits<BDataType>::PackedSize;
|
||||
constexpr index_t packed_size_a = ck_tile::numeric_traits<ADataTypeBuf>::PackedSize;
|
||||
constexpr index_t packed_size_b = ck_tile::numeric_traits<BDataTypeBuf>::PackedSize;
|
||||
// Adjust indexing based on matrix layout
|
||||
int a_index = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
|
||||
? row * strideA + k
|
||||
@@ -802,7 +849,7 @@ __global__ void naive_gemm_kernel(ADataType* A,
|
||||
|
||||
AccDataType v_a;
|
||||
AccDataType v_b;
|
||||
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
|
||||
if constexpr(std::is_same_v<ADataTypeBuf, pk_int4_t>)
|
||||
{
|
||||
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(A[a_index / packed_size_a]);
|
||||
if(k % 2 == 1)
|
||||
@@ -810,7 +857,7 @@ __global__ void naive_gemm_kernel(ADataType* A,
|
||||
else
|
||||
v_a = fp32_val.lo;
|
||||
}
|
||||
else if constexpr(std::is_same_v<ADataType, pk_fp4_t>)
|
||||
else if constexpr(std::is_same_v<ADataTypeBuf, pk_fp4_t>)
|
||||
{
|
||||
const fp32x2_t fp32_val = pk_fp4_to_fp32x2(A[a_index / packed_size_a], 1.0f);
|
||||
if(k % 2 == 1)
|
||||
@@ -822,7 +869,7 @@ __global__ void naive_gemm_kernel(ADataType* A,
|
||||
{
|
||||
v_a = ck_tile::type_convert<AccDataType>(A[a_index]);
|
||||
}
|
||||
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
|
||||
if constexpr(std::is_same_v<BDataTypeBuf, pk_int4_t>)
|
||||
{
|
||||
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(B[b_index / packed_size_b]);
|
||||
if(k % 2 == 1)
|
||||
@@ -830,7 +877,7 @@ __global__ void naive_gemm_kernel(ADataType* A,
|
||||
else
|
||||
v_b = fp32_val.lo;
|
||||
}
|
||||
else if constexpr(std::is_same_v<BDataType, pk_fp4_t>)
|
||||
else if constexpr(std::is_same_v<BDataTypeBuf, pk_fp4_t>)
|
||||
{
|
||||
const fp32x2_t fp32_val = pk_fp4_to_fp32x2(B[b_index / packed_size_b], 1.0f);
|
||||
if(k % 2 == 1)
|
||||
@@ -842,7 +889,33 @@ __global__ void naive_gemm_kernel(ADataType* A,
|
||||
{
|
||||
v_b = ck_tile::type_convert<AccDataType>(B[b_index]);
|
||||
}
|
||||
acc += v_a * v_b;
|
||||
|
||||
if constexpr(std::is_same_v<ADataTypeCompute, tf32_t>)
|
||||
{
|
||||
#ifdef CK_GFX950_SUPPORT
|
||||
// gfx950: use 3x bf16 emulation
|
||||
bf16_t v_a_bf16_big = ck_tile::type_convert<bf16_t>(v_a);
|
||||
bf16_t v_a_bf16_small =
|
||||
ck_tile::type_convert<bf16_t>(v_a - type_convert<AccDataType>(v_a_bf16_big));
|
||||
bf16_t v_b_bf16_big = ck_tile::type_convert<bf16_t>(v_b);
|
||||
bf16_t v_b_bf16_small =
|
||||
ck_tile::type_convert<bf16_t>(v_b - type_convert<AccDataType>(v_b_bf16_big));
|
||||
|
||||
acc += ck_tile::type_convert<AccDataType>(v_a_bf16_big) *
|
||||
ck_tile::type_convert<AccDataType>(v_b_bf16_small) +
|
||||
ck_tile::type_convert<AccDataType>(v_a_bf16_small) *
|
||||
ck_tile::type_convert<AccDataType>(v_b_bf16_big) +
|
||||
ck_tile::type_convert<AccDataType>(v_a_bf16_big) *
|
||||
ck_tile::type_convert<AccDataType>(v_b_bf16_big);
|
||||
#else
|
||||
// Other architectures: use fp32 fallback
|
||||
acc += v_a * v_b;
|
||||
#endif
|
||||
}
|
||||
else
|
||||
{
|
||||
acc += v_a * v_b;
|
||||
}
|
||||
}
|
||||
|
||||
int c_index = (std::is_same_v<LayoutC, tensor_layout::gemm::RowMajor>)
|
||||
@@ -852,15 +925,15 @@ __global__ void naive_gemm_kernel(ADataType* A,
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
template <typename ADataType_,
|
||||
typename BDataType_,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename LayoutA,
|
||||
typename LayoutB,
|
||||
typename LayoutC>
|
||||
__global__ void blockwise_gemm_kernel(ADataType* A,
|
||||
BDataType* B,
|
||||
__global__ void blockwise_gemm_kernel(if_select_t<ADataType_, tf32_t, float, ADataType_>* A,
|
||||
if_select_t<BDataType_, tf32_t, float, BDataType_>* B,
|
||||
CDataType* C,
|
||||
ck_tile::index_t M,
|
||||
ck_tile::index_t N,
|
||||
@@ -874,6 +947,14 @@ __global__ void blockwise_gemm_kernel(ADataType* A,
|
||||
float* scale_A_ptr,
|
||||
float* scale_B_ptr)
|
||||
{
|
||||
if constexpr(std::is_same_v<ADataType_, tf32_t> || std::is_same_v<BDataType_, tf32_t>)
|
||||
static_assert(std::is_same_v<ADataType_, BDataType_>,
|
||||
"ADataType and BDataType must be the same");
|
||||
using ADataTypeCompute = ADataType_;
|
||||
// ADataTypeBuf: buffer/storage type (fp32 when tf32)
|
||||
using ADataTypeBuf = if_select_t<ADataType_, tf32_t, float, ADataType_>;
|
||||
using BDataTypeBuf = if_select_t<BDataType_, tf32_t, float, BDataType_>;
|
||||
|
||||
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int row = idx / N; // Compute row index
|
||||
int col = idx % N; // Compute column index
|
||||
@@ -902,8 +983,8 @@ __global__ void blockwise_gemm_kernel(ADataType* A,
|
||||
(k / scale_granularity_k) * scale_B_stride];
|
||||
}
|
||||
|
||||
constexpr index_t packed_size_a = ck_tile::numeric_traits<ADataType>::PackedSize;
|
||||
constexpr index_t packed_size_b = ck_tile::numeric_traits<BDataType>::PackedSize;
|
||||
constexpr index_t packed_size_a = ck_tile::numeric_traits<ADataTypeBuf>::PackedSize;
|
||||
constexpr index_t packed_size_b = ck_tile::numeric_traits<BDataTypeBuf>::PackedSize;
|
||||
// Adjust indexing based on matrix layout
|
||||
int a_index = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
|
||||
? row * strideA + k
|
||||
@@ -914,7 +995,7 @@ __global__ void blockwise_gemm_kernel(ADataType* A,
|
||||
|
||||
AccDataType v_a;
|
||||
AccDataType v_b;
|
||||
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
|
||||
if constexpr(std::is_same_v<ADataTypeBuf, pk_int4_t>)
|
||||
{
|
||||
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(A[a_index / packed_size_a]);
|
||||
if(k % 2 == 1)
|
||||
@@ -922,7 +1003,7 @@ __global__ void blockwise_gemm_kernel(ADataType* A,
|
||||
else
|
||||
v_a = fp32_val.lo;
|
||||
}
|
||||
else if constexpr(std::is_same_v<ADataType, pk_fp4_t>)
|
||||
else if constexpr(std::is_same_v<ADataTypeBuf, pk_fp4_t>)
|
||||
{
|
||||
const fp32x2_t fp32_val = pk_fp4_to_fp32x2(A[a_index / packed_size_a], 1.0f);
|
||||
if(k % 2 == 1)
|
||||
@@ -935,7 +1016,7 @@ __global__ void blockwise_gemm_kernel(ADataType* A,
|
||||
v_a = ck_tile::type_convert<AccDataType>(A[a_index]);
|
||||
}
|
||||
|
||||
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
|
||||
if constexpr(std::is_same_v<BDataTypeBuf, pk_int4_t>)
|
||||
{
|
||||
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(B[b_index / packed_size_b]);
|
||||
if(k % 2 == 1)
|
||||
@@ -943,7 +1024,7 @@ __global__ void blockwise_gemm_kernel(ADataType* A,
|
||||
else
|
||||
v_b = fp32_val.lo;
|
||||
}
|
||||
else if constexpr(std::is_same_v<BDataType, pk_fp4_t>)
|
||||
else if constexpr(std::is_same_v<BDataTypeBuf, pk_fp4_t>)
|
||||
{
|
||||
const fp32x2_t fp32_val = pk_fp4_to_fp32x2(B[b_index / packed_size_b], 1.0f);
|
||||
if(k % 2 == 1)
|
||||
@@ -955,7 +1036,33 @@ __global__ void blockwise_gemm_kernel(ADataType* A,
|
||||
{
|
||||
v_b = ck_tile::type_convert<AccDataType>(B[b_index]);
|
||||
}
|
||||
acc_temp += v_a * v_b;
|
||||
|
||||
if constexpr(std::is_same_v<ADataTypeCompute, tf32_t>)
|
||||
{
|
||||
#ifdef CK_GFX950_SUPPORT
|
||||
// gfx950: use 3x bf16 emulation
|
||||
bf16_t v_a_bf16_big = ck_tile::type_convert<bf16_t>(v_a);
|
||||
bf16_t v_a_bf16_small =
|
||||
ck_tile::type_convert<bf16_t>(v_a - type_convert<AccDataType>(v_a_bf16_big));
|
||||
bf16_t v_b_bf16_big = ck_tile::type_convert<bf16_t>(v_b);
|
||||
bf16_t v_b_bf16_small =
|
||||
ck_tile::type_convert<bf16_t>(v_b - type_convert<AccDataType>(v_b_bf16_big));
|
||||
|
||||
acc_temp += ck_tile::type_convert<AccDataType>(v_a_bf16_big) *
|
||||
ck_tile::type_convert<AccDataType>(v_b_bf16_small) +
|
||||
ck_tile::type_convert<AccDataType>(v_a_bf16_small) *
|
||||
ck_tile::type_convert<AccDataType>(v_b_bf16_big) +
|
||||
ck_tile::type_convert<AccDataType>(v_a_bf16_big) *
|
||||
ck_tile::type_convert<AccDataType>(v_b_bf16_big);
|
||||
#else
|
||||
// Other architectures: use fp32 fallback
|
||||
acc_temp += v_a * v_b;
|
||||
#endif
|
||||
}
|
||||
else
|
||||
{
|
||||
acc_temp += v_a * v_b;
|
||||
}
|
||||
}
|
||||
// final accumulation
|
||||
acc += acc_temp * scale_A * scale_B;
|
||||
@@ -974,8 +1081,8 @@ template <typename ADataType,
|
||||
typename LayoutA,
|
||||
typename LayoutB,
|
||||
typename LayoutC>
|
||||
void reference_gemm_gpu(ADataType* a_ptr,
|
||||
BDataType* b_ptr,
|
||||
void reference_gemm_gpu(if_select_t<ADataType, tf32_t, float, ADataType>* a_ptr,
|
||||
if_select_t<BDataType, tf32_t, float, BDataType>* b_ptr,
|
||||
CDataType* c_ptr,
|
||||
index_t M,
|
||||
index_t N,
|
||||
@@ -1002,8 +1109,8 @@ template <typename ADataType,
|
||||
typename LayoutA,
|
||||
typename LayoutB,
|
||||
typename LayoutC>
|
||||
void reference_blockwise_gemm_gpu(ADataType* a_ptr,
|
||||
BDataType* b_ptr,
|
||||
void reference_blockwise_gemm_gpu(if_select_t<ADataType, tf32_t, float, ADataType>* a_ptr,
|
||||
if_select_t<BDataType, tf32_t, float, BDataType>* b_ptr,
|
||||
CDataType* c_ptr,
|
||||
index_t M,
|
||||
index_t N,
|
||||
@@ -1040,15 +1147,15 @@ void reference_blockwise_gemm_gpu(ADataType* a_ptr,
|
||||
return;
|
||||
}
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
template <typename ADataType_,
|
||||
typename BDataType_,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename LayoutA,
|
||||
typename LayoutB,
|
||||
typename LayoutC>
|
||||
void reference_batched_gemm_gpu(ADataType* a_ptr,
|
||||
BDataType* b_ptr,
|
||||
void reference_batched_gemm_gpu(if_select_t<ADataType_, tf32_t, float, ADataType_>* a_ptr,
|
||||
if_select_t<BDataType_, tf32_t, float, BDataType_>* b_ptr,
|
||||
CDataType* c_ptr,
|
||||
index_t M,
|
||||
index_t N,
|
||||
@@ -1061,18 +1168,29 @@ void reference_batched_gemm_gpu(ADataType* a_ptr,
|
||||
index_t batch_stride_C,
|
||||
index_t batch_count)
|
||||
{
|
||||
using ADataTypeBuf = if_select_t<ADataType_, tf32_t, float, ADataType_>;
|
||||
using BDataTypeBuf = if_select_t<BDataType_, tf32_t, float, BDataType_>;
|
||||
|
||||
using ADataTypeCompute = ADataType_;
|
||||
using BDataTypeCompute = BDataType_;
|
||||
|
||||
int totalElements = M * N;
|
||||
int numThreadsPerBlock = 256; // Common choice for threads per block
|
||||
int numBlocks = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;
|
||||
|
||||
for(index_t batch_id = 0; batch_id < batch_count; ++batch_id)
|
||||
{
|
||||
ADataType* d_ATemp = a_ptr + batch_id * batch_stride_A;
|
||||
BDataType* d_BTemp = b_ptr + batch_id * batch_stride_B;
|
||||
CDataType* d_CTemp = c_ptr + batch_id * batch_stride_C;
|
||||
naive_gemm_kernel<ADataType, BDataType, AccDataType, CDataType, LayoutA, LayoutB, LayoutC>
|
||||
<<<numBlocks, numThreadsPerBlock>>>(
|
||||
d_ATemp, d_BTemp, d_CTemp, M, N, K, stride_a, stride_b, stride_c);
|
||||
ADataTypeBuf* d_ATemp = a_ptr + batch_id * batch_stride_A;
|
||||
BDataTypeBuf* d_BTemp = b_ptr + batch_id * batch_stride_B;
|
||||
CDataType* d_CTemp = c_ptr + batch_id * batch_stride_C;
|
||||
naive_gemm_kernel<ADataTypeCompute,
|
||||
BDataTypeCompute,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
LayoutA,
|
||||
LayoutB,
|
||||
LayoutC><<<numBlocks, numThreadsPerBlock>>>(
|
||||
d_ATemp, d_BTemp, d_CTemp, M, N, K, stride_a, stride_b, stride_c);
|
||||
}
|
||||
|
||||
return;
|
||||
|
||||
Reference in New Issue
Block a user