Mirror of https://github.com/ROCm/composable_kernel.git — synced 2026-05-11 17:00:18 +00:00.
Commit: vector/scalar pointer cast — use a C-style pointer cast instead of reinterpret_cast.
This commit is contained in:
@@ -350,8 +350,8 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x2bf16>
               class FloatC>
     __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
     {
-        const auto p_a = reinterpret_cast<const ushort2_t*>(a);
-        const auto p_b = reinterpret_cast<const ushort2_t*>(b);
+        const auto p_a = c_style_pointer_cast<const ushort2_t*>(a);
+        const auto p_b = c_style_pointer_cast<const ushort2_t*>(b);
 
         return intrin_mfma_f32_32x32x2bf16<MPerXdlops, NPerXdlops, AStride, BStride>::run(
             p_a, p_b, reg_c);
@@ -384,8 +384,8 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x4bf16>
               class FloatC>
     __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
     {
-        const auto p_a = reinterpret_cast<const ushort2_t*>(a);
-        const auto p_b = reinterpret_cast<const ushort2_t*>(b);
+        const auto p_a = c_style_pointer_cast<const ushort2_t*>(a);
+        const auto p_b = c_style_pointer_cast<const ushort2_t*>(b);
 
         return intrin_mfma_f32_32x32x4bf16(p_a, p_b, reg_c);
     }
@@ -417,8 +417,8 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x8bf16>
               class FloatC>
     __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
     {
-        const auto p_a = reinterpret_cast<const ushort2_t*>(a);
-        const auto p_b = reinterpret_cast<const ushort2_t*>(b);
+        const auto p_a = c_style_pointer_cast<const ushort2_t*>(a);
+        const auto p_b = c_style_pointer_cast<const ushort2_t*>(b);
 
         return intrin_mfma_f32_16x16x8bf16(p_a, p_b, reg_c);
     }
@@ -450,8 +450,8 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x2bf16>
               class FloatC>
     __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
     {
-        const auto p_a = reinterpret_cast<const ushort2_t*>(a);
-        const auto p_b = reinterpret_cast<const ushort2_t*>(b);
+        const auto p_a = c_style_pointer_cast<const ushort2_t*>(a);
+        const auto p_b = c_style_pointer_cast<const ushort2_t*>(b);
 
         return intrin_mfma_f32_16x16x2bf16<MPerXdlops, NPerXdlops>(p_a, p_b, reg_c);
     }
@@ -483,8 +483,8 @@ struct mfma_info<mfma_instr::mfma_f32_4x4x2bf16>
               class FloatC>
     __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
     {
-        const auto p_a = reinterpret_cast<const ushort2_t*>(a);
-        const auto p_b = reinterpret_cast<const ushort2_t*>(b);
+        const auto p_a = c_style_pointer_cast<const ushort2_t*>(a);
+        const auto p_b = c_style_pointer_cast<const ushort2_t*>(b);
 
         return intrin_mfma_f32_4x4x2bf16<MPerXdlops, NPerXdlops>::run(p_a, p_b, reg_c);
     }

@@ -2,6 +2,7 @@
 #define CK_AMD_INLINE_ASM_HPP
 
 #include "data_type.hpp"
+#include "c_style_pointer_cast.hpp"
 
 namespace ck {
 
@@ -53,9 +54,9 @@ __device__ void
 amd_assembly_outer_product_1x2(half4_t a, half4_t b0, half4_t b1, float& c0, float& c1)
 {
     // TODO remove pointer casting
-    const half2_t* p_a_half2 = reinterpret_cast<const half2_t*>(&a);
-    const half2_t* p_b0_half2 = reinterpret_cast<const half2_t*>(&b0);
-    const half2_t* p_b1_half2 = reinterpret_cast<const half2_t*>(&b1);
+    const half2_t* p_a_half2 = c_style_pointer_cast<const half2_t*>(&a);
+    const half2_t* p_b0_half2 = c_style_pointer_cast<const half2_t*>(&b0);
+    const half2_t* p_b1_half2 = c_style_pointer_cast<const half2_t*>(&b1);
 
     // do dot2 two times
     asm volatile("\n \
@@ -114,11 +115,11 @@ __device__ void amd_assembly_outer_product_1x4(half4_t a,
                                                float& c3)
 {
     // TODO remove pointer casting
-    const half2_t* p_a_half2 = reinterpret_cast<const half2_t*>(&a);
-    const half2_t* p_b0_half2 = reinterpret_cast<const half2_t*>(&b0);
-    const half2_t* p_b1_half2 = reinterpret_cast<const half2_t*>(&b1);
-    const half2_t* p_b2_half2 = reinterpret_cast<const half2_t*>(&b2);
-    const half2_t* p_b3_half2 = reinterpret_cast<const half2_t*>(&b3);
+    const half2_t* p_a_half2 = c_style_pointer_cast<const half2_t*>(&a);
+    const half2_t* p_b0_half2 = c_style_pointer_cast<const half2_t*>(&b0);
+    const half2_t* p_b1_half2 = c_style_pointer_cast<const half2_t*>(&b1);
+    const half2_t* p_b2_half2 = c_style_pointer_cast<const half2_t*>(&b2);
+    const half2_t* p_b3_half2 = c_style_pointer_cast<const half2_t*>(&b3);
 
     // do dot2 two times
     asm volatile("\n \
@@ -160,11 +161,11 @@ __device__ void amd_assembly_outer_product_1x4(half8_t a,
 {
 
     // TODO remove pointer casting
-    const half4_t* p_a_half4 = reinterpret_cast<const half4_t*>(&a);
-    const half4_t* p_b0_half4 = reinterpret_cast<const half4_t*>(&b0);
-    const half4_t* p_b1_half4 = reinterpret_cast<const half4_t*>(&b1);
-    const half4_t* p_b2_half4 = reinterpret_cast<const half4_t*>(&b2);
-    const half4_t* p_b3_half4 = reinterpret_cast<const half4_t*>(&b3);
+    const half4_t* p_a_half4 = c_style_pointer_cast<const half4_t*>(&a);
+    const half4_t* p_b0_half4 = c_style_pointer_cast<const half4_t*>(&b0);
+    const half4_t* p_b1_half4 = c_style_pointer_cast<const half4_t*>(&b1);
+    const half4_t* p_b2_half4 = c_style_pointer_cast<const half4_t*>(&b2);
+    const half4_t* p_b3_half4 = c_style_pointer_cast<const half4_t*>(&b3);
 
     amd_assembly_outer_product_1x4(
         p_a_half4[0], p_b0_half4[0], p_b1_half4[0], p_b2_half4[0], p_b3_half4[0], c0, c1, c2, c3);
@@ -184,11 +185,11 @@ __device__ void amd_assembly_outer_product_1x4(half16_t a,
                                                float& c3)
 {
     // TODO remove pointer casting
-    const half8_t* p_a_half8 = reinterpret_cast<const half8_t*>(&a);
-    const half8_t* p_b0_half8 = reinterpret_cast<const half8_t*>(&b0);
-    const half8_t* p_b1_half8 = reinterpret_cast<const half8_t*>(&b1);
-    const half8_t* p_b2_half8 = reinterpret_cast<const half8_t*>(&b2);
-    const half8_t* p_b3_half8 = reinterpret_cast<const half8_t*>(&b3);
+    const half8_t* p_a_half8 = c_style_pointer_cast<const half8_t*>(&a);
+    const half8_t* p_b0_half8 = c_style_pointer_cast<const half8_t*>(&b0);
+    const half8_t* p_b1_half8 = c_style_pointer_cast<const half8_t*>(&b1);
+    const half8_t* p_b2_half8 = c_style_pointer_cast<const half8_t*>(&b2);
+    const half8_t* p_b3_half8 = c_style_pointer_cast<const half8_t*>(&b3);
 
     amd_assembly_outer_product_1x4(
         p_a_half8[0], p_b0_half8[0], p_b1_half8[0], p_b2_half8[0], p_b3_half8[0], c0, c1, c2, c3);

@@ -51,7 +51,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhw
     const auto out_n_ho_wo_k_desc =
         make_dynamic_naive_tensor_descriptor_packed_v2(out_n_ho_wo_k_lengths);
 
-#if 1
+#if 0
     // [M, N, K0, K1] = [128, 128, 8, 1] for fp32
     // cdata = 64, BlockSize = 256
     constexpr index_t BlockSize = 256;
@@ -81,7 +81,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhw
     using GemmBBlockTransferDstVectorTensorLengths_K0_N0_N1_K1 = Sequence<1, 1, 1, 1>;
 
     constexpr index_t GemmCThreadTransferDstScalarPerVector_N11 = 4;
-#elif 1
+#elif 0
     // [M, N, K0, K1] = [128, 128, 8, 2] for fp16
     // cdata = 64, BlockSize = 256
     constexpr index_t BlockSize = 256;

@@ -34,7 +34,7 @@ std::ostream& LogRangeAsType(std::ostream& os, Range&& range, std::string delim)
             first = false;
         else
             os << delim;
-        os << T{v};
+        os << static_cast<T>(v);
     }
     return os;
 }

Reference in New Issue
Block a user