"""SFTP/SSH helpers: connect to the remote host, discover zip files, download them.

Remote discovery runs ``find`` over an SSH exec channel (the ``-printf`` and
``-newermt`` options used here are GNU find extensions) and records results in
the remote-zip cache exposed by the ``db`` module.
"""

import io
import logging
import shlex
import socket
import time
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from pathlib import Path

import paramiko

import db
from config import SFTPConfig

CONNECT_TIMEOUT = 30  # seconds for TCP + SSH handshake

log = logging.getLogger(__name__)

# Format of the "remote_cache_last_scan" setting; the value is always UTC.
_TIMESTAMP_FORMAT = "%Y-%m-%d %H:%M:%S"

# Key types tried, in order, when parsing a private key.
_KEY_CLASSES = [
    paramiko.Ed25519Key,
    paramiko.RSAKey,
    paramiko.ECDSAKey,
]


@dataclass
class RemoteZip:
    """A zip file found on the remote host."""

    remote_path: str  # path as printed by the remote ``find`` (%p)
    file_size: int  # size in bytes as printed by the remote ``find`` (%s)


def _load_private_key(pem: str) -> paramiko.PKey:
    """Parse *pem* as a private key, trying each class in ``_KEY_CLASSES``.

    Raises:
        ValueError: if none of the key classes can parse the data.
    """
    for cls in _KEY_CLASSES:
        try:
            # A fresh StringIO per attempt: a failed parse may have consumed it.
            return cls.from_private_key(io.StringIO(pem))
        except Exception:
            # Each key class fails differently on a foreign key type; any
            # failure just means "try the next class".
            continue
    raise ValueError("Could not parse private key — unsupported format or bad PEM data")


def get_key_fingerprint(pem: str) -> str | None:
    """Return a display string ``"<key type> MD5:<aa:bb:...>"`` for *pem*.

    Returns ``None`` when *pem* is blank, and an ``"Invalid key: ..."`` message
    (rather than raising) when the key cannot be parsed.
    """
    if not pem.strip():
        return None
    try:
        key = _load_private_key(pem)
        fp = ":".join(f"{b:02x}" for b in key.get_fingerprint())
        return f"{key.get_name()} MD5:{fp}"
    except Exception as e:
        return f"Invalid key: {e}"


def _make_transport(cfg: SFTPConfig) -> paramiko.Transport:
    """Open a TCP connection to ``cfg.host:cfg.port`` and authenticate over SSH.

    Uses key authentication when ``cfg.auth_method == "key"`` and a key is
    configured, otherwise password authentication. The caller owns the returned
    transport and must close it. If the handshake or authentication fails, the
    socket/transport is closed before the exception propagates.
    """
    log.debug("Opening TCP connection to %s:%s", cfg.host, cfg.port)
    sock = socket.create_connection((cfg.host, cfg.port), timeout=CONNECT_TIMEOUT)
    try:
        transport = paramiko.Transport(sock)
    except Exception:
        sock.close()
        raise

    try:
        transport.banner_timeout = CONNECT_TIMEOUT
        transport.handshake_timeout = CONNECT_TIMEOUT
        log.debug("Starting SSH handshake")
        # NOTE(review): no ``hostkey`` is passed to connect(), so the server's
        # host key is not verified here.
        if cfg.auth_method == "key" and cfg.key:
            key = _load_private_key(cfg.key)
            transport.connect(username=cfg.user, pkey=key)
        else:
            transport.connect(username=cfg.user, password=cfg.password)
    except Exception:
        # Closing the transport also closes the underlying socket.
        transport.close()
        raise

    log.debug("SSH authenticated")
    return transport


def test_connection(cfg: SFTPConfig) -> tuple[bool, str]:
    """Connect, list ``cfg.remote_path`` over SFTP, and report the outcome.

    Returns:
        ``(True, summary)`` on success, or ``(False, error_text)`` on any
        failure. Never raises.
    """
    try:
        transport = _make_transport(cfg)
        try:
            sftp = paramiko.SFTPClient.from_transport(transport)
            try:
                entries = sftp.listdir(cfg.remote_path)
                zip_count = sum(1 for e in entries if e.lower().endswith(".zip"))
                return True, f"Connected to {cfg.host}. {len(entries)} item(s) in {cfg.remote_path} ({zip_count} zip file(s) at top level)."
            finally:
                sftp.close()
        finally:
            # Runs even when opening the SFTP session itself fails.
            transport.close()
    except Exception as e:
        return False, str(e)


