#!/bin/sh
# Exercise hipify-clang and hipify-perl against a real CUDA toolkit.
# Exit status: 0 = all tests passed, 1 = a test failed,
#              77 = skipped (no CUDA toolkit detected).

# `pipefail` is not part of POSIX-2017 sh, and older dash releases (a common
# /bin/sh) reject `set -o pipefail` outright, killing the script before any
# test runs. Keep the portable `-eu` and add pipefail only where supported.
set -eu
if (set -o pipefail) 2>/dev/null; then
  set -o pipefail
fi

# Scratch directory for the generated sources. Deliberately not named TMPDIR:
# that variable tells mktemp(1), compilers and other child processes where to
# create their own temporary files.
WORK_DIR=$(mktemp -d)
# Remove the scratch directory on every exit path; turn common signals into a
# regular exit so the EXIT trap also runs when the test is interrupted.
trap 'rm -rf -- "${WORK_DIR:?}"' EXIT
trap 'exit 1' HUP INT TERM
cd "$WORK_DIR"

# ---------------------------------------------------------------------------
# CUDA detection
# ---------------------------------------------------------------------------
# Find a CUDA toolkit root (any directory with an nvvm/libdevice entry) and an
# include directory that actually contains cuda_runtime.h, which both test
# sources include. Candidate roots are tried in order: /usr/lib/cuda,
# /usr/local/cuda, then the installation prefix of the nvcc found on PATH.
# If either piece is missing, the whole script is skipped with status 77.
CUDA_PATH=""
CUDA_INCLUDE=""
# Reset explicitly so a value inherited from the caller's environment can
# never be picked up as a candidate root.
NVCC_PREFIX=""

NVCC_BIN=$(command -v nvcc 2>/dev/null || true)
if [ -n "$NVCC_BIN" ]; then
  # <prefix>/bin/nvcc -> <prefix>
  NVCC_PREFIX=$(dirname "$(dirname "$NVCC_BIN")")
fi
for cand in /usr/lib/cuda /usr/local/cuda "$NVCC_PREFIX"; do
  [ -n "$cand" ] || continue
  # libdevice may be a directory or a single file/symlink; -e accepts either
  # (and can only succeed if the enclosing nvvm directory exists).
  if [ -e "$cand/nvvm/libdevice" ]; then
    CUDA_PATH="$cand"
    break
  fi
done
# Prefer the toolkit's own headers; fall back to a system-wide install.
if [ -n "$CUDA_PATH" ] && [ -f "$CUDA_PATH/include/cuda_runtime.h" ]; then
  CUDA_INCLUDE="$CUDA_PATH/include"
elif [ -f /usr/include/cuda_runtime.h ]; then
  CUDA_INCLUDE=/usr/include
fi

if [ -z "$CUDA_PATH" ] || [ -z "$CUDA_INCLUDE" ]; then
  echo "Skipping: CUDA toolkit not detected"
  exit 77
fi

echo "CUDA detected: path=$CUDA_PATH include=$CUDA_INCLUDE"

# ---------------------------------------------------------------------------
# TEST 3 — hipify-clang with real CUDA (requires CUDA install)
# Verify AST-based translation of a __global__ kernel + <<<>>> launch.
# ---------------------------------------------------------------------------
# Binary under test; override via HIPIFY_CLANG (default: the system install).
HIPIFY_CLANG=${HIPIFY_CLANG:-/usr/bin/hipify-clang}

echo "[TEST 3] hipify-clang kernel launch translation..."

cat > sample_clang.cu <<'EOF'
#include <cuda_runtime.h>
__global__ void kernel() {}
int main() { kernel<<<1,1>>>(); return 0; }
EOF

# --hip-kernel-execution-syntax is passed so the <<<...>>> launch is expected
# to come out as a hipLaunchKernelGGL(...) call, which is checked below.
# Capture the exit status instead of letting `set -e` abort: a crash or parse
# error must be reported as a test FAIL, and the tool's own status (which
# could be 77) must not leak out as the script's status.
RC=0
"$HIPIFY_CLANG" --hip-kernel-execution-syntax \
  --cuda-path="$CUDA_PATH" \
  -I"$CUDA_INCLUDE" \
  -o sample_clang.hip.cu \
  sample_clang.cu || RC=$?
if [ "$RC" -ne 0 ]; then
  echo "[TEST 3] FAIL: hipify-clang exited with status $RC" >&2
  exit 1
fi
if [ ! -f sample_clang.hip.cu ]; then
  echo "[TEST 3] FAIL: hipify-clang did not produce sample_clang.hip.cu" >&2
  exit 1
fi

if grep -qF -- "hipLaunchKernelGGL" sample_clang.hip.cu; then
  echo "[TEST 3] PASS: kernel launch translated to hipLaunchKernelGGL"
  # Show the translated source in the log for reference.
  head -n 50 sample_clang.hip.cu
else
  echo "[TEST 3] FAIL: hipLaunchKernelGGL not found in output" >&2
  cat sample_clang.hip.cu >&2
  exit 1
fi

# ---------------------------------------------------------------------------
# TEST 4 — hipify-perl with real CUDA source (requires CUDA install)
# Verify regex-based translation of multiple CUDA API calls.
# ---------------------------------------------------------------------------
# Script under test; override via HIPIFY_PERL (default: the system install).
HIPIFY_PERL=${HIPIFY_PERL:-/usr/bin/hipify-perl}

echo "[TEST 4] hipify-perl API call translation..."

cat > sample_perl_cuda.cu <<'EOF'
#include <cuda_runtime.h>
int main() {
  void *ptr;
  cudaMalloc(&ptr, 1024);
  cudaMemcpy(ptr, ptr, 1024, cudaMemcpyDeviceToDevice);
  cudaDeviceSynchronize();
  cudaFree(ptr);
  return 0;
}
EOF

# Report a hipify-perl failure as a test FAIL rather than a silent `set -e`
# abort carrying the tool's own exit status.
RC=0
"$HIPIFY_PERL" sample_perl_cuda.cu > sample_perl_cuda.hip.cu || RC=$?
if [ "$RC" -ne 0 ]; then
  echo "[TEST 4] FAIL: hipify-perl exited with status $RC" >&2
  exit 1
fi

# Each entry is "<cuda api>:<expected hip api>". Matching is whole-word and
# fixed-string (-wF) so that, e.g., the enum hipMemcpyDeviceToDevice cannot
# satisfy the check for the function hipMemcpy; the original CUDA name must
# also be gone from the translated output.
FAIL=0
for pair in \
  "cudaMalloc:hipMalloc" \
  "cudaMemcpy:hipMemcpy" \
  "cudaDeviceSynchronize:hipDeviceSynchronize" \
  "cudaFree:hipFree"; do
  cuda_api="${pair%%:*}"
  hip_api="${pair##*:}"
  if grep -qwF -- "$hip_api" sample_perl_cuda.hip.cu \
    && ! grep -qwF -- "$cuda_api" sample_perl_cuda.hip.cu; then
    echo "[TEST 4] PASS: $cuda_api -> $hip_api"
  else
    echo "[TEST 4] FAIL: $cuda_api was not translated to $hip_api" >&2
    FAIL=1
  fi
done

# Dump the translated source only on failure, to aid debugging.
if [ "$FAIL" -eq 1 ]; then
  cat sample_perl_cuda.hip.cu >&2
  exit 1
fi

# Every test above exits on failure, so reaching this line means success.
printf '%s\n' "All hipify CUDA tests passed."
exit 0
