mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 00:40:09 +00:00
Reorganize files, Part 1 (#119)
* delete obselete files * move files * build * update cmake * update cmake * fix build * reorg examples * update cmake for example and test
This commit is contained in:
298
include/ck/utility/amd_xdlops.hpp
Normal file
298
include/ck/utility/amd_xdlops.hpp
Normal file
@@ -0,0 +1,298 @@
|
||||
#ifndef CK_AMD_XDLOPS_HPP
|
||||
#define CK_AMD_XDLOPS_HPP
|
||||
|
||||
#include "data_type.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// fp32
|
||||
template <index_t MPerWave, index_t NPerWave>
|
||||
struct intrin_mfma_f32_32x32x1f32;
|
||||
|
||||
template <>
|
||||
struct intrin_mfma_f32_32x32x1f32<64, 64>
|
||||
{
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
|
||||
{
|
||||
reg_c.template AsType<float32_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x1f32(
|
||||
reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<0>{}], 1, 0, 0);
|
||||
reg_c.template AsType<float32_t>()(Number<1>{}) = __builtin_amdgcn_mfma_f32_32x32x1f32(
|
||||
reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<1>{}], 1, 1, 0);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct intrin_mfma_f32_32x32x1f32<32, 64>
|
||||
{
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
|
||||
{
|
||||
reg_c.template AsType<float32_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x1f32(
|
||||
reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<0>{}], 1, 0, 0);
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t MPerWave, index_t NPerWave>
|
||||
struct intrin_mfma_f32_32x32x2f32;
|
||||
|
||||
template <>
|
||||
struct intrin_mfma_f32_32x32x2f32<32, 32>
|
||||
{
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
|
||||
{
|
||||
reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x2f32(
|
||||
reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t MPerWave, index_t NPerWave>
|
||||
struct intrin_mfma_f32_16x16x4f32;
|
||||
|
||||
template <>
|
||||
struct intrin_mfma_f32_16x16x4f32<16, 16>
|
||||
{
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
|
||||
{
|
||||
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x4f32(
|
||||
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t MPerWave, index_t NPerWave>
|
||||
struct intrin_mfma_f32_16x16x1f32;
|
||||
|
||||
template <>
|
||||
struct intrin_mfma_f32_16x16x1f32<16, 64>
|
||||
{
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
|
||||
{
|
||||
reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x1f32(
|
||||
reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 2, 0, 0);
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t MPerWave, index_t NPerWave>
|
||||
struct intrin_mfma_f32_4x4x1f32;
|
||||
|
||||
template <>
|
||||
struct intrin_mfma_f32_4x4x1f32<4, 64>
|
||||
{
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
|
||||
{
|
||||
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_4x4x1f32(
|
||||
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 4, 0, 0);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct intrin_mfma_f32_4x4x1f32<8, 64>
|
||||
{
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
|
||||
{
|
||||
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_4x4x1f32(
|
||||
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 4, 0, 0);
|
||||
reg_c.template AsType<float4_t>()(Number<1>{}) = __builtin_amdgcn_mfma_f32_4x4x1f32(
|
||||
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<1>{}], 4, 1, 0);
|
||||
}
|
||||
};
|
||||
|
||||
// fp16
|
||||
template <index_t MPerWave, index_t NPerWave>
|
||||
struct intrin_mfma_f32_32x32x4f16;
|
||||
|
||||
template <>
|
||||
struct intrin_mfma_f32_32x32x4f16<64, 64>
|
||||
{
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
reg_c.template AsType<float32_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4f16(
|
||||
reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<0>{}], 1, 0, 0);
|
||||
reg_c.template AsType<float32_t>()(Number<1>{}) = __builtin_amdgcn_mfma_f32_32x32x4f16(
|
||||
reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<1>{}], 1, 1, 0);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct intrin_mfma_f32_32x32x4f16<32, 64>
|
||||
{
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
reg_c.template AsType<float32_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4f16(
|
||||
reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<0>{}], 1, 0, 0);
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t MPerWave, index_t NPerWave>
|
||||
struct intrin_mfma_f32_32x32x8f16;
|
||||
|
||||
template <>
|
||||
struct intrin_mfma_f32_32x32x8f16<32, 32>
|
||||
{
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x8f16(
|
||||
reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t MPerWave, index_t NPerWave>
|
||||
struct intrin_mfma_f32_16x16x16f16;
|
||||
|
||||
template <>
|
||||
struct intrin_mfma_f32_16x16x16f16<16, 16>
|
||||
{
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x16f16(
|
||||
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t MPerWave, index_t NPerWave>
|
||||
struct intrin_mfma_f32_16x16x4f16;
|
||||
|
||||
template <>
|
||||
struct intrin_mfma_f32_16x16x4f16<16, 64>
|
||||
{
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x4f16(
|
||||
reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 2, 0, 0);
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t MPerWave, index_t NPerWave>
|
||||
struct intrin_mfma_f32_4x4x4f16;
|
||||
|
||||
template <>
|
||||
struct intrin_mfma_f32_4x4x4f16<4, 64>
|
||||
{
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_4x4x4f16(
|
||||
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 4, 0, 0);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct intrin_mfma_f32_4x4x4f16<8, 64>
|
||||
{
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_4x4x4f16(
|
||||
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 4, 0, 0);
|
||||
reg_c.template AsType<float4_t>()(Number<1>{}) = __builtin_amdgcn_mfma_f32_4x4x4f16(
|
||||
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<1>{}], 4, 1, 0);
|
||||
}
|
||||
};
|
||||
|
||||
// bfp16
|
||||
template <index_t MPerWave, index_t NPerWave>
|
||||
struct intrin_mfma_f32_32x32x8bf16_1k;
|
||||
|
||||
template <>
|
||||
struct intrin_mfma_f32_32x32x8bf16_1k<32, 32>
|
||||
{
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const bhalf4_t& reg_a, const bhalf4_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(
|
||||
reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t MPerWave, index_t NPerWave>
|
||||
struct intrin_mfma_f32_16x16x16bf16_1k;
|
||||
|
||||
template <>
|
||||
struct intrin_mfma_f32_16x16x16bf16_1k<16, 16>
|
||||
{
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const bhalf4_t& reg_a, const bhalf4_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(
|
||||
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t MPerWave, index_t NPerWave>
|
||||
struct intrin_mfma_f32_32x32x4bf16;
|
||||
|
||||
template <>
|
||||
struct intrin_mfma_f32_32x32x4bf16<32, 32>
|
||||
{
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const bhalf2_t& reg_a, const bhalf2_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4bf16(
|
||||
reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t MPerWave, index_t NPerWave>
|
||||
struct intrin_mfma_f32_16x16x8bf16;
|
||||
|
||||
template <>
|
||||
struct intrin_mfma_f32_16x16x8bf16<16, 16>
|
||||
{
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const bhalf2_t& reg_a, const bhalf2_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4bf16(
|
||||
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t MPerWave, index_t NPerWave>
|
||||
struct intrin_mfma_i32_32x32x8i8;
|
||||
|
||||
template <>
|
||||
struct intrin_mfma_i32_32x32x8i8<32, 32>
|
||||
{
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const int8x4_t& reg_a, const int8x4_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
reg_c.template AsType<int32x16_t>()(Number<0>{}) =
|
||||
__builtin_amdgcn_mfma_i32_32x32x8i8(bit_cast<int>(reg_a),
|
||||
bit_cast<int>(reg_b),
|
||||
reg_c.template AsType<int32x16_t>()[Number<0>{}],
|
||||
0,
|
||||
0,
|
||||
0);
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t MPerWave, index_t NPerWave>
|
||||
struct intrin_mfma_i32_16x16x16i8;
|
||||
|
||||
template <>
|
||||
struct intrin_mfma_i32_16x16x16i8<16, 16>
|
||||
{
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const int8x4_t& reg_a, const int8x4_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
reg_c.template AsType<int32x4_t>()(Number<0>{}) =
|
||||
__builtin_amdgcn_mfma_i32_16x16x16i8(bit_cast<int>(reg_a),
|
||||
bit_cast<int>(reg_b),
|
||||
reg_c.template AsType<int32x4_t>()[Number<0>{}],
|
||||
0,
|
||||
0,
|
||||
0);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
Reference in New Issue
Block a user