def list_new_zips(cfg: SFTPConfig, max_results: int | None = None) -> list[RemoteZip]:
    """Scan the remote host, update the zip cache, and return unprocessed zips.

    An incremental scan is used when a previous scan time is stored (and can be
    parsed); otherwise a full scan runs. The result is every cached zip whose
    path is not among the processed paths, truncated to *max_results* entries
    when *max_results* is truthy (``None``/``0`` mean no limit).
    """
    last_scan = db.get_setting("remote_cache_last_scan")

    cutoff = None
    if last_scan:
        try:
            cutoff = _scan_cutoff(last_scan)
        except ValueError:
            # A malformed stored timestamp must not break every future scan.
            log.warning("Unparseable last-scan timestamp %r — falling back to full scan", last_scan)

    # Capture the timestamp BEFORE scanning: a scan can run for minutes, and
    # anything added while it runs must still be newer than the next cutoff.
    scan_started = datetime.now(timezone.utc).strftime(_TIMESTAMP_FORMAT)

    transport = _make_transport(cfg)
    try:
        t0 = time.monotonic()
        if cutoff:
            # Fast incremental: prune directories not modified since last scan.
            # NOTE(review): a directory's mtime only changes when entries are
            # added or removed directly inside it, so a zip added deeper below
            # an otherwise unchanged directory is pruned and missed here;
            # refresh_remote_zip_cache() does a full scan.
            log.info("Incremental scan — looking for directories modified since %s ...", cutoff)
            new_remote = _find_remote_zips(transport, cfg.remote_path, newer_than=cutoff)
            log.info("Incremental scan done in %.1fs — %d new zip(s) on remote", time.monotonic() - t0, len(new_remote))
        else:
            log.info("First run — full remote scan (may take several minutes for large trees) ...")
            new_remote = _find_remote_zips(transport, cfg.remote_path)
            log.info("Full scan done in %.1fs — %d zip(s) found", time.monotonic() - t0, len(new_remote))
    finally:
        transport.close()

    # Record scan time, then update cache with any new entries found
    db.set_setting("remote_cache_last_scan", scan_started)
    if new_remote:
        db.upsert_remote_zip_cache([(z.remote_path, z.file_size) for z in new_remote])
        log.info("Cache updated with %d new entry(ies)", len(new_remote))

    # Filter full cache against already-processed paths
    t1 = time.monotonic()
    all_cached = db.get_remote_zip_cache()
    processed = db.get_all_processed_paths()
    log.info("DB lookup done in %.1fs — cache: %d, processed: %d", time.monotonic() - t1, len(all_cached), len(processed))

    # Build the lookup once so each membership test is O(1) regardless of the
    # container type the db helper returns.
    processed_paths = set(processed)
    new_zips: list[RemoteZip] = []
    for path, size in all_cached:
        if path in processed_paths:
            continue
        new_zips.append(RemoteZip(remote_path=path, file_size=size))
        if max_results and len(new_zips) >= max_results:
            break

    log.info("%d zip(s) to process", len(new_zips))
    return new_zips


def refresh_remote_zip_cache(cfg: SFTPConfig) -> int:
    """Force a full remote scan, replacing the entire cache.

    Used by the manual rescan button. Returns the number of zips stored.
    """
    log.info("Forced full remote cache refresh ...")
    # Timestamp taken before the scan so files added while it runs are still
    # newer than the cutoff of the next incremental scan.
    scan_started = datetime.now(timezone.utc).strftime(_TIMESTAMP_FORMAT)
    t0 = time.monotonic()
    transport = _make_transport(cfg)
    try:
        all_zips = _find_remote_zips(transport, cfg.remote_path)
    finally:
        transport.close()
    log.info("Full scan done in %.1fs — %d zip(s)", time.monotonic() - t0, len(all_zips))

    db.clear_remote_zip_cache()
    db.upsert_remote_zip_cache([(z.remote_path, z.file_size) for z in all_zips])
    db.set_setting("remote_cache_last_scan", scan_started)
    log.info("Cache refreshed: %d zip(s) stored", len(all_zips))
    return len(all_zips)


