mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-24 23:05:54 +00:00
Add column to image kernel (#930)
* Add column to image kernel * Minor fixes for dtypes and client examples * Disable tests for disabled dtypes * Disable add instances functions for disabled data types * Minor stylistic fixes * Revert "Disable add instances functions for disabled data types" This reverts commit 728b869563. * Instances reduction * Add comments in device_column_to_image_impl * Update changelog and Copyrights * Improve changelog [ROCm/composable_kernel commit:e2243a4d1e]
This commit is contained in:
@@ -140,10 +140,36 @@ struct DynamicBuffer
|
||||
}
|
||||
else if constexpr(Op == InMemoryDataOperationEnum::Add)
|
||||
{
|
||||
auto tmp = this->template Get<X>(i, is_valid_element);
|
||||
this->template Set<X>(i, is_valid_element, x + tmp);
|
||||
// tmp += x;
|
||||
// this->template Set<X>(i, is_valid_element, tmp);
|
||||
auto tmp = this->template Get<X>(i, is_valid_element);
|
||||
using scalar_t = typename scalar_type<remove_cvref_t<T>>::type;
|
||||
// handle bfloat addition
|
||||
if constexpr(is_same_v<scalar_t, bhalf_t>)
|
||||
{
|
||||
if constexpr(is_scalar_type<X>::value)
|
||||
{
|
||||
// Scalar type
|
||||
auto result =
|
||||
type_convert<X>(type_convert<float>(x) + type_convert<float>(tmp));
|
||||
this->template Set<X>(i, is_valid_element, result);
|
||||
}
|
||||
else
|
||||
{
|
||||
// Vector type
|
||||
constexpr auto vector_size = scalar_type<remove_cvref_t<X>>::vector_size;
|
||||
const vector_type<scalar_t, vector_size> a_vector{tmp};
|
||||
const vector_type<scalar_t, vector_size> b_vector{x};
|
||||
static_for<0, vector_size, 1>{}([&](auto idx) {
|
||||
auto result = type_convert<scalar_t>(
|
||||
type_convert<float>(a_vector.template AsType<scalar_t>()[idx]) +
|
||||
type_convert<float>(b_vector.template AsType<scalar_t>()[idx]));
|
||||
this->template Set<scalar_t>(i + idx, is_valid_element, result);
|
||||
});
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
this->template Set<X>(i, is_valid_element, x + tmp);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user