mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-13 09:45:56 +00:00
Add BF16 tests for batched_gemm_softmax_gemm_permute (#504)
* fixed bug in softmax reference & add bf16 examples for batched_gemm_scale_softmax_gemm * added bf16 tests for batched_gemm_softmax_gemm_permute * changed format of device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp * changed format device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp * aligned annotations * modified CMakeLists for examples * add common example code of fp16/bf16 version for batched_gemm_scale_softmax_gemm_xdl * use macro to control the instances * added macro control into instances * clang-format some files * changed error tolerance for bf16 * changed index for 10_elementwise_normalization * fixed xdlops code bug in amd_xdlops.hpp Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com>
This commit is contained in:
@@ -86,8 +86,8 @@ struct ReferenceSoftmax : public device::BaseOperator
|
||||
};
|
||||
|
||||
arg.in_.ForEach([&](auto& self, auto idx) {
|
||||
reduce_max(to_sm_scalar_idx(idx)) = std::max(reduce_max(to_sm_scalar_idx(idx)),
|
||||
static_cast<AccDataType>(self(idx)));
|
||||
reduce_max(to_sm_scalar_idx(idx)) = std::max(
|
||||
reduce_max(to_sm_scalar_idx(idx)), ck::type_convert<AccDataType>(self(idx)));
|
||||
});
|
||||
|
||||
// LogRangeAsType<float>(std::cout << "reduce_max: ", reduce_max.mData, ",") <<
|
||||
@@ -96,7 +96,7 @@ struct ReferenceSoftmax : public device::BaseOperator
|
||||
Tensor<AccDataType> in_stable(arg.in_.mDesc);
|
||||
in_stable.ForEach([&](auto& self, auto idx) {
|
||||
// numerator = exp(x - max(x))
|
||||
self(idx) = std::exp(static_cast<AccDataType>(arg.in_(idx)) -
|
||||
self(idx) = std::exp(ck::type_convert<AccDataType>(arg.in_(idx)) -
|
||||
reduce_max(to_sm_scalar_idx(idx)));
|
||||
});
|
||||
|
||||
@@ -111,8 +111,10 @@ struct ReferenceSoftmax : public device::BaseOperator
|
||||
// std::endl;
|
||||
|
||||
arg.out_.ForEach([&](auto& self, auto idx) {
|
||||
self(idx) = arg.alpha_ * in_stable(idx) / reduce_sum(to_sm_scalar_idx(idx)) +
|
||||
arg.beta_ * self(idx);
|
||||
AccDataType temp_result =
|
||||
arg.alpha_ * in_stable(idx) / reduce_sum(to_sm_scalar_idx(idx)) +
|
||||
arg.beta_ * self(idx);
|
||||
self(idx) = ck::type_convert<OutDataType>(temp_result);
|
||||
});
|
||||
|
||||
// LogRangeAsType<float>(std::cout << "out: ", arg.out_.mData, ",") << std::endl;
|
||||
|
||||
@@ -59,6 +59,48 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_g
|
||||
MaskingSpecialization::MaskDisabled>>>&
|
||||
instances);
|
||||
|
||||
void add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceBatchedGemmSoftmaxGemmPermute<2,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
BF16,
|
||||
BF16,
|
||||
BF16,
|
||||
BF16,
|
||||
ck::Tuple<>,
|
||||
ck::Tuple<>,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Scale,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MaskingSpecialization::MaskOutUpperTriangle>>>&
|
||||
instances);
|
||||
|
||||
void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<2,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
BF16,
|
||||
BF16,
|
||||
BF16,
|
||||
BF16,
|
||||
ck::Tuple<>,
|
||||
ck::Tuple<>,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Scale,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MaskingSpecialization::MaskDisabled>>>&
|
||||
instances);
|
||||
|
||||
template <typename ADataType,
|
||||
typename B0DataType,
|
||||
typename B1DataType,
|
||||
@@ -119,6 +161,20 @@ struct DeviceOperationInstanceFactory<
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same_v<ADataType, BF16> && is_same_v<B0DataType, BF16> &&
|
||||
is_same_v<B1DataType, BF16> && is_same_v<CDataType, BF16>)
|
||||
{
|
||||
if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle)
|
||||
{
|
||||
add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
else if(MaskingSpec == MaskingSpecialization::MaskDisabled)
|
||||
{
|
||||
add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
return op_ptrs;
|
||||
}
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user