mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 05:31:24 +00:00
Softmax unit-test reduction across all and non innermost dims cases. (#406)
* Add reduction across all dims cases. * host softmax: handle all reduce * Test cases when reduced dim is not innermost axis. * Fix syntax. * Test non innermost dim for fp32 and int8 * Group test suites wrt NumReduceDim. * Additionally test failing cases. * Throw error when Rank or NumReduceDims doesn't match arguments. * Check reducedDims has correct values * Don't reuse the DeviceReduceMultiBlock IsSupportedArgument method; implement our own instead (in fact, just get rid of one check to enable reduction across inner dimensions). * Reorganize unit tests to better cover use scenarios. * Test input validation * Test reduction of inner dimensions with custom op instances. * Refactor fp32 and int8 unit tests. * Fix FP32 instance template parameters. * Add more instances. * Instances with InSrcVectorDim=0. * Do not initialize and copy data when arg not supported. * ckProfiler Softmax use instance factory. * Refactor device softmax IsSupported. * Additionally add non-polymorphic api functions * Split softmax instances into multiple files. * Fix profiler. * Reorganize tests to reuse profiler and cover edge cases. * Clang-format * I8 Softmax instances along with UT. * Reuse type alias definitions from instance factory header. * Clean included headers * Fix variable names. * Add missing checks in Argument constructor. Co-authored-by: Adam Osewski <aosewski@amd.com> Co-authored-by: Anthony Chang <ac.chang@outlook.com>
This commit is contained in:
@@ -226,6 +226,30 @@ struct DeviceReduceMultiBlock
|
||||
in_elementwise_op_{in_elementwise_op},
|
||||
acc_elementwise_op_{acc_elementwise_op}
|
||||
{
|
||||
if(Rank != inLengths.size() || Rank != inStrides.size() ||
|
||||
NumReduceDim != reduceDims.size())
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"One of inLengths/inStrides/reduceDims has invalid size!"
|
||||
"\nExpected size inLengths: " +
|
||||
std::to_string(Rank) + ", inStrides: " + std::to_string(Rank) +
|
||||
", reduceDims: " + std::to_string(NumReduceDim) +
|
||||
"\nBut have inLengths: " + std::to_string(inLengths.size()) +
|
||||
", inStrides: " + std::to_string(inStrides.size()) +
|
||||
", reduceDims: " + std::to_string(reduceDims.size()));
|
||||
}
|
||||
|
||||
for(std::size_t i = 0; i < reduceDims.size(); ++i)
|
||||
{
|
||||
if(reduceDims[i] < 0 || reduceDims[i] >= Rank)
|
||||
{
|
||||
throw std::runtime_error("Provided reduce dimension exceed input tensor Rank!"
|
||||
"\nHave reduceDims[" +
|
||||
std::to_string(i) +
|
||||
"]: " + std::to_string(reduceDims[i]));
|
||||
}
|
||||
}
|
||||
|
||||
inLengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, reduceDims);
|
||||
inStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inStrides, reduceDims);
|
||||
|
||||
|
||||
@@ -40,8 +40,9 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
|
||||
AccElementwiseOp,
|
||||
Rank>
|
||||
{
|
||||
static constexpr index_t kRank = Rank;
|
||||
static constexpr index_t kNumReduceDim = NumReduceDim;
|
||||
static constexpr index_t kRank = Rank;
|
||||
static constexpr index_t kNumReduceDim = NumReduceDim;
|
||||
static constexpr index_t kNumInvariantDim = Rank - NumReduceDim;
|
||||
|
||||
// Runtime accessor for the compile-time tensor rank (kRank) of this instance.
index_t GetRank() const override
{
    return kRank;
}
|
||||
|
||||
@@ -168,6 +169,30 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
|
||||
in_elementwise_op_{in_elementwise_op},
|
||||
acc_elementwise_op_{acc_elementwise_op}
|
||||
{
|
||||
if(Rank != inLengths.size() || Rank != inStrides.size() ||
|
||||
NumReduceDim != reduceDims.size())
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"One of inLengths/inStrides/reduceDims has invalid size!"
|
||||
"\nExpected size inLengths: " +
|
||||
std::to_string(Rank) + ", inStrides: " + std::to_string(Rank) +
|
||||
", reduceDims: " + std::to_string(NumReduceDim) +
|
||||
"\nBut have inLengths: " + std::to_string(inLengths.size()) +
|
||||
", inStrides: " + std::to_string(inStrides.size()) +
|
||||
", reduceDims: " + std::to_string(reduceDims.size()));
|
||||
}
|
||||
|
||||
for(std::size_t i = 0; i < reduceDims.size(); ++i)
|
||||
{
|
||||
if(reduceDims[i] < 0 || reduceDims[i] >= Rank)
|
||||
{
|
||||
throw std::runtime_error("Provided reduce dimension exceed input tensor Rank!"
|
||||
"\nHave reduceDims[" +
|
||||
std::to_string(i) +
|
||||
"]: " + std::to_string(reduceDims[i]));
|
||||
}
|
||||
}
|
||||
|
||||
inLengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, reduceDims);
|
||||
inStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inStrides, reduceDims);
|
||||
|
||||
@@ -257,40 +282,78 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
|
||||
};
|
||||
};
|
||||
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
const Argument* p_arg_ = dynamic_cast<const Argument*>(p_arg);
|
||||
|
||||
if constexpr(InSrcVectorDim == 0)
|
||||
{
|
||||
if constexpr(NumInvariantDim == 0)
|
||||
if constexpr(kNumInvariantDim == 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
else
|
||||
{
|
||||
if(p_arg_->inStrides_[NumInvariantDim - 1] != 1)
|
||||
if(arg.inStrides_[kNumInvariantDim - 1] != 1 && InSrcVectorSize != 1)
|
||||
{
|
||||
return false;
|
||||
|
||||
if(p_arg_->invariant_lowest_length_ % InSrcVectorSize != 0)
|
||||
}
|
||||
if(arg.invariant_lowest_length_ % InSrcVectorSize != 0)
|
||||
{
|
||||
return false;
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(p_arg_->inStrides_[Rank - 1] != 1)
|
||||
if(arg.inStrides_[Rank - 1] != 1 && InSrcVectorSize != 1)
|
||||
{
|
||||
return false;
|
||||
|
||||
if(p_arg_->inLengths_[Rank - 1] % InSrcVectorSize != 0)
|
||||
}
|
||||
if(arg.inLengths_[Rank - 1] % InSrcVectorSize != 0)
|
||||
{
|
||||
return false;
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
if(p_arg_->invariant_lowest_length_ % OutDstVectorSize != 0)
|
||||
// To improve
|
||||
if(kNumInvariantDim > 0 && arg.invariant_lowest_length_ % OutDstVectorSize != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if(arg.inLengths_[Rank - 1] % OutDstVectorSize != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
};
|
||||
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
|
||||
static auto MakeArgument(const std::vector<index_t> inLengths,
|
||||
const std::vector<index_t> inStrides,
|
||||
const std::vector<int> reduceDims,
|
||||
const AccDataType alpha,
|
||||
const AccDataType beta,
|
||||
const InDataType* in_dev,
|
||||
OutDataType* out_dev,
|
||||
InElementwiseOp in_elementwise_op,
|
||||
AccElementwiseOp acc_elementwise_op)
|
||||
{
|
||||
return Argument{inLengths,
|
||||
inStrides,
|
||||
reduceDims,
|
||||
alpha,
|
||||
beta,
|
||||
in_dev,
|
||||
out_dev,
|
||||
in_elementwise_op,
|
||||
acc_elementwise_op};
|
||||
};
|
||||
|
||||
//
|
||||
// @brief Makes a pointer to Argument class.
|
||||
//
|
||||
@@ -330,6 +393,8 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
|
||||
acc_elementwise_op);
|
||||
};
|
||||
|
||||
// Makes a default-constructed Invoker object, returned by value.
static auto MakeInvoker()
{
    return Invoker{};
}
|
||||
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>();
|
||||
@@ -340,10 +405,13 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceReduceSoftmax<" << BlockSize << ",";
|
||||
str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
|
||||
str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
|
||||
str << "InSrcVectorDim_" << InSrcVectorDim << "_InSrcVectorSize_" << InSrcVectorSize << "_OutDstVectorSize_" << OutDstVectorSize << ">";
|
||||
str << "DeviceReduceSoftmax<"
|
||||
<< Rank << "," << NumReduceDim << "," << BlockSize << ","
|
||||
<< "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ","
|
||||
<< "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ","
|
||||
<< "InSrcVectorDim_" << InSrcVectorDim
|
||||
<< "_InSrcVectorSize_" << InSrcVectorSize
|
||||
<< "_OutDstVectorSize_" << OutDstVectorSize << ">";
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
|
||||
Reference in New Issue
Block a user