changed gemm_b_scale and gemm_universal tests to use correct parameters

This commit is contained in:
Kevin Abraham
2025-08-20 13:48:14 +00:00
committed by Kevin Abraham
parent c35aee2b56
commit ea36b9eead
3 changed files with 54 additions and 17 deletions

View File

@@ -8,6 +8,7 @@
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include <stdexcept>
namespace ck {
namespace tensor_operation {
@@ -30,14 +31,18 @@ struct ReferenceBatchedGemm : public device::BaseOperator
Tensor<CDataType>& c_g_m_n,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
CElementwiseOperation c_element_op,
const int k_batch=1)
: a_g_m_k_{a_g_m_k},
b_g_k_n_{b_g_k_n},
c_g_m_n_{c_g_m_n},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
c_element_op_{c_element_op}
c_element_op_{c_element_op},
k_batch_(k_batch)
{
if(k_batch < 1)
throw std::invalid_argument("Batch size must be at least 1");
}
const Tensor<ADataType>& a_g_m_k_;
@@ -47,6 +52,8 @@ struct ReferenceBatchedGemm : public device::BaseOperator
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_;
const int k_batch_;
};
// Invoker
@@ -59,23 +66,52 @@ struct ReferenceBatchedGemm : public device::BaseOperator
// Computes a single output element c_g_m_n_(g, m, n) of the reference GEMM by
// splitting the K reduction into arg.k_batch_ partial sums and then adding the
// partial sums together.
// NOTE(review): this listing appears to interleave lines from the previous
// single-loop implementation with the new per-batch implementation (two outer
// `for` headers, duplicated `v_acc +=` statements, `v_c` declared more than
// once) -- read it as a diff without +/- markers, not as compilable code.
auto f_gmk_gkn_gmn = [&](auto g, auto m, auto n) {
// K is the extent of dimension 2 of the A tensor descriptor.
const int K = arg.a_g_m_k_.mDesc.GetLengths()[2];
AccDataType v_acc = 0;
// Simulate the floating-point accuracy implications of k batching: each batch
// is accumulated separately and stored as CDataType before the final sum.
std::vector<CDataType> partialSums(arg.k_batch_);
for(int k = 0; k < K; ++k)
for(int batchIdx = 0; batchIdx < arg.k_batch_; ++batchIdx)
{
ADataType v_a;
BDataType v_b;
// Each batch covers batchSize consecutive k indices; batchSize is clamped to
// at least 1 so that K < k_batch_ does not yield zero-length batches.
// NOTE(review): when k_batch_ > K, batchSize is 1 and a non-last batch with
// batchIdx >= K gets the range [batchIdx, batchIdx + 1), which looks like it
// indexes k >= K (e.g. K = 2, k_batch_ = 4, batchIdx = 2) -- confirm whether
// batchStart/batchEnd should be clamped to K.
int batchSize = std::max(K/arg.k_batch_, 1);
int batchStart = batchSize*batchIdx;
int batchEnd = batchSize*(batchIdx+1);
// The last batch absorbs the remainder when K is not a multiple of k_batch_.
if(batchIdx == arg.k_batch_-1)
batchEnd = K;
arg.a_element_op_(v_a, arg.a_g_m_k_(g, m, k));
arg.b_element_op_(v_b, arg.b_g_k_n_(g, k, n));
// Per-batch accumulator, reset for every batch.
AccDataType v_acc = 0;
for(int k = batchStart; k < batchEnd; ++k)
{
ADataType v_a;
BDataType v_b;
// Apply the A/B element-wise operations to the loaded operands.
arg.a_element_op_(v_a, arg.a_g_m_k_(g, m, k));
arg.b_element_op_(v_b, arg.b_g_k_n_(g, k, n));
// Multiply-accumulate in AccDataType.
v_acc +=
ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b);
}
// The C element-wise operation is applied to each batch's partial sum, and the
// result is rounded to CDataType before being stored.
// NOTE(review): applying c_element_op_ per partial sum matches applying it to
// the total only if the operation is linear -- confirm this is intended.
AccDataType v_c;
arg.c_element_op_(v_c, v_acc);
partialSums[batchIdx] = ck::type_convert<CDataType>(v_c);
v_acc +=
ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b);
}
AccDataType v_c;
arg.c_element_op_(v_c, v_acc);
// Finally, sum up the partial sums.
// Note that we can't simulate the random ordering of atomic additions, but at least we can
// simulate the effect of partial sums.
AccDataType v_c = 0;
if(arg.k_batch_ > 1)
{
// The running total is rounded to CDataType after every addition, then held
// in v_c.
// NOTE(review): v_c is declared as AccDataType although every value assigned
// to it has already been converted to CDataType -- confirm the intended type.
for(int batchIdx = 0;batchIdx < arg.k_batch_;batchIdx++)
{
v_c = ck::type_convert<CDataType>(ck::type_convert<AccDataType>(v_c) + ck::type_convert<AccDataType>(partialSums[batchIdx]));
}
}
else
{
// Single batch: the only partial sum is the result.
v_c = partialSums[0];
}
// Write the result for this (g, m, n) element.
arg.c_g_m_n_(g, m, n) = ck::type_convert<CDataType>(v_c);
};
@@ -108,9 +144,10 @@ struct ReferenceBatchedGemm : public device::BaseOperator
Tensor<CDataType>& c_g_m_n,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
CElementwiseOperation c_element_op,
const int k_batch=1)
{
return Argument{a_g_m_k, b_g_k_n, c_g_m_n, a_element_op, b_element_op, c_element_op};
return Argument{a_g_m_k, b_g_k_n, c_g_m_n, a_element_op, b_element_op, c_element_op, k_batch};
}
static auto MakeInvoker() { return Invoker{}; }

View File

@@ -105,7 +105,7 @@ bool profile_gemm_b_scale_impl(int do_verification,
break;
case 2:
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-1, 2});
b1_k_n.GenerateTensorValue(GeneratorTensor_3<BScaleDataType>{0, 1.0});
break;
default:

View File

@@ -90,7 +90,7 @@ bool profile_gemm_universal_impl(int do_verification,
break;
case 2:
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-1, 2});
break;
default:
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});