mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
Merge branch 'dev/ck_moe_gemm' into ck_moe_merge
This commit is contained in:
@@ -246,13 +246,6 @@ if(NOT WIN32 AND ${hip_VERSION_FLAT} GREATER 500500000)
|
||||
add_compile_options("SHELL: -mllvm --lsr-drop-solution=1")
|
||||
endif()
|
||||
endif()
|
||||
if(NOT WIN32 AND ${hip_VERSION_FLAT} GREATER 600140090)
|
||||
check_cxx_compiler_flag("-mllvm -enable-post-misched=0" HAS_ENABLE_POST_MISCHED)
|
||||
if(HAS_ENABLE_POST_MISCHED)
|
||||
message("Adding the enable-post-misched=0 compiler flag")
|
||||
add_compile_options("SHELL: -mllvm -enable-post-misched=0")
|
||||
endif()
|
||||
endif()
|
||||
set(check-coerce)
|
||||
check_cxx_compiler_flag(" -mllvm -amdgpu-coerce-illegal-types=1" check-coerce)
|
||||
if(NOT WIN32 AND check-coerce AND ${hip_VERSION_FLAT} GREATER 600241132)
|
||||
@@ -534,7 +527,6 @@ include_directories(BEFORE
|
||||
${HIP_INCLUDE_DIRS}
|
||||
)
|
||||
|
||||
SET(BUILD_DEV ON CACHE BOOL "BUILD_DEV")
|
||||
message("CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}")
|
||||
|
||||
if("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang")
|
||||
|
||||
3
Jenkinsfile
vendored
3
Jenkinsfile
vendored
@@ -722,6 +722,9 @@ CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;ROCM
|
||||
|
||||
pipeline {
|
||||
agent none
|
||||
triggers {
|
||||
parameterizedCron(CRON_SETTINGS)
|
||||
}
|
||||
options {
|
||||
parallelsAlwaysFailFast()
|
||||
}
|
||||
|
||||
@@ -55,7 +55,7 @@ using CDEElementOp = PassThrough;
|
||||
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
|
||||
static constexpr ck::index_t Scale_Block_M = 128;
|
||||
static constexpr ck::index_t Scale_Block_M = 1;
|
||||
static constexpr ck::index_t Scale_Block_N = 128;
|
||||
static constexpr ck::index_t Scale_Block_K = 128;
|
||||
|
||||
@@ -65,14 +65,14 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_ABScale_
|
||||
A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, BElementOp, CDEElementOp, GemmSpec,
|
||||
256, Scale_Block_M, Scale_Block_N, Scale_Block_K,
|
||||
128, 128,
|
||||
128, 16, 16,
|
||||
16, 128,
|
||||
256, 16, 16,
|
||||
16, 16,
|
||||
4, 4,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
|
||||
1, 2, S<1, 32, 1, 8>, S<8, 8, 1>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>;
|
||||
1, 2,
|
||||
S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
|
||||
S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
|
||||
1, 2, S<1, 16, 1, 16>, S<8>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, FP8>;
|
||||
// clang-format on
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
@@ -80,11 +80,12 @@ int main(int argc, char* argv[])
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = false;
|
||||
bool flush_cache = true;
|
||||
|
||||
// GEMM shape
|
||||
ck::index_t M = 3840;
|
||||
ck::index_t N = 4096;
|
||||
ck::index_t K = 4096;
|
||||
ck::index_t M = 128;
|
||||
ck::index_t N = 1024;
|
||||
ck::index_t K = 1024;
|
||||
|
||||
ck::index_t StrideA = K;
|
||||
ck::index_t StrideB = K;
|
||||
@@ -100,7 +101,7 @@ int main(int argc, char* argv[])
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
}
|
||||
else if(argc == 10)
|
||||
else if(argc == 8)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
@@ -110,16 +111,19 @@ int main(int argc, char* argv[])
|
||||
N = std::stoi(argv[5]);
|
||||
K = std::stoi(argv[6]);
|
||||
|
||||
StrideA = std::stoi(argv[7]);
|
||||
StrideB = std::stoi(argv[8]);
|
||||
StrideE = std::stoi(argv[9]);
|
||||
flush_cache = std::stoi(argv[7]);
|
||||
|
||||
StrideA = K;
|
||||
StrideB = K;
|
||||
StrideE = N;
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: time kernel (0=no, 1=yes)\n");
|
||||
printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideE\n");
|
||||
printf("arg4 to 6: M, N, K\n");
|
||||
printf("arg7: flush both I$ and L2$ (0=no, 1=yes)\n");
|
||||
exit(0);
|
||||
}
|
||||
|
||||
@@ -182,9 +186,15 @@ int main(int argc, char* argv[])
|
||||
b1_k_n.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
|
||||
break;
|
||||
case 4:
|
||||
a0_m_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
|
||||
b0_k_n.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
|
||||
a0_m_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
|
||||
b0_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
|
||||
a1_m_k.GenerateTensorValue(GeneratorTensor_3<A1DataType>{0, 1.0});
|
||||
b1_k_n.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
|
||||
break;
|
||||
case 5:
|
||||
a0_m_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
|
||||
b0_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
|
||||
a1_m_k.GenerateTensorValue(GeneratorTensor_1<A1DataType>{});
|
||||
b1_k_n.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 1.0});
|
||||
break;
|
||||
default:
|
||||
@@ -194,6 +204,16 @@ int main(int argc, char* argv[])
|
||||
b1_k_n.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 1.0});
|
||||
}
|
||||
#endif
|
||||
#if 0
|
||||
for(int im =0; im< (M + Scale_Block_M - 1) / Scale_Block_M; im++){
|
||||
float row_sum = .0;
|
||||
for(int ik =0; ik< (K + Scale_Block_K - 1) / Scale_Block_K; ik++){
|
||||
printf("%lf ",a1_m_k(im, ik));
|
||||
row_sum += a1_m_k(im, ik);
|
||||
}
|
||||
printf("sum: %lf\n", row_sum * 128);
|
||||
}
|
||||
#endif
|
||||
|
||||
DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize());
|
||||
DeviceMem a1_device_buf(sizeof(A1DataType) * a1_m_k.mDesc.GetElementSpaceSize());
|
||||
@@ -239,12 +259,24 @@ int main(int argc, char* argv[])
|
||||
"not support this GEMM problem");
|
||||
}
|
||||
|
||||
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel, 20, 50});
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
std::size_t num_btype =
|
||||
sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N;
|
||||
|
||||
float ave_time = .0;
|
||||
|
||||
if(flush_cache)
|
||||
{
|
||||
int rotating_buf = (512 * 1024 * 1024 + num_btype - 1) / num_btype;
|
||||
|
||||
ave_time = invoker.Run(argument,
|
||||
StreamConfig{nullptr, time_kernel, 0, 50, 100, true, rotating_buf});
|
||||
}
|
||||
else
|
||||
{
|
||||
ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel, 0, 50, 100});
|
||||
}
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
@@ -176,7 +176,8 @@ float fmha_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
);
|
||||
}}
|
||||
|
||||
float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& s){{
|
||||
template <>
|
||||
float fmha_bwd<2>(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& s){{
|
||||
float r = -1;
|
||||
{F_dispatch}
|
||||
return r;
|
||||
@@ -412,14 +413,26 @@ class FmhaBwdDQDKDVKernel:
|
||||
pn = pad_name()
|
||||
n = f"fmha_bwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + self.F_tile.name + f'_{self.F_pipeline}'
|
||||
if pn != '' : n += f'_{pn}'
|
||||
else: n += '_npad'
|
||||
|
||||
if self.F_bias != 'no' : n += f'_{self.F_bias}'
|
||||
else: n += '_nbias'
|
||||
|
||||
if self.F_dbias == 't' : n += '_dbias'
|
||||
else: n += '_ndbias'
|
||||
|
||||
if self.F_mask[0:2] == 's_':
|
||||
if self.F_mask == 's_mask': n += f'_mask'
|
||||
else: n += '_nmask'
|
||||
else:
|
||||
if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}'
|
||||
else: n += '_nmask'
|
||||
|
||||
if self.F_dropout != 'no' : n += f'_{self.F_dropout}'
|
||||
else: n += '_ndropout'
|
||||
|
||||
if self.F_deterministic == 't' : n += '_deterministic'
|
||||
else: n += '_ndeterministic'
|
||||
return n
|
||||
|
||||
@property
|
||||
@@ -489,7 +502,7 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
|
||||
F_spad=spad, F_skpad=skpad, F_dpad=dpad, F_dvpad=dvpad,
|
||||
F_bias=bias, F_dbias=dbias, F_dropout=dropout, F_mask=mask, F_mode=mode,
|
||||
F_pipeline=ppl, mask_impl=mask_impl, F_deterministic=deterministic)
|
||||
if kernel_filter != None:
|
||||
if kernel_filter != '':
|
||||
if not fnmatch.fnmatch(k.name, kernel_filter):
|
||||
continue
|
||||
# Flash attention integration
|
||||
@@ -517,23 +530,19 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
|
||||
if not cond:
|
||||
continue
|
||||
# Aiter (mha_bwd) integration
|
||||
elif receipt == 10:
|
||||
elif receipt == 300:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= mode == "batch"
|
||||
cond &= bias in ['no', 'alibi']
|
||||
cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16']
|
||||
cond &= dpad == dvpad
|
||||
cond &= deterministic == "t"
|
||||
if not cond:
|
||||
continue
|
||||
# Aiter (mha_varlen_bwd) integration
|
||||
elif receipt == 11:
|
||||
elif receipt == 400:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= mode == "group"
|
||||
cond &= bias in ['no', 'alibi']
|
||||
cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16']
|
||||
cond &= dpad == dvpad
|
||||
cond &= deterministic == "t"
|
||||
if not cond:
|
||||
continue
|
||||
api_pool.register_dq_dk_dv_traits(k.api_trait())
|
||||
@@ -632,13 +641,14 @@ class FmhaBwdOGradDotOKernel:
|
||||
pn = pad_name()
|
||||
n = f"fmha_bwd_dot_do_o_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_o{self.F_occupancy}"
|
||||
if pn != '' : n += f'_{pn}'
|
||||
else: n += '_npad'
|
||||
return n
|
||||
|
||||
@property
|
||||
def filename(self) -> str:
|
||||
return self.name + ".cpp"
|
||||
|
||||
def get_bwd_dot_do_o_blobs() -> List[FmhaBwdOGradDotOKernel]:
|
||||
def get_bwd_dot_do_o_blobs(kernel_filter : Optional[str], receipt) -> List[FmhaBwdOGradDotOKernel]:
|
||||
# TODO: we don't support tuning yet, so pick up one value for pad/occupancy
|
||||
# support this in future
|
||||
def get_occupancy(dtype, hdim):
|
||||
@@ -657,6 +667,21 @@ def get_bwd_dot_do_o_blobs() -> List[FmhaBwdOGradDotOKernel]:
|
||||
k = FmhaBwdOGradDotOKernel(F_idx=0, F_hdim=hdim, F_dtype=dtype,
|
||||
F_spad=spad, F_dvpad=dvpad, F_mode=mode,
|
||||
F_occupancy=get_occupancy(dtype, hdim))
|
||||
if kernel_filter != '':
|
||||
if not fnmatch.fnmatch(k.name, kernel_filter):
|
||||
continue
|
||||
# Aiter (mha_bwd) integration
|
||||
if receipt == 300:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= mode == "batch"
|
||||
if not cond:
|
||||
continue
|
||||
# Aiter (mha_varlen_bwd) integration
|
||||
elif receipt == 400:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= mode == "group"
|
||||
if not cond:
|
||||
continue
|
||||
gen.append(k)
|
||||
|
||||
return gen
|
||||
@@ -766,14 +791,16 @@ class FmhaBwdConvertQGradKernel:
|
||||
pn = pad_name()
|
||||
n = f"fmha_bwd_convert_dq_d{self.F_hdim}_{self.F_dtype}_b{self.F_bm0}x{self.F_bn0}_{self.F_mode}_o{self.F_occupancy}"
|
||||
if pn != '' : n += f'_{pn}'
|
||||
if self.F_deterministic == 't' : n += f'_deterministic'
|
||||
else: n += '_npad'
|
||||
if self.F_deterministic == 't' : n += '_deterministic'
|
||||
else: n += '_ndeterministic'
|
||||
return n
|
||||
|
||||
@property
|
||||
def filename(self) -> str:
|
||||
return self.name + ".cpp"
|
||||
|
||||
def get_bwd_convert_dq_blobs() -> List[FmhaBwdConvertQGradKernel]:
|
||||
def get_bwd_convert_dq_blobs(kernel_filter : Optional[str], receipt) -> List[FmhaBwdConvertQGradKernel]:
|
||||
# TODO: we don't support tuning yet, so pick up one value for pad/occupancy
|
||||
# support this in future
|
||||
def get_occupancy(dtype, hdim):
|
||||
@@ -792,6 +819,21 @@ def get_bwd_convert_dq_blobs() -> List[FmhaBwdConvertQGradKernel]:
|
||||
continue
|
||||
k = FmhaBwdConvertQGradKernel(F_idx=0, F_hdim=hdim, F_dtype=dtype, F_bm0=64, F_bn0=tile.F_bn0,
|
||||
F_spad=spad, F_dpad=dpad, F_mode=mode, F_occupancy=get_occupancy(dtype, hdim), F_deterministic=deterministic)
|
||||
if kernel_filter != '':
|
||||
if not fnmatch.fnmatch(k.name, kernel_filter):
|
||||
continue
|
||||
# Aiter (mha_bwd) integration
|
||||
if receipt == 300:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= mode == "batch"
|
||||
if not cond:
|
||||
continue
|
||||
# Aiter (mha_varlen_bwd) integration
|
||||
elif receipt == 400:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= mode == "group"
|
||||
if not cond:
|
||||
continue
|
||||
gen.append(k)
|
||||
|
||||
return gen
|
||||
@@ -808,27 +850,33 @@ def write_single_bwd_convert_dq_kernel(kernel: FmhaBwdConvertQGradKernel, autoge
|
||||
def write_bwd_api(api_pool : FmhaBwdApiPool, autogen_dir: Path) -> None:
|
||||
(autogen_dir / FMHA_BWD_API_FILENAME).write_text(api_pool.api)
|
||||
|
||||
def write_blobs(output_dir : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None:
|
||||
kernels = get_bwd_dot_do_o_blobs()
|
||||
def write_blobs(output_dir : Path, filter_list : str, receipt, mask_impl) -> None:
|
||||
filter_list = filter_list.split('@')
|
||||
filter_list.extend([''] * (3 - len(filter_list)))
|
||||
|
||||
kernels = get_bwd_dot_do_o_blobs(filter_list[0], receipt)
|
||||
for kernel in kernels:
|
||||
write_single_bwd_dot_do_o_kernel(kernel, output_dir)
|
||||
kernels = get_bwd_convert_dq_blobs()
|
||||
kernels = get_bwd_convert_dq_blobs(filter_list[1], receipt)
|
||||
for kernel in kernels:
|
||||
write_single_bwd_convert_dq_kernel(kernel, output_dir)
|
||||
api_pool, kernels = get_bwd_dq_dk_dv_blobs(kernel_filter, receipt, mask_impl)
|
||||
api_pool, kernels = get_bwd_dq_dk_dv_blobs(filter_list[2], receipt, mask_impl)
|
||||
for kernel in kernels:
|
||||
write_single_bwd_dq_dk_dv_kernel(kernel, output_dir)
|
||||
write_bwd_api(api_pool, output_dir)
|
||||
|
||||
def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None:
|
||||
def list_blobs(file_path : Path, filter_list : str, receipt, mask_impl) -> None:
|
||||
filter_list = filter_list.split('@')
|
||||
filter_list.extend([''] * (3 - len(filter_list)))
|
||||
|
||||
with file_path.open('a') as f:
|
||||
kernels = get_bwd_dot_do_o_blobs()
|
||||
kernels = get_bwd_dot_do_o_blobs(filter_list[0], receipt)
|
||||
for kernel in kernels:
|
||||
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
|
||||
kernels = get_bwd_convert_dq_blobs()
|
||||
kernels = get_bwd_convert_dq_blobs(filter_list[1], receipt)
|
||||
for kernel in kernels:
|
||||
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
|
||||
_, kernels = get_bwd_dq_dk_dv_blobs(kernel_filter, receipt, mask_impl)
|
||||
_, kernels = get_bwd_dq_dk_dv_blobs(filter_list[2], receipt, mask_impl)
|
||||
for kernel in kernels:
|
||||
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
|
||||
f.write(str(file_path.parent / GEN_DIR / FMHA_BWD_API_FILENAME) + "\n")
|
||||
|
||||
@@ -233,14 +233,26 @@ class FmhaFwdPipeline:
|
||||
pn = pad_name()
|
||||
n = f'{self.tag}_v{self.F_vlayout[0]}'
|
||||
if pn != '' : n += f'_{pn}'
|
||||
else: n += '_npad'
|
||||
|
||||
if self.F_bias != 'no' : n += f'_{self.F_bias}'
|
||||
else: n += '_nbias'
|
||||
|
||||
if self.F_mask[0:2] == 's_':
|
||||
if self.F_mask == 's_mask': n += f'_mask'
|
||||
else: n += '_nmask'
|
||||
else:
|
||||
if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}'
|
||||
else: n += '_nmask'
|
||||
|
||||
if self.F_lse == 't' : n += '_lse'
|
||||
else: n += '_nlse'
|
||||
|
||||
if self.F_dropout == 't' : n += '_dropout'
|
||||
else: n += '_ndropout'
|
||||
|
||||
if self.F_squant == 't' : n += '_squant'
|
||||
else: n += '_nsquant'
|
||||
return n
|
||||
|
||||
class FmhaFwdApiPool:
|
||||
@@ -484,7 +496,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm
|
||||
F_tile=tile,
|
||||
F_pipeline=pipeline,
|
||||
mask_impl=mask_impl)
|
||||
if kernel_filter != None:
|
||||
if kernel_filter != '':
|
||||
if not fnmatch.fnmatch(k.name, kernel_filter):
|
||||
continue
|
||||
# 2 - Flash attention integration
|
||||
@@ -504,20 +516,18 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm
|
||||
if not cond:
|
||||
continue
|
||||
# Aiter(mha_fwd) integration
|
||||
elif receipt == 10:
|
||||
elif receipt == 100:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= mode == "batch"
|
||||
cond &= mode == 'batch'
|
||||
cond &= pipeline.F_vlayout == 'row'
|
||||
cond &= pipeline.F_bias in ['no', 'alibi']
|
||||
cond &= pipeline.F_squant == 'f'
|
||||
if not cond:
|
||||
continue
|
||||
# Aiter(mha_varlen_fwd) integration
|
||||
elif receipt == 11:
|
||||
elif receipt == 200:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= mode == "group"
|
||||
cond &= mode == 'group'
|
||||
cond &= pipeline.F_vlayout == 'row'
|
||||
cond &= pipeline.F_bias in ['no', 'alibi']
|
||||
cond &= pipeline.F_squant == 'f'
|
||||
if not cond:
|
||||
continue
|
||||
@@ -532,13 +542,13 @@ def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None:
|
||||
def write_fwd_api(api_pool : FmhaFwdApiPool, autogen_dir: Path) -> None:
|
||||
(autogen_dir / FMHA_FWD_API_FILENAME).write_text(api_pool.api)
|
||||
|
||||
def write_blobs(output_dir : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None:
|
||||
def write_blobs(output_dir : Path, kernel_filter : str, receipt, mask_impl) -> None:
|
||||
api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, mask_impl)
|
||||
for kernel in kernels:
|
||||
write_single_fwd_kernel(kernel, output_dir)
|
||||
write_fwd_api(api_pool, output_dir)
|
||||
|
||||
def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None:
|
||||
def list_blobs(file_path : Path, kernel_filter : str, receipt, mask_impl) -> None:
|
||||
with file_path.open('a') as f:
|
||||
_, kernels = get_fwd_blobs(kernel_filter, receipt, mask_impl)
|
||||
for kernel in kernels:
|
||||
|
||||
@@ -323,12 +323,11 @@ def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
|
||||
F_tile=tile,
|
||||
F_pipeline=pipeline,
|
||||
mask_impl=mask_impl)
|
||||
if kernel_filter != None:
|
||||
if kernel_filter != '':
|
||||
if not fnmatch.fnmatch(k.name, kernel_filter):
|
||||
continue
|
||||
# 2 - Flash attention integration
|
||||
# 12 - Aiter(mha_fwd_kvcache) integration
|
||||
if receipt in (2, 12):
|
||||
if receipt == 2:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= pipeline.F_vlayout == 'row'
|
||||
if not cond:
|
||||
|
||||
@@ -397,14 +397,26 @@ class FmhaFwdSplitKVPipeline:
|
||||
pn = pad_name()
|
||||
n = f'{self.tag}_v{self.F_vlayout[0]}'
|
||||
if pn != '' : n += f'_{pn}'
|
||||
else: n += '_npad'
|
||||
|
||||
if self.F_bias != 'no' : n += f'_{self.F_bias}'
|
||||
else: n += '_nbias'
|
||||
|
||||
if self.F_mask[0:2] == 's_':
|
||||
if self.F_mask == 's_mask': n += f'_mask'
|
||||
else: n += '_nmask'
|
||||
else:
|
||||
if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}'
|
||||
else: n += '_nmask'
|
||||
|
||||
if self.F_lse == 't' : n += '_lse'
|
||||
else: n += '_nlse'
|
||||
|
||||
if self.F_squant == 't' : n += '_squant'
|
||||
else: n += '_nsquant'
|
||||
|
||||
if self.F_pagedkv == 't' : n += '_pagedkv'
|
||||
else: n += '_npagedkv'
|
||||
return n
|
||||
|
||||
@dataclass
|
||||
@@ -702,7 +714,7 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
|
||||
F_tile=tile,
|
||||
F_pipeline=pipeline,
|
||||
mask_impl=mask_impl)
|
||||
if kernel_filter != None:
|
||||
if kernel_filter != '':
|
||||
if not fnmatch.fnmatch(k.name, kernel_filter):
|
||||
continue
|
||||
# Flash attention integration
|
||||
@@ -714,20 +726,10 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
|
||||
if not cond:
|
||||
continue
|
||||
# Aiter(mha_varlen_fwd) integration
|
||||
elif receipt == 11:
|
||||
elif receipt == 200:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= mode == "group"
|
||||
cond &= pipeline.F_vlayout == 'row'
|
||||
cond &= pipeline.F_bias in ['no', 'alibi']
|
||||
cond &= pipeline.F_squant == 'f'
|
||||
if not cond:
|
||||
continue
|
||||
# Aiter(mha_fwd_kvcache) integration
|
||||
elif receipt == 12:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= mode == "batch"
|
||||
cond &= pipeline.F_vlayout == 'row'
|
||||
cond &= pipeline.F_bias in ['no', 'alibi']
|
||||
cond &= pipeline.F_squant == 'f'
|
||||
if not cond:
|
||||
continue
|
||||
@@ -780,9 +782,15 @@ def get_fwd_splitkv_combine_blobs(kernel_filter : Optional[str], receipt) -> Lis
|
||||
F_mode=mode,
|
||||
F_tile=tile,
|
||||
F_pipeline=pipeline)
|
||||
if kernel_filter != None:
|
||||
if kernel_filter != '':
|
||||
if not fnmatch.fnmatch(k.name, kernel_filter):
|
||||
continue
|
||||
# Aiter(mha_varlen_fwd) integration
|
||||
if receipt == 200:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= mode == "group"
|
||||
if not cond:
|
||||
continue
|
||||
gen.append(k)
|
||||
|
||||
return gen
|
||||
@@ -794,21 +802,27 @@ def write_fwd_splitkv_api(api_pool : FmhaFwdSplitKVApiPool, autogen_dir: Path) -
|
||||
file_path = autogen_dir / FMHA_FWD_SPLITKV_API_FILENAME
|
||||
file_path.write_text(api_pool.api)
|
||||
|
||||
def write_blobs(output_dir : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None:
|
||||
kernels = get_fwd_splitkv_combine_blobs(kernel_filter, receipt)
|
||||
def write_blobs(output_dir : Path, filter_list : str, receipt, mask_impl) -> None:
|
||||
filter_list = filter_list.split('@')
|
||||
filter_list.extend([''] * (2 - len(filter_list)))
|
||||
|
||||
kernels = get_fwd_splitkv_combine_blobs(filter_list[0], receipt)
|
||||
for kernel in kernels:
|
||||
write_single_kernel(kernel, output_dir)
|
||||
api_pool, kernels = get_fwd_splitkv_blobs(kernel_filter, receipt, mask_impl)
|
||||
api_pool, kernels = get_fwd_splitkv_blobs(filter_list[1], receipt, mask_impl)
|
||||
for kernel in kernels:
|
||||
write_single_kernel(kernel, output_dir)
|
||||
write_fwd_splitkv_api(api_pool, output_dir)
|
||||
|
||||
def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None:
|
||||
def list_blobs(file_path : Path, filter_list : str, receipt, mask_impl) -> None:
|
||||
filter_list = filter_list.split('@')
|
||||
filter_list.extend([''] * (2 - len(filter_list)))
|
||||
|
||||
with file_path.open('a') as f:
|
||||
kernels = get_fwd_splitkv_combine_blobs(kernel_filter, receipt)
|
||||
kernels = get_fwd_splitkv_combine_blobs(filter_list[0], receipt)
|
||||
for kernel in kernels:
|
||||
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
|
||||
_, kernels = get_fwd_splitkv_blobs(kernel_filter, receipt, mask_impl)
|
||||
_, kernels = get_fwd_splitkv_blobs(filter_list[1], receipt, mask_impl)
|
||||
for kernel in kernels:
|
||||
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
|
||||
f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_SPLITKV_API_FILENAME) + "\n")
|
||||
|
||||
@@ -452,4 +452,5 @@ struct fmha_bwd_traits
|
||||
bool is_deterministic;
|
||||
// TODO: padding check is inside this api
|
||||
};
|
||||
template <int Version = 2>
|
||||
float fmha_bwd(fmha_bwd_traits, fmha_bwd_args, const ck_tile::stream_config&);
|
||||
|
||||
@@ -30,7 +30,7 @@ handlers = dict(
|
||||
)
|
||||
assert 0 < len(handlers)
|
||||
|
||||
def write_blobs(output_dir: Optional[str], api_list : List[str], kernel_filter : Optional[str], receipt, mask_impl) -> None:
|
||||
def write_blobs(output_dir: Optional[str], api_list : List[str], filters_list : List[str], receipt, mask_impl) -> None:
|
||||
if output_dir is None:
|
||||
output_dir = Path(__file__).parent
|
||||
else:
|
||||
@@ -38,19 +38,19 @@ def write_blobs(output_dir: Optional[str], api_list : List[str], kernel_filter :
|
||||
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
for api in api_list:
|
||||
for api, kernel_filter in zip(api_list, filters_list):
|
||||
handler = handlers[api][HandlerId.WRITE_BLOBS]
|
||||
handler(output_dir, kernel_filter, receipt, mask_impl)
|
||||
|
||||
# list all the files that will be generated
|
||||
def list_blobs(output_file : Optional[str], api_list : List[str], kernel_filter : Optional[str], receipt, mask_impl) -> None:
|
||||
def list_blobs(output_file : Optional[str], api_list : List[str], filters_list : List[str], receipt, mask_impl) -> None:
|
||||
assert output_file is not None
|
||||
file_path = Path(output_file)
|
||||
|
||||
# create an empty file / drop its contents if it exists
|
||||
open(file_path, "w").close()
|
||||
|
||||
for api in api_list:
|
||||
for api, kernel_filter in zip(api_list, filters_list):
|
||||
handler = handlers[api][HandlerId.LIST_BLOBS]
|
||||
handler(file_path, kernel_filter, receipt, mask_impl)
|
||||
|
||||
@@ -84,6 +84,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"-f",
|
||||
"--filter",
|
||||
default='',
|
||||
required=False,
|
||||
help="filter out kernels that need to generate, using fnmatch module"
|
||||
)
|
||||
@@ -105,15 +106,19 @@ if __name__ == "__main__":
|
||||
" 1: generate more instance to cover all hdim\n" + \
|
||||
" 2: Only generate instance for Flash attention integration\n" + \
|
||||
" 4: Only generate instance for PyTorch integration\n" + \
|
||||
" 10: Only generate instance for Aiter(mha_fwd, mha_bwd) integration\n" + \
|
||||
" 11: Only generate instance for Aiter(mha_varlen_fwd, mha_varlen_bwd) integration\n" + \
|
||||
" 12: Only generate instance for Aiter(mha_fwd_kvcache) integration"
|
||||
|
||||
" 100-199: Only generate instance for Aiter(mha_fwd) integration\n" + \
|
||||
" 200-299: Only generate instance for Aiter(mha_varlen_fwd) integration\n" + \
|
||||
" 300-399: Only generate instance for Aiter(mha_bwd) integration\n" + \
|
||||
" 400-499: Only generate instance for Aiter(mha_varlen_bwd) integration"
|
||||
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
api_list = args.direction.split(',')
|
||||
filter_list = args.filter.split(',')
|
||||
filter_list.extend([''] * (len(api_list) - len(filter_list)))
|
||||
|
||||
if args.list_blobs is not None:
|
||||
list_blobs(args.list_blobs, api_list, args.filter, int(args.receipt), mask_impl=args.mask)
|
||||
list_blobs(args.list_blobs, api_list, filter_list, int(args.receipt), mask_impl=args.mask)
|
||||
else:
|
||||
write_blobs(args.output_dir, api_list, args.filter, int(args.receipt), mask_impl=args.mask)
|
||||
write_blobs(args.output_dir, api_list, filter_list, int(args.receipt), mask_impl=args.mask)
|
||||
|
||||
@@ -10,7 +10,7 @@
|
||||
#include <tuple>
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "gemm_basic.hpp"
|
||||
#include "gemm_utils.hpp"
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
|
||||
@@ -35,11 +35,76 @@
|
||||
#error "unsupported CK_TILE_PIPELINE_DEFAULT value"
|
||||
#endif
|
||||
|
||||
struct GemmConfig
|
||||
{
|
||||
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
|
||||
// Memory friendly for Interwave scheduler
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Tile = 64;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 4;
|
||||
static constexpr ck_tile::index_t N_Warp = 1;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 8;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
#endif
|
||||
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3)
|
||||
// Compute friendly for Intrawave scheduler
|
||||
static constexpr ck_tile::index_t M_Tile = 256;
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Tile = 64;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 2;
|
||||
static constexpr ck_tile::index_t N_Warp = 2;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4)
|
||||
// Compute friendly for Intrawave scheduler
|
||||
// Using the ping pong reader in the lds level
|
||||
static constexpr ck_tile::index_t M_Tile = 256;
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Tile = 32;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 2;
|
||||
static constexpr ck_tile::index_t N_Warp = 2;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
#endif
|
||||
|
||||
static constexpr bool kPadM = false;
|
||||
static constexpr bool kPadN = false;
|
||||
static constexpr bool kPadK = false;
|
||||
|
||||
static constexpr bool PermuteA = false;
|
||||
static constexpr bool PermuteB = false;
|
||||
|
||||
static constexpr bool TransposeC = false;
|
||||
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
|
||||
static constexpr ck_tile::index_t TileParitionerM01 = 4;
|
||||
};
|
||||
|
||||
template <typename ADataType, typename BDataType = ADataType, typename CDataType = ADataType>
|
||||
struct GemmBasicTypeConfig;
|
||||
struct GemmTypeConfig;
|
||||
|
||||
template <>
|
||||
struct GemmBasicTypeConfig<ck_tile::half_t>
|
||||
struct GemmTypeConfig<ck_tile::half_t>
|
||||
{
|
||||
using ADataType = ck_tile::half_t;
|
||||
using BDataType = ck_tile::half_t;
|
||||
@@ -49,7 +114,7 @@ struct GemmBasicTypeConfig<ck_tile::half_t>
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmBasicTypeConfig<ck_tile::bf16_t>
|
||||
struct GemmTypeConfig<ck_tile::bf16_t>
|
||||
{
|
||||
using ADataType = ck_tile::bf16_t;
|
||||
using BDataType = ck_tile::bf16_t;
|
||||
@@ -58,7 +123,7 @@ struct GemmBasicTypeConfig<ck_tile::bf16_t>
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmBasicTypeConfig<ck_tile::fp8_t>
|
||||
struct GemmTypeConfig<ck_tile::fp8_t>
|
||||
{
|
||||
using ADataType = ck_tile::fp8_t;
|
||||
using BDataType = ck_tile::fp8_t;
|
||||
@@ -67,7 +132,7 @@ struct GemmBasicTypeConfig<ck_tile::fp8_t>
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmBasicTypeConfig<ck_tile::bf8_t>
|
||||
struct GemmTypeConfig<ck_tile::bf8_t>
|
||||
{
|
||||
using ADataType = ck_tile::bf8_t;
|
||||
using BDataType = ck_tile::bf8_t;
|
||||
@@ -76,7 +141,7 @@ struct GemmBasicTypeConfig<ck_tile::bf8_t>
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmBasicTypeConfig<ck_tile::half_t, ck_tile::pk_int4_t, ck_tile::half_t>
|
||||
struct GemmTypeConfig<ck_tile::half_t, ck_tile::pk_int4_t, ck_tile::half_t>
|
||||
{
|
||||
using ADataType = ck_tile::half_t;
|
||||
using BDataType = ck_tile::pk_int4_t;
|
||||
@@ -29,8 +29,67 @@ auto calculate_rtol_atol(const ck_tile::index_t K,
|
||||
// Use higher threshold
|
||||
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
|
||||
}
|
||||
template <typename Tensor>
|
||||
|
||||
template <typename Tensor,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
void permute_tensor_b(Tensor& tensor)
|
||||
{
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
|
||||
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
|
||||
ck_tile::
|
||||
sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>,
|
||||
GemmConfig::PermuteA,
|
||||
GemmConfig::PermuteB>;
|
||||
|
||||
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
GemmConfig::DoubleSmemBuffer,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
GemmConfig::TransposeC>;
|
||||
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
GEMM_PIPELINE_SCHEDULER,
|
||||
true,
|
||||
ck_tile::TailNumber::Full>;
|
||||
|
||||
using GemmPipeline = GEMM_PIPELINE<UniversalGemmProblem>;
|
||||
|
||||
const ck_tile::index_t K = tensor.get_length(0);
|
||||
const ck_tile::index_t N = tensor.get_length(1);
|
||||
const ck_tile::index_t K1 = GemmPipeline::GetSmemPackB();
|
||||
const ck_tile::index_t K0 = K / K1;
|
||||
|
||||
Tensor tensor_copy = tensor;
|
||||
|
||||
// int K0, N, K1
|
||||
for(int j = 0; j < K0; j++)
|
||||
{
|
||||
for(int i = 0; i < N; i++)
|
||||
{
|
||||
for(int jj = 0; jj < K1; jj++)
|
||||
{
|
||||
tensor(j * N * K1 + i * K1 + jj) = tensor_copy(i * K + (j * K1 + jj));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Tensor>
|
||||
void permute_vectors_i4x4_b(Tensor& tensor)
|
||||
{
|
||||
const ck_tile::index_t K = tensor.get_length(0);
|
||||
const ck_tile::index_t N = tensor.get_length(1);
|
||||
@@ -153,7 +212,7 @@ int run_gemm_example_with_layouts(int argc,
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
using AccDataType = typename GemmBasicTypeConfig<ADataType, BDataType, CDataType>::AccDataType;
|
||||
using AccDataType = typename GemmTypeConfig<ADataType, BDataType, CDataType>::AccDataType;
|
||||
|
||||
ck_tile::index_t M = arg_parser.get_int("m");
|
||||
ck_tile::index_t N = arg_parser.get_int("n");
|
||||
@@ -181,8 +240,8 @@ int run_gemm_example_with_layouts(int argc,
|
||||
|
||||
if(init_method == 0)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n);
|
||||
ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_m_k);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_k_n);
|
||||
}
|
||||
else if(init_method == 1)
|
||||
{
|
||||
@@ -204,18 +263,36 @@ int run_gemm_example_with_layouts(int argc,
|
||||
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
|
||||
|
||||
a_m_k_dev_buf.ToDevice(a_m_k.data());
|
||||
static_assert(!GemmConfig::PermuteA, "Not implemented");
|
||||
if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
|
||||
{
|
||||
// Permute data for device implementation
|
||||
// Permute vector pk_i4x4 data for device implementation
|
||||
ck_tile::HostTensor<BDataType> b_k_n_dev = b_k_n;
|
||||
permute_tensor_b(b_k_n_dev);
|
||||
if constexpr(GemmConfig::PermuteB)
|
||||
{
|
||||
permute_tensor_b<decltype(b_k_n_dev),
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>(b_k_n_dev);
|
||||
}
|
||||
permute_vectors_i4x4_b(b_k_n_dev);
|
||||
b_k_n_dev_buf.ToDevice(b_k_n_dev.data());
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(GemmConfig::PermuteB)
|
||||
{
|
||||
std::cout << "Permute for this DataType is not implemented." << std::endl;
|
||||
return false;
|
||||
}
|
||||
b_k_n_dev_buf.ToDevice(b_k_n.data());
|
||||
}
|
||||
|
||||
a_m_k_dev_buf.ToDevice(a_m_k.data());
|
||||
c_m_n_dev_buf.SetZero();
|
||||
c_m_n_dev_result.SetZero();
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@
|
||||
#include <tuple>
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "gemm_basic.hpp"
|
||||
#include "gemm_utils.hpp"
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
@@ -21,90 +21,39 @@ template <typename ADataType,
|
||||
typename CLayout>
|
||||
float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
{
|
||||
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
|
||||
// Memory friendly for Interwave scheduler
|
||||
constexpr ck_tile::index_t M_Tile = 128;
|
||||
constexpr ck_tile::index_t N_Tile = 32;
|
||||
constexpr ck_tile::index_t K_Tile = 64;
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
|
||||
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
|
||||
ck_tile::
|
||||
sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>,
|
||||
GemmConfig::PermuteA,
|
||||
GemmConfig::PermuteB>;
|
||||
using TilePartitioner =
|
||||
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
|
||||
GemmConfig::TileParitionerGroupNum,
|
||||
GemmConfig::TileParitionerM01>;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp = 4;
|
||||
constexpr ck_tile::index_t N_Warp = 1;
|
||||
constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = 8;
|
||||
|
||||
constexpr bool DoubleSmemBuffer = false;
|
||||
#endif
|
||||
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3)
|
||||
// Compute friendly for Intrawave scheduler
|
||||
constexpr ck_tile::index_t M_Tile = 256;
|
||||
constexpr ck_tile::index_t N_Tile = 256;
|
||||
constexpr ck_tile::index_t K_Tile = 64;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp = 2;
|
||||
constexpr ck_tile::index_t N_Warp = 2;
|
||||
constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
constexpr bool DoubleSmemBuffer = false;
|
||||
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4)
|
||||
// Compute friendly for Intrawave scheduler
|
||||
// Using the ping pong reader in the lds level
|
||||
constexpr ck_tile::index_t M_Tile = 256;
|
||||
constexpr ck_tile::index_t N_Tile = 256;
|
||||
constexpr ck_tile::index_t K_Tile = 32;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp = 2;
|
||||
constexpr ck_tile::index_t N_Warp = 2;
|
||||
constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
constexpr bool DoubleSmemBuffer = true;
|
||||
#endif
|
||||
|
||||
constexpr bool kPadM = false;
|
||||
constexpr bool kPadN = false;
|
||||
constexpr bool kPadK = false;
|
||||
|
||||
constexpr bool TransposeC = false;
|
||||
|
||||
constexpr int kBlockPerCu = 1;
|
||||
constexpr ck_tile::index_t TileParitionerGroupNum = 8;
|
||||
constexpr ck_tile::index_t TileParitionerM01 = 4;
|
||||
|
||||
// ===============================================
|
||||
|
||||
using GemmShape =
|
||||
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
|
||||
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
|
||||
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
|
||||
using TilePartitioner = ck_tile::
|
||||
GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;
|
||||
|
||||
using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
|
||||
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<kPadM,
|
||||
kPadN,
|
||||
kPadK,
|
||||
DoubleSmemBuffer,
|
||||
using Traits = ck_tile::TileGemmTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>;
|
||||
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
GemmConfig::DoubleSmemBuffer,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
TransposeC>;
|
||||
GemmConfig::TransposeC>;
|
||||
using GemmPipelineProblem =
|
||||
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
|
||||
|
||||
using BaseGemmPipeline = UNIVERSAL_GEMM_PIPELINE<GemmPipelineProblem>;
|
||||
|
||||
const ck_tile::index_t k_grain = args.k_batch * K_Tile;
|
||||
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * K_Tile;
|
||||
const ck_tile::index_t k_grain = args.k_batch * GemmConfig::K_Tile;
|
||||
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * GemmConfig::K_Tile;
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
@@ -133,11 +82,11 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
|
||||
GemmPipelineProblem::kBlockSize,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
M_Warp,
|
||||
N_Warp,
|
||||
M_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC>>;
|
||||
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
@@ -158,8 +107,9 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
ave_time = ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
ave_time = ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<blocks.x, GemmConfig::kBlockPerCu>(
|
||||
Kernel{}, grids, blocks, 0, kargs));
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
|
||||
@@ -17,6 +17,9 @@ struct fused_moe_args
|
||||
const void* y_smooth_scale_ptr; // [e, 1, n], smooth-quant-scale for 2nd gemm input
|
||||
const void* local_expert_mask_ptr; // [e], local_expert_mask_ptr for EP
|
||||
void* o_ptr; // [m, k], output token (no need to do zeroing)
|
||||
void* ws_ptr; // size is moe_sorting_get_workspace_size()
|
||||
// if return zero, then could be nullptr
|
||||
// must be cleard before use
|
||||
|
||||
const void* topk_ids_ptr; // [tokens, topk]
|
||||
const void* topk_weight_ptr; // [tokens, topk]
|
||||
|
||||
@@ -27,6 +27,7 @@ float fused_moe(fused_moe_traits t, fused_moe_args a, const ck_tile::stream_conf
|
||||
a.sorted_expert_ids_ptr, // void* p_sorted_expert_ids;
|
||||
a.num_sorted_tiles_ptr, // void* p_total_tokens_post_pad;
|
||||
a.o_ptr, // void* p_moe_buf;
|
||||
a.ws_ptr, // void* p_ws;
|
||||
a.num_tokens, // index_t tokens;
|
||||
a.block_m, // index_t unit_size;
|
||||
a.num_experts, // index_t num_experts;
|
||||
|
||||
@@ -371,6 +371,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::DeviceMem num_sorted_tiles_buf(
|
||||
num_sorted_tiles_host.get_element_space_size_in_bytes());
|
||||
|
||||
// if return zero, means no need workspace, can set moe_sorting_args.p_ws to nullptr
|
||||
ck_tile::index_t workspace_size = ck_tile::moe_sorting_get_workspace_size(tokens, experts);
|
||||
ck_tile::DeviceMem moe_sorting_ws(workspace_size != 0 ? workspace_size : 0);
|
||||
if(workspace_size != 0)
|
||||
moe_sorting_ws.SetZero(); // note, clear here!!!!
|
||||
|
||||
fused_moe_traits traits{prec_i,
|
||||
prec_w,
|
||||
prec_o,
|
||||
@@ -394,6 +400,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
local_expert_masking ? local_expert_mask_buf.GetDeviceBuffer()
|
||||
: nullptr,
|
||||
o_buf.GetDeviceBuffer(),
|
||||
workspace_size != 0 ? moe_sorting_ws.GetDeviceBuffer() : nullptr,
|
||||
topk_ids_buf.GetDeviceBuffer(),
|
||||
topk_weight_buf.GetDeviceBuffer(),
|
||||
sorted_token_ids_buf.GetDeviceBuffer(),
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -10,10 +10,10 @@
|
||||
#include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp"
|
||||
|
||||
template <typename DataType>
|
||||
struct GemmBasicTypeConfig;
|
||||
struct GemmTypeConfig;
|
||||
|
||||
template <>
|
||||
struct GemmBasicTypeConfig<ck_tile::half_t>
|
||||
struct GemmTypeConfig<ck_tile::half_t>
|
||||
{
|
||||
using ADataType = ck_tile::half_t;
|
||||
using BDataType = ck_tile::half_t;
|
||||
@@ -21,7 +21,7 @@ struct GemmBasicTypeConfig<ck_tile::half_t>
|
||||
using AccDataType = float;
|
||||
};
|
||||
|
||||
using Types = GemmBasicTypeConfig<ck_tile::half_t>;
|
||||
using Types = GemmTypeConfig<ck_tile::half_t>;
|
||||
|
||||
// Specific type aliases for easy access
|
||||
using ADataType = Types::ADataType;
|
||||
|
||||
@@ -5,11 +5,17 @@
|
||||
|
||||
#ifndef __HIPCC_RTC__
|
||||
#include <string>
|
||||
#include <map>
|
||||
#include <string_view>
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
namespace ck {
|
||||
|
||||
constexpr unsigned int fnv1a_hash(std::string_view str, unsigned int h = 2166136261u)
|
||||
{
|
||||
return str.empty() ? h
|
||||
: fnv1a_hash(str.substr(1),
|
||||
(h ^ static_cast<unsigned char>(str.front())) * 16777619u);
|
||||
}
|
||||
inline std::string get_device_name()
|
||||
{
|
||||
hipDeviceProp_t props{};
|
||||
@@ -19,37 +25,31 @@ inline std::string get_device_name()
|
||||
{
|
||||
return std::string();
|
||||
}
|
||||
|
||||
status = hipGetDeviceProperties(&props, device);
|
||||
if(status != hipSuccess)
|
||||
{
|
||||
return std::string();
|
||||
}
|
||||
const std::string raw_name(props.gcnArchName);
|
||||
|
||||
// https://github.com/ROCm/MIOpen/blob/8498875aef84878e04c1eabefdf6571514891086/src/target_properties.cpp#L40
|
||||
static std::map<std::string, std::string> device_name_map = {
|
||||
{"Ellesmere", "gfx803"},
|
||||
{"Baffin", "gfx803"},
|
||||
{"RacerX", "gfx803"},
|
||||
{"Polaris10", "gfx803"},
|
||||
{"Polaris11", "gfx803"},
|
||||
{"Tonga", "gfx803"},
|
||||
{"Fiji", "gfx803"},
|
||||
{"gfx800", "gfx803"},
|
||||
{"gfx802", "gfx803"},
|
||||
{"gfx804", "gfx803"},
|
||||
{"Vega10", "gfx900"},
|
||||
{"gfx901", "gfx900"},
|
||||
{"10.3.0 Sienna_Cichlid 18", "gfx1030"},
|
||||
};
|
||||
|
||||
const auto name = raw_name.substr(0, raw_name.find(':')); // str.substr(0, npos) returns str.
|
||||
|
||||
auto match = device_name_map.find(name);
|
||||
if(match != device_name_map.end())
|
||||
return match->second;
|
||||
return name;
|
||||
switch(fnv1a_hash(name))
|
||||
{
|
||||
// https://github.com/ROCm/MIOpen/blob/8498875aef84878e04c1eabefdf6571514891086/src/target_properties.cpp#L40
|
||||
case fnv1a_hash("Ellesmere"):
|
||||
case fnv1a_hash("Baffin"):
|
||||
case fnv1a_hash("RacerX"):
|
||||
case fnv1a_hash("Polaris10"):
|
||||
case fnv1a_hash("Polaris11"):
|
||||
case fnv1a_hash("Tonga"):
|
||||
case fnv1a_hash("Fiji"):
|
||||
case fnv1a_hash("gfx800"):
|
||||
case fnv1a_hash("gfx802"):
|
||||
case fnv1a_hash("gfx804"): return "gfx803";
|
||||
case fnv1a_hash("Vega10"):
|
||||
case fnv1a_hash("gfx901"): return "gfx900";
|
||||
case fnv1a_hash("10.3.0 Sienna_Cichlid 18"): return "gfx1030";
|
||||
default: return name;
|
||||
}
|
||||
}
|
||||
|
||||
inline bool is_xdl_supported()
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
|
||||
@@ -306,9 +306,6 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, local_read_buf);
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
|
||||
// printf("bid %d tid %d %f %f\n", blockIdx.x, threadIdx.x,
|
||||
// type_convert<float>(a_thread_buf[I0]),
|
||||
// type_convert<float>(b_thread_bufs[mfma_reg_buf][I0]));
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
|
||||
@@ -7,10 +7,10 @@
|
||||
|
||||
namespace ck {
|
||||
|
||||
// Naive pipeline with lowest resource request per WGP
|
||||
// GlobalPrefetchStages: 1
|
||||
// Compute optimized pipeline
|
||||
// GlobalPrefetchStages: 2
|
||||
// LocalPreFillStages: 1
|
||||
// LocalPreFetchStages: 0
|
||||
// LocalPreFetchStages: 1
|
||||
// LocalSharedMemoryBuffer: 1
|
||||
|
||||
template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
|
||||
@@ -96,7 +96,8 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>
|
||||
KPack,
|
||||
true>
|
||||
|
||||
{
|
||||
using Base = BlockwiseGemmXdlops_pipeline_base<BlockSize,
|
||||
@@ -117,10 +118,15 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>;
|
||||
KPack,
|
||||
true>;
|
||||
using Base::A_K1;
|
||||
using Base::B_K1;
|
||||
using Base::I0;
|
||||
using Base::I1;
|
||||
using Base::KRepeat;
|
||||
using Base::xdlops_gemm;
|
||||
using typename Base::HotLoopInstList;
|
||||
|
||||
using Base::CalculateCThreadOriginDataIndex;
|
||||
using Base::CalculateCThreadOriginDataIndex8D;
|
||||
@@ -131,19 +137,43 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
|
||||
using Base::GetWaveIdx;
|
||||
using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
|
||||
using Base::a_block_desc_m0_m1_m2_k;
|
||||
using Base::b_block_desc_n0_n1_n2_k;
|
||||
|
||||
using Base::AMmaKStride;
|
||||
using Base::BMmaKStride;
|
||||
static constexpr index_t AMmaKStride = xdlops_gemm.K0PerXdlops * KPack;
|
||||
static constexpr index_t BMmaKStride = xdlops_gemm.K0PerXdlops * KPack;
|
||||
|
||||
static constexpr index_t PrefetchStages = 1;
|
||||
static constexpr index_t PrefetchStages = 2;
|
||||
static constexpr index_t PrefillStages = 1;
|
||||
static constexpr index_t GlobalBufferNum = 1;
|
||||
|
||||
// Force mfma not cross the scaleblock
|
||||
__device__ static auto CalculateAThreadOriginDataIndex()
|
||||
{
|
||||
const auto wave_idx = GetWaveIdx();
|
||||
|
||||
const auto waveId_m = wave_idx[I0];
|
||||
|
||||
const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex();
|
||||
|
||||
return make_tuple(0, waveId_m, xdlops_a_idx[I1], KPack * xdlops_a_idx[I0]);
|
||||
}
|
||||
|
||||
__device__ static auto CalculateBThreadOriginDataIndex()
|
||||
{
|
||||
const auto wave_idx = GetWaveIdx();
|
||||
|
||||
const auto waveId_n = wave_idx[I1];
|
||||
|
||||
const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex();
|
||||
|
||||
return make_tuple(0, waveId_n, xdlops_b_idx[I1], KPack * xdlops_b_idx[I0]);
|
||||
}
|
||||
|
||||
__host__ static constexpr bool BlockHasHotloop(index_t num_loop)
|
||||
{
|
||||
return num_loop > PrefetchStages;
|
||||
@@ -151,11 +181,116 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
|
||||
__host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
|
||||
{
|
||||
ignore = num_loop;
|
||||
return TailNumber::Full;
|
||||
return num_loop == 1 ? TailNumber::Odd : TailNumber::Full;
|
||||
}
|
||||
|
||||
__device__ static constexpr auto HotLoopScheduler()
|
||||
{
|
||||
// A/B split schedule
|
||||
// compiler is likely to use ds_read2 when instruction width smaller than 16bytes
|
||||
constexpr auto num_ds_read_inst_a =
|
||||
HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16
|
||||
? HotLoopInstList::A_LDS_Read_Inst_Num
|
||||
: HotLoopInstList::A_LDS_Read_Inst_Num / 2;
|
||||
constexpr auto num_ds_read_inst_b =
|
||||
HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16
|
||||
? HotLoopInstList::B_LDS_Read_Inst_Num
|
||||
: HotLoopInstList::B_LDS_Read_Inst_Num / 2;
|
||||
|
||||
constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num;
|
||||
constexpr auto num_ds_write_inst_b = HotLoopInstList::B_LDS_Write_Inst_Num;
|
||||
|
||||
constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
|
||||
constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num;
|
||||
|
||||
constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num;
|
||||
|
||||
constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32;
|
||||
constexpr auto ds_read_a_issue_cycle =
|
||||
HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4;
|
||||
constexpr auto ds_read_b_issue_cycle =
|
||||
HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 ? 8 : 4;
|
||||
constexpr auto ds_read_a_mfma_rate =
|
||||
(mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle);
|
||||
constexpr auto ds_read_b_mfma_rate =
|
||||
(mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle);
|
||||
|
||||
constexpr auto num_dsread_a_mfma =
|
||||
(num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate;
|
||||
constexpr auto num_dsread_b_mfma =
|
||||
(num_ds_read_inst_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate;
|
||||
|
||||
// stage 1
|
||||
// Separate this part?
|
||||
// constexpr auto num_mfma_per_ds_read = sizeof(ComputeDataType) / sizeof(ADataType) >
|
||||
// sizeof(ComputeDataType) / sizeof(BDataType)
|
||||
// ? sizeof(ComputeDataType) / sizeof(ADataType)
|
||||
// : sizeof(ComputeDataType) / sizeof(BDataType);
|
||||
constexpr auto num_mfma_stage1 = num_mfma_inst - (num_dsread_a_mfma + num_dsread_b_mfma);
|
||||
constexpr auto num_mfma_per_issue =
|
||||
num_mfma_stage1 / (num_buffer_load_inst_a + num_buffer_load_inst_b);
|
||||
constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a;
|
||||
constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b;
|
||||
|
||||
static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) {
|
||||
ignore = idswrite;
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(
|
||||
0x008, num_mfma_per_issue - num_dswrite_per_issue_a, 0); // MFMA
|
||||
});
|
||||
static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) {
|
||||
ignore = idswrite;
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(
|
||||
0x008, num_mfma_per_issue - num_dswrite_per_issue_b, 0); // MFMA
|
||||
});
|
||||
|
||||
// stage 2
|
||||
static_for<0, num_dsread_a_mfma, 1>{}([&](auto i) {
|
||||
if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >=
|
||||
ds_read_a_mfma_rate)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
|
||||
}
|
||||
else
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x100,
|
||||
num_ds_read_inst_a - (num_dsread_a_mfma - 1) *
|
||||
ds_read_a_mfma_rate,
|
||||
0); // DS read
|
||||
}
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
|
||||
static_for<0, num_dsread_b_mfma, 1>{}([&](auto i) {
|
||||
if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_mfma_rate) >=
|
||||
ds_read_b_mfma_rate)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read
|
||||
}
|
||||
else
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x100,
|
||||
num_ds_read_inst_b - (num_dsread_b_mfma - 1) *
|
||||
ds_read_b_mfma_rate,
|
||||
0); // DS read
|
||||
}
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
}
|
||||
|
||||
template <bool HasMainLoop,
|
||||
int NumKBlockPerScale,
|
||||
TailNumber TailNum,
|
||||
typename AGridDesc,
|
||||
typename ABlockDesc,
|
||||
@@ -169,6 +304,7 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
typename BGridBuffer,
|
||||
typename BBlockBuffer,
|
||||
typename BBlockTransferStep,
|
||||
typename CScaleThreadDesc,
|
||||
typename CThreadBuffer,
|
||||
typename AScaleGridBuffer,
|
||||
typename AScaleGridDesc,
|
||||
@@ -196,6 +332,7 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
BBlockBuffer& b_block_buf,
|
||||
const BBlockTransferStep& b_block_copy_step,
|
||||
// CThread
|
||||
const CScaleThreadDesc& c_scale_thread_desc,
|
||||
CThreadBuffer& c_thread_buf,
|
||||
// AScaleThreadCopy
|
||||
const AScaleGridDesc& a_scale_grid_desc,
|
||||
@@ -210,11 +347,10 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
const BScaleGridBuffer& b_scale_grid_buf,
|
||||
const BScaleThreadTransferStep& b_scale_thread_copy_step,
|
||||
// num_loop
|
||||
index_t num_loop,
|
||||
index_t num_loop_per_scale) const
|
||||
index_t num_loop) const
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
// assume kperblock = scaleblockk
|
||||
ignore = num_loop_per_scale;
|
||||
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
|
||||
a_thread_desc_.GetElementSpaceSize());
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
|
||||
@@ -223,6 +359,8 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
a_scale_thread_desc.GetElementSpaceSize());
|
||||
auto b_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
|
||||
b_scale_thread_desc.GetElementSpaceSize());
|
||||
auto c_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
|
||||
c_scale_thread_desc.GetElementSpaceSize());
|
||||
|
||||
// Global prefetch 1
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
|
||||
@@ -231,11 +369,26 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
a_scale_thread_copy.Run(a_scale_grid_desc,
|
||||
a_scale_grid_buf,
|
||||
a_scale_thread_desc,
|
||||
make_tuple(I0, I0),
|
||||
a_scale_thread_buf);
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
a_scale_thread_copy.Run(a_scale_grid_desc,
|
||||
a_scale_grid_buf,
|
||||
a_scale_thread_desc,
|
||||
make_tuple(m0, I0),
|
||||
a_scale_thread_buf);
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
|
||||
a_scale_thread_copy_step.At(Number<0>{}));
|
||||
});
|
||||
|
||||
if constexpr(NumKBlockPerScale == 1)
|
||||
{
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
|
||||
a_scale_thread_copy_step.At(Number<2>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
|
||||
a_scale_thread_copy_step.At(Number<1>{}));
|
||||
}
|
||||
|
||||
b_scale_thread_copy.Run(b_scale_grid_desc,
|
||||
b_scale_grid_buf,
|
||||
@@ -243,17 +396,101 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
make_tuple(I0, I0),
|
||||
b_scale_thread_buf);
|
||||
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, a_scale_thread_copy_step);
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);
|
||||
|
||||
constexpr auto num_scale_k_block = CScaleThreadDesc{}.GetLength(Number<0>{});
|
||||
constexpr auto num_scale_m_block = CScaleThreadDesc{}.GetLength(Number<1>{});
|
||||
constexpr auto num_scale_n_block = CScaleThreadDesc{}.GetLength(Number<2>{});
|
||||
|
||||
static_for<0, num_scale_m_block, 1>{}([&](auto m0) {
|
||||
static_for<0, num_scale_n_block, 1>{}([&](auto n0) {
|
||||
static_for<0, num_scale_k_block, 1>{}([&](auto k0) {
|
||||
constexpr index_t c_offset =
|
||||
CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0));
|
||||
constexpr index_t a_offset =
|
||||
AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0));
|
||||
constexpr index_t b_offset =
|
||||
BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0));
|
||||
|
||||
c_scale_thread_buf(Number<c_offset>{}) =
|
||||
a_scale_thread_buf[Number<a_offset>{}] *
|
||||
b_scale_thread_buf[Number<b_offset>{}];
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
// Local prefill 1
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
|
||||
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
|
||||
|
||||
// Global prefetch 2
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
|
||||
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
|
||||
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
a_scale_thread_copy.Run(a_scale_grid_desc,
|
||||
a_scale_grid_buf,
|
||||
a_scale_thread_desc,
|
||||
make_tuple(m0, I0),
|
||||
a_scale_thread_buf);
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
|
||||
a_scale_thread_copy_step.At(Number<0>{}));
|
||||
});
|
||||
|
||||
if constexpr(NumKBlockPerScale == 1)
|
||||
{
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
|
||||
a_scale_thread_copy_step.At(Number<2>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
|
||||
a_scale_thread_copy_step.At(Number<1>{}));
|
||||
}
|
||||
|
||||
b_scale_thread_copy.Run(b_scale_grid_desc,
|
||||
b_scale_grid_buf,
|
||||
b_scale_thread_desc,
|
||||
make_tuple(I0, I0),
|
||||
b_scale_thread_buf);
|
||||
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);
|
||||
|
||||
// Initialize C
|
||||
c_thread_buf.Clear();
|
||||
|
||||
auto c_thread_buf_per_scale = remove_cvref_t<decltype(c_thread_buf)>();
|
||||
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
|
||||
AccDataType,
|
||||
1,
|
||||
xdlops_gemm.GetRegSizePerXdlops(),
|
||||
true>
|
||||
c_thread_buf_per_scale;
|
||||
|
||||
// Local prefetch 1
|
||||
block_sync_lds();
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
|
||||
make_tuple(m0, I0, I0, Number<k0 * AMmaKStride>{}),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, k0, I0),
|
||||
a_thread_buf);
|
||||
});
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
|
||||
make_tuple(n0, I0, I0, Number<k0 * BMmaKStride>{}),
|
||||
b_block_buf,
|
||||
b_thread_desc_,
|
||||
make_tuple(n0, I0, k0, I0),
|
||||
b_thread_buf);
|
||||
});
|
||||
});
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// main body
|
||||
if constexpr(HasMainLoop)
|
||||
@@ -261,13 +498,85 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
index_t i = 0;
|
||||
do
|
||||
{
|
||||
// -------------------------------------------------------------------------------------------
|
||||
block_sync_lds();
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
|
||||
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
|
||||
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
|
||||
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
|
||||
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) {
|
||||
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
|
||||
.template AsType<AccDataType>()(Number<t>{}) = 0;
|
||||
});
|
||||
static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0,
|
||||
I0,
|
||||
kscale0 * KRepeat / num_scale_k_block + k0,
|
||||
ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0,
|
||||
I0,
|
||||
kscale0 * KRepeat / num_scale_k_block + k0,
|
||||
ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType,
|
||||
xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
xdlops_gemm.template Run<>(
|
||||
a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
|
||||
});
|
||||
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
|
||||
constexpr index_t cscale_offset =
|
||||
CScaleThreadDesc{}.CalculateOffset(
|
||||
make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat));
|
||||
|
||||
c_thread_buf(Number<c_offset>{}) +=
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
|
||||
.template AsType<AccDataType>()[Number<t>{}] *
|
||||
type_convert<AccDataType>(
|
||||
c_scale_thread_buf[Number<cscale_offset>{}]);
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, num_scale_n_block, 1>{}([&](auto n0) {
|
||||
static_for<0, num_scale_k_block, 1>{}([&](auto k0) {
|
||||
constexpr index_t c_offset =
|
||||
CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0));
|
||||
constexpr index_t a_offset =
|
||||
AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0));
|
||||
constexpr index_t b_offset =
|
||||
BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0));
|
||||
|
||||
c_scale_thread_buf(Number<c_offset>{}) =
|
||||
a_scale_thread_buf[Number<a_offset>{}] *
|
||||
b_scale_thread_buf[Number<b_offset>{}];
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
block_sync_lds();
|
||||
static_for<0, KRepeat, 1>{}([&](auto k) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
@@ -289,19 +598,70 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
});
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
c_thread_buf_per_scale.Clear();
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
a_scale_thread_copy.Run(a_scale_grid_desc,
|
||||
a_scale_grid_buf,
|
||||
a_scale_thread_desc,
|
||||
make_tuple(m0, I0),
|
||||
a_scale_thread_buf);
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc, a_scale_thread_copy_step.At(Number<0>{}));
|
||||
});
|
||||
|
||||
if constexpr(NumKBlockPerScale == 1)
|
||||
{
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc, a_scale_thread_copy_step.At(Number<2>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc, a_scale_thread_copy_step.At(Number<1>{}));
|
||||
}
|
||||
|
||||
b_scale_thread_copy.Run(b_scale_grid_desc,
|
||||
b_scale_grid_buf,
|
||||
b_scale_thread_desc,
|
||||
make_tuple(I0, I0),
|
||||
b_scale_thread_buf);
|
||||
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);
|
||||
HotLoopScheduler();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
i += 1;
|
||||
} while(i < (num_loop - 2));
|
||||
}
|
||||
|
||||
// tail
|
||||
if constexpr(TailNum == TailNumber::Full)
|
||||
{
|
||||
block_sync_lds();
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
|
||||
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) {
|
||||
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
|
||||
.template AsType<AccDataType>()(Number<t>{}) = 0;
|
||||
});
|
||||
static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, k0, ik))>{}];
|
||||
make_tuple(m0,
|
||||
I0,
|
||||
kscale0 * KRepeat / num_scale_k_block + k0,
|
||||
ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
make_tuple(n0,
|
||||
I0,
|
||||
kscale0 * KRepeat / num_scale_k_block + k0,
|
||||
ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
@@ -311,46 +671,41 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
xdlops_gemm.template Run<>(
|
||||
a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(I0));
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
|
||||
});
|
||||
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
|
||||
constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset(
|
||||
make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat));
|
||||
|
||||
c_thread_buf(Number<c_offset>{}) +=
|
||||
c_thread_buf_per_scale[Number<t>{}] *
|
||||
type_convert<AccDataType>(a_scale_thread_buf[I0]) *
|
||||
type_convert<AccDataType>(b_scale_thread_buf[I0]);
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
|
||||
.template AsType<AccDataType>()[Number<t>{}] *
|
||||
type_convert<AccDataType>(
|
||||
c_scale_thread_buf[Number<cscale_offset>{}]);
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
a_scale_thread_copy.Run(a_scale_grid_desc,
|
||||
a_scale_grid_buf,
|
||||
a_scale_thread_desc,
|
||||
make_tuple(I0, I0),
|
||||
a_scale_thread_buf);
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, num_scale_n_block, 1>{}([&](auto n0) {
|
||||
static_for<0, num_scale_k_block, 1>{}([&](auto k0) {
|
||||
constexpr index_t c_offset =
|
||||
CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0));
|
||||
constexpr index_t a_offset =
|
||||
AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0));
|
||||
constexpr index_t b_offset =
|
||||
BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0));
|
||||
|
||||
b_scale_thread_copy.Run(b_scale_grid_desc,
|
||||
b_scale_grid_buf,
|
||||
b_scale_thread_desc,
|
||||
make_tuple(I0, I0),
|
||||
b_scale_thread_buf);
|
||||
c_scale_thread_buf(Number<c_offset>{}) =
|
||||
a_scale_thread_buf[Number<a_offset>{}] *
|
||||
b_scale_thread_buf[Number<b_offset>{}];
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, a_scale_thread_copy_step);
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);
|
||||
|
||||
block_sync_lds();
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
|
||||
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
|
||||
|
||||
i += 1;
|
||||
|
||||
} while(i < (num_loop - 1));
|
||||
}
|
||||
|
||||
// tail
|
||||
if constexpr(TailNum == TailNumber::Full)
|
||||
{
|
||||
block_sync_lds();
|
||||
static_for<0, KRepeat, 1>{}([&](auto k) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
@@ -371,49 +726,143 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
});
|
||||
});
|
||||
|
||||
HotLoopScheduler();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
c_thread_buf_per_scale.Clear();
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, k0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) {
|
||||
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
|
||||
.template AsType<AccDataType>()(Number<t>{}) = 0;
|
||||
});
|
||||
static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0,
|
||||
I0,
|
||||
kscale0 * KRepeat / num_scale_k_block + k0,
|
||||
ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0,
|
||||
I0,
|
||||
kscale0 * KRepeat / num_scale_k_block + k0,
|
||||
ik))>{}];
|
||||
});
|
||||
|
||||
xdlops_gemm.template Run<>(
|
||||
a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(I0));
|
||||
});
|
||||
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
|
||||
c_thread_buf(Number<c_offset>{}) +=
|
||||
c_thread_buf_per_scale[Number<t>{}] *
|
||||
type_convert<AccDataType>(a_scale_thread_buf[I0]) *
|
||||
type_convert<AccDataType>(b_scale_thread_buf[I0]);
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType,
|
||||
xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
xdlops_gemm.template Run<>(
|
||||
a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
|
||||
});
|
||||
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
|
||||
constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset(
|
||||
make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat));
|
||||
|
||||
c_thread_buf(Number<c_offset>{}) +=
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
|
||||
.template AsType<AccDataType>()[Number<t>{}] *
|
||||
type_convert<AccDataType>(
|
||||
c_scale_thread_buf[Number<cscale_offset>{}]);
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
else if constexpr(TailNum == TailNumber::Odd)
|
||||
{
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) {
|
||||
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
|
||||
.template AsType<AccDataType>()(Number<t>{}) = 0;
|
||||
});
|
||||
static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0,
|
||||
I0,
|
||||
kscale0 * KRepeat / num_scale_k_block + k0,
|
||||
ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0,
|
||||
I0,
|
||||
kscale0 * KRepeat / num_scale_k_block + k0,
|
||||
ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType,
|
||||
xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
xdlops_gemm.template Run<>(
|
||||
a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
|
||||
});
|
||||
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
|
||||
constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset(
|
||||
make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat));
|
||||
|
||||
c_thread_buf(Number<c_offset>{}) +=
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
|
||||
.template AsType<AccDataType>()[Number<t>{}] *
|
||||
type_convert<AccDataType>(
|
||||
c_scale_thread_buf[Number<cscale_offset>{}]);
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
using Base::a_thread_copy_;
|
||||
using Base::a_thread_desc_;
|
||||
using Base::b_thread_copy_;
|
||||
using Base::b_thread_desc_;
|
||||
using Base::c_thread_desc_;
|
||||
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<ADataType,
|
||||
ComputeDataType,
|
||||
decltype(a_block_desc_m0_m1_m2_k),
|
||||
decltype(a_thread_desc_),
|
||||
Sequence<1, 1, 1, KPack>,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
3,
|
||||
A_K1,
|
||||
A_K1>;
|
||||
|
||||
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<BDataType,
|
||||
ComputeDataType,
|
||||
decltype(b_block_desc_n0_n1_n2_k),
|
||||
decltype(b_thread_desc_),
|
||||
Sequence<1, 1, 1, KPack>,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
3,
|
||||
B_K1,
|
||||
B_K1>;
|
||||
|
||||
AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()};
|
||||
BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()};
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
|
||||
@@ -96,7 +96,8 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>
|
||||
KPack,
|
||||
true>
|
||||
|
||||
{
|
||||
using Base = BlockwiseGemmXdlops_pipeline_base<BlockSize,
|
||||
@@ -117,7 +118,8 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>;
|
||||
KPack,
|
||||
true>;
|
||||
using Base::I0;
|
||||
using Base::KRepeat;
|
||||
using Base::xdlops_gemm;
|
||||
@@ -270,11 +272,26 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
a_scale_thread_copy.Run(a_scale_grid_desc,
|
||||
a_scale_grid_buf,
|
||||
a_scale_thread_desc,
|
||||
make_tuple(I0, I0),
|
||||
a_scale_thread_buf);
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
a_scale_thread_copy.Run(a_scale_grid_desc,
|
||||
a_scale_grid_buf,
|
||||
a_scale_thread_desc,
|
||||
make_tuple(m0, I0),
|
||||
a_scale_thread_buf);
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
|
||||
a_scale_thread_copy_step.At(Number<0>{}));
|
||||
});
|
||||
|
||||
if(num_loop_per_scale == 1)
|
||||
{
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
|
||||
a_scale_thread_copy_step.At(Number<2>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
|
||||
a_scale_thread_copy_step.At(Number<1>{}));
|
||||
}
|
||||
|
||||
b_scale_thread_copy.Run(b_scale_grid_desc,
|
||||
b_scale_grid_buf,
|
||||
@@ -282,7 +299,6 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
make_tuple(I0, I0),
|
||||
b_scale_thread_buf);
|
||||
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, a_scale_thread_copy_step);
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);
|
||||
|
||||
// Local prefill 1
|
||||
@@ -360,17 +376,32 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
|
||||
c_thread_buf(Number<c_offset>{}) +=
|
||||
c_thread_buf_per_scale[Number<t>{}] *
|
||||
type_convert<AccDataType>(a_scale_thread_buf[I0]) *
|
||||
type_convert<AccDataType>(a_scale_thread_buf[m0]) *
|
||||
type_convert<AccDataType>(b_scale_thread_buf[I0]);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
a_scale_thread_copy.Run(a_scale_grid_desc,
|
||||
a_scale_grid_buf,
|
||||
a_scale_thread_desc,
|
||||
make_tuple(I0, I0),
|
||||
a_scale_thread_buf);
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
a_scale_thread_copy.Run(a_scale_grid_desc,
|
||||
a_scale_grid_buf,
|
||||
a_scale_thread_desc,
|
||||
make_tuple(m0, I0),
|
||||
a_scale_thread_buf);
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc, a_scale_thread_copy_step.At(Number<0>{}));
|
||||
});
|
||||
|
||||
if(num_loop_per_scale == 1)
|
||||
{
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc, a_scale_thread_copy_step.At(Number<2>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc, a_scale_thread_copy_step.At(Number<1>{}));
|
||||
}
|
||||
|
||||
b_scale_thread_copy.Run(b_scale_grid_desc,
|
||||
b_scale_grid_buf,
|
||||
@@ -378,8 +409,6 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
make_tuple(I0, I0),
|
||||
b_scale_thread_buf);
|
||||
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
|
||||
a_scale_thread_copy_step);
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
|
||||
b_scale_thread_copy_step);
|
||||
|
||||
@@ -453,17 +482,32 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
|
||||
c_thread_buf(Number<c_offset>{}) +=
|
||||
c_thread_buf_per_scale[Number<t>{}] *
|
||||
type_convert<AccDataType>(a_scale_thread_buf[I0]) *
|
||||
type_convert<AccDataType>(a_scale_thread_buf[m0]) *
|
||||
type_convert<AccDataType>(b_scale_thread_buf[I0]);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
a_scale_thread_copy.Run(a_scale_grid_desc,
|
||||
a_scale_grid_buf,
|
||||
a_scale_thread_desc,
|
||||
make_tuple(I0, I0),
|
||||
a_scale_thread_buf);
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
a_scale_thread_copy.Run(a_scale_grid_desc,
|
||||
a_scale_grid_buf,
|
||||
a_scale_thread_desc,
|
||||
make_tuple(m0, I0),
|
||||
a_scale_thread_buf);
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc, a_scale_thread_copy_step.At(Number<0>{}));
|
||||
});
|
||||
|
||||
if(num_loop_per_scale == 1)
|
||||
{
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc, a_scale_thread_copy_step.At(Number<2>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc, a_scale_thread_copy_step.At(Number<1>{}));
|
||||
}
|
||||
|
||||
b_scale_thread_copy.Run(b_scale_grid_desc,
|
||||
b_scale_grid_buf,
|
||||
@@ -471,7 +515,6 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
make_tuple(I0, I0),
|
||||
b_scale_thread_buf);
|
||||
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, a_scale_thread_copy_step);
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);
|
||||
|
||||
block_sync_lds();
|
||||
@@ -528,7 +571,7 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
|
||||
c_thread_buf(Number<c_offset>{}) +=
|
||||
c_thread_buf_per_scale[Number<t>{}] *
|
||||
type_convert<AccDataType>(a_scale_thread_buf[I0]) *
|
||||
type_convert<AccDataType>(a_scale_thread_buf[m0]) *
|
||||
type_convert<AccDataType>(b_scale_thread_buf[I0]);
|
||||
});
|
||||
});
|
||||
@@ -586,7 +629,7 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
|
||||
c_thread_buf(Number<c_offset>{}) +=
|
||||
c_thread_buf_per_scale[Number<t>{}] *
|
||||
type_convert<AccDataType>(a_scale_thread_buf[I0]) *
|
||||
type_convert<AccDataType>(a_scale_thread_buf[m0]) *
|
||||
type_convert<AccDataType>(b_scale_thread_buf[I0]);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -96,7 +96,8 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>
|
||||
KPack,
|
||||
true>
|
||||
|
||||
{
|
||||
using Base = BlockwiseGemmXdlops_pipeline_base<BlockSize,
|
||||
@@ -117,7 +118,8 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>;
|
||||
KPack,
|
||||
true>;
|
||||
using Base::I0;
|
||||
using Base::KRepeat;
|
||||
using Base::xdlops_gemm;
|
||||
@@ -177,11 +179,11 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
|
||||
constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num;
|
||||
|
||||
constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32;
|
||||
constexpr auto ds_read_a_issue_cycle = 4;
|
||||
// HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4;
|
||||
constexpr auto ds_read_b_issue_cycle = 4;
|
||||
// HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 ? 8 : 4;
|
||||
constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32;
|
||||
constexpr auto ds_read_a_issue_cycle =
|
||||
HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4;
|
||||
constexpr auto ds_read_b_issue_cycle =
|
||||
HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 ? 8 : 4;
|
||||
constexpr auto ds_read_a_mfma_rate =
|
||||
(mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle);
|
||||
constexpr auto ds_read_b_mfma_rate =
|
||||
@@ -262,6 +264,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
}
|
||||
|
||||
template <bool HasMainLoop,
|
||||
int NumKBlockPerScale,
|
||||
TailNumber TailNum,
|
||||
typename AGridDesc,
|
||||
typename ABlockDesc,
|
||||
@@ -275,6 +278,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
typename BGridBuffer,
|
||||
typename BBlockBuffer,
|
||||
typename BBlockTransferStep,
|
||||
typename CScaleThreadDesc,
|
||||
typename CThreadBuffer,
|
||||
typename AScaleGridBuffer,
|
||||
typename AScaleGridDesc,
|
||||
@@ -302,6 +306,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
BBlockBuffer& b_block_buf,
|
||||
const BBlockTransferStep& b_block_copy_step,
|
||||
// CThread
|
||||
const CScaleThreadDesc& c_scale_thread_desc,
|
||||
CThreadBuffer& c_thread_buf,
|
||||
// AScaleThreadCopy
|
||||
const AScaleGridDesc& a_scale_grid_desc,
|
||||
@@ -316,12 +321,14 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
const BScaleGridBuffer& b_scale_grid_buf,
|
||||
const BScaleThreadTransferStep& b_scale_thread_copy_step,
|
||||
// num_loop
|
||||
index_t num_loop,
|
||||
index_t num_loop_per_scale) const
|
||||
index_t num_loop) const
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
static_assert(CScaleThreadDesc{}.GetLength(Number<0>{}) == 1,
|
||||
"Pipeline v3 only support scaleblocksliceK=1");
|
||||
static_assert(CScaleThreadDesc{}.GetLength(Number<2>{}) == 1,
|
||||
"Pipeline v3 only support scaleblocksliceN=1");
|
||||
// assume kperblock = scaleblockk
|
||||
ignore = num_loop_per_scale;
|
||||
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
|
||||
a_thread_desc_.GetElementSpaceSize());
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
|
||||
@@ -330,6 +337,8 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
a_scale_thread_desc.GetElementSpaceSize());
|
||||
auto b_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
|
||||
b_scale_thread_desc.GetElementSpaceSize());
|
||||
auto c_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
|
||||
c_scale_thread_desc.GetElementSpaceSize());
|
||||
|
||||
// Global prefetch 1
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
|
||||
@@ -338,11 +347,26 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
a_scale_thread_copy.Run(a_scale_grid_desc,
|
||||
a_scale_grid_buf,
|
||||
a_scale_thread_desc,
|
||||
make_tuple(I0, I0),
|
||||
a_scale_thread_buf);
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
a_scale_thread_copy.Run(a_scale_grid_desc,
|
||||
a_scale_grid_buf,
|
||||
a_scale_thread_desc,
|
||||
make_tuple(m0, I0),
|
||||
a_scale_thread_buf);
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
|
||||
a_scale_thread_copy_step.At(Number<0>{}));
|
||||
});
|
||||
|
||||
if constexpr(NumKBlockPerScale == 1)
|
||||
{
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
|
||||
a_scale_thread_copy_step.At(Number<2>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
|
||||
a_scale_thread_copy_step.At(Number<1>{}));
|
||||
}
|
||||
|
||||
b_scale_thread_copy.Run(b_scale_grid_desc,
|
||||
b_scale_grid_buf,
|
||||
@@ -350,8 +374,12 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
make_tuple(I0, I0),
|
||||
b_scale_thread_buf);
|
||||
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, a_scale_thread_copy_step);
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
c_scale_thread_buf(m0) = a_scale_thread_buf[m0] * b_scale_thread_buf[I0];
|
||||
});
|
||||
|
||||
// Local prefill 1
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
|
||||
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
|
||||
@@ -363,10 +391,44 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
a_scale_thread_copy.Run(a_scale_grid_desc,
|
||||
a_scale_grid_buf,
|
||||
a_scale_thread_desc,
|
||||
make_tuple(m0, I0),
|
||||
a_scale_thread_buf);
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
|
||||
a_scale_thread_copy_step.At(Number<0>{}));
|
||||
});
|
||||
|
||||
if constexpr(NumKBlockPerScale == 1)
|
||||
{
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
|
||||
a_scale_thread_copy_step.At(Number<2>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
|
||||
a_scale_thread_copy_step.At(Number<1>{}));
|
||||
}
|
||||
|
||||
b_scale_thread_copy.Run(b_scale_grid_desc,
|
||||
b_scale_grid_buf,
|
||||
b_scale_thread_desc,
|
||||
make_tuple(I0, I0),
|
||||
b_scale_thread_buf);
|
||||
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);
|
||||
|
||||
// Initialize C
|
||||
c_thread_buf.Clear();
|
||||
|
||||
auto c_thread_buf_per_scale = remove_cvref_t<decltype(c_thread_buf)>();
|
||||
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
|
||||
AccDataType,
|
||||
1,
|
||||
xdlops_gemm.GetRegSizePerXdlops(),
|
||||
true>
|
||||
c_thread_buf_per_scale;
|
||||
|
||||
// Local prefetch 1
|
||||
block_sync_lds();
|
||||
@@ -409,7 +471,10 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
c_thread_buf_per_scale.Clear();
|
||||
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
|
||||
.template AsType<AccDataType>()(Number<t>{}) = 0;
|
||||
});
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
@@ -430,19 +495,23 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
xdlops_gemm.template Run<>(
|
||||
a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(I0));
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
|
||||
});
|
||||
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
|
||||
c_thread_buf(Number<c_offset>{}) +=
|
||||
c_thread_buf_per_scale[Number<t>{}] *
|
||||
type_convert<AccDataType>(a_scale_thread_buf[I0]) *
|
||||
type_convert<AccDataType>(b_scale_thread_buf[I0]);
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
|
||||
.template AsType<AccDataType>()[Number<t>{}] *
|
||||
type_convert<AccDataType>(c_scale_thread_buf[m0]);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
c_scale_thread_buf(m0) = a_scale_thread_buf[m0] * b_scale_thread_buf[I0];
|
||||
});
|
||||
|
||||
block_sync_lds();
|
||||
static_for<0, KRepeat, 1>{}([&](auto k) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
@@ -462,11 +531,27 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
b_thread_buf);
|
||||
});
|
||||
});
|
||||
a_scale_thread_copy.Run(a_scale_grid_desc,
|
||||
a_scale_grid_buf,
|
||||
a_scale_thread_desc,
|
||||
make_tuple(I0, I0),
|
||||
a_scale_thread_buf);
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
a_scale_thread_copy.Run(a_scale_grid_desc,
|
||||
a_scale_grid_buf,
|
||||
a_scale_thread_desc,
|
||||
make_tuple(m0, I0),
|
||||
a_scale_thread_buf);
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc, a_scale_thread_copy_step.At(Number<0>{}));
|
||||
});
|
||||
|
||||
if constexpr(NumKBlockPerScale == 1)
|
||||
{
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc, a_scale_thread_copy_step.At(Number<2>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc, a_scale_thread_copy_step.At(Number<1>{}));
|
||||
}
|
||||
|
||||
b_scale_thread_copy.Run(b_scale_grid_desc,
|
||||
b_scale_grid_buf,
|
||||
@@ -474,7 +559,6 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
make_tuple(I0, I0),
|
||||
b_scale_thread_buf);
|
||||
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, a_scale_thread_copy_step);
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);
|
||||
HotLoopScheduler();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
@@ -487,7 +571,10 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
{
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
c_thread_buf_per_scale.Clear();
|
||||
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
|
||||
.template AsType<AccDataType>()(Number<t>{}) = 0;
|
||||
});
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
@@ -507,15 +594,15 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
xdlops_gemm.template Run<>(
|
||||
a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(I0));
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
|
||||
});
|
||||
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
|
||||
c_thread_buf(Number<c_offset>{}) +=
|
||||
c_thread_buf_per_scale[Number<t>{}] *
|
||||
type_convert<AccDataType>(a_scale_thread_buf[I0]) *
|
||||
type_convert<AccDataType>(b_scale_thread_buf[I0]);
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
|
||||
.template AsType<AccDataType>()[Number<t>{}] *
|
||||
type_convert<AccDataType>(c_scale_thread_buf[m0]);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
#include "ck/host_utility/flush_cache.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -177,14 +178,57 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3
|
||||
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
|
||||
|
||||
const auto Run = [&](const auto& kernel) {
|
||||
if(arg.KBatch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
|
||||
0,
|
||||
arg.M * arg.N * sizeof(CDataType),
|
||||
stream_config.stream_id_));
|
||||
if(stream_config.flush_cache)
|
||||
{
|
||||
Argument arg_ = arg;
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
|
||||
const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
|
||||
arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0);
|
||||
const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(
|
||||
arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0);
|
||||
|
||||
auto size_a_buffer =
|
||||
a_grid_desc_ak0_m_ak1.GetElementSpaceSize() * sizeof(ADataType);
|
||||
auto size_b_buffer =
|
||||
b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType);
|
||||
|
||||
ck::utility::RotatingMemWrapper<Argument> rotating_mem(
|
||||
arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer);
|
||||
rotating_mem.Print();
|
||||
|
||||
auto run_flush_cache = [&]() {
|
||||
// flush icache
|
||||
ck::utility::flush_icache();
|
||||
// rotating mem
|
||||
rotating_mem.Next();
|
||||
// clear c mem
|
||||
if(arg_.KBatch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
|
||||
0,
|
||||
arg_.M * arg_.N * sizeof(CDataType),
|
||||
stream_config.stream_id_));
|
||||
};
|
||||
|
||||
ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
|
||||
stream_config,
|
||||
run_flush_cache,
|
||||
kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg_);
|
||||
}
|
||||
else
|
||||
{
|
||||
if(arg.KBatch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
|
||||
0,
|
||||
arg.M * arg.N * sizeof(CDataType),
|
||||
stream_config.stream_id_));
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
|
||||
}
|
||||
};
|
||||
|
||||
constexpr index_t minimum_occupancy =
|
||||
@@ -195,7 +239,7 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3
|
||||
|
||||
if(has_main_k_block_loop)
|
||||
{
|
||||
// Tail number always 1
|
||||
// Tail number always full
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
|
||||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
|
||||
{
|
||||
@@ -208,127 +252,13 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
// Tail number could be One to Seven
|
||||
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
|
||||
{
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::One>;
|
||||
Run(kernel);
|
||||
}
|
||||
else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
|
||||
TailNumber::Full)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Full>;
|
||||
Run(kernel);
|
||||
}
|
||||
|
||||
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Two>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
|
||||
TailNumber::Three)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Three>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
|
||||
TailNumber::Four)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Four>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
|
||||
TailNumber::Five)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Five>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Six>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
|
||||
TailNumber::Seven)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Seven>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// Tail number always 1
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Full)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
@@ -337,6 +267,16 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3
|
||||
minimum_occupancy>;
|
||||
Run(kernel);
|
||||
}
|
||||
else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
false,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Odd>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
}
|
||||
return ave_time;
|
||||
@@ -363,10 +303,11 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3
|
||||
return false;
|
||||
}
|
||||
|
||||
if(ScaleBlockM % MPerBlock != 0 || ScaleBlockN % NPerBlock != 0 || ScaleBlockK != KPerBlock)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
// if(ScaleBlockM % MPerBlock != 0 || ScaleBlockN % NPerBlock != 0 || ScaleBlockK !=
|
||||
// KPerBlock)
|
||||
// {
|
||||
// return false;
|
||||
// }
|
||||
|
||||
if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
|
||||
GemmSpec == GemmSpecialization::NKPadding ||
|
||||
|
||||
@@ -225,7 +225,7 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
|
||||
make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}));
|
||||
}
|
||||
|
||||
__device__ static auto MakeAGridDescriptor_AK0_M_AK1(
|
||||
__host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
|
||||
index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0)
|
||||
{
|
||||
const auto a_grid_desc_mraw_kraw = [&]() {
|
||||
@@ -307,7 +307,7 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
|
||||
}
|
||||
}
|
||||
|
||||
__device__ static auto MakeBGridDescriptor_BK0_N_BK1(
|
||||
__host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
|
||||
index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
|
||||
{
|
||||
const auto b_grid_desc_nraw_kraw = [&]() {
|
||||
@@ -422,6 +422,13 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
|
||||
}
|
||||
}();
|
||||
|
||||
// pad M and N
|
||||
return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
|
||||
make_tuple(make_right_pad_transform(M, MPad - M),
|
||||
make_right_pad_transform(N, NPad - N)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
#if 0
|
||||
using GemmSpecialization = tensor_operation::device::GemmSpecialization;
|
||||
|
||||
if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
|
||||
@@ -459,6 +466,7 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
|
||||
// not pad M or N
|
||||
return c_grid_desc_mraw_nraw;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
__host__ __device__ static auto MakeDsGridDescriptor_M_N(
|
||||
@@ -656,40 +664,19 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
|
||||
// in some cases.
|
||||
else if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
|
||||
{
|
||||
constexpr auto MLdsLayer = 32 * 4 / KPerBlock / sizeof(LDSTypeA) < 1
|
||||
? 1
|
||||
: 32 * 4 / KPerBlock / sizeof(LDSTypeA);
|
||||
constexpr auto a_lds_block_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(
|
||||
AK0Number * Number<MLdsLayer>{}, Number<MPerBlock / MLdsLayer>{}, AK1Number),
|
||||
make_tuple(AK1Number, Number<KPerBlock * MLdsLayer>{}, I1));
|
||||
constexpr auto a_lds_block_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(AK0Number, Number<MPerBlock>{}, AK1Number),
|
||||
make_tuple(AK1Number, Number<KPerBlock>{}, I1));
|
||||
|
||||
constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
a_lds_block_desc,
|
||||
make_tuple(make_xor_with_modulo_transform(make_tuple(
|
||||
Number<MPerBlock / MLdsLayer>{}, Number<AK0Number * MLdsLayer>{})),
|
||||
make_tuple(make_xor_with_modulo_transform(
|
||||
make_tuple(Number<MPerBlock>{}, Number<AK0Number>{})),
|
||||
make_pass_through_transform(AK1Number)),
|
||||
make_tuple(Sequence<1, 0>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<1, 0>{}, Sequence<2>{}));
|
||||
|
||||
constexpr auto a_lds_block_desc_ak0_mldslayer_m_ak1 = transform_tensor_descriptor(
|
||||
a_lds_block_desc_permuted,
|
||||
make_tuple(make_unmerge_transform(make_tuple(AK0Number, Number<MLdsLayer>{})),
|
||||
make_pass_through_transform(Number<MPerBlock / MLdsLayer>{}),
|
||||
make_pass_through_transform(AK1Number)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{}));
|
||||
|
||||
constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor(
|
||||
a_lds_block_desc_ak0_mldslayer_m_ak1,
|
||||
make_tuple(make_pass_through_transform(AK0Number),
|
||||
make_merge_transform_v3_division_mod(
|
||||
make_tuple(Number<MPerBlock / MLdsLayer>{}, Number<MLdsLayer>{})),
|
||||
make_pass_through_transform(AK1Number)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
|
||||
return a_lds_block_desc_ak0_m_ak1;
|
||||
return a_lds_block_desc_permuted;
|
||||
}
|
||||
else // ColumnMajor A
|
||||
{
|
||||
@@ -791,42 +778,19 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
|
||||
}
|
||||
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
|
||||
{
|
||||
// NLdsLayer * K0 as logical Bank
|
||||
constexpr auto NLdsLayer = 32 * 4 / KPerBlock / sizeof(LDSTypeB) < 1
|
||||
? 1
|
||||
: 32 * 4 / KPerBlock / sizeof(LDSTypeB);
|
||||
;
|
||||
constexpr auto b_lds_block_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(
|
||||
BK0Number * Number<NLdsLayer>{}, Number<NPerBlock / NLdsLayer>{}, BK1Number),
|
||||
make_tuple(BK1Number, Number<KPerBlock * NLdsLayer>{}, I1));
|
||||
constexpr auto b_lds_block_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(BK0Number, Number<NPerBlock>{}, BK1Number),
|
||||
make_tuple(BK1Number, Number<KPerBlock>{}, I1));
|
||||
|
||||
constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
b_lds_block_desc,
|
||||
make_tuple(make_xor_with_modulo_transform(make_tuple(
|
||||
Number<NPerBlock / NLdsLayer>{}, Number<BK0Number * NLdsLayer>{})),
|
||||
make_tuple(make_xor_with_modulo_transform(
|
||||
make_tuple(Number<NPerBlock>{}, Number<BK0Number>{})),
|
||||
make_pass_through_transform(BK1Number)),
|
||||
make_tuple(Sequence<1, 0>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<1, 0>{}, Sequence<2>{}));
|
||||
|
||||
constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor(
|
||||
b_lds_block_desc_permuted,
|
||||
make_tuple(make_unmerge_transform(make_tuple(BK0Number, Number<NLdsLayer>{})),
|
||||
make_pass_through_transform(Number<NPerBlock / NLdsLayer>{}),
|
||||
make_pass_through_transform(BK1Number)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{}));
|
||||
|
||||
constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor(
|
||||
b_lds_block_desc_bk0_nldslayer_n_bk1,
|
||||
make_tuple(make_pass_through_transform(BK0Number),
|
||||
make_merge_transform_v3_division_mod(
|
||||
make_tuple(Number<NPerBlock / NLdsLayer>{}, Number<NLdsLayer>{})),
|
||||
make_pass_through_transform(BK1Number)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
|
||||
return b_lds_block_desc_bk0_n_bk1;
|
||||
return b_lds_block_desc_permuted;
|
||||
}
|
||||
else // RowMajor B
|
||||
{
|
||||
@@ -992,7 +956,8 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
|
||||
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
|
||||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
|
||||
GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
|
||||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
|
||||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) &&
|
||||
!(is_same<tensor_layout::gemm::RowMajor, ALayout>::value))
|
||||
{
|
||||
if(!(karg.M % MPerBlock == 0))
|
||||
{
|
||||
@@ -1009,7 +974,8 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
|
||||
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
|
||||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
|
||||
GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
|
||||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
|
||||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) &&
|
||||
(is_same<tensor_layout::gemm::RowMajor, BLayout>::value))
|
||||
{
|
||||
if(!(karg.N % NPerBlock == 0))
|
||||
{
|
||||
@@ -1357,28 +1323,39 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
|
||||
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
|
||||
KPerBlock);
|
||||
|
||||
const index_t ScaleSliceSizeM = 1;
|
||||
const index_t ScaleSliceSizeN = 1;
|
||||
const index_t ScaleSliceSizeK = 1;
|
||||
constexpr index_t ScaleSliceSizeM = MXdlPerWave;
|
||||
constexpr index_t ScaleSliceSizeN = math::integer_divide_ceil(NPerBlock, ScaleBlockN);
|
||||
constexpr index_t ScaleSliceSizeK = math::integer_divide_ceil(KPerBlock, ScaleBlockK);
|
||||
|
||||
// ScaleSliceSizeK is last dimension in A/B scale for vector memory access
|
||||
// ScaleSliceSizeK is first dimension in C scale for packed math
|
||||
constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<ScaleSliceSizeM>{}, Number<ScaleSliceSizeK>{}));
|
||||
|
||||
constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
|
||||
constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
|
||||
auto a_thread_offset =
|
||||
get_thread_local_1d_id() % MPerXdl + (get_thread_local_1d_id() / 64) / NWaves * MPerXdl;
|
||||
|
||||
constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<ScaleSliceSizeM>{}, Number<ScaleSliceSizeK>{}));
|
||||
make_tuple(Number<ScaleSliceSizeN>{}, Number<ScaleSliceSizeK>{}));
|
||||
|
||||
constexpr auto c_scale_thread_desc = make_naive_tensor_descriptor_packed(make_tuple(
|
||||
Number<ScaleSliceSizeK>{}, Number<ScaleSliceSizeM>{}, Number<ScaleSliceSizeN>{}));
|
||||
|
||||
auto a_scale_thread_copy =
|
||||
ThreadwiseTensorSliceTransfer_v2<AScaleType,
|
||||
AScaleType,
|
||||
decltype(a_scale_grid_desc_am_ak),
|
||||
decltype(a_scale_thread_desc),
|
||||
Sequence<ScaleSliceSizeM, ScaleSliceSizeK>,
|
||||
Sequence<1, ScaleSliceSizeK>,
|
||||
Sequence<0, 1>,
|
||||
1,
|
||||
1,
|
||||
ScaleSliceSizeK,
|
||||
1,
|
||||
false>(
|
||||
a_scale_grid_desc_am_ak, make_multi_index(block_m_id * MPerBlock / ScaleBlockM, 0));
|
||||
a_scale_grid_desc_am_ak,
|
||||
make_multi_index(block_m_id * MPerBlock / ScaleBlockM + a_thread_offset, 0));
|
||||
|
||||
auto b_scale_thread_copy =
|
||||
ThreadwiseTensorSliceTransfer_v2<BScaleType,
|
||||
@@ -1388,17 +1365,21 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
|
||||
Sequence<ScaleSliceSizeN, ScaleSliceSizeK>,
|
||||
Sequence<0, 1>,
|
||||
1,
|
||||
1,
|
||||
ScaleSliceSizeK,
|
||||
1,
|
||||
false>(
|
||||
b_scale_grid_desc_bn_ak, make_multi_index(block_n_id * NPerBlock / ScaleBlockN, 0));
|
||||
|
||||
constexpr auto a_scale_thread_slice_copy_step = make_multi_index(0, 1);
|
||||
constexpr auto b_scale_thread_slice_copy_step = make_multi_index(0, 1);
|
||||
// constexpr auto a_scale_thread_slice_copy_step = make_multi_index(0, 1);
|
||||
constexpr auto a_scale_thread_slice_copy_step =
|
||||
make_tuple(make_multi_index(MWaves * MPerXdl, 0),
|
||||
make_multi_index(-MPerBlock, 0),
|
||||
make_multi_index(-MPerBlock, ScaleSliceSizeK));
|
||||
constexpr auto b_scale_thread_slice_copy_step = make_multi_index(0, ScaleSliceSizeK);
|
||||
|
||||
const index_t num_k_block_per_scale = ScaleBlockK / KPerBlock;
|
||||
constexpr auto NumKBlockPerScale = math::integer_divide_ceil(ScaleBlockK, KPerBlock);
|
||||
|
||||
blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
|
||||
blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, NumKBlockPerScale, TailNum>(
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
a_block_desc_ak0_m_ak1,
|
||||
a_blockwise_copy,
|
||||
@@ -1411,6 +1392,8 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
|
||||
b_grid_buf,
|
||||
b_block_buf,
|
||||
b_block_slice_copy_step,
|
||||
|
||||
c_scale_thread_desc,
|
||||
c_thread_buf,
|
||||
|
||||
a_scale_grid_desc_am_ak,
|
||||
@@ -1425,8 +1408,7 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
|
||||
b_scale_grid_buf,
|
||||
b_scale_thread_slice_copy_step,
|
||||
|
||||
num_k_block_main_loop,
|
||||
num_k_block_per_scale);
|
||||
num_k_block_main_loop);
|
||||
|
||||
// shuffle C and write out
|
||||
{
|
||||
@@ -1437,23 +1419,24 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
|
||||
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
|
||||
constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
|
||||
|
||||
// TODO: hacky, fix it!
|
||||
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
|
||||
blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
|
||||
// transposed XDL
|
||||
// // TODO: hacky, fix it!
|
||||
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
|
||||
blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
|
||||
|
||||
// TODO: hacky, fix it!
|
||||
// c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
|
||||
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
|
||||
blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
|
||||
// // TODO: hacky, fix it!
|
||||
// only used to get lengths
|
||||
constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp =
|
||||
blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
|
||||
|
||||
constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
|
||||
constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
|
||||
constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
|
||||
constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
|
||||
constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
|
||||
constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
|
||||
constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
|
||||
constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
|
||||
constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I0);
|
||||
constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I1);
|
||||
constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I2);
|
||||
constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I3);
|
||||
constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I4);
|
||||
constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I5);
|
||||
constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I6);
|
||||
constexpr auto N4 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I7);
|
||||
|
||||
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
|
||||
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
|
||||
@@ -1462,24 +1445,24 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
|
||||
static_cast<CShuffleDataType*>(p_shared),
|
||||
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
|
||||
|
||||
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
|
||||
constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 = transform_tensor_descriptor(
|
||||
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
|
||||
make_tuple(
|
||||
make_freeze_transform(I0),
|
||||
make_unmerge_transform(make_tuple(
|
||||
Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
|
||||
M1, // M1 = MWave
|
||||
M2, // M2 * M3 * M4 = MPerXdl
|
||||
M3,
|
||||
M4)),
|
||||
M2)), // M2 = MPerXdl
|
||||
make_freeze_transform(I0),
|
||||
make_unmerge_transform(make_tuple(
|
||||
Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
|
||||
N1, // N1 = NWave
|
||||
N2))), // N2 = NPerXdl
|
||||
N2, // N2 * N3 * N4 = NPerXdl
|
||||
N3,
|
||||
N4))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(
|
||||
Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));
|
||||
Sequence<>{}, Sequence<0, 2, 4>{}, Sequence<>{}, Sequence<1, 3, 5, 6, 7>{}));
|
||||
|
||||
// calculate origin of thread output tensor on global memory
|
||||
// blockwise GEMM c matrix starting index
|
||||
@@ -1489,57 +1472,57 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
|
||||
const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
|
||||
const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
|
||||
|
||||
const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
|
||||
const auto m_thread_data_on_block_to_m0_m1_m2_adaptor =
|
||||
make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
|
||||
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
const auto m_thread_data_on_block_idx =
|
||||
m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
|
||||
make_multi_index(m_thread_data_on_block));
|
||||
|
||||
const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
|
||||
make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
|
||||
make_tuple(make_merge_transform(make_tuple(M0, M1, M2))),
|
||||
make_tuple(Sequence<0, 1, 2>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
const auto m_thread_data_on_block_idx =
|
||||
m_thread_data_on_block_to_m0_m1_m2_adaptor.CalculateBottomIndex(
|
||||
make_multi_index(m_thread_data_on_block));
|
||||
|
||||
const auto n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor =
|
||||
make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3, N4))),
|
||||
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
const auto n_thread_data_on_block_idx =
|
||||
n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
|
||||
n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor.CalculateBottomIndex(
|
||||
make_multi_index(n_thread_data_on_block));
|
||||
|
||||
// shuffle: threadwise copy C from VGPR to LDS
|
||||
auto c_thread_copy_vgpr_to_lds =
|
||||
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
|
||||
CShuffleDataType,
|
||||
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
|
||||
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
decltype(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
|
||||
decltype(c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4),
|
||||
tensor_operation::element_wise::PassThrough,
|
||||
Sequence<CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
I1,
|
||||
I1,
|
||||
M2,
|
||||
I1,
|
||||
M4,
|
||||
I1>,
|
||||
N2,
|
||||
I1,
|
||||
N4>,
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
|
||||
7,
|
||||
1,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
1,
|
||||
true>{
|
||||
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
|
||||
c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
|
||||
make_multi_index(0,
|
||||
0,
|
||||
m_thread_data_on_block_idx[I1],
|
||||
n_thread_data_on_block_idx[I1],
|
||||
m_thread_data_on_block_idx[I2],
|
||||
m_thread_data_on_block_idx[I3],
|
||||
m_thread_data_on_block_idx[I4],
|
||||
n_thread_data_on_block_idx[I2]),
|
||||
ck::tensor_operation::element_wise::PassThrough{}};
|
||||
n_thread_data_on_block_idx[I2],
|
||||
n_thread_data_on_block_idx[I3],
|
||||
n_thread_data_on_block_idx[I4]),
|
||||
tensor_operation::element_wise::PassThrough{}};
|
||||
|
||||
using EDataType = CDataType;
|
||||
|
||||
@@ -1621,18 +1604,17 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
|
||||
make_tuple(make_multi_index(block_m_id, 0, block_n_id, 0)),
|
||||
c_element_op};
|
||||
|
||||
// space filling curve for threadwise C in VGPR
|
||||
constexpr auto sfc_c_vgpr =
|
||||
SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
|
||||
SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, 1, N2, 1, N4>,
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
|
||||
Sequence<CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
1,
|
||||
1,
|
||||
M2,
|
||||
1,
|
||||
M4,
|
||||
1>>{};
|
||||
N2,
|
||||
1,
|
||||
N4>>{};
|
||||
|
||||
constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
|
||||
|
||||
@@ -1652,10 +1634,10 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
|
||||
block_sync_lds();
|
||||
|
||||
// each thread write its data from VGPR to LDS
|
||||
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
|
||||
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
|
||||
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
|
||||
c_thread_buf,
|
||||
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
|
||||
c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
|
||||
c_shuffle_block_buf);
|
||||
|
||||
// make sure it's safe to read from LDS
|
||||
|
||||
@@ -58,6 +58,7 @@
|
||||
#include "ck_tile/core/tensor/transpose_tile.hpp"
|
||||
#include "ck_tile/core/tensor/update_tile.hpp"
|
||||
#include "ck_tile/core/utility/bit_cast.hpp"
|
||||
#include "ck_tile/core/utility/env.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
#include "ck_tile/core/utility/functional_with_tuple.hpp"
|
||||
#include "ck_tile/core/utility/ignore.hpp"
|
||||
|
||||
@@ -29,6 +29,12 @@
|
||||
#include "hip/hip_fp16.h"
|
||||
#endif
|
||||
|
||||
#include "ck_tile/core/utility/env.hpp"
|
||||
|
||||
// environment variable to enable logging:
|
||||
// export CK_TILE_LOGGING=ON or CK_TILE_LOGGING=1 or CK_TILE_LOGGING=ENABLED
|
||||
CK_TILE_DECLARE_ENV_VAR_BOOL(CK_TILE_LOGGING)
|
||||
|
||||
#ifdef __HIPCC__
|
||||
#define CK_TILE_HOST inline __host__
|
||||
#define CK_TILE_DEVICE inline __device__
|
||||
|
||||
204
include/ck_tile/core/utility/env.hpp
Normal file
204
include/ck_tile/core/utility/env.hpp
Normal file
@@ -0,0 +1,204 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename... Args>
|
||||
void CK_TILE_ERROR(Args&&... args) noexcept
|
||||
{
|
||||
std::ostringstream oss;
|
||||
(oss << ... << args);
|
||||
std::cerr << "[ERROR] " << oss.str() << std::endl;
|
||||
}
|
||||
|
||||
namespace internal {
|
||||
|
||||
template <size_t N>
|
||||
bool is_any_of(const char* const (&names)[N], const std::string& str)
|
||||
{
|
||||
return std::any_of(std::begin(names), std::end(names), [&](const char* inner_str) {
|
||||
return str == inner_str;
|
||||
});
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct ParseEnvVal
|
||||
{
|
||||
};
|
||||
template <>
|
||||
struct ParseEnvVal<bool>
|
||||
{
|
||||
static bool parse_env_var_value(const char* vp)
|
||||
{
|
||||
std::string value_env_str{vp};
|
||||
|
||||
for(auto& c : value_env_str)
|
||||
{
|
||||
if(std::isalpha(c) != 0)
|
||||
{
|
||||
c = std::tolower(static_cast<unsigned char>(c));
|
||||
}
|
||||
}
|
||||
|
||||
if(is_any_of(enabled_names, value_env_str))
|
||||
{
|
||||
return true;
|
||||
}
|
||||
else if(is_any_of(disabled_names, value_env_str))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Invalid value for env variable");
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
private:
|
||||
static constexpr const char* enabled_names[] = {"enable", "enabled", "1", "yes", "on", "true"};
|
||||
static constexpr const char* disabled_names[] = {
|
||||
"disable", "disabled", "0", "no", "off", "false"};
|
||||
};
|
||||
|
||||
// Supports hexadecimals (with leading "0x"), octals (if prefix is "0") and decimals (default).
|
||||
// Returns 0 if environment variable is in wrong format (strtoull fails to parse the string).
|
||||
template <>
|
||||
struct ParseEnvVal<uint64_t>
|
||||
{
|
||||
static uint64_t parse_env_var_value(const char* vp) { return std::strtoull(vp, nullptr, 0); }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ParseEnvVal<std::string>
|
||||
{
|
||||
static std::string parse_env_var_value(const char* vp) { return std::string{vp}; }
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct EnvVar
|
||||
{
|
||||
private:
|
||||
T value{};
|
||||
bool is_unset = true;
|
||||
|
||||
public:
|
||||
const T& GetValue() const { return value; }
|
||||
|
||||
bool IsUnset() const { return is_unset; }
|
||||
|
||||
void Unset() { is_unset = true; }
|
||||
|
||||
void UpdateValue(const T& val)
|
||||
{
|
||||
is_unset = false;
|
||||
value = val;
|
||||
}
|
||||
|
||||
explicit EnvVar(const char* const name, const T& def_val)
|
||||
{
|
||||
// NOLINTNEXTLINE (concurrency-mt-unsafe)
|
||||
const char* vp = std::getenv(name);
|
||||
if(vp != nullptr) // a value was provided
|
||||
{
|
||||
is_unset = false;
|
||||
value = ParseEnvVal<T>::parse_env_var_value(vp);
|
||||
}
|
||||
else // no value provided, use default value
|
||||
{
|
||||
value = def_val;
|
||||
}
|
||||
}
|
||||
};
|
||||
} // end namespace internal
|
||||
|
||||
// Static inside function hides the variable and provides
|
||||
// thread-safety/locking
|
||||
// Used in global namespace
|
||||
#define CK_TILE_DECLARE_ENV_VAR(name, type, default_val) \
|
||||
namespace ck_tile::env { \
|
||||
struct name \
|
||||
{ \
|
||||
static_assert(std::is_same_v<name, ::ck_tile::env::name>, \
|
||||
"CK_TILE_DECLARE_ENV* must be used in the global namespace"); \
|
||||
using value_type = type; \
|
||||
static ck_tile::internal::EnvVar<type>& Ref() \
|
||||
{ \
|
||||
static ck_tile::internal::EnvVar<type> var{#name, default_val}; \
|
||||
return var; \
|
||||
} \
|
||||
}; \
|
||||
}
|
||||
|
||||
#define CK_TILE_DECLARE_ENV_VAR_BOOL(name) CK_TILE_DECLARE_ENV_VAR(name, bool, false)
|
||||
|
||||
#define CK_TILE_DECLARE_ENV_VAR_UINT64(name) CK_TILE_DECLARE_ENV_VAR(name, uint64_t, 0)
|
||||
|
||||
#define CK_TILE_DECLARE_ENV_VAR_STR(name) CK_TILE_DECLARE_ENV_VAR(name, std::string, "")
|
||||
|
||||
#define CK_TILE_ENV(name) \
|
||||
ck_tile::env::name {}
|
||||
|
||||
template <class EnvVar>
|
||||
inline const std::string& EnvGetString(EnvVar)
|
||||
{
|
||||
static_assert(std::is_same_v<typename EnvVar::value_type, std::string>);
|
||||
return EnvVar::Ref().GetValue();
|
||||
}
|
||||
|
||||
template <class EnvVar>
|
||||
inline bool EnvIsEnabled(EnvVar)
|
||||
{
|
||||
static_assert(std::is_same_v<typename EnvVar::value_type, bool>);
|
||||
return !EnvVar::Ref().IsUnset() && EnvVar::Ref().GetValue();
|
||||
}
|
||||
|
||||
template <class EnvVar>
|
||||
inline bool EnvIsDisabled(EnvVar)
|
||||
{
|
||||
static_assert(std::is_same_v<typename EnvVar::value_type, bool>);
|
||||
return !EnvVar::Ref().IsUnset() && !EnvVar::Ref().GetValue();
|
||||
}
|
||||
|
||||
template <class EnvVar>
|
||||
inline uint64_t EnvValue(EnvVar)
|
||||
{
|
||||
static_assert(std::is_same_v<typename EnvVar::value_type, uint64_t>);
|
||||
return EnvVar::Ref().GetValue();
|
||||
}
|
||||
|
||||
template <class EnvVar>
|
||||
inline bool EnvIsUnset(EnvVar)
|
||||
{
|
||||
return EnvVar::Ref().IsUnset();
|
||||
}
|
||||
|
||||
template <class EnvVar>
|
||||
void EnvUnset(EnvVar)
|
||||
{
|
||||
EnvVar::Ref().Unset();
|
||||
}
|
||||
|
||||
/// Updates the cached value of an environment variable
|
||||
template <typename EnvVar, typename ValueType>
|
||||
void UpdateEnvVar(EnvVar, const ValueType& val)
|
||||
{
|
||||
static_assert(std::is_same_v<typename EnvVar::value_type, ValueType>);
|
||||
EnvVar::Ref().UpdateValue(val);
|
||||
}
|
||||
|
||||
template <typename EnvVar>
|
||||
void UpdateEnvVar(EnvVar, const std::string_view& val)
|
||||
{
|
||||
EnvVar::Ref().UpdateValue(
|
||||
ck_tile::internal::ParseEnvVal<typename EnvVar::value_type>::parse_env_var_value(
|
||||
val.data()));
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -68,16 +68,6 @@ struct BlockUniversalGemmAsBsCr
|
||||
static constexpr index_t NPerBlockPerIter = NWarp * WarpGemm::kN;
|
||||
static constexpr index_t KPerBlockPerIter = WarpGemm::kK;
|
||||
|
||||
using AWarpTileDistr = remove_cvref_t<decltype(make_static_tile_distribution(
|
||||
typename WarpGemm::AWarpDstrEncoding{}))>;
|
||||
using BWarpTileDistr = remove_cvref_t<decltype(make_static_tile_distribution(
|
||||
typename WarpGemm::BWarpDstrEncoding{}))>;
|
||||
|
||||
using AWarpTile = remove_cvref_t<decltype(make_static_distributed_tensor<ComputeDataType>(
|
||||
AWarpTileDistr{}))>;
|
||||
using BWarpTile = remove_cvref_t<decltype(make_static_distributed_tensor<ComputeDataType>(
|
||||
BWarpTileDistr{}))>;
|
||||
|
||||
// TODO: Should we have two policies? Interwave & Intrawave ??
|
||||
static constexpr index_t InterWaveSchedulingMacClusters = 1;
|
||||
|
||||
@@ -108,6 +98,25 @@ struct BlockUniversalGemmAsBsCr
|
||||
|
||||
static constexpr auto Scheduler = Traits::Scheduler;
|
||||
|
||||
using AWarpDstr = typename WarpGemm::AWarpDstr;
|
||||
using BWarpDstr = typename WarpGemm::BWarpDstr;
|
||||
using CWarpDstr = typename WarpGemm::CWarpDstr;
|
||||
|
||||
using AWarpTensor = typename WarpGemm::AWarpTensor;
|
||||
using BWarpTensor = typename WarpGemm::BWarpTensor;
|
||||
using CWarpTensor = typename WarpGemm::CWarpTensor;
|
||||
|
||||
static constexpr auto a_warp_y_lengths =
|
||||
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
static constexpr auto b_warp_y_lengths =
|
||||
to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
static constexpr auto c_warp_y_lengths =
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
|
||||
static constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
|
||||
static constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t<BWarpDstr::NDimY, 0>{};
|
||||
static constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
static constexpr index_t APackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
|
||||
static constexpr index_t BPackedSize =
|
||||
@@ -116,18 +125,65 @@ struct BlockUniversalGemmAsBsCr
|
||||
using I0 = number<0>;
|
||||
using I1 = number<1>;
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto MakeABlockDistributionEncode()
|
||||
{
|
||||
constexpr index_t KPerThread = Traits::KPerThread;
|
||||
constexpr index_t NumMacClusters = Traits::InterWaveSchedulingMacClusters;
|
||||
constexpr index_t KPerInnerLoop = ck_tile::max(KPerThread / NumMacClusters, Traits::KPack);
|
||||
constexpr index_t KIterInterWave = KPerInnerLoop / WarpGemm::kK;
|
||||
|
||||
using KIterSeq = std::conditional_t<Scheduler == GemmPipelineScheduler::Interwave,
|
||||
sequence<KIterInterWave>,
|
||||
sequence<KIterPerWarp>>;
|
||||
|
||||
constexpr auto a_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<NWarp>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, KIterSeq>,
|
||||
tuple<sequence<1, 0>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
|
||||
|
||||
return a_block_dstr_encode;
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto MakeBBlockDistributionEncode()
|
||||
{
|
||||
constexpr index_t KPerThread = Traits::KPerThread;
|
||||
constexpr index_t NumMacClusters = Traits::InterWaveSchedulingMacClusters;
|
||||
constexpr index_t KPerInnerLoop = ck_tile::max(KPerThread / NumMacClusters, Traits::KPack);
|
||||
constexpr index_t KIterInterWave = KPerInnerLoop / WarpGemm::kK;
|
||||
|
||||
using KIterSeq = std::conditional_t<Scheduler == GemmPipelineScheduler::Interwave,
|
||||
sequence<KIterInterWave>,
|
||||
sequence<KIterPerWarp>>;
|
||||
|
||||
constexpr auto b_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<MWarp>,
|
||||
tuple<sequence<NIterPerWarp, NWarp>, KIterSeq>,
|
||||
tuple<sequence<0, 1>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
|
||||
|
||||
return b_block_dstr_encode;
|
||||
}
|
||||
|
||||
private:
|
||||
template <typename WarpWindow, typename WarpTile>
|
||||
CK_TILE_DEVICE static void load_interleaved_pk_type(const WarpWindow& warp_window,
|
||||
WarpTile& warp_tile)
|
||||
CK_TILE_DEVICE static void load_interleaved_pk_type(WarpTile& warp_tile,
|
||||
const WarpWindow& warp_window)
|
||||
{
|
||||
constexpr index_t UnaryOpSize = 8;
|
||||
const element_wise::PassThroughPack8 elementwise_op{};
|
||||
constexpr index_t thread_buffer_size =
|
||||
Traits::AWarpTile::get_thread_buffer_size() / UnaryOpSize;
|
||||
const auto in_dstr_tensors = load_tile(warp_window);
|
||||
constexpr index_t thread_buffer_size = WarpTile::get_thread_buffer_size() / UnaryOpSize;
|
||||
const auto in_dstr_tensors = load_tile(warp_window);
|
||||
|
||||
static_assert(Traits::AWarpTile::get_thread_buffer_size() % UnaryOpSize == 0);
|
||||
static_assert(WarpTile::get_thread_buffer_size() % UnaryOpSize == 0);
|
||||
|
||||
using ComputeVectorType = ComputeDataType __attribute__((ext_vector_type(UnaryOpSize)));
|
||||
static_for<0, thread_buffer_size, 1>{}([&](auto i) {
|
||||
@@ -144,6 +200,17 @@ struct BlockUniversalGemmAsBsCr
|
||||
template <typename GemmTraits>
|
||||
struct BlockGemmImpl<GemmPipelineScheduler::Default, GemmTraits>
|
||||
{
|
||||
static constexpr auto ALdsTileDistr =
|
||||
decltype(make_static_tile_distribution(MakeABlockDistributionEncode())){};
|
||||
static constexpr auto BLdsTileDistr =
|
||||
decltype(make_static_tile_distribution(MakeBBlockDistributionEncode())){};
|
||||
|
||||
using ALdsTile = decltype(make_static_distributed_tensor<ComputeDataType>(ALdsTileDistr));
|
||||
using BLdsTile = decltype(make_static_distributed_tensor<ComputeDataType>(BLdsTileDistr));
|
||||
|
||||
ALdsTile a_warp_tile_;
|
||||
ALdsTile b_warp_tile_;
|
||||
|
||||
// C += A * B
|
||||
template <typename CBlockTensor, typename ASmemBlockWindow, typename BSmemBlockWindow>
|
||||
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
|
||||
@@ -158,114 +225,39 @@ struct BlockUniversalGemmAsBsCr
|
||||
"The ADataType and BDataType as defined in "
|
||||
"traits should be the same as correspoinding block window data type!");
|
||||
|
||||
static_assert(
|
||||
GemmTraits::MPerBlock == ASmemBlockWindow{}.get_window_lengths()[I0{}] &&
|
||||
GemmTraits::NPerBlock == BSmemBlockWindow{}.get_window_lengths()[I0{}] &&
|
||||
GemmTraits::KPerBlock == ASmemBlockWindow{}.get_window_lengths()[I1{}],
|
||||
"MPerBlock, NPerBlock, KPerBlock defined in "
|
||||
" BlockGemmShape are different from A/B block smem windows apropriate dims!");
|
||||
|
||||
const index_t iMWarp = get_warp_id() / NWarp;
|
||||
const index_t iNWarp = get_warp_id() - (iMWarp * NWarp);
|
||||
|
||||
// TODO: refactor warp_window tile type to class member as it should be
|
||||
// compile-time known information.
|
||||
auto a_warp_window_tmp = make_tile_window(
|
||||
a_block_window.get_bottom_tensor_view(),
|
||||
make_tuple(number<WarpGemm::kM>{}, number<WarpGemm::kK>{}),
|
||||
a_block_window.get_window_origin() + multi_index<2>{iMWarp * WarpGemm::kM, 0},
|
||||
make_static_tile_distribution(typename WarpGemm::AWarpDstrEncoding{}));
|
||||
|
||||
using AWarpWindow = remove_cvref_t<decltype(a_warp_window_tmp)>;
|
||||
|
||||
static_assert(GemmTraits::AWarpTile::get_num_of_dimension() ==
|
||||
AWarpWindow::get_num_of_dimension(),
|
||||
"AWarpWindow number of dimensions must be equal to "
|
||||
"AWarpTile number of dimensions!");
|
||||
static_assert(GemmTraits::AWarpTile::get_lengths() ==
|
||||
AWarpWindow{}.get_window_lengths(),
|
||||
"AWarpWindow lengths must be equal to AWarpTile lengths!");
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<AWarpWindow, GemmTraits::KIterPerWarp>,
|
||||
MIterPerWarp>
|
||||
a_warp_windows;
|
||||
|
||||
// construct B-warp-window
|
||||
auto b_warp_window_tmp = make_tile_window(
|
||||
b_block_window.get_bottom_tensor_view(),
|
||||
make_tuple(number<WarpGemm::kN>{}, number<WarpGemm::kK>{}),
|
||||
b_block_window.get_window_origin() + multi_index<2>{iNWarp * WarpGemm::kN, 0},
|
||||
make_static_tile_distribution(typename WarpGemm::BWarpDstrEncoding{}));
|
||||
|
||||
using BWarpWindow = remove_cvref_t<decltype(b_warp_window_tmp)>;
|
||||
|
||||
static_assert(GemmTraits::BWarpTile::get_num_of_dimension() ==
|
||||
BWarpWindow::get_num_of_dimension(),
|
||||
"BWarpWindow number of dimensions must be equal to "
|
||||
"BWarpTile number of dimensions!");
|
||||
static_assert(GemmTraits::BWarpTile::get_lengths() ==
|
||||
BWarpWindow{}.get_window_lengths(),
|
||||
"BWarpWindow lengths must be equal to BWarpTile lengths!");
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<BWarpWindow, GemmTraits::KIterPerWarp>,
|
||||
NIterPerWarp>
|
||||
b_warp_windows;
|
||||
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
a_warp_windows(mIter)(kIter) = a_warp_window_tmp;
|
||||
|
||||
// TODO: I don't have to move 0,0 window!
|
||||
move_tile_window(a_warp_windows(mIter)(kIter),
|
||||
{mIter * GemmTraits::MPerBlockPerIter,
|
||||
kIter * GemmTraits::KPerBlockPerIter});
|
||||
});
|
||||
});
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
|
||||
|
||||
move_tile_window(b_warp_windows(nIter)(kIter),
|
||||
{nIter * GemmTraits::NPerBlockPerIter,
|
||||
kIter * GemmTraits::KPerBlockPerIter});
|
||||
});
|
||||
});
|
||||
|
||||
using CWarpDstr = typename WarpGemm::CWarpDstr;
|
||||
using AWarpTensor = typename WarpGemm::AWarpTensor;
|
||||
using BWarpTensor = typename WarpGemm::BWarpTensor;
|
||||
using CWarpTensor = typename WarpGemm::CWarpTensor;
|
||||
|
||||
constexpr auto c_warp_y_lengths =
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
|
||||
{
|
||||
load_interleaved_pk_type(a_warp_tile_, a_block_window);
|
||||
}
|
||||
else
|
||||
{
|
||||
load_tile(a_warp_tile_, a_block_window);
|
||||
}
|
||||
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
|
||||
{
|
||||
load_interleaved_pk_type(b_warp_tile_, b_block_window);
|
||||
}
|
||||
else
|
||||
{
|
||||
load_tile(b_warp_tile_, b_block_window);
|
||||
}
|
||||
// hot loop:
|
||||
static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
AWarpTensor a_warp_tile;
|
||||
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
|
||||
{
|
||||
load_interleaved_pk_type(a_warp_windows(mIter)(kIter), a_warp_tile);
|
||||
}
|
||||
else
|
||||
{
|
||||
a_warp_tile = load_tile(a_warp_windows(mIter)(kIter));
|
||||
}
|
||||
// read A warp tensor from A block tensor
|
||||
AWarpTensor a_warp_tensor;
|
||||
|
||||
a_warp_tensor.get_thread_buffer() = a_warp_tile_.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
BWarpTensor b_warp_tile;
|
||||
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
|
||||
{
|
||||
load_interleaved_pk_type(b_warp_windows(nIter)(kIter), b_warp_tile);
|
||||
}
|
||||
else
|
||||
{
|
||||
b_warp_tile = load_tile(b_warp_windows(nIter)(kIter));
|
||||
}
|
||||
// read B warp tensor from B block tensor
|
||||
BWarpTensor b_warp_tensor;
|
||||
|
||||
b_warp_tensor.get_thread_buffer() = b_warp_tile_.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
|
||||
|
||||
// read C warp tensor from C block tensor-
|
||||
CWarpTensor c_warp_tensor;
|
||||
@@ -275,7 +267,7 @@ struct BlockUniversalGemmAsBsCr
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
WarpGemm{}(c_warp_tensor, a_warp_tile, b_warp_tile);
|
||||
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tensor.set_y_sliced_thread_data(
|
||||
@@ -291,149 +283,68 @@ struct BlockUniversalGemmAsBsCr
|
||||
template <typename GemmTraits>
|
||||
struct BlockGemmImpl<GemmPipelineScheduler::Intrawave, GemmTraits>
|
||||
{
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<typename GemmTraits::AWarpTile, KIterPerWarp>,
|
||||
MIterPerWarp>
|
||||
a_warp_tiles_;
|
||||
static constexpr auto ALdsTileDistr =
|
||||
decltype(make_static_tile_distribution(MakeABlockDistributionEncode())){};
|
||||
static constexpr auto BLdsTileDistr =
|
||||
decltype(make_static_tile_distribution(MakeBBlockDistributionEncode())){};
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<typename GemmTraits::BWarpTile, KIterPerWarp>,
|
||||
NIterPerWarp>
|
||||
b_warp_tiles_;
|
||||
using ALdsTile = decltype(make_static_distributed_tensor<ComputeDataType>(ALdsTileDistr));
|
||||
using BLdsTile = decltype(make_static_distributed_tensor<ComputeDataType>(BLdsTileDistr));
|
||||
|
||||
ALdsTile a_warp_tile_;
|
||||
ALdsTile b_warp_tile_;
|
||||
|
||||
template <typename ASmemBlockWindow, typename BSmemBlockWindow>
|
||||
CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window,
|
||||
const BSmemBlockWindow& b_block_window)
|
||||
{
|
||||
static_assert(
|
||||
GemmTraits::MPerBlock == ASmemBlockWindow{}.get_window_lengths()[I0{}] &&
|
||||
GemmTraits::NPerBlock == BSmemBlockWindow{}.get_window_lengths()[I0{}] &&
|
||||
GemmTraits::KPerBlock == ASmemBlockWindow{}.get_window_lengths()[I1{}],
|
||||
"MPerBlock, NPerBlock, KPerBlock defined in "
|
||||
" BlockGemmShape are different from A/B block smem windows apropriate dims!");
|
||||
|
||||
static_assert(std::is_same_v<ADataType, typename ASmemBlockWindow::DataType> &&
|
||||
std::is_same_v<BDataType, typename BSmemBlockWindow::DataType>,
|
||||
"The ADataType and BDataType as defined in "
|
||||
"traits should be the same as correspoinding block window data type!");
|
||||
|
||||
const index_t iMWarp = get_warp_id() / NWarp;
|
||||
const index_t iNWarp = get_warp_id() - (iMWarp * NWarp);
|
||||
|
||||
// TODO: refactor warp_window tile type to class member as it should be
|
||||
// compile-time known information.
|
||||
auto a_warp_window_tmp = make_tile_window(
|
||||
a_block_window.get_bottom_tensor_view(),
|
||||
make_tuple(number<WarpGemm::kM>{}, number<WarpGemm::kK>{}),
|
||||
a_block_window.get_window_origin() + multi_index<2>{iMWarp * WarpGemm::kM, 0},
|
||||
make_static_tile_distribution(typename WarpGemm::AWarpDstrEncoding{}));
|
||||
|
||||
using AWarpWindow = remove_cvref_t<decltype(a_warp_window_tmp)>;
|
||||
|
||||
static_assert(GemmTraits::AWarpTile::get_num_of_dimension() ==
|
||||
AWarpWindow::get_num_of_dimension(),
|
||||
"AWarpWindow number of dimensions must be equal to "
|
||||
"AWarpTile number of dimensions!");
|
||||
static_assert(GemmTraits::AWarpTile::get_lengths() ==
|
||||
AWarpWindow{}.get_window_lengths(),
|
||||
"AWarpWindow lengths must be equal to AWarpTile lengths!");
|
||||
|
||||
statically_indexed_array<statically_indexed_array<AWarpWindow, KIterPerWarp>,
|
||||
MIterPerWarp>
|
||||
a_warp_windows;
|
||||
|
||||
// construct B-warp-window
|
||||
auto b_warp_window_tmp = make_tile_window(
|
||||
b_block_window.get_bottom_tensor_view(),
|
||||
make_tuple(number<WarpGemm::kN>{}, number<WarpGemm::kK>{}),
|
||||
b_block_window.get_window_origin() + multi_index<2>{iNWarp * WarpGemm::kN, 0},
|
||||
make_static_tile_distribution(typename WarpGemm::BWarpDstrEncoding{}));
|
||||
|
||||
using BWarpWindow = remove_cvref_t<decltype(b_warp_window_tmp)>;
|
||||
|
||||
static_assert(GemmTraits::BWarpTile::get_num_of_dimension() ==
|
||||
BWarpWindow::get_num_of_dimension(),
|
||||
"BWarpWindow number of dimensions must be equal to "
|
||||
"BWarpTile number of dimensions!");
|
||||
static_assert(GemmTraits::BWarpTile::get_lengths() ==
|
||||
BWarpWindow{}.get_window_lengths(),
|
||||
"BWarpWindow lengths must be equal to BWarpTile lengths!");
|
||||
|
||||
statically_indexed_array<statically_indexed_array<BWarpWindow, KIterPerWarp>,
|
||||
NIterPerWarp>
|
||||
b_warp_windows;
|
||||
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
a_warp_windows(mIter)(kIter) = a_warp_window_tmp;
|
||||
|
||||
// TODO: I don't have to move 0,0 window!
|
||||
move_tile_window(a_warp_windows(mIter)(kIter),
|
||||
{mIter * GemmTraits::MPerBlockPerIter,
|
||||
kIter * GemmTraits::KPerBlockPerIter});
|
||||
});
|
||||
});
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
|
||||
|
||||
move_tile_window(b_warp_windows(nIter)(kIter),
|
||||
{nIter * GemmTraits::NPerBlockPerIter,
|
||||
kIter * GemmTraits::KPerBlockPerIter});
|
||||
});
|
||||
});
|
||||
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
// read A warp tensor from A block window
|
||||
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
|
||||
{
|
||||
load_interleaved_pk_type(a_warp_windows(mIter)(kIter),
|
||||
a_warp_tiles_(mIter)(kIter));
|
||||
}
|
||||
else
|
||||
{
|
||||
a_warp_tiles_(mIter)(kIter) = load_tile(a_warp_windows(mIter)(kIter));
|
||||
}
|
||||
});
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read B warp tensor from B Block window
|
||||
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
|
||||
{
|
||||
load_interleaved_pk_type(b_warp_windows(nIter)(kIter),
|
||||
b_warp_tiles_(nIter)(kIter));
|
||||
}
|
||||
else
|
||||
{
|
||||
b_warp_tiles_(nIter)(kIter) = load_tile(b_warp_windows(nIter)(kIter));
|
||||
}
|
||||
});
|
||||
});
|
||||
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
|
||||
{
|
||||
load_interleaved_pk_type(a_warp_tile_, a_block_window);
|
||||
}
|
||||
else
|
||||
{
|
||||
load_tile(a_warp_tile_, a_block_window);
|
||||
}
|
||||
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
|
||||
{
|
||||
load_interleaved_pk_type(b_warp_tile_, b_block_window);
|
||||
}
|
||||
else
|
||||
{
|
||||
load_tile(b_warp_tile_, b_block_window);
|
||||
}
|
||||
}
|
||||
|
||||
// C += A * B
|
||||
template <typename CBlockTensor, typename ASmemBlockWindow, typename BSmemBlockWindow>
|
||||
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
|
||||
[[maybe_unused]] const ASmemBlockWindow& a_block_window,
|
||||
[[maybe_unused]] const BSmemBlockWindow& b_block_window)
|
||||
[[maybe_unused]] ASmemBlockWindow& a_block_window,
|
||||
[[maybe_unused]] BSmemBlockWindow& b_block_window)
|
||||
{
|
||||
static_assert(std::is_same_v<CDataType, typename CBlockTensor::DataType>,
|
||||
"The CDataType as defined in traits should be the same as correspoinding "
|
||||
"C block tensor data type!");
|
||||
|
||||
using CWarpDstr = typename WarpGemm::CWarpDstr;
|
||||
using CWarpTensor = typename WarpGemm::CWarpTensor;
|
||||
|
||||
constexpr auto c_warp_y_lengths =
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
// hot loop:
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
// read A warp tensor from A block tensor
|
||||
AWarpTensor a_warp_tensor;
|
||||
|
||||
a_warp_tensor.get_thread_buffer() = a_warp_tile_.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read C warp tensor from C block tensor-
|
||||
// read B warp tensor from B block tensor
|
||||
BWarpTensor b_warp_tensor;
|
||||
|
||||
b_warp_tensor.get_thread_buffer() = b_warp_tile_.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
|
||||
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
|
||||
@@ -441,9 +352,7 @@ struct BlockUniversalGemmAsBsCr
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
WarpGemm{}(c_warp_tensor,
|
||||
a_warp_tiles_[mIter][kIter],
|
||||
b_warp_tiles_[nIter][kIter]);
|
||||
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tensor.set_y_sliced_thread_data(
|
||||
@@ -468,126 +377,53 @@ struct BlockUniversalGemmAsBsCr
|
||||
static constexpr index_t KRepeat = KPerThread / KPerInnerLoop;
|
||||
static constexpr index_t KInnerLoopIter = KPerInnerLoop / GemmTraits::KPack;
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<typename GemmTraits::AWarpTile, KInnerLoopIter>,
|
||||
MIterPerWarp>
|
||||
a_warp_tiles_;
|
||||
static constexpr auto ALdsTileDistr =
|
||||
decltype(make_static_tile_distribution(MakeABlockDistributionEncode())){};
|
||||
static constexpr auto BLdsTileDistr =
|
||||
decltype(make_static_tile_distribution(MakeBBlockDistributionEncode())){};
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<typename GemmTraits::BWarpTile, KInnerLoopIter>,
|
||||
NIterPerWarp>
|
||||
b_warp_tiles_;
|
||||
using ALdsTile = decltype(make_static_distributed_tensor<ComputeDataType>(ALdsTileDistr));
|
||||
using BLdsTile = decltype(make_static_distributed_tensor<ComputeDataType>(BLdsTileDistr));
|
||||
|
||||
ALdsTile a_warp_tile_;
|
||||
ALdsTile b_warp_tile_;
|
||||
|
||||
template <index_t KIdx, typename ASmemBlockWindow, typename BSmemBlockWindow>
|
||||
CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window,
|
||||
const BSmemBlockWindow& b_block_window)
|
||||
{
|
||||
static_assert(
|
||||
GemmTraits::MPerBlock == ASmemBlockWindow{}.get_window_lengths()[I0{}] &&
|
||||
GemmTraits::NPerBlock == BSmemBlockWindow{}.get_window_lengths()[I0{}] &&
|
||||
GemmTraits::KPerBlock == ASmemBlockWindow{}.get_window_lengths()[I1{}],
|
||||
"MPerBlock, NPerBlock, KPerBlock defined in "
|
||||
" BlockGemmShape are different from A/B block smem windows apropriate dims!");
|
||||
constexpr auto a_lds_load_tile_distr =
|
||||
make_static_tile_distribution(MakeABlockDistributionEncode());
|
||||
constexpr auto b_lds_load_tile_distr =
|
||||
make_static_tile_distribution(MakeBBlockDistributionEncode());
|
||||
|
||||
static_assert(std::is_same_v<ADataType, typename ASmemBlockWindow::DataType> &&
|
||||
std::is_same_v<BDataType, typename BSmemBlockWindow::DataType>,
|
||||
"The ADataType and BDataType as defined in "
|
||||
"traits should be the same as correspoinding block window data type!");
|
||||
|
||||
const index_t iMWarp = get_warp_id() / NWarp;
|
||||
const index_t iNWarp = get_warp_id() - (iMWarp * NWarp);
|
||||
|
||||
// TODO: refactor warp_window tile type to class member as it should be
|
||||
// compile-time known information.
|
||||
auto a_warp_window_tmp = make_tile_window(
|
||||
auto a_lds_gemm_window = make_tile_window(
|
||||
a_block_window.get_bottom_tensor_view(),
|
||||
make_tuple(number<WarpGemm::kM>{}, number<WarpGemm::kK>{}),
|
||||
a_block_window.get_window_origin() +
|
||||
multi_index<2>{iMWarp * WarpGemm::kM, KIdx * KPerInnerLoop},
|
||||
make_static_tile_distribution(typename WarpGemm::AWarpDstrEncoding{}));
|
||||
|
||||
using AWarpWindow = remove_cvref_t<decltype(a_warp_window_tmp)>;
|
||||
|
||||
static_assert(GemmTraits::AWarpTile::get_num_of_dimension() ==
|
||||
AWarpWindow::get_num_of_dimension(),
|
||||
"AWarpWindow number of dimensions must be equal to "
|
||||
"AWarpTile number of dimensions!");
|
||||
static_assert(GemmTraits::AWarpTile::get_lengths() ==
|
||||
AWarpWindow{}.get_window_lengths(),
|
||||
"AWarpWindow lengths must be equal to AWarpTile lengths!");
|
||||
|
||||
statically_indexed_array<statically_indexed_array<AWarpWindow, KInnerLoopIter>,
|
||||
MIterPerWarp>
|
||||
a_warp_windows;
|
||||
|
||||
// construct B-warp-window
|
||||
auto b_warp_window_tmp = make_tile_window(
|
||||
make_tuple(number<GemmTraits::MPerBlock>{}, number<KPerInnerLoop>{}),
|
||||
{0, KIdx * KPerInnerLoop},
|
||||
a_lds_load_tile_distr);
|
||||
auto b_lds_gemm_window = make_tile_window(
|
||||
b_block_window.get_bottom_tensor_view(),
|
||||
make_tuple(number<WarpGemm::kN>{}, number<WarpGemm::kK>{}),
|
||||
b_block_window.get_window_origin() +
|
||||
multi_index<2>{iNWarp * WarpGemm::kN, KIdx * KPerInnerLoop},
|
||||
make_static_tile_distribution(typename WarpGemm::BWarpDstrEncoding{}));
|
||||
make_tuple(number<GemmTraits::NPerBlock>{}, number<KPerInnerLoop>{}),
|
||||
{0, KIdx * KPerInnerLoop},
|
||||
b_lds_load_tile_distr);
|
||||
|
||||
using BWarpWindow = remove_cvref_t<decltype(b_warp_window_tmp)>;
|
||||
|
||||
static_assert(GemmTraits::BWarpTile::get_num_of_dimension() ==
|
||||
BWarpWindow::get_num_of_dimension(),
|
||||
"BWarpWindow number of dimensions must be equal to "
|
||||
"BWarpTile number of dimensions!");
|
||||
static_assert(GemmTraits::BWarpTile::get_lengths() ==
|
||||
BWarpWindow{}.get_window_lengths(),
|
||||
"BWarpWindow lengths must be equal to BWarpTile lengths!");
|
||||
|
||||
statically_indexed_array<statically_indexed_array<BWarpWindow, KInnerLoopIter>,
|
||||
NIterPerWarp>
|
||||
b_warp_windows;
|
||||
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
static_for<0, KInnerLoopIter, 1>{}([&](auto kIter) {
|
||||
a_warp_windows(mIter)(kIter) = a_warp_window_tmp;
|
||||
|
||||
move_tile_window(a_warp_windows(mIter)(kIter),
|
||||
{mIter * GemmTraits::MPerBlockPerIter,
|
||||
kIter * GemmTraits::KPerBlockPerIter});
|
||||
});
|
||||
});
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
static_for<0, KInnerLoopIter, 1>{}([&](auto kIter) {
|
||||
b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
|
||||
|
||||
move_tile_window(b_warp_windows(nIter)(kIter),
|
||||
{nIter * GemmTraits::NPerBlockPerIter,
|
||||
kIter * GemmTraits::KPerBlockPerIter});
|
||||
});
|
||||
});
|
||||
|
||||
// TODO check if a_warp_tiles has same desc as a_warp_window
|
||||
static_for<0, KInnerLoopIter, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
|
||||
{
|
||||
load_interleaved_pk_type(a_warp_windows(mIter)(kIter),
|
||||
a_warp_tiles_(mIter)(kIter));
|
||||
}
|
||||
else
|
||||
{
|
||||
a_warp_tiles_(mIter)(kIter) = load_tile(a_warp_windows(mIter)(kIter));
|
||||
}
|
||||
});
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read B warp tensor from B Block window
|
||||
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
|
||||
{
|
||||
load_interleaved_pk_type(b_warp_windows(nIter)(kIter),
|
||||
b_warp_tiles_(nIter)(kIter));
|
||||
}
|
||||
else
|
||||
{
|
||||
b_warp_tiles_(nIter)(kIter) = load_tile(b_warp_windows(nIter)(kIter));
|
||||
}
|
||||
});
|
||||
});
|
||||
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
|
||||
{
|
||||
load_interleaved_pk_type(a_warp_tile_, a_block_window);
|
||||
}
|
||||
else
|
||||
{
|
||||
load_tile(a_warp_tile_, a_lds_gemm_window);
|
||||
}
|
||||
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
|
||||
{
|
||||
load_interleaved_pk_type(b_warp_tile_, b_block_window);
|
||||
}
|
||||
else
|
||||
{
|
||||
load_tile(b_warp_tile_, b_lds_gemm_window);
|
||||
}
|
||||
}
|
||||
|
||||
// C += A * B
|
||||
@@ -600,13 +436,6 @@ struct BlockUniversalGemmAsBsCr
|
||||
"The CDataType as defined in traits should be the same as correspoinding "
|
||||
"C block tensor data type!");
|
||||
|
||||
using CWarpDstr = typename WarpGemm::CWarpDstr;
|
||||
using CWarpTensor = typename WarpGemm::CWarpTensor;
|
||||
|
||||
constexpr auto c_warp_y_lengths =
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
// hot loop:
|
||||
static_for<0, KRepeat, 1>{}([&](auto kIter) {
|
||||
LocalPrefetch<kIter.value>(a_block_window, b_block_window);
|
||||
@@ -626,7 +455,21 @@ struct BlockUniversalGemmAsBsCr
|
||||
|
||||
static_for<0, KInnerLoopIter, 1>{}([&](auto kInnerIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
// read A warp tensor from A block tensor
|
||||
AWarpTensor a_warp_tensor;
|
||||
|
||||
a_warp_tensor.get_thread_buffer() = a_warp_tile_.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, kInnerIter>{}, a_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read B warp tensor from B block tensor
|
||||
BWarpTensor b_warp_tensor;
|
||||
|
||||
b_warp_tensor.get_thread_buffer() =
|
||||
b_warp_tile_.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<nIter, kInnerIter>{},
|
||||
b_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
|
||||
// read C warp tensor from C block tensor-
|
||||
CWarpTensor c_warp_tensor;
|
||||
|
||||
@@ -651,9 +494,7 @@ struct BlockUniversalGemmAsBsCr
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
// warp GEMM
|
||||
WarpGemm{}(c_warp_tensor,
|
||||
a_warp_tiles_[mIter][kInnerIter],
|
||||
b_warp_tiles_[nIter][kInnerIter]);
|
||||
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tensor.set_y_sliced_thread_data(
|
||||
|
||||
@@ -129,34 +129,34 @@ struct GemmKernel
|
||||
const std::size_t k_id = blockIdx.z)
|
||||
{
|
||||
constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{});
|
||||
const index_t K_t = kargs.k_batch * K1;
|
||||
const index_t KRead = (kargs.K + K_t - 1) / K_t * K1;
|
||||
const index_t K_t = __builtin_amdgcn_readfirstlane(kargs.k_batch * K1);
|
||||
const index_t KRead = __builtin_amdgcn_readfirstlane((kargs.K + K_t - 1) / K_t * K1);
|
||||
|
||||
if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
|
||||
{
|
||||
a_k_split_offset = k_id * KRead;
|
||||
a_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead);
|
||||
}
|
||||
else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
|
||||
{
|
||||
a_k_split_offset = k_id * KRead * kargs.stride_A;
|
||||
a_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead * kargs.stride_A);
|
||||
}
|
||||
|
||||
if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
|
||||
{
|
||||
b_k_split_offset = k_id * KRead * kargs.stride_B;
|
||||
b_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead * kargs.stride_B);
|
||||
}
|
||||
else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
|
||||
{
|
||||
b_k_split_offset = k_id * KRead;
|
||||
b_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead);
|
||||
}
|
||||
|
||||
if(k_id < static_cast<uint32_t>(kargs.k_batch - 1))
|
||||
{
|
||||
splitted_k = KRead;
|
||||
splitted_k = __builtin_amdgcn_readfirstlane(KRead);
|
||||
}
|
||||
else
|
||||
{
|
||||
splitted_k = kargs.K - KRead * (kargs.k_batch - 1);
|
||||
splitted_k = __builtin_amdgcn_readfirstlane(kargs.K - KRead * (kargs.k_batch - 1));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -172,23 +172,32 @@ struct GemmKernel
|
||||
{
|
||||
if(kargs.k_batch != 1)
|
||||
{
|
||||
std::cerr << "Conditions not met for Kbatch >1 !" << std::endl;
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("Conditions not met for Kbatch >1 !");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
if(kargs.K % TilePartitioner::KPerBlock != 0 && GemmPipeline::kPadK == false)
|
||||
if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 &&
|
||||
GemmPipeline::kPadK == false)
|
||||
{
|
||||
std::cerr << "Can't support K that is not a multiple of KPerBlock"
|
||||
" without padding!"
|
||||
<< std::endl;
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("Can't support K that is not a multiple of k_batch * KPerBlock "
|
||||
"without padding!");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
if(kargs.K % GemmPipeline::GetVectorSizeA() != 0)
|
||||
{
|
||||
std::cerr << "K is not a multiple of vector load size for A tensor!" << std::endl;
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("K is not a multiple of vector load size for A tensor!");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@@ -196,14 +205,19 @@ struct GemmKernel
|
||||
{
|
||||
if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
|
||||
{
|
||||
std::cerr << "Can't support M that is not a multiple of MPerBlock"
|
||||
" without padding!"
|
||||
<< std::endl;
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR(
|
||||
"Can't support M that is not a multiple of MPerBlock without padding!");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
if(kargs.M % GemmPipeline::GetVectorSizeA() != 0)
|
||||
{
|
||||
std::cerr << "M is not a multiple of vector load size for A tensor!" << std::endl;
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("M is not a multiple of vector load size for A tensor!");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@@ -212,29 +226,40 @@ struct GemmKernel
|
||||
{
|
||||
if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
|
||||
{
|
||||
std::cerr << "Can't support N that is not a multiple of NPerBlock"
|
||||
" without padding!"
|
||||
<< std::endl;
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR(
|
||||
"Can't support N that is not a multiple of NPerBlock without padding!");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
if(kargs.N % GemmPipeline::GetVectorSizeB() != 0)
|
||||
{
|
||||
std::cerr << "N is not a multiple of vector load size for B tensor!" << std::endl;
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("N is not a multiple of vector load size for B tensor!");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(kargs.K % TilePartitioner::KPerBlock != 0 && GemmPipeline::kPadK == false)
|
||||
if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 &&
|
||||
GemmPipeline::kPadK == false)
|
||||
{
|
||||
std::cerr << "Can't support K that is not a multiple of KPerBlock"
|
||||
" without padding!"
|
||||
<< std::endl;
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("Can't support K that is not a multiple of k_batch * KPerBlock "
|
||||
"without padding!");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
if(kargs.K % GemmPipeline::GetVectorSizeB() != 0)
|
||||
{
|
||||
std::cerr << "K is not a multiple of vector load size for B tensor!" << std::endl;
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("K is not a multiple of vector load size for B tensor!");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@@ -243,14 +268,19 @@ struct GemmKernel
|
||||
{
|
||||
if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
|
||||
{
|
||||
std::cerr << "Can't support N that is not a multiple of NPerBlock"
|
||||
" without padding!"
|
||||
<< std::endl;
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR(
|
||||
"Can't support N that is not a multiple of NPerBlock without padding!");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
if(kargs.N % EpiloguePipeline::template GetVectorSizeC<CDataType>() != 0)
|
||||
{
|
||||
std::cerr << "N is not a multiple of vector load size for C tensor!" << std::endl;
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("N is not a multiple of vector load size for C tensor!");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@@ -258,14 +288,19 @@ struct GemmKernel
|
||||
{
|
||||
if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
|
||||
{
|
||||
std::cerr << "Can't support M that is not a multiple of MPerBlock"
|
||||
" without padding!"
|
||||
<< std::endl;
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR(
|
||||
"Can't support M that is not a multiple of MPerBlock without padding!");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
if(kargs.M % EpiloguePipeline::template GetVectorSizeC<CDataType>() != 0)
|
||||
{
|
||||
std::cerr << "M is not a multiple of vector load size for C tensor!" << std::endl;
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("M is not a multiple of vector load size for C tensor!");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@@ -279,6 +314,7 @@ struct GemmKernel
|
||||
const GemmKernelArgs& kargs,
|
||||
const SplitKBatchOffset& splitk_batch_offset)
|
||||
{
|
||||
static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!");
|
||||
const auto& a_tensor_view = [&]() {
|
||||
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
@@ -303,21 +339,63 @@ struct GemmKernel
|
||||
const auto& b_tensor_view = [&]() {
|
||||
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
b_ptr,
|
||||
make_tuple(splitk_batch_offset.splitted_k, kargs.N),
|
||||
make_tuple(kargs.stride_B, 1),
|
||||
number<GemmPipeline::GetVectorSizeB()>{},
|
||||
number<1>{});
|
||||
if constexpr(TilePartitioner::BlockGemmShape::PermuteB)
|
||||
{
|
||||
constexpr index_t K1 = GemmPipeline::GetSmemPackB();
|
||||
const index_t K0 = splitk_batch_offset.splitted_k / K1;
|
||||
constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB());
|
||||
const auto b_k0_n_k1_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1),
|
||||
make_tuple(kargs.N * K1, K1, I1),
|
||||
number<VectorSizeB>{},
|
||||
number<1>{});
|
||||
const auto b_n_k_desc = transform_tensor_descriptor(
|
||||
b_k0_n_k1_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(K0, K1)),
|
||||
make_pass_through_transform(kargs.N)),
|
||||
make_tuple(sequence<0, 2>{}, sequence<1>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
return make_tensor_view<address_space_enum::global>(b_ptr, b_n_k_desc);
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
b_ptr,
|
||||
make_tuple(splitk_batch_offset.splitted_k, kargs.N),
|
||||
make_tuple(kargs.stride_B, 1),
|
||||
number<GemmPipeline::GetVectorSizeB()>{},
|
||||
number<1>{});
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
b_ptr,
|
||||
make_tuple(kargs.N, splitk_batch_offset.splitted_k),
|
||||
make_tuple(kargs.stride_B, 1),
|
||||
number<GemmPipeline::GetVectorSizeB()>{},
|
||||
number<1>{});
|
||||
if constexpr(TilePartitioner::BlockGemmShape::PermuteB)
|
||||
{
|
||||
constexpr index_t K1 = GemmPipeline::GetSmemPackB();
|
||||
const index_t K0 = splitk_batch_offset.splitted_k / K1;
|
||||
constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB());
|
||||
const auto b_k0_n_k1_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1),
|
||||
make_tuple(kargs.N * K1, K1, I1),
|
||||
number<VectorSizeB>{},
|
||||
number<1>{});
|
||||
const auto b_n_k_desc = transform_tensor_descriptor(
|
||||
b_k0_n_k1_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(K0, K1)),
|
||||
make_pass_through_transform(kargs.N)),
|
||||
make_tuple(sequence<0, 2>{}, sequence<1>{}),
|
||||
make_tuple(sequence<1>{}, sequence<0>{}));
|
||||
return make_tensor_view<address_space_enum::global>(b_ptr, b_n_k_desc);
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
b_ptr,
|
||||
make_tuple(kargs.N, splitk_batch_offset.splitted_k),
|
||||
make_tuple(kargs.stride_B, 1),
|
||||
number<GemmPipeline::GetVectorSizeB()>{},
|
||||
number<1>{});
|
||||
}
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -488,7 +566,8 @@ struct GemmKernel
|
||||
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
|
||||
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
|
||||
|
||||
const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k);
|
||||
const index_t num_loop = __builtin_amdgcn_readfirstlane(
|
||||
TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
|
||||
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
const auto& a_block_window = gemm_tile_windows.at(I0);
|
||||
@@ -539,7 +618,8 @@ struct GemmKernel
|
||||
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
|
||||
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
|
||||
|
||||
const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k);
|
||||
const index_t num_loop = __builtin_amdgcn_readfirstlane(
|
||||
TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
|
||||
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
const auto& a_block_window = gemm_tile_windows.at(I0);
|
||||
@@ -558,7 +638,8 @@ struct GemmKernel
|
||||
|
||||
CK_TILE_DEVICE void operator()(GemmKernelArgs kargs) const
|
||||
{
|
||||
const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockIdx.x);
|
||||
const auto blockId = __builtin_amdgcn_readfirstlane(blockIdx.x);
|
||||
const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockId);
|
||||
const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
|
||||
const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
|
||||
|
||||
@@ -572,11 +653,11 @@ struct GemmKernel
|
||||
|
||||
// allocate LDS
|
||||
__shared__ char smem_ptr_0[GetSmemSize()];
|
||||
__shared__ char smem_ptr_1[GetSmemSize()];
|
||||
|
||||
if(kargs.k_batch == 1)
|
||||
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
|
||||
{
|
||||
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
|
||||
__shared__ char smem_ptr_1[GetSmemSize()];
|
||||
if(kargs.k_batch == 1)
|
||||
{
|
||||
RunGemm2LDS(a_ptr,
|
||||
b_ptr,
|
||||
@@ -590,17 +671,8 @@ struct GemmKernel
|
||||
}
|
||||
else
|
||||
{
|
||||
RunGemm(a_ptr, b_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// Do not compile in case where we have unsupported
|
||||
// VectorSizeC & data type configuration.
|
||||
if constexpr(!(EpiloguePipeline::template GetVectorSizeC<CDataType>() % 2 != 0 &&
|
||||
is_any_of<CDataType, fp16_t, bf16_t>::value))
|
||||
{
|
||||
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
|
||||
if constexpr(!(EpiloguePipeline::template GetVectorSizeC<CDataType>() % 2 != 0 &&
|
||||
is_any_of<CDataType, fp16_t, bf16_t>::value))
|
||||
{
|
||||
RunGemm2LDS<memory_operation_enum::atomic_add>(a_ptr,
|
||||
b_ptr,
|
||||
@@ -612,7 +684,18 @@ struct GemmKernel
|
||||
i_m,
|
||||
i_n);
|
||||
}
|
||||
else
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(kargs.k_batch == 1)
|
||||
{
|
||||
RunGemm(a_ptr, b_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n);
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(!(EpiloguePipeline::template GetVectorSizeC<CDataType>() % 2 != 0 &&
|
||||
is_any_of<CDataType, fp16_t, bf16_t>::value))
|
||||
{
|
||||
RunGemm<memory_operation_enum::atomic_add>(
|
||||
a_ptr, b_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n);
|
||||
|
||||
@@ -68,9 +68,10 @@ struct GemmPipelineAgBgCrImplBase
|
||||
return make_tuple(std::move(a_lds_block), std::move(b_lds_block));
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp, typename ALdsTensorView>
|
||||
CK_TILE_DEVICE auto GetAWindows(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const ALdsTensorView& a_lds_block_view) const
|
||||
template <typename ADramBlockWindowTmp, typename ALdsTensorView, typename ALdsLoadTileDistr>
|
||||
CK_TILE_DEVICE constexpr auto GetAWindows(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const ALdsTensorView& a_lds_block_view,
|
||||
const ALdsLoadTileDistr&) const
|
||||
{
|
||||
constexpr bool is_col_major = std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
|
||||
|
||||
@@ -88,17 +89,21 @@ struct GemmPipelineAgBgCrImplBase
|
||||
auto a_copy_lds_window = make_tile_window(
|
||||
a_lds_block_view, make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), {0, 0});
|
||||
|
||||
auto a_lds_gemm_window = make_tile_window(
|
||||
a_lds_block_view, make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), {0, 0});
|
||||
auto a_lds_gemm_window =
|
||||
make_tile_window(a_lds_block_view,
|
||||
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
|
||||
{0, 0},
|
||||
ALdsLoadTileDistr{});
|
||||
|
||||
return make_tuple(std::move(a_copy_dram_window),
|
||||
std::move(a_copy_lds_window),
|
||||
std::move(a_lds_gemm_window));
|
||||
}
|
||||
|
||||
template <typename BDramBlockWindowTmp, typename BLdsTensorView>
|
||||
CK_TILE_DEVICE auto GetBWindows(const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BLdsTensorView& b_lds_block_view) const
|
||||
template <typename BDramBlockWindowTmp, typename BLdsTensorView, typename BLdsLoadTileDistr>
|
||||
CK_TILE_DEVICE constexpr auto GetBWindows(const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BLdsTensorView& b_lds_block_view,
|
||||
const BLdsLoadTileDistr&) const
|
||||
{
|
||||
constexpr bool is_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
|
||||
|
||||
@@ -117,8 +122,11 @@ struct GemmPipelineAgBgCrImplBase
|
||||
auto b_copy_lds_window = make_tile_window(
|
||||
b_lds_block_view, make_tuple(number<NPerBlock>{}, number<KPerBlock>{}), {0, 0});
|
||||
|
||||
auto b_lds_gemm_window = make_tile_window(
|
||||
b_lds_block_view, make_tuple(number<NPerBlock>{}, number<KPerBlock>{}), {0, 0});
|
||||
auto b_lds_gemm_window =
|
||||
make_tile_window(b_lds_block_view,
|
||||
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
|
||||
{0, 0},
|
||||
BLdsLoadTileDistr{});
|
||||
|
||||
return make_tuple(std::move(b_copy_dram_window),
|
||||
std::move(b_copy_lds_window),
|
||||
|
||||
@@ -77,6 +77,9 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB<Problem>(); }
|
||||
static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC<Problem>(); }
|
||||
|
||||
static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA<Problem>(); }
|
||||
static constexpr index_t GetSmemPackB() { return Policy::template GetSmemPackB<Problem>(); }
|
||||
|
||||
static constexpr bool kPadM = Problem::kPadM;
|
||||
static constexpr bool kPadN = Problem::kPadN;
|
||||
static constexpr bool kPadK = Problem::kPadK;
|
||||
@@ -114,11 +117,11 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{});
|
||||
|
||||
// Below should be equal to AK1|BK1
|
||||
constexpr index_t A_LDS_Read_Width = Policy::template GetSmemPackA<Problem>();
|
||||
constexpr index_t B_LDS_Read_Width = Policy::template GetSmemPackB<Problem>();
|
||||
constexpr index_t A_LDS_Read_Width = GetSmemPackA();
|
||||
constexpr index_t B_LDS_Read_Width = GetSmemPackB();
|
||||
|
||||
constexpr index_t A_LDS_Write_Width = Policy::template GetSmemPackA<Problem>();
|
||||
constexpr index_t B_LDS_Write_Width = Policy::template GetSmemPackB<Problem>();
|
||||
constexpr index_t A_LDS_Write_Width = GetSmemPackA();
|
||||
constexpr index_t B_LDS_Write_Width = GetSmemPackB();
|
||||
|
||||
constexpr index_t A_Buffer_Load_Inst_Num =
|
||||
MPerBlock * KPerBlock / (BlockSize * GetVectorSizeA());
|
||||
@@ -174,11 +177,11 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{});
|
||||
|
||||
// Below should be equal to AK1|BK1
|
||||
constexpr index_t A_LDS_Read_Width = Policy::template GetSmemPackA<Problem>();
|
||||
constexpr index_t B_LDS_Read_Width = Policy::template GetSmemPackB<Problem>();
|
||||
constexpr index_t A_LDS_Read_Width = GetSmemPackA();
|
||||
constexpr index_t B_LDS_Read_Width = GetSmemPackB();
|
||||
|
||||
constexpr index_t A_LDS_Write_Width = Policy::template GetSmemPackA<Problem>();
|
||||
constexpr index_t B_LDS_Write_Width = Policy::template GetSmemPackB<Problem>();
|
||||
constexpr index_t A_LDS_Write_Width = GetSmemPackA();
|
||||
constexpr index_t B_LDS_Write_Width = GetSmemPackB();
|
||||
|
||||
constexpr index_t A_Buffer_Load_Inst_Num =
|
||||
MPerBlock * KPerBlock / (BlockSize * GetVectorSizeA());
|
||||
@@ -346,17 +349,23 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
// A/B tiles in LDS
|
||||
auto&& [a_lds_block, b_lds_block] = Base::GetABLdsTensorViews(p_smem);
|
||||
|
||||
// Tile distribution for load from lds
|
||||
constexpr auto a_lds_load_tile_distr =
|
||||
make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode());
|
||||
constexpr auto b_lds_load_tile_distr =
|
||||
make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode());
|
||||
|
||||
// A DRAM tile window for load
|
||||
// A LDS tile window for store
|
||||
// A LDS tile for block GEMM
|
||||
auto&& [a_copy_dram_window, a_copy_lds_window, a_lds_gemm_window] =
|
||||
Base::GetAWindows(a_dram_block_window_tmp, a_lds_block);
|
||||
Base::GetAWindows(a_dram_block_window_tmp, a_lds_block, a_lds_load_tile_distr);
|
||||
|
||||
// B DRAM tile window for load
|
||||
// B LDS tile window for store
|
||||
// B LDS tile for block GEMM
|
||||
auto&& [b_copy_dram_window, b_copy_lds_window, b_lds_gemm_window] =
|
||||
Base::GetBWindows(b_dram_block_window_tmp, b_lds_block);
|
||||
Base::GetBWindows(b_dram_block_window_tmp, b_lds_block, b_lds_load_tile_distr);
|
||||
|
||||
// Block GEMM
|
||||
auto block_gemm = BlockGemm();
|
||||
|
||||
@@ -86,6 +86,9 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
|
||||
static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB<Problem>(); }
|
||||
static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC<Problem>(); }
|
||||
|
||||
static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA<Problem>(); }
|
||||
static constexpr index_t GetSmemPackB() { return Policy::template GetSmemPackB<Problem>(); }
|
||||
|
||||
static constexpr bool kPadM = Problem::kPadM;
|
||||
static constexpr bool kPadN = Problem::kPadN;
|
||||
static constexpr bool kPadK = Problem::kPadK;
|
||||
|
||||
@@ -129,6 +129,9 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB<Problem>(); }
|
||||
static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC<Problem>(); }
|
||||
|
||||
static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA<Problem>(); }
|
||||
static constexpr index_t GetSmemPackB() { return Policy::template GetSmemPackB<Problem>(); }
|
||||
|
||||
static constexpr bool kPadM = Problem::kPadM;
|
||||
static constexpr bool kPadN = Problem::kPadN;
|
||||
static constexpr bool kPadK = Problem::kPadK;
|
||||
@@ -215,10 +218,17 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
auto& a_lds_block = ab_lds_blocks.at(I0{});
|
||||
auto& b_lds_block = ab_lds_blocks.at(I1{});
|
||||
|
||||
// Tile distribution for load from lds
|
||||
constexpr auto a_lds_load_tile_distr = decltype(make_static_tile_distribution(
|
||||
BlockGemm::MakeABlockDistributionEncode())){};
|
||||
constexpr auto b_lds_load_tile_distr = decltype(make_static_tile_distribution(
|
||||
BlockGemm::MakeBBlockDistributionEncode())){};
|
||||
|
||||
// A DRAM tile window for load
|
||||
// A LDS tile window for store
|
||||
// A LDS tile for block GEMM
|
||||
auto a_windows = Base::GetAWindows(a_dram_block_window_tmp, a_lds_block);
|
||||
auto a_windows =
|
||||
Base::GetAWindows(a_dram_block_window_tmp, a_lds_block, a_lds_load_tile_distr);
|
||||
auto& a_copy_dram_window = a_windows.at(I0{});
|
||||
auto& a_copy_lds_window = a_windows.at(I1{});
|
||||
auto& a_lds_gemm_window = a_windows.at(I2{});
|
||||
@@ -226,7 +236,8 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
// B DRAM tile window for load
|
||||
// B LDS tile window for store
|
||||
// B LDS tile for block GEMM
|
||||
auto b_windows = Base::GetBWindows(b_dram_block_window_tmp, b_lds_block);
|
||||
auto b_windows =
|
||||
Base::GetBWindows(b_dram_block_window_tmp, b_lds_block, b_lds_load_tile_distr);
|
||||
auto& b_copy_dram_window = b_windows.at(I0{});
|
||||
auto& b_copy_lds_window = b_windows.at(I1{});
|
||||
auto& b_lds_gemm_window = b_windows.at(I2{});
|
||||
@@ -493,10 +504,17 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
auto& a_lds_block = ab_lds_blocks.at(I0{});
|
||||
auto& b_lds_block = ab_lds_blocks.at(I1{});
|
||||
|
||||
// Tile distribution for load from lds
|
||||
constexpr auto a_lds_load_tile_distr = decltype(make_static_tile_distribution(
|
||||
BlockGemm::MakeABlockDistributionEncode())){};
|
||||
constexpr auto b_lds_load_tile_distr = decltype(make_static_tile_distribution(
|
||||
BlockGemm::MakeBBlockDistributionEncode())){};
|
||||
|
||||
// A DRAM tile window for load
|
||||
// A LDS tile window for store
|
||||
// A LDS tile for block GEMM
|
||||
auto a_windows = Base::GetAWindows(a_dram_block_window_tmp, a_lds_block);
|
||||
auto a_windows =
|
||||
Base::GetAWindows(a_dram_block_window_tmp, a_lds_block, a_lds_load_tile_distr);
|
||||
auto& a_copy_dram_window = a_windows.at(I0{});
|
||||
auto& a_copy_lds_window = a_windows.at(I1{});
|
||||
auto& a_lds_gemm_window = a_windows.at(I2{});
|
||||
@@ -504,7 +522,8 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
// B DRAM tile window for load
|
||||
// B LDS tile window for store
|
||||
// B LDS tile for block GEMM
|
||||
auto b_windows = Base::GetBWindows(b_dram_block_window_tmp, b_lds_block);
|
||||
auto b_windows =
|
||||
Base::GetBWindows(b_dram_block_window_tmp, b_lds_block, b_lds_load_tile_distr);
|
||||
auto& b_copy_dram_window = b_windows.at(I0{});
|
||||
auto& b_copy_lds_window = b_windows.at(I1{});
|
||||
auto& b_lds_gemm_window = b_windows.at(I2{});
|
||||
|
||||
@@ -36,6 +36,9 @@ struct GemmPipelineAGmemBGmemCRegV1
|
||||
static constexpr index_t GetVectorSizeB() { return Problem::VectorSizeB; }
|
||||
static constexpr index_t GetVectorSizeC() { return Problem::VectorSizeC; }
|
||||
|
||||
static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA<Problem>(); }
|
||||
static constexpr index_t GetSmemPackB() { return Policy::template GetSmemPackB<Problem>(); }
|
||||
|
||||
static constexpr bool kPadM = Problem::kPadM;
|
||||
static constexpr bool kPadN = Problem::kPadN;
|
||||
static constexpr bool kPadK = Problem::kPadK;
|
||||
@@ -125,13 +128,25 @@ struct GemmPipelineAGmemBGmemCRegV1
|
||||
auto b_copy_lds_window = make_tile_window(
|
||||
b_lds_block, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0});
|
||||
|
||||
// Tile distribution for load from lds
|
||||
constexpr auto a_lds_load_tile_distr =
|
||||
make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode());
|
||||
constexpr auto b_lds_load_tile_distr =
|
||||
make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode());
|
||||
|
||||
// A LDS tile for block GEMM
|
||||
auto a_lds_gemm_window = make_tile_window(
|
||||
a_lds_block, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});
|
||||
auto a_lds_gemm_window =
|
||||
make_tile_window(a_lds_block,
|
||||
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
|
||||
{0, 0},
|
||||
a_lds_load_tile_distr);
|
||||
|
||||
// B LDS tile for block GEMM
|
||||
auto b_lds_gemm_window = make_tile_window(
|
||||
b_lds_block, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0});
|
||||
auto b_lds_gemm_window =
|
||||
make_tile_window(b_lds_block,
|
||||
make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
|
||||
{0, 0},
|
||||
b_lds_load_tile_distr);
|
||||
|
||||
// Block GEMM
|
||||
auto block_gemm = BlockGemm();
|
||||
|
||||
@@ -31,6 +31,9 @@ struct GemmPipelineAGmemBGmemCRegV2
|
||||
static constexpr index_t kNPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t kKPerBlock = BlockGemmShape::kK;
|
||||
|
||||
static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA<Problem>(); }
|
||||
static constexpr index_t GetSmemPackB() { return Policy::template GetSmemPackB<Problem>(); }
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
@@ -122,17 +125,29 @@ struct GemmPipelineAGmemBGmemCRegV2
|
||||
{0, 0},
|
||||
b_copy_dram_window.get_tile_distribution());
|
||||
|
||||
// A LDS tile for block GEMM
|
||||
auto a_lds_gemm_window = make_tile_window(
|
||||
a_lds_block, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});
|
||||
|
||||
// B LDS tile for block GEMM
|
||||
auto b_lds_gemm_window = make_tile_window(
|
||||
b_lds_block, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0});
|
||||
|
||||
// Block GEMM
|
||||
constexpr auto block_gemm = Policy::template GetBlockGemm<Problem>();
|
||||
|
||||
// Tile distribution for load from lds
|
||||
constexpr auto a_lds_load_tile_distr =
|
||||
make_static_tile_distribution(decltype(block_gemm)::MakeABlockDistributionEncode());
|
||||
constexpr auto b_lds_load_tile_distr =
|
||||
make_static_tile_distribution(decltype(block_gemm)::MakeBBlockDistributionEncode());
|
||||
|
||||
// A LDS tile for block GEMM
|
||||
auto a_lds_gemm_window =
|
||||
make_tile_window(a_lds_block,
|
||||
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
|
||||
{0, 0},
|
||||
a_lds_load_tile_distr);
|
||||
|
||||
// B LDS tile for block GEMM
|
||||
auto b_lds_gemm_window =
|
||||
make_tile_window(b_lds_block,
|
||||
make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
|
||||
{0, 0},
|
||||
b_lds_load_tile_distr);
|
||||
|
||||
// Acc register tile
|
||||
auto c_block_tile = decltype(block_gemm(a_lds_gemm_window, b_lds_gemm_window)){};
|
||||
|
||||
|
||||
@@ -8,7 +8,11 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename BlockTile_, typename BlockWarps_, typename WarpTile_>
|
||||
template <typename BlockTile_,
|
||||
typename BlockWarps_,
|
||||
typename WarpTile_,
|
||||
bool PermuteA_ = false,
|
||||
bool PermuteB_ = false>
|
||||
struct TileGemmShape
|
||||
{
|
||||
using BlockTile = remove_cvref_t<BlockTile_>;
|
||||
@@ -21,6 +25,9 @@ struct TileGemmShape
|
||||
static constexpr index_t kN = BlockTile::at(number<1>{});
|
||||
static constexpr index_t kK = BlockTile::at(number<2>{});
|
||||
|
||||
static constexpr bool PermuteA = PermuteA_;
|
||||
static constexpr bool PermuteB = PermuteB_;
|
||||
|
||||
CK_TILE_HOST static std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
|
||||
@@ -17,7 +17,7 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8))
|
||||
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instances(
|
||||
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
|
||||
Col,
|
||||
Tuple<>,
|
||||
@@ -28,14 +28,14 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_i
|
||||
F32,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
128,
|
||||
1,
|
||||
128,
|
||||
128,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instances(
|
||||
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_kpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
|
||||
Col,
|
||||
Tuple<>,
|
||||
@@ -46,14 +46,14 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_
|
||||
F32,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
128,
|
||||
1,
|
||||
128,
|
||||
128,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnpadding_instances(
|
||||
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_mem_v1_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
|
||||
Col,
|
||||
Tuple<>,
|
||||
@@ -64,14 +64,14 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnpadding
|
||||
F32,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
128,
|
||||
1,
|
||||
128,
|
||||
128,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnkpadding_instances(
|
||||
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_mem_v1_kpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
|
||||
Col,
|
||||
Tuple<>,
|
||||
@@ -82,61 +82,7 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnkpaddin
|
||||
F32,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
128,
|
||||
128,
|
||||
128,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
|
||||
Col,
|
||||
Tuple<>,
|
||||
Row,
|
||||
F8,
|
||||
F32,
|
||||
F8,
|
||||
F32,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
128,
|
||||
128,
|
||||
128,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
|
||||
Col,
|
||||
Tuple<>,
|
||||
Row,
|
||||
F8,
|
||||
F32,
|
||||
F8,
|
||||
F32,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
128,
|
||||
128,
|
||||
128,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
|
||||
Col,
|
||||
Tuple<>,
|
||||
Row,
|
||||
F8,
|
||||
F32,
|
||||
F8,
|
||||
F32,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
128,
|
||||
1,
|
||||
128,
|
||||
128,
|
||||
PassThrough,
|
||||
@@ -163,7 +109,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
|
||||
B1DataType,
|
||||
Tuple<>,
|
||||
CDataType,
|
||||
128,
|
||||
1,
|
||||
128,
|
||||
128,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
@@ -180,7 +126,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
|
||||
B1DataType,
|
||||
Tuple<>,
|
||||
CDataType,
|
||||
128,
|
||||
1,
|
||||
128,
|
||||
128,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
@@ -198,20 +144,14 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
|
||||
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instances(
|
||||
add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_default_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnkpadding_instances(
|
||||
add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_kpadding_instances(
|
||||
op_ptrs);
|
||||
|
||||
add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instances(
|
||||
add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_mem_v1_default_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_mnkpadding_instances(
|
||||
add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_mem_v1_kpadding_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,16 +4,13 @@ set(GEMM_AB_SCALE_INSTANCES)
|
||||
list(APPEND GEMM_AB_SCALE_INSTANCES
|
||||
device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instance.cpp
|
||||
device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instance.cpp
|
||||
device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnpadding_instance.cpp
|
||||
device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnkpadding_instance.cpp
|
||||
device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instance.cpp
|
||||
device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instance.cpp
|
||||
device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_mnkpadding_instance.cpp
|
||||
)
|
||||
|
||||
set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnkpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
|
||||
add_instance_library(device_gemm_ab_scale_instance ${GEMM_AB_SCALE_INSTANCES})
|
||||
|
||||
@@ -34,49 +34,50 @@ static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
|
||||
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
|
||||
|
||||
template <GemmSpecialization GemmSpec>
|
||||
using device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_instances = std::tuple<
|
||||
using device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_instances = std::tuple<
|
||||
// clang-format off
|
||||
//################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| Scale| Scale| Scale| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
|
||||
//################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
|
||||
//################################| | | | | | | | | | | Operation| Operation| Operation| | | M| N| K| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
|
||||
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
//################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| Scale| Scale| Scale| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
|
||||
//################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
|
||||
//################################| | | | | | | | | | | Operation| Operation| Operation| | | M| N| K| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
|
||||
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
|
||||
// Compute friendly
|
||||
// Spill in current compiler
|
||||
// DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 128, 16, 16, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
// DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 224, 128, 16, 16, 16, 16, 8, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 128, 128, 128, 128, 16, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 128, 128, 64, 128, 16, 16, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 128, 64, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 128, 64, 64, 128, 16, 16, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>
|
||||
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 128, 128, 16, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 128, 16, 16, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 64, 128, 16, 16, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <BlockGemmPipelineScheduler BlkGemmPipeSched, GemmSpecialization GemmSpec>
|
||||
using device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_instances = std::tuple<
|
||||
using device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_mem_instances = std::tuple<
|
||||
// clang-format off
|
||||
//################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| Scale| Scale| Scale| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
|
||||
//################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
|
||||
//################################| | | | | | | | | | | Operation| Operation| Operation| | | M| N| K| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
|
||||
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
//################################| ALayout| BLayout| DsLayout| ELayout|AData | BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| Scale| Scale| Scale| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
|
||||
//################################| | | | | Type | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
|
||||
//################################| | | | | | | | | | | Operation| Operation| Operation| | | M| N| K| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
|
||||
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
|
||||
// Latency friendly
|
||||
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 128, 32, 16, 128, 16, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<2, 2, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
|
||||
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 128, 128, 128, 16, 16, 128, 16, 16, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
|
||||
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 128, 16, 32, 128, 16, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
|
||||
// Memory friendly
|
||||
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 128, 128, 32, 128, 16, 16, 32, 32, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>,
|
||||
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 128, 128, 16, 128, 16, 16, 16, 16, 4, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<2, 2, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>,
|
||||
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 128, 64, 32, 128, 16, 16, 32, 32, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>,
|
||||
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 128, 64, 16, 128, 16, 16, 16, 16, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<2, 2, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>,
|
||||
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 128, 32, 16, 128, 16, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<2, 2, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>,
|
||||
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 128, 128, 128, 16, 16, 64, 16, 16, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>,
|
||||
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 128, 128, 128, 16, 16, 128, 16, 16, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>,
|
||||
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 128, 16, 32, 128, 16, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>,
|
||||
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 128, 16, 64, 128, 16, 16, 16, 16, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>,
|
||||
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 128, 32, 64, 128, 16, 16, 32, 32, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>,
|
||||
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 128, 16, 128, 128, 16, 16, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>,
|
||||
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>
|
||||
// Memory friendly
|
||||
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 16, 256, 128, 8, 16, 16, 16, 1, 4, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
|
||||
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 16, 128, 128, 8, 16, 16, 16, 1, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
|
||||
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 16, 64, 128, 8, 16, 16, 16, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
|
||||
|
||||
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 16, 128, 256, 16, 16, 16, 16, 1, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
|
||||
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 16, 64, 256, 16, 16, 16, 16, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
|
||||
|
||||
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 32, 256, 128, 16, 16, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
|
||||
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 32, 128, 128, 16, 16, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
|
||||
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 32, 64, 128, 16, 16, 16, 16, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
|
||||
|
||||
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 32, 128, 256, 16, 16, 32, 32, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
|
||||
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 32, 64, 256, 16, 16, 16, 16, 2, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
|
||||
|
||||
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 256, 128, 16, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
|
||||
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 128, 128, 16, 16, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
|
||||
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 64, 128, 16, 16, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
|
||||
|
||||
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 128, 256, 16, 16, 32, 32, 2, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
|
||||
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 64, 256, 16, 16, 32, 32, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>
|
||||
// clang-format on
|
||||
>;
|
||||
} // namespace instance
|
||||
|
||||
@@ -8,7 +8,7 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instances(
|
||||
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
|
||||
Col,
|
||||
Tuple<>,
|
||||
@@ -19,7 +19,7 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_i
|
||||
F32,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
128,
|
||||
1,
|
||||
128,
|
||||
128,
|
||||
PassThrough,
|
||||
@@ -28,7 +28,7 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_i
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_instances<GemmDefault>{});
|
||||
device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_instances<GemmDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
|
||||
@@ -8,7 +8,7 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instances(
|
||||
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_kpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
|
||||
Col,
|
||||
Tuple<>,
|
||||
@@ -19,7 +19,7 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_
|
||||
F32,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
128,
|
||||
1,
|
||||
128,
|
||||
128,
|
||||
PassThrough,
|
||||
@@ -28,7 +28,7 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_instances<GemmKPadding>{});
|
||||
device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_instances<GemmKPadding>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
|
||||
@@ -1,37 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
|
||||
Col,
|
||||
Tuple<>,
|
||||
Row,
|
||||
F8,
|
||||
F32,
|
||||
F8,
|
||||
F32,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
128,
|
||||
128,
|
||||
128,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_instances<GemmMNKPadding>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -1,37 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
|
||||
Col,
|
||||
Tuple<>,
|
||||
Row,
|
||||
F8,
|
||||
F32,
|
||||
F8,
|
||||
F32,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
128,
|
||||
128,
|
||||
128,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_instances<GemmMNPadding>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -8,7 +8,7 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instances(
|
||||
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_mem_v1_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
|
||||
Col,
|
||||
Tuple<>,
|
||||
@@ -19,7 +19,7 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default
|
||||
F32,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
128,
|
||||
1,
|
||||
128,
|
||||
128,
|
||||
PassThrough,
|
||||
@@ -28,8 +28,8 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_instances<Intrawave,
|
||||
GemmDefault>{});
|
||||
device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_mem_instances<Intrawave,
|
||||
GemmDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
|
||||
@@ -8,7 +8,7 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instances(
|
||||
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_mem_v1_kpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
|
||||
Col,
|
||||
Tuple<>,
|
||||
@@ -19,7 +19,7 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpaddin
|
||||
F32,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
128,
|
||||
1,
|
||||
128,
|
||||
128,
|
||||
PassThrough,
|
||||
@@ -28,8 +28,8 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpaddin
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_instances<Intrawave,
|
||||
GemmKPadding>{});
|
||||
device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_mem_instances<Intrawave,
|
||||
GemmKPadding>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
|
||||
@@ -1,38 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
|
||||
Col,
|
||||
Tuple<>,
|
||||
Row,
|
||||
F8,
|
||||
F32,
|
||||
F8,
|
||||
F32,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
128,
|
||||
128,
|
||||
128,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_instances<Intrawave,
|
||||
GemmMNKPadding>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -32,6 +32,7 @@ enum struct GemmDataType
|
||||
enum struct ScaleBlockTile
|
||||
{
|
||||
Tile_128_128_128, // 0
|
||||
Tile_1_128_128, // 1
|
||||
};
|
||||
|
||||
#define OP_NAME "gemm_ab_scale"
|
||||
@@ -49,7 +50,8 @@ int profile_gemm_ab_scale(int argc, char* argv[])
|
||||
printf(" 1: A[m, k] * B[n, k] = C[m, n];\n");
|
||||
printf(" 2: A[k, m] * B[k, n] = C[m, n];\n");
|
||||
printf(" 3: A[k, m] * B[n, k] = C[m, n])\n");
|
||||
printf("arg4: scale block tile (0: ScaleBlockM/N/K = [128, 128, 128];\n");
|
||||
printf("arg4: scale block tile (0: ScaleBlockM/N/K = [128, 128, 128]; 1: ScaleBlockM/N/K = "
|
||||
"[1, 128, 128];\n");
|
||||
printf("arg5: verification (0: no; 1: yes)\n");
|
||||
printf("arg6: initialization (0: no init; 1: integer value; 2: decimal value)\n");
|
||||
printf("arg7: print tensor value (0: no; 1: yes)\n");
|
||||
@@ -155,7 +157,7 @@ int profile_gemm_ab_scale(int argc, char* argv[])
|
||||
};
|
||||
|
||||
if(data_type == GemmDataType::F8_F8_BF16 && layout == GemmMatrixLayout::MK_NK_MN &&
|
||||
scale_block_tile == ScaleBlockTile::Tile_128_128_128)
|
||||
scale_block_tile == ScaleBlockTile::Tile_1_128_128)
|
||||
{
|
||||
return profile(F8{},
|
||||
F32{},
|
||||
@@ -164,7 +166,7 @@ int profile_gemm_ab_scale(int argc, char* argv[])
|
||||
F8{},
|
||||
F32{},
|
||||
BF16{},
|
||||
ck::Number<128>{},
|
||||
ck::Number<1>{},
|
||||
ck::Number<128>{},
|
||||
ck::Number<128>{},
|
||||
Row{},
|
||||
|
||||
@@ -71,7 +71,9 @@ class TestCkTileGemmPipeline : public ::testing::Test
|
||||
|
||||
constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = 8;
|
||||
// TODO: Restore to 8. At now after changes in block_universal_gemm_as_bs_cr it return wrong
|
||||
// values.
|
||||
constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
constexpr bool kPadM = PadM;
|
||||
constexpr bool kPadN = PadN;
|
||||
|
||||
Reference in New Issue
Block a user