mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 19:57:40 +00:00
changed gemm_b_scale and gemm_universal tests to use correct parameters
This commit is contained in:
committed by
Kevin Abraham
parent
c35aee2b56
commit
ea36b9eead
@@ -8,6 +8,7 @@
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include <stdexcept>
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -30,14 +31,18 @@ struct ReferenceBatchedGemm : public device::BaseOperator
|
||||
Tensor<CDataType>& c_g_m_n,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op)
|
||||
CElementwiseOperation c_element_op,
|
||||
const int k_batch=1)
|
||||
: a_g_m_k_{a_g_m_k},
|
||||
b_g_k_n_{b_g_k_n},
|
||||
c_g_m_n_{c_g_m_n},
|
||||
a_element_op_{a_element_op},
|
||||
b_element_op_{b_element_op},
|
||||
c_element_op_{c_element_op}
|
||||
c_element_op_{c_element_op},
|
||||
k_batch_(k_batch)
|
||||
{
|
||||
if(k_batch < 1)
|
||||
throw std::invalid_argument("Batch size must be at least 1");
|
||||
}
|
||||
|
||||
const Tensor<ADataType>& a_g_m_k_;
|
||||
@@ -47,6 +52,8 @@ struct ReferenceBatchedGemm : public device::BaseOperator
|
||||
AElementwiseOperation a_element_op_;
|
||||
BElementwiseOperation b_element_op_;
|
||||
CElementwiseOperation c_element_op_;
|
||||
|
||||
const int k_batch_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
@@ -59,23 +66,52 @@ struct ReferenceBatchedGemm : public device::BaseOperator
|
||||
auto f_gmk_gkn_gmn = [&](auto g, auto m, auto n) {
|
||||
const int K = arg.a_g_m_k_.mDesc.GetLengths()[2];
|
||||
|
||||
AccDataType v_acc = 0;
|
||||
// simulate fp accuracy implications of k batching
|
||||
std::vector<CDataType> partialSums(arg.k_batch_);
|
||||
|
||||
for(int k = 0; k < K; ++k)
|
||||
for(int batchIdx = 0; batchIdx < arg.k_batch_; ++batchIdx)
|
||||
{
|
||||
ADataType v_a;
|
||||
BDataType v_b;
|
||||
int batchSize = std::max(K/arg.k_batch_, 1);
|
||||
int batchStart = batchSize*batchIdx;
|
||||
int batchEnd = batchSize*(batchIdx+1);
|
||||
// add any extra round-off to last batch
|
||||
if(batchIdx == arg.k_batch_-1)
|
||||
batchEnd = K;
|
||||
|
||||
arg.a_element_op_(v_a, arg.a_g_m_k_(g, m, k));
|
||||
arg.b_element_op_(v_b, arg.b_g_k_n_(g, k, n));
|
||||
AccDataType v_acc = 0;
|
||||
for(int k = batchStart; k < batchEnd; ++k)
|
||||
{
|
||||
ADataType v_a;
|
||||
BDataType v_b;
|
||||
|
||||
arg.a_element_op_(v_a, arg.a_g_m_k_(g, m, k));
|
||||
arg.b_element_op_(v_b, arg.b_g_k_n_(g, k, n));
|
||||
|
||||
v_acc +=
|
||||
ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b);
|
||||
}
|
||||
|
||||
AccDataType v_c;
|
||||
arg.c_element_op_(v_c, v_acc);
|
||||
partialSums[batchIdx] = ck::type_convert<CDataType>(v_c);
|
||||
|
||||
v_acc +=
|
||||
ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b);
|
||||
}
|
||||
|
||||
AccDataType v_c;
|
||||
|
||||
arg.c_element_op_(v_c, v_acc);
|
||||
// finally, sum up partial sums
|
||||
// note that we can't simulate the random nature of atomic additions, but at least we can
|
||||
// simulate the effect of partial sums
|
||||
AccDataType v_c = 0;
|
||||
if(arg.k_batch_ > 1)
|
||||
{
|
||||
for(int batchIdx = 0;batchIdx < arg.k_batch_;batchIdx++)
|
||||
{
|
||||
v_c = ck::type_convert<CDataType>(ck::type_convert<AccDataType>(v_c) + ck::type_convert<AccDataType>(partialSums[batchIdx]));
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
v_c = partialSums[0];
|
||||
}
|
||||
|
||||
arg.c_g_m_n_(g, m, n) = ck::type_convert<CDataType>(v_c);
|
||||
};
|
||||
@@ -108,9 +144,10 @@ struct ReferenceBatchedGemm : public device::BaseOperator
|
||||
Tensor<CDataType>& c_g_m_n,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op)
|
||||
CElementwiseOperation c_element_op,
|
||||
const int k_batch=1)
|
||||
{
|
||||
return Argument{a_g_m_k, b_g_k_n, c_g_m_n, a_element_op, b_element_op, c_element_op};
|
||||
return Argument{a_g_m_k, b_g_k_n, c_g_m_n, a_element_op, b_element_op, c_element_op, k_batch};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
@@ -105,7 +105,7 @@ bool profile_gemm_b_scale_impl(int do_verification,
|
||||
break;
|
||||
case 2:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-1, 2});
|
||||
b1_k_n.GenerateTensorValue(GeneratorTensor_3<BScaleDataType>{0, 1.0});
|
||||
break;
|
||||
default:
|
||||
|
||||
@@ -90,7 +90,7 @@ bool profile_gemm_universal_impl(int do_verification,
|
||||
break;
|
||||
case 2:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-1, 2});
|
||||
break;
|
||||
default:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
|
||||
Reference in New Issue
Block a user