diff --git a/install.sh b/install.sh index 106dfe8..5e7769f 100644 --- a/install.sh +++ b/install.sh @@ -87,24 +87,42 @@ fi DETECTED_CUDA="$(nvidia-smi | grep 'CUDA Version' | sed 's/.*CUDA Version: *\([0-9.]*\).*/\1/')" log "Detected CUDA capability: ${DETECTED_CUDA}" +# Extract major.minor (e.g., 12.4) +DETECTED_MM="$(echo "$DETECTED_CUDA" | cut -d. -f1,2)" + choose_cuda() { - # Pick highest supported version <= detected CUDA capability - for v in $(printf '%s\n' "${SUPPORTED_CUDA[@]}" | sort -rV); do - if [ "$(printf '%s\n%s\n' "$v" "$DETECTED_CUDA" | sort -V | head -n1)" = "$v" ]; then - echo "$v" - return 0 + local best="" + local best_mm="" + + for v in "${SUPPORTED_CUDA[@]}"; do + v_mm="$(echo "$v" | cut -d. -f1,2)" + + # Only consider versions with same major.minor + if [ "$v_mm" = "$DETECTED_MM" ]; then + # Pick highest patch of that minor + if [ -z "$best" ] || [ "$(printf '%s\n%s\n' "$best" "$v" | sort -V | tail -n1)" = "$v" ]; then + best="$v" + fi fi done + + if [ -n "$best" ]; then + echo "$best" + return 0 + fi + return 1 } if CUDA_VERSION="$(choose_cuda)"; then log "Selected CUDA image version: ${CUDA_VERSION}" else - warn "Could not auto-match CUDA version, defaulting to highest supported" - CUDA_VERSION="$(printf '%s\n' "${SUPPORTED_CUDA[@]}" | sort -V | tail -n1)" + err "No compatible CUDA image found for driver capability ${DETECTED_MM}.x" + err "Supported versions: ${SUPPORTED_CUDA[*]}" + exit 1 fi + IMAGE="${IMAGE_BASE}/${CUDA_VERSION}/tentacle:${IMAGE_TAG}" # -----------------------------