mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 13:41:24 +00:00
fix a bug for int4 scale weight only kernel (#1820)
Co-authored-by: mtgu0705 <mtgu@amd.com>
This commit is contained in:
@@ -29,6 +29,13 @@ struct DynamicBuffer
|
||||
ElementSpaceSize element_space_size_;
|
||||
T invalid_element_value_ = T{0};
|
||||
|
||||
static constexpr index_t PackedSize = []() {
|
||||
if constexpr(is_same_v<remove_cvref_t<T>, pk_i4_t>)
|
||||
return 2;
|
||||
else
|
||||
return 1;
|
||||
}();
|
||||
|
||||
__host__ __device__ constexpr DynamicBuffer(T* p_data, ElementSpaceSize element_space_size)
|
||||
: p_data_{p_data}, element_space_size_{element_space_size}
|
||||
{
|
||||
@@ -82,14 +89,18 @@ struct DynamicBuffer
|
||||
return amd_buffer_load_invalid_element_return_zero<remove_cvref_t<T>,
|
||||
t_per_x,
|
||||
coherence>(
|
||||
p_data_, i, is_valid_element, element_space_size_);
|
||||
p_data_, i, is_valid_element, element_space_size_ / PackedSize);
|
||||
}
|
||||
else
|
||||
{
|
||||
return amd_buffer_load_invalid_element_return_customized_value<remove_cvref_t<T>,
|
||||
t_per_x,
|
||||
coherence>(
|
||||
p_data_, i, is_valid_element, element_space_size_, invalid_element_value_);
|
||||
p_data_,
|
||||
i,
|
||||
is_valid_element,
|
||||
element_space_size_ / PackedSize,
|
||||
invalid_element_value_);
|
||||
}
|
||||
}
|
||||
else
|
||||
@@ -191,7 +202,7 @@ struct DynamicBuffer
|
||||
dst_buf.p_data_,
|
||||
dst_offset,
|
||||
is_valid_element,
|
||||
element_space_size_);
|
||||
element_space_size_ / PackedSize);
|
||||
}
|
||||
|
||||
template <typename X,
|
||||
@@ -226,7 +237,7 @@ struct DynamicBuffer
|
||||
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
|
||||
|
||||
amd_buffer_store<remove_cvref_t<T>, t_per_x, coherence>(
|
||||
x, p_data_, i, is_valid_element, element_space_size_);
|
||||
x, p_data_, i, is_valid_element, element_space_size_ / PackedSize);
|
||||
}
|
||||
else if constexpr(GetAddressSpace() == AddressSpaceEnum::Lds &&
|
||||
is_same<typename scalar_type<remove_cvref_t<T>>::type, int8_t>::value &&
|
||||
@@ -378,7 +389,7 @@ struct DynamicBuffer
|
||||
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
|
||||
|
||||
amd_buffer_atomic_add<remove_cvref_t<T>, t_per_x>(
|
||||
x, p_data_, i, is_valid_element, element_space_size_);
|
||||
x, p_data_, i, is_valid_element, element_space_size_ / PackedSize);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -417,7 +428,7 @@ struct DynamicBuffer
|
||||
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
|
||||
|
||||
amd_buffer_atomic_max<remove_cvref_t<T>, t_per_x>(
|
||||
x, p_data_, i, is_valid_element, element_space_size_);
|
||||
x, p_data_, i, is_valid_element, element_space_size_ / PackedSize);
|
||||
}
|
||||
else if(is_valid_element)
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user