mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-18 12:00:07 +00:00
Implement batched gemm bias permute for RDNA4 (#3534)
* feat: test setup for batched contraction (aka batched gemm multiple d e permute)
* wip: device struct for WMMA batched contraction multiple d based on new gridwise op
* feat: working batched contraction on RDNA, non-naive tensor descriptors for gridwise_gemm_wmma_cshuffle_v3, test setup for odd cases
* fix: failure to resolve template parameters when calling new function overload
* fix: passing reference type as parameter instead of underlying types
* fix: merge error caused duplicate definitions
* fix: make sure constness of template and parameters types match
* fix: don't compile batched contraction test on unsupported architectures
* feat: add example for new wmma implementation, and consolidate example code between platforms
* style: return inline instead of with branch
* chore: add extra assert on vector memory access sizes
* chore: clean up some unused variables
* fix: correct tail number calculation, added small cases and extra instances to the test
* fix: properly support wave transfer by generating correct grid descriptors dependent on the transfer method
[ROCm/composable_kernel commit: fe40a5d139]
This commit is contained in:
@@ -231,6 +231,279 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base
|
||||
}
|
||||
};
|
||||
|
||||
// hardcoded for NumDimG == 1, NumDimM == 2, NumDimN == 3, NumDimK == 1
|
||||
template <ck::index_t NumDimG,
|
||||
ck::index_t NumDimM,
|
||||
ck::index_t NumDimN,
|
||||
ck::index_t NumDimK,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename EDataType,
|
||||
typename AccDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation,
|
||||
ck::enable_if_t<NumDimG == 1 && NumDimM == 2 && NumDimN == 3 && NumDimK == 1, bool> =
|
||||
false>
|
||||
struct ReferenceBatchedContraction_G1_M2_N3_K1 : public ck::tensor_operation::device::BaseOperator
|
||||
{
|
||||
// Argument
|
||||
struct Argument : public ck::tensor_operation::device::BaseArgument
|
||||
{
|
||||
Argument(const Tensor<ADataType>& a_gs_ms_ks,
|
||||
const Tensor<BDataType>& b_gs_ns_ks,
|
||||
Tensor<EDataType>& e_gs_ms_ns,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op)
|
||||
: a_gs_ms_ks_{a_gs_ms_ks},
|
||||
b_gs_ns_ks_{b_gs_ns_ks},
|
||||
e_gs_ms_ns_{e_gs_ms_ns},
|
||||
a_element_op_{a_element_op},
|
||||
b_element_op_{b_element_op},
|
||||
cde_element_op_{cde_element_op}
|
||||
{
|
||||
}
|
||||
|
||||
const Tensor<ADataType>& a_gs_ms_ks_;
|
||||
const Tensor<BDataType>& b_gs_ns_ks_;
|
||||
Tensor<EDataType>& e_gs_ms_ns_;
|
||||
|
||||
AElementwiseOperation a_element_op_;
|
||||
BElementwiseOperation b_element_op_;
|
||||
CDEElementwiseOperation cde_element_op_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public ck::tensor_operation::device::BaseInvoker
|
||||
{
|
||||
using Argument = ReferenceBatchedContraction_G1_M2_N3_K1::Argument;
|
||||
|
||||
float Run(const Argument& arg)
|
||||
{
|
||||
auto f_gs_ms_ns = [&](auto g0, auto m0, auto m1, auto n0, auto n1, auto n2) {
|
||||
const int K0 = arg.a_gs_ms_ks_.mDesc.GetLengths()[3];
|
||||
|
||||
AccDataType v_acc = 0;
|
||||
|
||||
for(int k0 = 0; k0 < K0; ++k0)
|
||||
{
|
||||
AccDataType v_a;
|
||||
AccDataType v_b;
|
||||
|
||||
arg.a_element_op_(
|
||||
v_a, ck::type_convert<const AccDataType>(arg.a_gs_ms_ks_(g0, m0, m1, k0)));
|
||||
arg.b_element_op_(
|
||||
v_b,
|
||||
ck::type_convert<const AccDataType>(arg.b_gs_ns_ks_(g0, n0, n1, n2, k0)));
|
||||
|
||||
v_acc += v_a * v_b;
|
||||
}
|
||||
|
||||
AccDataType v_c;
|
||||
|
||||
arg.cde_element_op_(v_c, v_acc);
|
||||
|
||||
arg.e_gs_ms_ns_(g0, m0, m1, n0, n1, n2) = v_c;
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_gs_ms_ns,
|
||||
arg.e_gs_ms_ns_.mDesc.GetLengths()[0],
|
||||
arg.e_gs_ms_ns_.mDesc.GetLengths()[1],
|
||||
arg.e_gs_ms_ns_.mDesc.GetLengths()[2],
|
||||
arg.e_gs_ms_ns_.mDesc.GetLengths()[3],
|
||||
arg.e_gs_ms_ns_.mDesc.GetLengths()[4],
|
||||
arg.e_gs_ms_ns_.mDesc.GetLengths()[5])(
|
||||
std::thread::hardware_concurrency());
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
float Run(const ck::tensor_operation::device::BaseArgument* p_arg,
|
||||
const StreamConfig& /* stream_config */ = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
};
|
||||
|
||||
static constexpr bool IsValidCompilationParameter()
|
||||
{
|
||||
// TODO: properly implement this check
|
||||
return true;
|
||||
}
|
||||
|
||||
bool IsSupportedArgument(const ck::tensor_operation::device::BaseArgument*) override
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
static auto MakeArgument(const Tensor<ADataType>& a_gs_ms_ks,
|
||||
const Tensor<BDataType>& b_gs_ns_ks,
|
||||
Tensor<EDataType>& e_gs_ms_ns,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op)
|
||||
{
|
||||
return Argument{
|
||||
a_gs_ms_ks, b_gs_ns_ks, e_gs_ms_ns, a_element_op, b_element_op, cde_element_op};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
virtual std::unique_ptr<ck::tensor_operation::device::BaseInvoker> MakeInvokerPointer()
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "ReferenceBatchedContraction_G1_M3_N2_K1"
|
||||
<< std::endl;
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
template <ck::index_t NumDimG,
|
||||
ck::index_t NumDimM,
|
||||
ck::index_t NumDimN,
|
||||
ck::index_t NumDimK,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename EDataType,
|
||||
typename AccDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation,
|
||||
ck::enable_if_t<NumDimG == 1 && NumDimM == 3 && NumDimN == 2 && NumDimK == 1, bool> =
|
||||
false>
|
||||
struct ReferenceBatchedContraction_G1_M3_N2_K1 : public ck::tensor_operation::device::BaseOperator
|
||||
{
|
||||
// Argument
|
||||
struct Argument : public ck::tensor_operation::device::BaseArgument
|
||||
{
|
||||
Argument(const Tensor<ADataType>& a_gs_ms_ks,
|
||||
const Tensor<BDataType>& b_gs_ns_ks,
|
||||
Tensor<EDataType>& e_gs_ms_ns,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op)
|
||||
: a_gs_ms_ks_{a_gs_ms_ks},
|
||||
b_gs_ns_ks_{b_gs_ns_ks},
|
||||
e_gs_ms_ns_{e_gs_ms_ns},
|
||||
a_element_op_{a_element_op},
|
||||
b_element_op_{b_element_op},
|
||||
cde_element_op_{cde_element_op}
|
||||
{
|
||||
}
|
||||
|
||||
const Tensor<ADataType>& a_gs_ms_ks_;
|
||||
const Tensor<BDataType>& b_gs_ns_ks_;
|
||||
Tensor<EDataType>& e_gs_ms_ns_;
|
||||
|
||||
AElementwiseOperation a_element_op_;
|
||||
BElementwiseOperation b_element_op_;
|
||||
CDEElementwiseOperation cde_element_op_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public ck::tensor_operation::device::BaseInvoker
|
||||
{
|
||||
using Argument = ReferenceBatchedContraction_G1_M3_N2_K1::Argument;
|
||||
|
||||
float Run(const Argument& arg)
|
||||
{
|
||||
auto f_gs_ms_ns = [&](auto g0, auto m0, auto m1, auto m2, auto n0, auto n1) {
|
||||
const int K0 = arg.a_gs_ms_ks_.mDesc.GetLengths()[4];
|
||||
|
||||
AccDataType v_acc = 0;
|
||||
|
||||
for(int k0 = 0; k0 < K0; ++k0)
|
||||
{
|
||||
AccDataType v_a;
|
||||
AccDataType v_b;
|
||||
|
||||
arg.a_element_op_(
|
||||
v_a,
|
||||
ck::type_convert<const AccDataType>(arg.a_gs_ms_ks_(g0, m0, m1, m2, k0)));
|
||||
arg.b_element_op_(
|
||||
v_b, ck::type_convert<const AccDataType>(arg.b_gs_ns_ks_(g0, n0, n1, k0)));
|
||||
|
||||
v_acc += v_a * v_b;
|
||||
}
|
||||
|
||||
AccDataType v_c;
|
||||
|
||||
arg.cde_element_op_(v_c, v_acc);
|
||||
|
||||
arg.e_gs_ms_ns_(g0, m0, m1, m2, n0, n1) = v_c;
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_gs_ms_ns,
|
||||
arg.e_gs_ms_ns_.mDesc.GetLengths()[0],
|
||||
arg.e_gs_ms_ns_.mDesc.GetLengths()[1],
|
||||
arg.e_gs_ms_ns_.mDesc.GetLengths()[2],
|
||||
arg.e_gs_ms_ns_.mDesc.GetLengths()[3],
|
||||
arg.e_gs_ms_ns_.mDesc.GetLengths()[4],
|
||||
arg.e_gs_ms_ns_.mDesc.GetLengths()[5])(
|
||||
std::thread::hardware_concurrency());
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
float Run(const ck::tensor_operation::device::BaseArgument* p_arg,
|
||||
const StreamConfig& /* stream_config */ = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
};
|
||||
|
||||
static constexpr bool IsValidCompilationParameter()
|
||||
{
|
||||
// TODO: properly implement this check
|
||||
return true;
|
||||
}
|
||||
|
||||
bool IsSupportedArgument(const ck::tensor_operation::device::BaseArgument*) override
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
static auto MakeArgument(const Tensor<ADataType>& a_gs_ms_ks,
|
||||
const Tensor<BDataType>& b_gs_ns_ks,
|
||||
Tensor<EDataType>& e_gs_ms_ns,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op)
|
||||
{
|
||||
return Argument{
|
||||
a_gs_ms_ks, b_gs_ns_ks, e_gs_ms_ns, a_element_op, b_element_op, cde_element_op};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
virtual std::unique_ptr<ck::tensor_operation::device::BaseInvoker> MakeInvokerPointer()
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "ReferenceBatchedContraction_G1_M3_N2_K1"
|
||||
<< std::endl;
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace host
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
|
||||
@@ -19,6 +19,7 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
#ifdef CK_USE_XDL
|
||||
void add_device_batched_contraction_bias_permute_m2_n3_k1_xdl_c_shuffle_f16_f16_f16_f16_mnnm_instance(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceBatchedContractionMultipleD<1,
|
||||
@@ -32,6 +33,23 @@ void add_device_batched_contraction_bias_permute_m2_n3_k1_xdl_c_shuffle_f16_f16_
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::Add>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_USE_WMMA
|
||||
// Declaration of the WMMA counterpart of the XDL instance-registration function
// declared above; it is only visible when CK_USE_WMMA is defined.
// Takes the list of DeviceBatchedContractionMultipleD operator pointers by
// non-const reference (presumably to append the WMMA f16 instances to it -
// confirm against the definition, which is not in this header).
// NOTE(review): the leading integer arguments 1, 2, 3, 1 look like
// NumDimG, NumDimM, NumDimN, NumDimK, matching the factory's
// `NumDimG == 1 && NumDimM == 2 && NumDimN == 3 && NumDimK == 1` branch - confirm.
void add_device_batched_contraction_bias_permute_m2_n3_k1_wmma_c_shuffle_f16_f16_f16_f16_mnnm_instance(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceBatchedContractionMultipleD<1,
|
||||
2,
|
||||
3,
|
||||
1,
|
||||
F16,
|
||||
F16,
|
||||
F16_Tuple,
|
||||
F16,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::Add>>>& instances);
|
||||
#endif
|
||||
|
||||
// Contraction + add
|
||||
template <index_t NumDimG,
|
||||
@@ -76,10 +94,17 @@ struct DeviceOperationInstanceFactory<
|
||||
if constexpr(is_same_v<ADataType, ck::half_t> && is_same_v<BDataType, ck::half_t> &&
|
||||
is_same_v<DDataType, ck::half_t> && is_same_v<EDataType, ck::half_t>)
|
||||
{
|
||||
|
||||
if constexpr(NumDimG == 1 && NumDimM == 2 && NumDimN == 3 && NumDimK == 1)
|
||||
{
|
||||
#ifdef CK_USE_XDL
|
||||
add_device_batched_contraction_bias_permute_m2_n3_k1_xdl_c_shuffle_f16_f16_f16_f16_mnnm_instance(
|
||||
op_ptrs);
|
||||
#endif
|
||||
#ifdef CK_USE_WMMA
|
||||
add_device_batched_contraction_bias_permute_m2_n3_k1_wmma_c_shuffle_f16_f16_f16_f16_mnnm_instance(
|
||||
op_ptrs);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user