Add BF16 tests for batched_gemm_softmax_gemm_permute (#504)

* fixed bug in softmax reference & add bf16 examples for batched_gemm_scale_softmax_gemm

* added bf16 tests for batched_gemm_softmax_gemm_permute

* changed format of device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp

* changed format device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp

* aligned annotations

* modified CMakeLists for examples

* add common example code of fp16/bf16 version for batched_gemm_scale_softmax_gemm_xdl

* use macro to control the instances

* added macro control into instances

* clang-format some files

* changed error tolerance for bf16

* changed index for 10_elementwise_normalization

* fixed xdlops code bug in amd_xdlops.hpp

Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com>
This commit is contained in:
guangzlu
2022-11-16 06:30:23 +08:00
committed by GitHub
parent db0eb1ea9c
commit 4c4c7328a6
17 changed files with 1133 additions and 269 deletions

View File

@@ -86,8 +86,8 @@ struct ReferenceSoftmax : public device::BaseOperator
};
arg.in_.ForEach([&](auto& self, auto idx) {
reduce_max(to_sm_scalar_idx(idx)) = std::max(reduce_max(to_sm_scalar_idx(idx)),
static_cast<AccDataType>(self(idx)));
reduce_max(to_sm_scalar_idx(idx)) = std::max(
reduce_max(to_sm_scalar_idx(idx)), ck::type_convert<AccDataType>(self(idx)));
});
// LogRangeAsType<float>(std::cout << "reduce_max: ", reduce_max.mData, ",") <<
@@ -96,7 +96,7 @@ struct ReferenceSoftmax : public device::BaseOperator
Tensor<AccDataType> in_stable(arg.in_.mDesc);
in_stable.ForEach([&](auto& self, auto idx) {
// numerator = exp(x - max(x))
self(idx) = std::exp(static_cast<AccDataType>(arg.in_(idx)) -
self(idx) = std::exp(ck::type_convert<AccDataType>(arg.in_(idx)) -
reduce_max(to_sm_scalar_idx(idx)));
});
@@ -111,8 +111,10 @@ struct ReferenceSoftmax : public device::BaseOperator
// std::endl;
arg.out_.ForEach([&](auto& self, auto idx) {
self(idx) = arg.alpha_ * in_stable(idx) / reduce_sum(to_sm_scalar_idx(idx)) +
arg.beta_ * self(idx);
AccDataType temp_result =
arg.alpha_ * in_stable(idx) / reduce_sum(to_sm_scalar_idx(idx)) +
arg.beta_ * self(idx);
self(idx) = ck::type_convert<OutDataType>(temp_result);
});
// LogRangeAsType<float>(std::cout << "out: ", arg.out_.mData, ",") << std::endl;

View File

@@ -59,6 +59,48 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_g
MaskingSpecialization::MaskDisabled>>>&
instances);
void add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances(
std::vector<std::unique_ptr<
DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
1,
1,
1,
BF16,
BF16,
BF16,
BF16,
ck::Tuple<>,
ck::Tuple<>,
PassThrough,
PassThrough,
Scale,
PassThrough,
PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>>>&
instances);
void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances(
std::vector<
std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
1,
1,
1,
BF16,
BF16,
BF16,
BF16,
ck::Tuple<>,
ck::Tuple<>,
PassThrough,
PassThrough,
Scale,
PassThrough,
PassThrough,
MaskingSpecialization::MaskDisabled>>>&
instances);
template <typename ADataType,
typename B0DataType,
typename B1DataType,
@@ -119,6 +161,20 @@ struct DeviceOperationInstanceFactory<
op_ptrs);
}
}
else if constexpr(is_same_v<ADataType, BF16> && is_same_v<B0DataType, BF16> &&
is_same_v<B1DataType, BF16> && is_same_v<CDataType, BF16>)
{
if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle)
{
add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances(
op_ptrs);
}
else if(MaskingSpec == MaskingSpecialization::MaskDisabled)
{
add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances(
op_ptrs);
}
}
return op_ptrs;
}
};