Files
calibresync/sftp.py
T
2026-05-10 18:02:06 +02:00

210 lines
7.5 KiB
Python

import io
import logging
import shlex
import socket
import time
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from pathlib import Path
import paramiko
CONNECT_TIMEOUT = 30  # seconds; applied to the TCP connect and to the SSH banner and handshake waits
import db
from config import SFTPConfig
log = logging.getLogger(__name__)

# paramiko key classes tried, in this order, by _load_private_key; the first
# class that manages to parse the PEM text is the one used.
_KEY_CLASSES = [paramiko.Ed25519Key, paramiko.RSAKey, paramiko.ECDSAKey]
@dataclass
class RemoteZip:
    """One zip file on the SFTP server, as found by a scan or read from the cache."""

    # Path of the file on the remote host.
    remote_path: str
    # File size as reported for the remote file.
    file_size: int
def _load_private_key(pem: str) -> paramiko.PKey:
    """Parse *pem* with each class in ``_KEY_CLASSES`` in turn; return the first success.

    Raises:
        ValueError: if the key is passphrase-protected, or if none of the key
            classes can parse it. The underlying paramiko error is chained as
            ``__cause__`` so logs/tracebacks show why parsing failed.
    """
    last_error: Exception | None = None
    for cls in _KEY_CLASSES:
        try:
            # Fresh StringIO per attempt: a failed parse may have consumed the stream.
            return cls.from_private_key(io.StringIO(pem))
        except paramiko.PasswordRequiredException as e:
            # The key is encrypted and no passphrase is passed here, so the
            # remaining classes cannot load it either; say so explicitly
            # instead of the misleading "bad PEM data" message.
            raise ValueError(
                "Could not parse private key — key is passphrase-protected, which is not supported"
            ) from e
        except Exception as e:
            # Not this key type (or malformed data): try the next class, but
            # remember the failure so it can be chained below.
            last_error = e
    raise ValueError("Could not parse private key — unsupported format or bad PEM data") from last_error
def get_key_fingerprint(pem: str) -> str | None:
if not pem.strip():
return None
try:
key = _load_private_key(pem)
fp = ":".join(f"{b:02x}" for b in key.get_fingerprint())
return f"{key.get_name()} MD5:{fp}"
except Exception as e:
return f"Invalid key: {e}"
def _make_transport(cfg: SFTPConfig) -> paramiko.Transport:
    """Connect to ``cfg.host:cfg.port`` and return an authenticated SSH transport.

    Key authentication is used when ``cfg.auth_method == "key"`` and ``cfg.key``
    is non-empty; otherwise password authentication with ``cfg.password``. The
    TCP connect, SSH banner and SSH handshake each use ``CONNECT_TIMEOUT``.
    The caller owns the returned transport and must ``close()`` it.

    Raises:
        ValueError: if key auth is selected and the key cannot be parsed.
        OSError: if the TCP connection cannot be established.
        paramiko.SSHException: on handshake or authentication failure.
    """
    # Parse the key before touching the network: a bad key then fails fast
    # and never leaves an open socket behind.
    pkey = _load_private_key(cfg.key) if cfg.auth_method == "key" and cfg.key else None

    log.debug("Opening TCP connection to %s:%s", cfg.host, cfg.port)
    sock = socket.create_connection((cfg.host, cfg.port), timeout=CONNECT_TIMEOUT)
    transport = None
    try:
        transport = paramiko.Transport(sock)
        transport.banner_timeout = CONNECT_TIMEOUT
        transport.handshake_timeout = CONNECT_TIMEOUT
        log.debug("Starting SSH handshake")
        # NOTE(review): connect() is called without hostkey=, so the server's
        # host key is not verified. Consider storing and pinning the expected
        # host key in SFTPConfig.
        if pkey is not None:
            transport.connect(username=cfg.user, pkey=pkey)
        else:
            transport.connect(username=cfg.user, password=cfg.password)
    except BaseException:
        # Release the connection on any failure (e.g. wrong password);
        # otherwise every failed attempt leaks a socket. Closing the socket
        # directly as well covers a transport that never became active.
        if transport is not None:
            transport.close()
        sock.close()
        raise
    log.debug("SSH authenticated")
    return transport
def test_connection(cfg: SFTPConfig) -> tuple[bool, str]:
    """Try to connect with *cfg* and list ``cfg.remote_path``; never raises.

    Returns:
        ``(True, summary)`` with the number of entries and of top-level ``.zip``
        files on success, or ``(False, error message)`` on any failure.
    """
    try:
        transport = _make_transport(cfg)
        try:
            sftp = paramiko.SFTPClient.from_transport(transport)
            if sftp is None:
                return False, f"Could not open an SFTP session on {cfg.host}"
            try:
                entries = sftp.listdir(cfg.remote_path)
            finally:
                sftp.close()
        finally:
            # Own finally so the connection is released even when opening the
            # SFTP session itself fails.
            transport.close()
    except Exception as e:
        # Some exceptions carry no message; fall back to the type name so the
        # caller never receives a blank error.
        return False, str(e) or type(e).__name__
    zip_count = sum(1 for name in entries if name.lower().endswith(".zip"))
    return True, f"Connected to {cfg.host}. {len(entries)} item(s) in {cfg.remote_path} ({zip_count} zip file(s) at top level)."
