Files
calibresync/sftp.py
T
2026-05-10 17:49:42 +02:00

146 lines
4.6 KiB
Python

import io
import logging
import shlex
import socket
import time
from dataclasses import dataclass
from pathlib import Path
import paramiko
CONNECT_TIMEOUT: int = 30  # seconds; used for the TCP connect and for the SSH banner/handshake timeouts
import db
from config import SFTPConfig
log: logging.Logger = logging.getLogger(__name__)

# Key types tried by _load_private_key, in this order; the first class that
# parses the PEM data wins.
_KEY_CLASSES = [paramiko.Ed25519Key, paramiko.RSAKey, paramiko.ECDSAKey]
@dataclass()
class RemoteZip:
    """A ``.zip`` file discovered on the SFTP server."""

    remote_path: str  # absolute/remote path as reported by the server-side listing
    file_size: int  # size in bytes as reported by the server-side listing
def _load_private_key(pem: str) -> paramiko.PKey:
    """Parse a private key, trying each class in ``_KEY_CLASSES`` in order.

    Args:
        pem: The private key text (PEM / OpenSSH format).

    Returns:
        The key object produced by the first class that accepts *pem*.

    Raises:
        ValueError: No supported key class could parse *pem*. The underlying
            paramiko error is chained as ``__cause__`` so the real reason is
            visible in tracebacks.
    """
    errors: list[Exception] = []
    for cls in _KEY_CLASSES:
        try:
            return cls.from_private_key(io.StringIO(pem))
        except Exception as exc:  # wrong key type, malformed data, encrypted key, ...
            errors.append(exc)

    # No password is ever passed to from_private_key, so an encrypted key can
    # never load; report that instead of blaming the PEM format.
    encrypted = next(
        (e for e in errors if isinstance(e, paramiko.PasswordRequiredException)),
        None,
    )
    if encrypted is not None:
        raise ValueError(
            "Could not parse private key — it is passphrase-protected (encrypted keys are not supported)"
        ) from encrypted
    raise ValueError(
        "Could not parse private key — unsupported format or bad PEM data"
    ) from (errors[-1] if errors else None)
def get_key_fingerprint(pem: str) -> str | None:
if not pem.strip():
return None
try:
key = _load_private_key(pem)
fp = ":".join(f"{b:02x}" for b in key.get_fingerprint())
return f"{key.get_name()} MD5:{fp}"
except Exception as e:
return f"Invalid key: {e}"
def _make_transport(cfg: SFTPConfig) -> paramiko.Transport:
    """Connect to ``cfg.host:cfg.port`` and authenticate an SSH transport.

    Key authentication is used when ``cfg.auth_method == "key"`` and
    ``cfg.key`` is non-empty; otherwise password authentication is used.
    The returned transport is authenticated and owned by the caller, who
    must ``close()`` it. On any failure, everything opened here is closed
    before the exception propagates.

    Raises:
        ValueError: The configured private key cannot be parsed.
        OSError: The TCP connection fails or times out.
        paramiko.SSHException: Typically, SSH negotiation or authentication fails.
    """
    # Parse the key before touching the network: a bad key then fails fast
    # and no socket is opened that would need cleaning up.
    pkey = _load_private_key(cfg.key) if cfg.auth_method == "key" and cfg.key else None

    log.debug("Opening TCP connection to %s:%s", cfg.host, cfg.port)
    sock = socket.create_connection((cfg.host, cfg.port), timeout=CONNECT_TIMEOUT)
    transport = None
    try:
        transport = paramiko.Transport(sock)
        transport.banner_timeout = CONNECT_TIMEOUT
        transport.handshake_timeout = CONNECT_TIMEOUT
        log.debug("Starting SSH handshake")
        # NOTE(review): no hostkey is passed, so paramiko performs no host key
        # verification and will accept any server key. Consider pinning the
        # expected host key in SFTPConfig.
        if pkey is not None:
            transport.connect(username=cfg.user, pkey=pkey)
        else:
            transport.connect(username=cfg.user, password=cfg.password)
    except BaseException:
        # Transport.close() may not close the socket if the handshake never
        # started, so close the raw socket explicitly as well; both are safe
        # to call more than once.
        if transport is not None:
            transport.close()
        sock.close()
        raise
    log.debug("SSH authenticated")
    return transport
def test_connection(cfg: SFTPConfig) -> tuple[bool, str]:
    """Probe *cfg* by authenticating and listing ``cfg.remote_path``.

    Never raises.

    Returns:
        ``(True, summary)`` on success, where the summary counts the entries
        and the top-level ``.zip`` files in the remote directory; otherwise
        ``(False, error message)``.
    """
    try:
        transport = _make_transport(cfg)
        try:
            sftp = paramiko.SFTPClient.from_transport(transport)
            # from_transport returns None when no session channel could be opened.
            if sftp is None:
                raise paramiko.SSHException("Could not open SFTP session")
            try:
                entries = sftp.listdir(cfg.remote_path)
            finally:
                sftp.close()
        finally:
            # Closed even if opening the SFTP session itself fails.
            transport.close()
    except Exception as e:
        # Some exceptions carry an empty message; fall back to the type name so
        # the caller always gets something to display.
        return False, str(e) or type(e).__name__
    zip_count = sum(1 for name in entries if name.lower().endswith(".zip"))
    return True, f"Connected to {cfg.host}. {len(entries)} item(s) in {cfg.remote_path} ({zip_count} zip file(s) at top level)."