def download(cfg: SFTPConfig, remote_zip: RemoteZip, dest_dir: str) -> Path:
    """Download *remote_zip* into *dest_dir* and return the local path.

    *dest_dir* is created if missing. The local file keeps the remote file's
    base name. Connection and transfer errors propagate to the caller.
    """
    dest = Path(dest_dir)
    dest.mkdir(parents=True, exist_ok=True)
    local_path = dest / Path(remote_zip.remote_path).name

    transport = _make_transport(cfg)
    try:
        sftp = paramiko.SFTPClient.from_transport(transport)
        try:
            log.info("Downloading %s → %s", remote_zip.remote_path, local_path)
            sftp.get(remote_zip.remote_path, str(local_path))
        finally:
            sftp.close()
    finally:
        # Runs even when opening the SFTP session itself fails.
        transport.close()
    return local_path


def _find_remote_zips(transport: paramiko.Transport, remote_path: str, newer_than: str | None = None) -> list[RemoteZip]:
    """Run find on the remote host, streaming results with progress logging every 30 s.

    Args:
        transport: an authenticated transport; it is not closed here.
        remote_path: directory to search (shell-quoted before use).
        newer_than: optional timestamp for ``find -newermt``. When given,
            directories whose mtime is not newer are pruned, and every zip in
            the remaining directories is reported.

    Returns:
        One ``RemoteZip`` per well-formed ``"<size>\\t<path>"`` output line;
        malformed lines are skipped. Anything ``find`` writes to stderr and a
        non-zero exit status are logged as warnings, not raised.
    """
    quoted_root = shlex.quote(remote_path)
    if newer_than:
        # Prune directory subtrees whose mtime predates the cutoff.
        # NOTE(review): mtime reflects only direct entries of a directory, so
        # changes deeper inside a pruned subtree are not seen.
        cmd = (
            f"find {quoted_root}"
            f" \\( -type d ! -newermt {shlex.quote(newer_than)} -prune \\)"
            f" -o \\( -type f -iname '*.zip' -printf '%s\\t%p\\n' \\)"
        )
    else:
        cmd = f"find {quoted_root} -type f -iname '*.zip' -printf '%s\\t%p\\n'"

    zips: list[RemoteZip] = []
    channel = transport.open_session()
    try:
        channel.exec_command(cmd)

        last_log = time.monotonic()
        for line in channel.makefile("r", -1):
            line = line.rstrip("\n")
            if "\t" not in line:
                continue
            size_str, path = line.split("\t", 1)
            try:
                zips.append(RemoteZip(remote_path=path, file_size=int(size_str)))
            except ValueError:
                continue
            now = time.monotonic()
            if now - last_log >= 30:
                log.info("Find in progress: %d zip(s) found so far ...", len(zips))
                last_log = now

        stderr_out = channel.makefile_stderr("r", -1).read().strip()
        if isinstance(stderr_out, bytes):
            # Decode so the log shows text rather than a b'...' repr.
            stderr_out = stderr_out.decode("utf-8", errors="replace")
        if stderr_out:
            log.warning("find stderr: %s", stderr_out[:500])

        exit_status = channel.recv_exit_status()
        if exit_status != 0:
            # Only logged: find also exits non-zero for partial failures such
            # as an unreadable directory, where the results are still usable.
            log.warning("find exited with status %d", exit_status)
    finally:
        channel.close()
    return zips


def _scan_cutoff(last_scan: str) -> str:
    """Subtract 5-minute safety buffer from last-scan timestamp to handle clock skew.

    *last_scan* is a UTC timestamp in ``_TIMESTAMP_FORMAT``. The result carries
    an explicit ``UTC`` suffix so the remote ``find -newermt`` does not
    interpret it in the remote host's local time zone.

    Raises:
        ValueError: if *last_scan* does not match ``_TIMESTAMP_FORMAT``.
    """
    dt = datetime.strptime(last_scan, _TIMESTAMP_FORMAT).replace(tzinfo=timezone.utc)
    dt -= timedelta(minutes=5)
    return dt.strftime(_TIMESTAMP_FORMAT) + " UTC"