def list_new_zips(cfg: SFTPConfig, max_results: int | None = None) -> list[RemoteZip]:
    """Refresh the remote-zip cache from the server and return zips not yet processed.

    Without a ``remote_cache_last_scan`` setting the whole remote tree is
    scanned; otherwise only zips reported as changed since that time (minus the
    safety margin applied by ``_scan_cutoff``) are fetched and upserted into
    the cache. The result is computed from the *whole* cache, so entries found
    by earlier scans but not yet processed are still returned.

    Args:
        cfg: SFTP connection settings; ``cfg.remote_path`` is the directory scanned.
        max_results: stop after this many results; ``None`` or ``0`` means no limit.

    Returns:
        Unprocessed zips, in the order ``db.get_remote_zip_cache()`` yields them.
    """
    last_scan = db.get_setting("remote_cache_last_scan")
    # Take the new scan timestamp *before* scanning. A scan can run for
    # minutes; stamping it afterwards would leave files that arrive during the
    # scan older than the next cutoff, so they would never be picked up.
    scan_started = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S")
    transport = _make_transport(cfg)
    try:
        t0 = time.monotonic()
        if last_scan:
            # Incremental: only ask the remote for zips changed since the last scan.
            cutoff = _scan_cutoff(last_scan)
            log.info("Incremental scan — looking for directories modified since %s ...", cutoff)
            new_remote = _find_remote_zips(transport, cfg.remote_path, newer_than=cutoff)
            log.info("Incremental scan done in %.1fs — %d new zip(s) on remote", time.monotonic() - t0, len(new_remote))
        else:
            log.info("First run — full remote scan (may take several minutes for large trees) ...")
            new_remote = _find_remote_zips(transport, cfg.remote_path)
            log.info("Full scan done in %.1fs — %d zip(s) found", time.monotonic() - t0, len(new_remote))
    finally:
        transport.close()

    # Write the cache before advancing the timestamp: if the cache write fails,
    # the next run rescans the same window instead of skipping it.
    if new_remote:
        db.upsert_remote_zip_cache([(z.remote_path, z.file_size) for z in new_remote])
        log.info("Cache updated with %d new entry(ies)", len(new_remote))
    db.set_setting("remote_cache_last_scan", scan_started)

    # Filter the full cache against already-processed paths.
    t1 = time.monotonic()
    all_cached = db.get_remote_zip_cache()
    # A set keeps each membership test below O(1) even if the db layer returns a list.
    processed = set(db.get_all_processed_paths())
    log.info("DB lookup done in %.1fs — cache: %d, processed: %d", time.monotonic() - t1, len(all_cached), len(processed))
    new_zips: list[RemoteZip] = []
    for path, size in all_cached:
        if path in processed:
            continue
        new_zips.append(RemoteZip(remote_path=path, file_size=size))
        if max_results and len(new_zips) >= max_results:
            break
    log.info("%d zip(s) to process", len(new_zips))
    return new_zips
def refresh_remote_zip_cache(cfg: SFTPConfig) -> int:
    """Force a full remote scan, replacing the entire cache. Used by the manual rescan button.

    Also resets ``remote_cache_last_scan`` so later incremental scans start
    from this scan.

    Returns:
        The number of zips stored in the cache.
    """
    log.info("Forced full remote cache refresh ...")
    # Stamp the scan with its *start* time: files that land on the remote while
    # the (possibly long) scan runs are then newer than the next incremental
    # cutoff and get picked up, instead of falling into an unscanned gap.
    scan_started = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S")
    t0 = time.monotonic()
    transport = _make_transport(cfg)
    try:
        all_zips = _find_remote_zips(transport, cfg.remote_path)
    finally:
        transport.close()
    log.info("Full scan done in %.1fs — %d zip(s)", time.monotonic() - t0, len(all_zips))
    # NOTE(review): clear + upsert are separate db calls; if the upsert fails
    # the cache stays empty until the next scan. Confirm whether db offers a
    # transactional replace.
    db.clear_remote_zip_cache()
    db.upsert_remote_zip_cache([(z.remote_path, z.file_size) for z in all_zips])
    db.set_setting("remote_cache_last_scan", scan_started)
    log.info("Cache refreshed: %d zip(s) stored", len(all_zips))
    return len(all_zips)
