mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 12:17:00 +00:00
update reference gemm mx
This commit is contained in:
@@ -154,7 +154,10 @@ float gemm_mx_calc(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::strea
|
||||
#include "run_gemm_mx_example.inc"
|
||||
|
||||
template <typename TypeConfig, uint32_t BlockScaleSize>
|
||||
int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
|
||||
int run_gemm_mx_example_prec_type(std::string a_layout,
|
||||
std::string b_layout,
|
||||
int argc,
|
||||
char* argv[])
|
||||
{
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
@@ -163,7 +166,7 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
|
||||
{
|
||||
if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
return run_gemm_example_with_layouts<TypeConfig, BlockScaleSize>(
|
||||
return run_gemm_mx_example_with_layouts<TypeConfig, BlockScaleSize>(
|
||||
argc, argv, Row{}, Row{}, Col{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
@@ -196,7 +199,7 @@ int run_gemm_mx_example(int argc, char* argv[])
|
||||
ck_tile::e8m0_bexp_t,
|
||||
int32_t,
|
||||
ck_tile::half_t>{});
|
||||
return run_gemm_example_prec_type<TypeConfig, 32>(a_layout, b_layout, argc, argv);
|
||||
return run_gemm_mx_example_prec_type<TypeConfig, 32>(a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@@ -24,21 +24,21 @@ template <typename ADataType,
|
||||
typename BScaleLayout,
|
||||
typename CLayout,
|
||||
uint32_t BlockScaleSize>
|
||||
float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
ck_tile::DeviceMem& a_m_k_scale_dev_buf,
|
||||
ck_tile::DeviceMem& b_k_n_dev_buf,
|
||||
ck_tile::DeviceMem& b_k_n_scale_dev_buf,
|
||||
ck_tile::DeviceMem& c_m_n_dev_buf,
|
||||
ck_tile::index_t M,
|
||||
ck_tile::index_t N,
|
||||
ck_tile::index_t K,
|
||||
ck_tile::index_t stride_A,
|
||||
ck_tile::index_t stride_AQ,
|
||||
ck_tile::index_t stride_B,
|
||||
ck_tile::index_t stride_C,
|
||||
ck_tile::index_t kbatch,
|
||||
int n_warmup,
|
||||
int n_repeat)
|
||||
float invoke_gemm_mx(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
ck_tile::DeviceMem& a_m_k_scale_dev_buf,
|
||||
ck_tile::DeviceMem& b_k_n_dev_buf,
|
||||
ck_tile::DeviceMem& b_k_n_scale_dev_buf,
|
||||
ck_tile::DeviceMem& c_m_n_dev_buf,
|
||||
ck_tile::index_t M,
|
||||
ck_tile::index_t N,
|
||||
ck_tile::index_t K,
|
||||
ck_tile::index_t stride_A,
|
||||
ck_tile::index_t stride_AQ,
|
||||
ck_tile::index_t stride_B,
|
||||
ck_tile::index_t stride_C,
|
||||
ck_tile::index_t kbatch,
|
||||
int n_warmup,
|
||||
int n_repeat)
|
||||
{
|
||||
ck_tile::GemmMXKernelArgs args;
|
||||
args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
|
||||
@@ -102,13 +102,13 @@ template <typename TypeConfig,
|
||||
typename BLayout,
|
||||
typename BScaleLayout,
|
||||
typename CLayout>
|
||||
int run_gemm_example_with_layouts(int argc,
|
||||
char* argv[],
|
||||
const ALayout a_layout = ALayout{},
|
||||
const AScaleLayout a_scale_layout = AScaleLayout{},
|
||||
const BLayout b_layout = BLayout{},
|
||||
const BScaleLayout b_scale_layout = BScaleLayout{},
|
||||
[[maybe_unused]] const CLayout c_layout = CLayout{})
|
||||
int run_gemm_mx_example_with_layouts(int argc,
|
||||
char* argv[],
|
||||
const ALayout a_layout = ALayout{},
|
||||
const AScaleLayout a_scale_layout = AScaleLayout{},
|
||||
const BLayout b_layout = BLayout{},
|
||||
const BScaleLayout b_scale_layout = BScaleLayout{},
|
||||
[[maybe_unused]] const CLayout c_layout = CLayout{})
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
@@ -224,33 +224,33 @@ int run_gemm_example_with_layouts(int argc,
|
||||
c_m_n_dev_buf.SetZero();
|
||||
c_m_n_dev_result.SetZero();
|
||||
|
||||
invoke_gemm<ADataType,
|
||||
XPackedDataType,
|
||||
BDataType,
|
||||
XPackedDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
AScaleLayout,
|
||||
BLayout,
|
||||
BScaleLayout,
|
||||
CLayout,
|
||||
BlockScaleSize>(a_m_k_dev_buf,
|
||||
a_m_k_scale_dev_buf,
|
||||
b_k_n_dev_buf,
|
||||
b_k_n_scale_dev_buf,
|
||||
c_m_n_dev_buf,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
Scale_Stride_A,
|
||||
stride_B,
|
||||
Scale_Stride_B,
|
||||
stride_C,
|
||||
kbatch,
|
||||
n_warmup,
|
||||
n_repeat);
|
||||
invoke_gemm_mx<ADataType,
|
||||
XPackedDataType,
|
||||
BDataType,
|
||||
XPackedDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
AScaleLayout,
|
||||
BLayout,
|
||||
BScaleLayout,
|
||||
CLayout,
|
||||
BlockScaleSize>(a_m_k_dev_buf,
|
||||
a_m_k_scale_dev_buf,
|
||||
b_k_n_dev_buf,
|
||||
b_k_n_scale_dev_buf,
|
||||
c_m_n_dev_buf,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
Scale_Stride_A,
|
||||
stride_B,
|
||||
Scale_Stride_B,
|
||||
stride_C,
|
||||
kbatch,
|
||||
n_warmup,
|
||||
n_repeat);
|
||||
|
||||
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
|
||||
bool pass = true;
|
||||
@@ -261,13 +261,9 @@ int run_gemm_example_with_layouts(int argc,
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
|
||||
c_m_n_host_ref.SetZero();
|
||||
|
||||
ck_tile::reference_gemm_quant<ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
QuantGroupSize,
|
||||
true>(a_m_k, aq_m_aqk, b_k_n, c_m_n_host_ref);
|
||||
ck_tile::reference_gemm_mx<ADataType, BDataType, AScaleDataType, AccDataType, CDataType>(
|
||||
a_m_k, a_m_k_scale, b_k_n, b_k_n_scale, c_m_n_host_ref);
|
||||
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
|
||||
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
|
||||
|
||||
@@ -176,192 +176,287 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
|
||||
}
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename BDatatype,
|
||||
typename ScaleDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ACCElementOp,
|
||||
typename DDataType = remove_cvref_t<std::tuple_element_t<0, DsDataType>>>
|
||||
CK_TILE_HOST void
|
||||
reference_gemm_multiple_d(const HostTensor<ADataType>& a_m_k,
|
||||
const HostTensor<BDataType>& b_k_n,
|
||||
const std::array<HostTensor<DDataType>, DsDataType::size()>& ds_m_n,
|
||||
HostTensor<CDataType>& c_m_n,
|
||||
const ACCElementOp& acc_element_op = {})
|
||||
typename AElementOp = ck_tile::identity,
|
||||
typename BElementOp = ck_tile::identity,
|
||||
typename ACCElementOp = ck_tile::identity>
|
||||
CK_TILE_HOST void reference_gemm_mx(const HostTensor<ADataType>& a_m_k,
|
||||
const HostTensor<ScaleDataType>& a_m_k_scale,
|
||||
const HostTensor<BDatatype>& b_k_n,
|
||||
const HostTensor<ScaleDataType>& b_k_n_scale,
|
||||
HostTensor<CDataType>& c_m_n,
|
||||
const AElementOp& a_element_op = {},
|
||||
const BElementOp& b_element_op = {},
|
||||
const ACCElementOp& acc_element_op = {})
|
||||
{
|
||||
const std::size_t M = a_m_k.get_length(0);
|
||||
const std::size_t N = b_k_n.get_length(1);
|
||||
const std::size_t K = a_m_k.get_length(1);
|
||||
|
||||
auto f_mk_kn_mn = [&](auto m, auto n) {
|
||||
AccDataType v_acc = 0;
|
||||
for(std::size_t k = 0; k < K; ++k)
|
||||
{
|
||||
ADataType v_a = a_m_k(m, k);
|
||||
BDataType v_b = b_k_n(k, n);
|
||||
v_acc +=
|
||||
ck_tile::type_convert<AccDataType>(v_a) * ck_tile::type_convert<AccDataType>(v_b);
|
||||
}
|
||||
const std::size_t ScaleBlockSize = K / a_m_k_scale.get_length(1);
|
||||
|
||||
CDataType v_c = 0;
|
||||
if constexpr(DsDataType::size() == 0)
|
||||
{
|
||||
acc_element_op(v_c, ck_tile::type_convert<float>(v_acc));
|
||||
}
|
||||
else if constexpr(DsDataType::size() == 1)
|
||||
{
|
||||
acc_element_op(v_c,
|
||||
ck_tile::type_convert<float>(v_acc),
|
||||
ck_tile::type_convert<float>(ds_m_n[0](m, n)));
|
||||
}
|
||||
else if constexpr(DsDataType::size() == 2)
|
||||
{
|
||||
acc_element_op(v_c,
|
||||
ck_tile::type_convert<float>(v_acc),
|
||||
ck_tile::type_convert<float>(ds_m_n[0](m, n)),
|
||||
ck_tile::type_convert<float>(ds_m_n[1](m, n)));
|
||||
}
|
||||
c_m_n(m, n) = ck_tile::type_convert<CDataType>(v_c);
|
||||
};
|
||||
HostTensor<AccDataType> a_m_k_scaled({M, K}, {K, 1});
|
||||
HostTensor<AccDataType> b_k_n_scaled({K, N}, {1, N});
|
||||
|
||||
make_ParallelTensorFunctor(f_mk_kn_mn, M, N)(std::thread::hardware_concurrency());
|
||||
}
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename LayoutA,
|
||||
typename LayoutB,
|
||||
typename LayoutC>
|
||||
__global__ void naive_gemm_kernel(ADataType* A,
|
||||
BDataType* B,
|
||||
CDataType* C,
|
||||
ck_tile::index_t M,
|
||||
ck_tile::index_t N,
|
||||
ck_tile::index_t K,
|
||||
ck_tile::index_t strideA,
|
||||
ck_tile::index_t strideB,
|
||||
ck_tile::index_t strideC)
|
||||
{
|
||||
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int row = idx / N; // Compute row index
|
||||
int col = idx % N; // Compute column index
|
||||
|
||||
if(row < M && col < N)
|
||||
for(int m = 0; m < M; m++)
|
||||
{
|
||||
AccDataType acc = 0.0;
|
||||
for(int k = 0; k < K; ++k)
|
||||
for(int k = 0; k < K; k++)
|
||||
{
|
||||
constexpr index_t packed_size_a = ck_tile::numeric_traits<ADataType>::PackedSize;
|
||||
constexpr index_t packed_size_b = ck_tile::numeric_traits<BDataType>::PackedSize;
|
||||
// Adjust indexing based on matrix layout
|
||||
int a_index = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
|
||||
? row * strideA + k
|
||||
: k * strideA + row;
|
||||
int b_index = (std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
|
||||
? col * strideB + k
|
||||
: k * strideB + col;
|
||||
if constexpr(std::is_same_v<ADataType, f4x2_pk_t>)
|
||||
{
|
||||
if(k % 2 == 1)
|
||||
continue; // skip odd k
|
||||
|
||||
AccDataType v_a;
|
||||
AccDataType v_b;
|
||||
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
|
||||
{
|
||||
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(A[a_index / packed_size_a]);
|
||||
if(k % 2 == 1)
|
||||
v_a = fp32_val.hi;
|
||||
else
|
||||
v_a = fp32_val.lo;
|
||||
auto a_f4x2 = a_m_k(m, k);
|
||||
auto a_scale = a_m_k_scale(m, k / ScaleBlockSize);
|
||||
// auto f4_lo = ck_tile::type_convert<AccDataType>(f4x2)[0];
|
||||
// auto f4_hi = ck_tile::type_convert<AccDataType>(f4x2)[1];
|
||||
auto a_f4_lo =
|
||||
ck_tile::type_convert<AccDataType>(a_f4x2.template unpack<>(Number<0>{}));
|
||||
auto a_f4_hi =
|
||||
ck_tile::type_convert<AccDataType>(a_f4x2.template unpack<>(Number<1>{}));
|
||||
|
||||
a_m_k_scaled(m, k) = a_f4_lo * a_scale;
|
||||
a_m_k_scaled(m, k + 1) = a_f4_hi * a_scale;
|
||||
}
|
||||
else
|
||||
{
|
||||
v_a = ck_tile::type_convert<AccDataType>(A[a_index]);
|
||||
a_m_k_scaled(m, k) =
|
||||
ck_tile::type_convert<AccDataType>((a_m_k(m, k))) *
|
||||
ck_tile::type_convert<AccDataType>(a_m_k_scale(m, k / ScaleBlockSize));
|
||||
}
|
||||
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
|
||||
{
|
||||
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(B[b_index / packed_size_b]);
|
||||
if(k % 2 == 1)
|
||||
v_b = fp32_val.hi;
|
||||
else
|
||||
v_b = fp32_val.lo;
|
||||
}
|
||||
else
|
||||
{
|
||||
v_b = ck_tile::type_convert<AccDataType>(B[b_index]);
|
||||
}
|
||||
acc += v_a * v_b;
|
||||
}
|
||||
|
||||
int c_index = (std::is_same_v<LayoutC, tensor_layout::gemm::RowMajor>)
|
||||
? row * strideC + col
|
||||
: col * strideC + row;
|
||||
C[c_index] = ck_tile::type_convert<CDataType>(acc);
|
||||
for(int n = 0; n < N; n++)
|
||||
{
|
||||
for(int k = 0; k < K; k++)
|
||||
{
|
||||
if constexpr(std::is_same_v<BDatatype, f4x2_pk_t>)
|
||||
{
|
||||
if(k % 2 == 1)
|
||||
continue; // skip odd k
|
||||
|
||||
auto b_f4x2 = b_k_n(k, n);
|
||||
auto b_scale = b_k_n_scale(k / ScaleBlockSize, n);
|
||||
// auto f4_lo = ck_tile::type_convert<AccDataType>(f4x2)[0];
|
||||
// auto f4_hi = ck_tile::type_convert<AccDataType>(f4x2)[1];
|
||||
auto b_f4_lo =
|
||||
ck_tile::type_convert<AccDataType>(b_f4x2.template unpack<>(Number<0>{}));
|
||||
auto b_f4_hi =
|
||||
ck_tile::type_convert<AccDataType>(b_f4x2.template unpack<>(Number<1>{}));
|
||||
|
||||
b_k_n_scaled(k, n) = b_f4_lo * b_scale;
|
||||
b_k_n_scaled(k + 1, n) = b_f4_hi * b_scale;
|
||||
}
|
||||
else
|
||||
{
|
||||
b_k_n_scaled(k, n) =
|
||||
ck_tile::type_convert<AccDataType>((b_k_n(k, n))) *
|
||||
ck_tile::type_convert<AccDataType>(b_k_n_scale(k / ScaleBlockSize, n));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// call reference_gemm
|
||||
reference_gemm<ADataType, AccDataType, BDatatype, CDataType>(
|
||||
a_m_k_scaled, b_k_n_scaled, c_m_n);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename LayoutA,
|
||||
typename LayoutB,
|
||||
typename LayoutC>
|
||||
void reference_gemm_gpu(ADataType* a_ptr,
|
||||
BDataType* b_ptr,
|
||||
CDataType* c_ptr,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t stride_a,
|
||||
index_t stride_b,
|
||||
index_t stride_c)
|
||||
{
|
||||
int totalElements = M * N;
|
||||
int numThreadsPerBlock = 256; // Common choice for threads per block
|
||||
int numBlocks = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;
|
||||
|
||||
naive_gemm_kernel<ADataType, BDataType, AccDataType, CDataType, LayoutA, LayoutB, LayoutC>
|
||||
<<<numBlocks, numThreadsPerBlock>>>(
|
||||
a_ptr, b_ptr, c_ptr, M, N, K, stride_a, stride_b, stride_c);
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename LayoutA,
|
||||
typename LayoutB,
|
||||
typename LayoutC>
|
||||
void reference_batched_gemm_gpu(ADataType* a_ptr,
|
||||
BDataType* b_ptr,
|
||||
CDataType* c_ptr,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t stride_a,
|
||||
index_t stride_b,
|
||||
index_t stride_c,
|
||||
index_t batch_stride_A,
|
||||
index_t batch_stride_B,
|
||||
index_t batch_stride_C,
|
||||
index_t batch_count)
|
||||
{
|
||||
int totalElements = M * N;
|
||||
int numThreadsPerBlock = 256; // Common choice for threads per block
|
||||
int numBlocks = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;
|
||||
|
||||
for(index_t batch_id = 0; batch_id < batch_count; ++batch_id)
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ACCElementOp,
|
||||
typename DDataType = remove_cvref_t<std::tuple_element_t<0, DsDataType>>>
|
||||
CK_TILE_HOST void reference_gemm_multiple_d(
|
||||
const HostTensor<ADataType>& a_m_k,
|
||||
const HostTensor<BDataType>& b_k_n,
|
||||
const std::array<HostTensor<DDataType>, DsDataType::size()>& ds_m_n,
|
||||
HostTensor<CDataType>& c_m_n,
|
||||
const ACCElementOp& acc_element_op = {})
|
||||
{
|
||||
ADataType* d_ATemp = a_ptr + batch_id * batch_stride_A;
|
||||
BDataType* d_BTemp = b_ptr + batch_id * batch_stride_B;
|
||||
CDataType* d_CTemp = c_ptr + batch_id * batch_stride_C;
|
||||
const std::size_t M = a_m_k.get_length(0);
|
||||
const std::size_t N = b_k_n.get_length(1);
|
||||
const std::size_t K = a_m_k.get_length(1);
|
||||
|
||||
auto f_mk_kn_mn = [&](auto m, auto n) {
|
||||
AccDataType v_acc = 0;
|
||||
for(std::size_t k = 0; k < K; ++k)
|
||||
{
|
||||
ADataType v_a = a_m_k(m, k);
|
||||
BDataType v_b = b_k_n(k, n);
|
||||
v_acc += ck_tile::type_convert<AccDataType>(v_a) *
|
||||
ck_tile::type_convert<AccDataType>(v_b);
|
||||
}
|
||||
|
||||
CDataType v_c = 0;
|
||||
if constexpr(DsDataType::size() == 0)
|
||||
{
|
||||
acc_element_op(v_c, ck_tile::type_convert<float>(v_acc));
|
||||
}
|
||||
else if constexpr(DsDataType::size() == 1)
|
||||
{
|
||||
acc_element_op(v_c,
|
||||
ck_tile::type_convert<float>(v_acc),
|
||||
ck_tile::type_convert<float>(ds_m_n[0](m, n)));
|
||||
}
|
||||
else if constexpr(DsDataType::size() == 2)
|
||||
{
|
||||
acc_element_op(v_c,
|
||||
ck_tile::type_convert<float>(v_acc),
|
||||
ck_tile::type_convert<float>(ds_m_n[0](m, n)),
|
||||
ck_tile::type_convert<float>(ds_m_n[1](m, n)));
|
||||
}
|
||||
c_m_n(m, n) = ck_tile::type_convert<CDataType>(v_c);
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_mk_kn_mn, M, N)(std::thread::hardware_concurrency());
|
||||
}
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename LayoutA,
|
||||
typename LayoutB,
|
||||
typename LayoutC>
|
||||
__global__ void naive_gemm_kernel(ADataType * A,
|
||||
BDataType * B,
|
||||
CDataType * C,
|
||||
ck_tile::index_t M,
|
||||
ck_tile::index_t N,
|
||||
ck_tile::index_t K,
|
||||
ck_tile::index_t strideA,
|
||||
ck_tile::index_t strideB,
|
||||
ck_tile::index_t strideC)
|
||||
{
|
||||
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int row = idx / N; // Compute row index
|
||||
int col = idx % N; // Compute column index
|
||||
|
||||
if(row < M && col < N)
|
||||
{
|
||||
AccDataType acc = 0.0;
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
constexpr index_t packed_size_a = ck_tile::numeric_traits<ADataType>::PackedSize;
|
||||
constexpr index_t packed_size_b = ck_tile::numeric_traits<BDataType>::PackedSize;
|
||||
// Adjust indexing based on matrix layout
|
||||
int a_index = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
|
||||
? row * strideA + k
|
||||
: k * strideA + row;
|
||||
int b_index = (std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
|
||||
? col * strideB + k
|
||||
: k * strideB + col;
|
||||
|
||||
AccDataType v_a;
|
||||
AccDataType v_b;
|
||||
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
|
||||
{
|
||||
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(A[a_index / packed_size_a]);
|
||||
if(k % 2 == 1)
|
||||
v_a = fp32_val.hi;
|
||||
else
|
||||
v_a = fp32_val.lo;
|
||||
}
|
||||
else
|
||||
{
|
||||
v_a = ck_tile::type_convert<AccDataType>(A[a_index]);
|
||||
}
|
||||
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
|
||||
{
|
||||
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(B[b_index / packed_size_b]);
|
||||
if(k % 2 == 1)
|
||||
v_b = fp32_val.hi;
|
||||
else
|
||||
v_b = fp32_val.lo;
|
||||
}
|
||||
else
|
||||
{
|
||||
v_b = ck_tile::type_convert<AccDataType>(B[b_index]);
|
||||
}
|
||||
acc += v_a * v_b;
|
||||
}
|
||||
|
||||
int c_index = (std::is_same_v<LayoutC, tensor_layout::gemm::RowMajor>)
|
||||
? row * strideC + col
|
||||
: col * strideC + row;
|
||||
C[c_index] = ck_tile::type_convert<CDataType>(acc);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename LayoutA,
|
||||
typename LayoutB,
|
||||
typename LayoutC>
|
||||
void reference_gemm_gpu(ADataType * a_ptr,
|
||||
BDataType * b_ptr,
|
||||
CDataType * c_ptr,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t stride_a,
|
||||
index_t stride_b,
|
||||
index_t stride_c)
|
||||
{
|
||||
int totalElements = M * N;
|
||||
int numThreadsPerBlock = 256; // Common choice for threads per block
|
||||
int numBlocks = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;
|
||||
|
||||
naive_gemm_kernel<ADataType, BDataType, AccDataType, CDataType, LayoutA, LayoutB, LayoutC>
|
||||
<<<numBlocks, numThreadsPerBlock>>>(
|
||||
d_ATemp, d_BTemp, d_CTemp, M, N, K, stride_a, stride_b, stride_c);
|
||||
a_ptr, b_ptr, c_ptr, M, N, K, stride_a, stride_b, stride_c);
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename LayoutA,
|
||||
typename LayoutB,
|
||||
typename LayoutC>
|
||||
void reference_batched_gemm_gpu(ADataType * a_ptr,
|
||||
BDataType * b_ptr,
|
||||
CDataType * c_ptr,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t stride_a,
|
||||
index_t stride_b,
|
||||
index_t stride_c,
|
||||
index_t batch_stride_A,
|
||||
index_t batch_stride_B,
|
||||
index_t batch_stride_C,
|
||||
index_t batch_count)
|
||||
{
|
||||
int totalElements = M * N;
|
||||
int numThreadsPerBlock = 256; // Common choice for threads per block
|
||||
int numBlocks = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;
|
||||
|
||||
for(index_t batch_id = 0; batch_id < batch_count; ++batch_id)
|
||||
{
|
||||
ADataType* d_ATemp = a_ptr + batch_id * batch_stride_A;
|
||||
BDataType* d_BTemp = b_ptr + batch_id * batch_stride_B;
|
||||
CDataType* d_CTemp = c_ptr + batch_id * batch_stride_C;
|
||||
naive_gemm_kernel<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
LayoutA,
|
||||
LayoutB,
|
||||
LayoutC><<<numBlocks, numThreadsPerBlock>>>(
|
||||
d_ATemp, d_BTemp, d_CTemp, M, N, K, stride_a, stride_b, stride_c);
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user