Fix a bug in the int4 scale weight-only kernel (#1820)

Co-authored-by: mtgu0705 <mtgu@amd.com>
This commit is contained in:
Mingtao Gu
2025-01-19 11:18:18 +08:00
committed by GitHub
parent bdddf1eace
commit 86d1b46aa6
3 changed files with 18 additions and 9 deletions

View File

@@ -29,6 +29,13 @@ struct DynamicBuffer
ElementSpaceSize element_space_size_;
T invalid_element_value_ = T{0};
// Compile-time divisor applied to element_space_size_ before it is passed to
// the amd_buffer_* load/store/atomic helpers: 2 when T (with cv/ref qualifiers
// stripped) is pk_i4_t, 1 for every other element type.
// NOTE(review): presumably pk_i4_t holds two 4-bit values per storage element,
// which is why the size is halved — confirm against the pk_i4_t definition.
static constexpr index_t PackedSize = is_same_v<remove_cvref_t<T>, pk_i4_t> ? 2 : 1;
__host__ __device__ constexpr DynamicBuffer(T* p_data, ElementSpaceSize element_space_size)
: p_data_{p_data}, element_space_size_{element_space_size}
{
@@ -82,14 +89,18 @@ struct DynamicBuffer
return amd_buffer_load_invalid_element_return_zero<remove_cvref_t<T>,
t_per_x,
coherence>(
p_data_, i, is_valid_element, element_space_size_);
p_data_, i, is_valid_element, element_space_size_ / PackedSize);
}
else
{
return amd_buffer_load_invalid_element_return_customized_value<remove_cvref_t<T>,
t_per_x,
coherence>(
p_data_, i, is_valid_element, element_space_size_, invalid_element_value_);
p_data_,
i,
is_valid_element,
element_space_size_ / PackedSize,
invalid_element_value_);
}
}
else
@@ -191,7 +202,7 @@ struct DynamicBuffer
dst_buf.p_data_,
dst_offset,
is_valid_element,
element_space_size_);
element_space_size_ / PackedSize);
}
template <typename X,
@@ -226,7 +237,7 @@ struct DynamicBuffer
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
amd_buffer_store<remove_cvref_t<T>, t_per_x, coherence>(
x, p_data_, i, is_valid_element, element_space_size_);
x, p_data_, i, is_valid_element, element_space_size_ / PackedSize);
}
else if constexpr(GetAddressSpace() == AddressSpaceEnum::Lds &&
is_same<typename scalar_type<remove_cvref_t<T>>::type, int8_t>::value &&
@@ -378,7 +389,7 @@ struct DynamicBuffer
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
amd_buffer_atomic_add<remove_cvref_t<T>, t_per_x>(
x, p_data_, i, is_valid_element, element_space_size_);
x, p_data_, i, is_valid_element, element_space_size_ / PackedSize);
}
else
{
@@ -417,7 +428,7 @@ struct DynamicBuffer
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
amd_buffer_atomic_max<remove_cvref_t<T>, t_per_x>(
x, p_data_, i, is_valid_element, element_space_size_);
x, p_data_, i, is_valid_element, element_space_size_ / PackedSize);
}
else if(is_valid_element)
{