def download(cfg: SFTPConfig, remote_zip: RemoteZip, dest_dir: str) -> Path:
    """Download *remote_zip* into *dest_dir* and return the local file path.

    *dest_dir* is created if needed, and the local file takes the basename of
    the remote path. If the transfer fails, the partially written local file is
    removed before the error propagates.

    Raises:
        paramiko.SSHException: if no SFTP session can be opened on the connection.
        Any error raised by ``_make_transport`` or the SFTP transfer.
    """
    dest = Path(dest_dir)
    dest.mkdir(parents=True, exist_ok=True)
    local_path = dest / Path(remote_zip.remote_path).name
    transport = _make_transport(cfg)
    try:
        sftp = paramiko.SFTPClient.from_transport(transport)
        if sftp is None:
            raise paramiko.SSHException(f"Could not open an SFTP session on {cfg.host}")
        try:
            log.info("Downloading %s → %s", remote_zip.remote_path, local_path)
            sftp.get(remote_zip.remote_path, str(local_path))
        except BaseException:
            # Remove whatever was written so far, so no later step can pick up
            # a truncated zip; a cleanup failure must not mask the real error.
            try:
                local_path.unlink(missing_ok=True)
            except OSError:
                log.warning("Could not remove partial download %s", local_path)
            raise
        finally:
            sftp.close()
    finally:
        # Separate finally: the transport is closed even if opening or closing
        # the SFTP session fails.
        transport.close()
    return local_path
def _find_remote_zips(transport: paramiko.Transport, remote_path: str, newer_than: str | None = None) -> list[RemoteZip]:
    """Run ``find`` on the remote host and return the ``*.zip`` files under *remote_path*.

    Output is streamed line by line, with a progress log line at most every
    30 s. Requires GNU find on the remote (``-printf``, ``-newerct``).

    Args:
        transport: an authenticated transport; a session channel is opened on
            it for the command and closed before returning.
        remote_path: directory searched recursively (shell-quoted before use).
        newer_than: optional date string for GNU find's ``-newerXt`` tests.
            When given, only zips whose inode change time (ctime) is newer are
            returned. A string without an explicit zone is interpreted by find
            in the remote's local time zone.

    Returns:
        One RemoteZip per matched file (``.zip`` suffix matched
        case-insensitively), in find's output order. Lines that cannot be
        parsed or decoded are skipped.
    """
    # Incremental mode filters each file on its ctime instead of pruning
    # directories by mtime. Pruning is unsafe: a directory's mtime only changes
    # when its *direct* entries change, so a zip added to root/a/b/ leaves
    # root/ and root/a/ looking unchanged and find never descends to it (and if
    # root/ itself is old, nothing is found at all). A file's ctime is set to
    # "now" when it is created or renamed into place, even when the copying
    # tool preserves mtime (rsync -a, cp -p). Trade-off: every scan walks the
    # whole tree.
    cmd = f"find {shlex.quote(remote_path)} -type f -iname '*.zip'"
    if newer_than:
        cmd += f" -newerct {shlex.quote(newer_than)}"
    cmd += " -printf '%s\\t%p\\n'"

    zips: list[RemoteZip] = []
    channel = transport.open_session()
    try:
        channel.exec_command(cmd)
        last_log = time.monotonic()
        # NOTE(review): stderr is only read after stdout reaches EOF. paramiko
        # does not hand back flow-control window for unread stderr, so a find
        # emitting megabytes of errors could stall the channel; drain stderr
        # concurrently if that is ever observed.
        for raw in channel.makefile("rb", -1):
            try:
                line = raw.decode("utf-8")
            except UnicodeDecodeError:
                # One oddly-encoded filename must not abort the whole scan.
                log.warning("Skipping find output line that is not valid UTF-8: %r", raw[:200])
                continue
            size_str, sep, path = line.rstrip("\n").partition("\t")
            if not sep:
                continue
            try:
                size = int(size_str)
            except ValueError:
                continue
            zips.append(RemoteZip(remote_path=path, file_size=size))
            now = time.monotonic()
            if now - last_log >= 30:
                log.info("Find in progress: %d zip(s) found so far ...", len(zips))
                last_log = now
        stderr_out = channel.makefile_stderr("rb", -1).read().decode("utf-8", "replace").strip()
        if stderr_out:
            log.warning("find stderr: %s", stderr_out[:500])
        exit_status = channel.recv_exit_status()
        if exit_status != 0:
            # find exits non-zero both on partial failures (unreadable
            # directories) and when the root is missing; keep what was found
            # but make the failure visible.
            log.warning("find exited with status %d for %s; results may be incomplete", exit_status, remote_path)
    finally:
        channel.close()
    return zips
def _scan_cutoff(last_scan: str) -> str:
"""Subtract 5-minute safety buffer from last-scan timestamp to handle clock skew."""
dt = datetime.strptime(last_scan, "%Y-%m-%d %H:%M:%S").replace(tzinfo=timezone.utc)
dt -= timedelta(minutes=5)
return dt.strftime("%Y-%m-%d %H:%M:%S")