def list_new_zips(cfg: SFTPConfig, max_results: int | None = None) -> list[RemoteZip]:
    """Return remote zips under ``cfg.remote_path`` whose path is not yet processed.

    Args:
        cfg: SFTP connection settings.
        max_results: Stop once this many new zips are collected. ``None`` (or
            any falsy value such as ``0``) means no limit.

    Returns:
        The new zips, in the order the remote listing reported them.
    """
    transport = _make_transport(cfg)
    try:
        t0 = time.monotonic()
        all_zips = _find_remote_zips(transport, cfg.remote_path)
        find_secs = time.monotonic() - t0
    finally:
        # Everything below is local (DB lookup + filtering); don't keep the
        # SSH session open while it runs.
        transport.close()
    log.info("Remote find done in %.1fs — %d zip(s) found", find_secs, len(all_zips))

    t1 = time.monotonic()
    processed = db.get_all_processed_paths()
    log.info("DB lookup done in %.1fs — %d path(s) already processed", time.monotonic() - t1, len(processed))

    new_zips: list[RemoteZip] = []
    for zip_info in all_zips:
        if zip_info.remote_path in processed:
            continue
        new_zips.append(zip_info)
        if max_results and len(new_zips) >= max_results:
            log.info("Reached limit of %d", max_results)
            break
    log.info("%d new zip(s) to process", len(new_zips))
    return new_zips
def download(cfg: SFTPConfig, remote_zip: RemoteZip, dest_dir: str) -> Path:
    """Download *remote_zip* into *dest_dir* and return the local file path.

    *dest_dir* is created if missing, and the local file keeps the remote
    file's base name. If the transfer fails, the partially written local
    file is removed before the error propagates, so no truncated zip is left
    behind for later processing.

    Raises:
        ValueError: The remote path has no file name component.
    """
    file_name = Path(remote_zip.remote_path).name
    if not file_name:
        raise ValueError(f"Remote path has no file name: {remote_zip.remote_path!r}")
    dest = Path(dest_dir)
    dest.mkdir(parents=True, exist_ok=True)
    local_path = dest / file_name

    transport = _make_transport(cfg)
    try:
        sftp = paramiko.SFTPClient.from_transport(transport)
        # from_transport returns None when no session channel could be opened.
        if sftp is None:
            raise paramiko.SSHException("Could not open SFTP session")
        try:
            log.info("Downloading %s -> %s", remote_zip.remote_path, local_path)
            try:
                sftp.get(remote_zip.remote_path, str(local_path))
            except BaseException:
                local_path.unlink(missing_ok=True)
                raise
        finally:
            sftp.close()
    finally:
        # Closed even if opening the SFTP session itself fails.
        transport.close()
    return local_path
def _parse_find_output(raw: bytes) -> list[RemoteZip]:
    """Parse NUL-terminated, tab-separated ``size<TAB>path`` records from the remote find.

    Records without a tab or with a non-integer size are skipped silently.
    Paths that are not valid UTF-8 are skipped with a warning instead of
    aborting the whole listing.
    """
    zips: list[RemoteZip] = []
    for record in raw.split(b"\0"):
        size_raw, tab, path_raw = record.partition(b"\t")
        if not tab:
            continue  # includes the empty chunk after the final terminator
        try:
            size = int(size_raw)
        except ValueError:
            continue
        try:
            path = path_raw.decode("utf-8")
        except UnicodeDecodeError:
            log.warning("Skipping remote file with non-UTF-8 name: %r", path_raw)
            continue
        zips.append(RemoteZip(remote_path=path, file_size=size))
    return zips


def _find_remote_zips(transport: paramiko.Transport, remote_path: str) -> list[RemoteZip]:
    """Single SSH exec: find all .zip files server-side. Vastly faster than per-directory SFTP calls.

    Requires a ``find`` on the server that supports ``-printf`` (GNU find).
    Anything the command writes to stderr is logged (truncated) as a warning,
    and so is a non-zero exit status. Whatever zips were listed are still
    returned, because find also exits non-zero on partial failures such as
    unreadable subdirectories.
    """
    import threading

    channel = transport.open_session()
    try:
        # NUL is the only byte a POSIX path cannot contain, so it is a safe
        # record terminator even for file names with embedded newlines.
        cmd = f"find {shlex.quote(remote_path)} -type f -iname '*.zip' -printf '%s\\t%p\\0'"
        log.info("Running remote find under %s ...", remote_path)
        channel.exec_command(cmd)

        # Drain stderr concurrently. Stdout and stderr share one SSH
        # flow-control window, so if stderr is left unread while we block on
        # stdout, a large amount of error output can stall the remote command.
        stderr_chunks: list[bytes] = []
        drainer = threading.Thread(
            target=lambda: stderr_chunks.append(channel.makefile_stderr("rb", -1).read()),
            name="remote-find-stderr",
            daemon=True,
        )
        drainer.start()
        stdout_raw = channel.makefile("rb", -1).read()
        drainer.join()
        exit_status = channel.recv_exit_status()
    finally:
        # Always release the channel; closing it also unblocks the drainer
        # thread if we got here via an exception.
        channel.close()

    zips = _parse_find_output(stdout_raw)
    stderr_out = b"".join(stderr_chunks).decode("utf-8", errors="replace").strip()
    if stderr_out:
        log.warning("find stderr: %s", stderr_out[:500])
    if exit_status != 0:
        log.warning("Remote find exited with status %d", exit_status)
    return zips