mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
Extend XDL kernel to Support RDNA3/4 - Part 2 (#2722)
Update Blockwise and Gridwise files to support both wave32 & wave64. 1. Calculate WaveSize from template parameter, instead of hard code it to 64, some "64" is also replace with WaveSize 2. Move BN0Shuffled and BK0Shuffled to device side. we can't get correct mfma inst info in host side. 3. Update b_thread_offset_n and b_thread_offset_k in gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp for gfx11. in gfx11, input data is duplicated for each 16 threads, it is different with all of others. 4. Modify a1_threadwise_copy in gridwise_batched_*gemm*gemm for gfx11. for gfx11, we need duplicate input and swizzle A if transposeC isn't enabled.
This commit is contained in:
@@ -95,7 +95,33 @@ __device__ inline f8x4_t i4_to_f8x4(int q)
|
||||
return amd_assembly_cvt_f8_to_f32(f32_0, f32_1, f32_2, f32_3);
|
||||
}
|
||||
|
||||
__device__ inline f8x8_t i4_to_fp8x8(int q) { return amd_assembly_i4_to_fp8x8(q); }
|
||||
__device__ inline f8x8_t i4_to_fp8x8(int q)
|
||||
{
|
||||
#if defined(__gfx12__)
|
||||
uint32_t fp8x4_0;
|
||||
uint32_t fp8x4_1;
|
||||
// todo: replace amd_assemble_cvt_f32_i4 with __builtin_amdgcn_cvt_off_f32_i4
|
||||
float f32_0 = amd_assemble_cvt_f32_i4(q);
|
||||
float f32_1 = amd_assemble_cvt_f32_i4(q >> 16);
|
||||
fp8x4_0 = __builtin_amdgcn_cvt_pk_fp8_f32(f32_0, f32_1, 0, 0);
|
||||
float f32_2 = amd_assemble_cvt_f32_i4(q >> 8);
|
||||
float f32_3 = amd_assemble_cvt_f32_i4(q >> 24);
|
||||
fp8x4_1 = __builtin_amdgcn_cvt_pk_fp8_f32(f32_2, f32_3, 0, 0);
|
||||
q = q >> 4;
|
||||
f32_0 = amd_assemble_cvt_f32_i4(q);
|
||||
f32_1 = amd_assemble_cvt_f32_i4(q >> 16);
|
||||
fp8x4_0 = __builtin_amdgcn_cvt_pk_fp8_f32(f32_0, f32_1, fp8x4_0, 1);
|
||||
f32_2 = amd_assemble_cvt_f32_i4(q >> 8);
|
||||
f32_3 = amd_assemble_cvt_f32_i4(q >> 24);
|
||||
fp8x4_1 = __builtin_amdgcn_cvt_pk_fp8_f32(f32_2, f32_3, fp8x4_1, 1);
|
||||
return bit_cast<f8x8_t>(((static_cast<uint64_t>(fp8x4_1) << 32) | fp8x4_0));
|
||||
#elif defined(__gfx11__)
|
||||
ignore = q;
|
||||
return f8x8_t{};
|
||||
#else
|
||||
return amd_assembly_i4_to_fp8x8(q);
|
||||
#endif
|
||||
}
|
||||
|
||||
__device__ inline bhalf4_t i4_to_bhalf4(int q)
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user