Fixed the Cuda detection and fallback logic

On systems where CUDA_VERSION=12.4 the script was incorrectly labeling it as 12.6 by default
This commit is contained in:
2026-01-22 11:54:36 -05:00
parent b41b7d7420
commit 32e40d3a7f

View File

@@ -88,21 +88,24 @@ fi
DETECTED_CUDA="$(nvidia-smi | grep 'CUDA Version' | sed 's/.*CUDA Version: *\([0-9.]*\).*/\1/')" DETECTED_CUDA="$(nvidia-smi | grep 'CUDA Version' | sed 's/.*CUDA Version: *\([0-9.]*\).*/\1/')"
log "Detected CUDA capability: ${DETECTED_CUDA}" log "Detected CUDA capability: ${DETECTED_CUDA}"
choose_cuda() { choose_cuda() {
for v in "${SUPPORTED_CUDA[@]}"; do # Pick highest supported version <= detected CUDA capability
if [[ "$DETECTED_CUDA" == "$v"* ]]; then for v in $(printf '%s\n' "${SUPPORTED_CUDA[@]}" | sort -rV); do
if [ "$(printf '%s\n%s\n' "$v" "$DETECTED_CUDA" | sort -V | head -n1)" = "$v" ]; then
echo "$v" echo "$v"
return return
fi fi
done done
echo ""
}
# fallback: highest <= detected # fallback: highest <= detected
for v in $(printf '%s\n' "${SUPPORTED_CUDA[@]}" | sort -rV); do if [ -z "$CUDA_VERSION" ]; then
if [[ "$(printf '%s\n%s\n' "$v" "$DETECTED_CUDA" | sort -V | head -n1)" == "$v" ]]; then warn "Could not auto-match CUDA version, defaulting to highest supported"
echo "$v" CUDA_VERSION="$(printf '%s\n' "${SUPPORTED_CUDA[@]}" | sort -V | tail -n1)"
return else
log "Selected CUDA image version: ${CUDA_VERSION}"
fi fi
done
echo "" echo ""
} }