## Generate cuda2hip.h from lists of keywords (the header is printed to stdout).

# ~ ###To extract keywords from source files, use this:

# ~ import re
# ~ import sys

# ~ def extract_cuda_keywords(text):
	# ~ a = re.findall(r"cuda[A-Z][^\s\(\);,]*", text)
	# ~ a += re.findall(r"nvrtc[A-Z][^\s\(\);,]*", text)
	# ~ a += re.findall(r"cu[A-Z][^\s\(\);,]*", text)
	# ~ a += re.findall(r"CU[a-z][^\s\(\);,]*", text)
	# ~ return a

# ~ for fname in sys.argv[1:]:
	# ~ f = open(fname)
	# ~ text = f.read()
	# ~ print("\n".join(extract_cuda_keywords(text)))


# CUDA runtime names: the HIP name is the same identifier with the leading
# "cuda" replaced by "hip" (e.g. cudaMalloc -> hipMalloc).
cuda = """
cudaError_t
cudaEventCreate
cudaEventCreateWithFlags
cudaEventDestroy
cudaEventDisableTiming
cudaEventElapsedTime
cudaEventRecord
cudaEventSynchronize
cudaEvent_t
cudaFree
cudaGetDevice
cudaGetDeviceCount
cudaGetDeviceProperties
cudaGetErrorString
cudaGetLastError
cudaMalloc
cudaMemcpy
cudaMemcpyAsync
cudaMemcpyDeviceToHost
cudaMemcpyHostToDevice
cudaMemsetAsync
cudaSetDevice
cudaStreamCreateWithFlags
cudaStreamDestroy
cudaStreamNonBlocking
cudaStreamWaitEvent
cudaStream_t
cudaSuccess
"""

# Driver API and cuFFT names: the leading "cu" is replaced by "hip"
# (e.g. cuDeviceGet -> hipDeviceGet, cufftPlanMany -> hipfftPlanMany).
cu = """
cuDeviceGet
cuModuleGetFunction
cuModuleLoadDataEx
cufftDestroy
cufftDoubleComplex
cufftExecZ2Z
cufftGetSize
cufftHandle
cufftPlanMany
cufftResult
cufftSetStream
cufftType
"""

# NVRTC names: the leading "nvrtc" is replaced by "hiprtc".
nvrtc = """
nvrtcAddNameExpression
nvrtcCompileProgram
nvrtcCreateProgram
nvrtcDestroyProgram
nvrtcGetErrorString
nvrtcGetLoweredName
nvrtcGetProgramLog
nvrtcGetProgramLogSize
nvrtcProgram
nvrtcResult
"""

# Names whose HIP equivalent is not a plain prefix swap.
# Format: exactly one "CUDA_NAME HIP_NAME" pair per line.
irregular = """
nvrtcGetPTX hiprtcGetCode
nvrtcGetPTXSize hiprtcGetCodeSize
cuLaunchKernel hipModuleLaunchKernel
NVRTC_SUCCESS HIPRTC_SUCCESS
CUDA_SUCCESS hipSuccess
CUdevice hipDevice_t
CUfunction hipFunction_t
CUmodule hipModule_t
CUresult hipError_t
cudaDeviceProp hipDeviceProp_t
cudaMallocHost hipHostMalloc
cudaFreeHost hipHostFree
CUFFT_D2Z HIPFFT_D2Z
CUFFT_FORWARD HIPFFT_FORWARD
CUFFT_INVERSE HIPFFT_BACKWARD
CUFFT_SUCCESS HIPFFT_SUCCESS
CUFFT_Z2Z HIPFFT_Z2Z
"""


def _prefixed_defines(names, cuda_prefix, hip_prefix):
    """Yield one ``#define NAME\\t HIPNAME`` line per name in *names*.

    Args:
        names: whitespace-separated identifiers, each starting with
            *cuda_prefix*.
        cuda_prefix: prefix to strip from each identifier.
        hip_prefix: prefix to put in its place.

    Raises:
        ValueError: if a name does not start with *cuda_prefix*. A plain
            ``str.replace`` would leave such a name unchanged and emit a
            malformed ``#define nameName`` line instead of failing.
    """
    for name in names.split():
        if not name.startswith(cuda_prefix):
            raise ValueError(f"{name!r} does not start with {cuda_prefix!r}")
        yield f"#define {name}\t {hip_prefix}{name[len(cuda_prefix):]}"


def _irregular_defines(pairs):
    """Yield one ``#define CUDA_NAME\\t HIP_NAME`` line per pair in *pairs*.

    Args:
        pairs: text where every non-blank line holds exactly two
            whitespace-separated tokens: the CUDA name and its HIP name.

    Raises:
        ValueError: on a non-blank line that does not hold exactly two
            tokens. Pairing a flat token list two at a time would instead
            raise a bare IndexError or silently mis-pair all later entries.
    """
    for line in pairs.splitlines():
        tokens = line.split()
        if not tokens:
            continue  # tolerate blank lines (the literal starts/ends with one)
        if len(tokens) != 2:
            raise ValueError(f"expected 'CUDA_NAME HIP_NAME', got {line!r}")
        cuda_name, hip_name = tokens
        yield f"#define {cuda_name}\t {hip_name}"


def generate_header(cuda_names, cu_names, nvrtc_names, irregular_pairs):
    """Return the full text of cuda2hip.h, without a trailing newline.

    Args:
        cuda_names: names with a "cuda" prefix (mapped to "hip").
        cu_names: names with a "cu" prefix (mapped to "hip").
        nvrtc_names: names with an "nvrtc" prefix (mapped to "hiprtc").
        irregular_pairs: "CUDA_NAME HIP_NAME" lines, one pair per line.

    Raises:
        ValueError: if any input is malformed (see the helpers above).
    """
    lines = [
        "/// DO NOT MODIFY DIRECTLY! file generated by cuda2hip_gen.py",
        "/// this files simply translates cuda (nvidia) API to HIP (amd) API",
    ]
    lines.extend(_prefixed_defines(cuda_names, "cuda", "hip"))
    lines.extend(_prefixed_defines(cu_names, "cu", "hip"))
    lines.extend(_prefixed_defines(nvrtc_names, "nvrtc", "hiprtc"))
    lines.extend(_irregular_defines(irregular_pairs))
    return "\n".join(lines)


# Build the whole header before printing, so a malformed entry aborts the
# run without leaving a truncated header on stdout.
print(generate_header(cuda, cu, nvrtc, irregular))
