210 lines
7.5 KiB
Python
210 lines
7.5 KiB
Python
import io
|
|
import logging
|
|
import shlex
|
|
import socket
|
|
import time
|
|
from dataclasses import dataclass
|
|
from datetime import datetime, timedelta, timezone
|
|
from pathlib import Path
|
|
|
|
import paramiko
|
|
|
|
CONNECT_TIMEOUT: int = 30  # seconds allowed for the TCP connect, the SSH banner and the SSH handshake
|
|
|
|
import db
|
|
from config import SFTPConfig
|
|
|
|
log: logging.Logger = logging.getLogger(__name__)

# Private-key classes tried, in this order, by _load_private_key; the first
# class whose from_private_key() accepts the PEM text is used.
_KEY_CLASSES: list[type[paramiko.PKey]] = [paramiko.Ed25519Key, paramiko.RSAKey, paramiko.ECDSAKey]
|
|
|
|
|
|
@dataclass
class RemoteZip:
    """One zip file on the SFTP server, as reported by a remote scan or read back from the cache."""

    # Path of the file on the remote host (find's ``%p`` output, or the cached path).
    remote_path: str
    # Size of the file in bytes (find's ``%s`` output, or the cached size).
    file_size: int
|
|
|
|
|
|
def _load_private_key(pem: str) -> paramiko.PKey:
    """Parse an unencrypted private key given as PEM / OpenSSH text.

    Each class in ``_KEY_CLASSES`` is tried in order; the first one whose
    ``from_private_key`` accepts *pem* is returned.

    Raises:
        ValueError: the key is passphrase-protected, or no key class could parse
            it. The underlying paramiko error is chained as ``__cause__``.
    """
    last_error: Exception | None = None
    for cls in _KEY_CLASSES:
        try:
            return cls.from_private_key(io.StringIO(pem))
        except paramiko.PasswordRequiredException as e:
            # No key class can load an encrypted key without a passphrase, so stop
            # here and say so instead of reporting misleading "bad PEM data".
            raise ValueError("Private key is encrypted with a passphrase, which is not supported") from e
        except Exception as e:
            # Any other failure usually just means "not this key type": try the
            # next class, but remember the error so the final message has a cause.
            last_error = e
    raise ValueError("Could not parse private key — unsupported format or bad PEM data") from last_error
|
|
|
|
|
|
def get_key_fingerprint(pem: str) -> str | None:
|
|
if not pem.strip():
|
|
return None
|
|
try:
|
|
key = _load_private_key(pem)
|
|
fp = ":".join(f"{b:02x}" for b in key.get_fingerprint())
|
|
return f"{key.get_name()} MD5:{fp}"
|
|
except Exception as e:
|
|
return f"Invalid key: {e}"
|
|
|
|
|
|
def _make_transport(cfg: SFTPConfig) -> paramiko.Transport:
    """Connect to ``cfg.host:cfg.port`` and return an authenticated SSH transport.

    Public-key auth is used when ``cfg.auth_method == "key"`` and ``cfg.key`` is
    non-empty; otherwise password auth with ``cfg.password``. The TCP connect,
    SSH banner wait and SSH handshake are each bounded by ``CONNECT_TIMEOUT``.
    The caller owns the returned transport and must ``close()`` it.

    Raises:
        ValueError: ``cfg.key`` could not be parsed (from ``_load_private_key``).
        OSError: the TCP connection could not be established.
        paramiko.AuthenticationException: the session is not authenticated after
            ``connect()`` returns.
        Other paramiko errors from the handshake or authentication propagate
        unchanged. On any failure the socket and transport are closed first.
    """
    # Parse the key before touching the network so a malformed key fails fast
    # and cannot leave a half-open connection behind.
    pkey = _load_private_key(cfg.key) if cfg.auth_method == "key" and cfg.key else None

    log.debug("Opening TCP connection to %s:%s", cfg.host, cfg.port)
    sock = socket.create_connection((cfg.host, cfg.port), timeout=CONNECT_TIMEOUT)
    transport = None
    try:
        transport = paramiko.Transport(sock)
        transport.banner_timeout = CONNECT_TIMEOUT
        transport.handshake_timeout = CONNECT_TIMEOUT
        log.debug("Starting SSH handshake")
        # NOTE(review): no ``hostkey`` is passed to connect(), so the server's host
        # key is never verified and a man-in-the-middle would go unnoticed.
        # Consider storing the expected host key in the config and passing it here.
        if pkey is not None:
            transport.connect(username=cfg.user, pkey=pkey)
        else:
            transport.connect(username=cfg.user, password=cfg.password)
        # connect() performs no authentication at all when password is None;
        # without this check an unauthenticated transport would be returned.
        if not transport.is_authenticated():
            raise paramiko.AuthenticationException(f"SSH authentication as {cfg.user!r} did not complete")
    except BaseException:
        # Handshake/auth failures would otherwise leak the socket and the
        # transport's background thread.
        if transport is not None:
            transport.close()
        sock.close()
        raise
    log.debug("SSH authenticated")
    return transport
|
|
|
|
|
|
def test_connection(cfg: SFTPConfig) -> tuple[bool, str]:
    """Log in, list ``cfg.remote_path`` and summarise the result.

    Returns:
        ``(True, summary)`` with the number of entries and of top-level ``.zip``
        names on success, or ``(False, error message)`` if connecting, opening the
        SFTP session or listing the directory fails (errors are returned, not raised).
    """
    transport = None
    try:
        transport = _make_transport(cfg)
        sftp = paramiko.SFTPClient.from_transport(transport)
        if sftp is None:
            return False, "Could not open an SFTP session"
        try:
            entries = sftp.listdir(cfg.remote_path)
        finally:
            sftp.close()
    except Exception as e:
        # Some exceptions carry an empty message; fall back to the type name so
        # the caller always gets something readable.
        return False, str(e) or type(e).__name__
    finally:
        # Close the transport on every path, including a failure to start the
        # SFTP subsystem (previously that leaked the connection).
        if transport is not None:
            transport.close()

    zip_count = sum(1 for e in entries if e.lower().endswith(".zip"))
    return True, f"Connected to {cfg.host}. {len(entries)} item(s) in {cfg.remote_path} ({zip_count} zip file(s) at top level)."
|
|
|
|
|
|
def list_new_zips(cfg: SFTPConfig, max_results: int | None = None) -> list[RemoteZip]:
    """Refresh the remote-zip cache, then return cached zips not yet processed.

    If a previous scan time is stored in the ``remote_cache_last_scan`` setting,
    only zips reported since then (minus the margin applied by ``_scan_cutoff``)
    are fetched from the remote. Otherwise the whole ``cfg.remote_path`` tree is
    scanned. New entries are upserted into the cache. The full cache is then
    filtered against ``db.get_all_processed_paths()``.

    Args:
        cfg: SFTP connection settings.
        max_results: stop once this many unprocessed zips are collected;
            ``None`` or ``0`` means no limit.

    Returns:
        Unprocessed zips, in the order ``db.get_remote_zip_cache()`` yields them.
    """
    last_scan = db.get_setting("remote_cache_last_scan")

    # Stamp the scan with its *start* time. Anything that lands on the remote
    # while find is running must fall after the next scan's cutoff. Stamping with
    # the end time could skip such files for good whenever a scan takes longer
    # than the cutoff's safety margin.
    scan_started = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S")

    transport = _make_transport(cfg)
    try:
        t0 = time.monotonic()
        if last_scan:
            # Incremental: only ask the remote for zips that appeared since the
            # previous scan.
            cutoff = _scan_cutoff(last_scan)
            log.info("Incremental scan — looking for directories modified since %s ...", cutoff)
            new_remote = _find_remote_zips(transport, cfg.remote_path, newer_than=cutoff)
            log.info("Incremental scan done in %.1fs — %d new zip(s) on remote", time.monotonic() - t0, len(new_remote))
        else:
            log.info("First run — full remote scan (may take several minutes for large trees) ...")
            new_remote = _find_remote_zips(transport, cfg.remote_path)
            log.info("Full scan done in %.1fs — %d zip(s) found", time.monotonic() - t0, len(new_remote))
    finally:
        transport.close()

    # Persist the new entries *before* advancing the scan timestamp. If the
    # upsert fails, the old timestamp survives and the next incremental scan
    # reports these zips again instead of losing them.
    if new_remote:
        db.upsert_remote_zip_cache([(z.remote_path, z.file_size) for z in new_remote])
        log.info("Cache updated with %d new entry(ies)", len(new_remote))
    db.set_setting("remote_cache_last_scan", scan_started)

    # Filter the full cache against already-processed paths.
    t1 = time.monotonic()
    all_cached = db.get_remote_zip_cache()
    # NOTE(review): presumably a set; if db returns a list, each `in` test below is O(n).
    processed = db.get_all_processed_paths()
    log.info("DB lookup done in %.1fs — cache: %d, processed: %d", time.monotonic() - t1, len(all_cached), len(processed))

    new_zips: list[RemoteZip] = []
    for path, size in all_cached:
        if path in processed:
            continue
        new_zips.append(RemoteZip(remote_path=path, file_size=size))
        if max_results and len(new_zips) >= max_results:
            break

    log.info("%d zip(s) to process", len(new_zips))
    return new_zips
|
|
|
|
|
|
def refresh_remote_zip_cache(cfg: SFTPConfig) -> int:
    """Force a full remote scan, replacing the entire cache. Used by the manual rescan button.

    Returns:
        The number of zips found by the scan (and stored in the cache).
    """
    log.info("Forced full remote cache refresh ...")
    t0 = time.monotonic()

    # Record the scan *start* time: zips that appear while find is running must
    # still count as newer than the next incremental scan's cutoff.
    scan_started = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S")

    transport = _make_transport(cfg)
    try:
        all_zips = _find_remote_zips(transport, cfg.remote_path)
    finally:
        transport.close()
    log.info("Full scan done in %.1fs — %d zip(s)", time.monotonic() - t0, len(all_zips))

    # NOTE(review): clear + upsert are two separate calls. If the upsert fails,
    # the cache is left empty (the scan timestamp is not advanced in that case).
    # Also, _find_remote_zips only logs a failed find, so a find that fails
    # without output replaces the cache with nothing. Consider a single
    # transactional "replace cache" in db.
    db.clear_remote_zip_cache()
    db.upsert_remote_zip_cache([(z.remote_path, z.file_size) for z in all_zips])
    db.set_setting("remote_cache_last_scan", scan_started)
    log.info("Cache refreshed: %d zip(s) stored", len(all_zips))
    return len(all_zips)
|
|
|
|
|
|
def download(cfg: SFTPConfig, remote_zip: RemoteZip, dest_dir: str) -> Path:
    """Download *remote_zip* into *dest_dir* (created if missing) and return the local path.

    The local file takes the remote file's base name. Data is written to
    ``<name>.part`` first and moved into place only after the transfer succeeds,
    so a failed download never leaves a truncated file under the final name.

    NOTE(review): zips with the same base name in different remote directories
    map to the same local path; a later download replaces an earlier one.
    """
    dest = Path(dest_dir)
    dest.mkdir(parents=True, exist_ok=True)
    local_path = dest / Path(remote_zip.remote_path).name
    part_path = local_path.with_name(local_path.name + ".part")

    transport = _make_transport(cfg)
    try:
        # Opened inside the try so the transport is closed even if the SFTP
        # subsystem cannot be started.
        sftp = paramiko.SFTPClient.from_transport(transport)
        if sftp is None:
            raise paramiko.SSHException("Could not open an SFTP session")
        try:
            log.info("Downloading %s → %s", remote_zip.remote_path, local_path)
            sftp.get(remote_zip.remote_path, str(part_path))
        finally:
            sftp.close()
    except BaseException:
        # Remove any partial data. If cleanup itself fails, keep the original error.
        try:
            part_path.unlink(missing_ok=True)
        except OSError:
            log.warning("Could not remove partial download %s", part_path)
        raise
    finally:
        transport.close()

    # Same directory, so this rename replaces any existing file in one step.
    part_path.replace(local_path)
    return local_path
|
|
|
|
|
|
def _find_remote_zips(transport: paramiko.Transport, remote_path: str, newer_than: str | None = None) -> list[RemoteZip]:
    """Run GNU ``find`` on the remote host and collect ``*.zip`` files under *remote_path*.

    Output is streamed line by line, and progress is logged every 30 s.

    Args:
        transport: authenticated transport; a new session channel is opened on it.
        remote_path: root of the remote tree to search.
        newer_than: optional timestamp string for ``find -newerct``. When given,
            only zips whose inode change time is after it are returned.

    Returns:
        One ``RemoteZip`` per reported file (size in bytes from find's ``%s``).
    """
    root = shlex.quote(remote_path)
    if newer_than:
        # Do NOT prune directories by mtime. A directory's mtime only changes when
        # entries are added or removed *directly* inside it, not deeper in its
        # subtree. Pruning "old" directories (including the starting point itself)
        # therefore hides files added to existing subdirectories. Instead, walk the
        # whole tree and filter files by ctime, which is refreshed when a file is
        # created or renamed into place even if the uploader preserved an older
        # mtime. Limitation: files inside a directory moved into the tree as a
        # whole typically keep their old ctime; a full rescan picks those up.
        cmd = (
            f"find {root} -type f -iname '*.zip'"
            f" -newerct {shlex.quote(newer_than)} -printf '%s\\t%p\\n'"
        )
    else:
        cmd = f"find {root} -type f -iname '*.zip' -printf '%s\\t%p\\n'"

    zips: list[RemoteZip] = []
    channel = transport.open_session()
    try:
        channel.exec_command(cmd)
        last_log = time.monotonic()

        # Read bytes and decode each line separately: find prints raw file-name
        # bytes, and a single non-UTF-8 name must not abort the whole scan.
        for raw in channel.makefile("rb", -1):
            try:
                line = raw.decode("utf-8").rstrip("\n")
            except UnicodeDecodeError:
                log.warning("Skipping find output that is not valid UTF-8: %r", raw[:200])
                continue
            size_str, tab, path = line.partition("\t")
            if not tab:
                continue
            try:
                size = int(size_str)
            except ValueError:
                continue
            zips.append(RemoteZip(remote_path=path, file_size=size))

            now = time.monotonic()
            if now - last_log >= 30:
                log.info("Find in progress: %d zip(s) found so far ...", len(zips))
                last_log = now

        # NOTE(review): stderr is drained only after stdout reaches EOF. Unread
        # stderr may still count against the SSH channel window, so a very large
        # volume of error output could stall find. Consider draining it concurrently.
        stderr_out = channel.makefile_stderr("rb", -1).read().decode("utf-8", "replace").strip()
        if stderr_out:
            log.warning("find stderr: %s", stderr_out[:500])

        exit_status = channel.recv_exit_status()
        if exit_status != 0:
            # find also exits non-zero on partial errors (e.g. unreadable
            # subdirectories), so this is logged rather than raised. paramiko
            # reports -1 when the server sent no exit status.
            log.warning("find exited with status %d; the %d zip(s) collected may be incomplete", exit_status, len(zips))
    finally:
        channel.close()
    return zips
|
|
|
|
|
|
def _scan_cutoff(last_scan: str) -> str:
|
|
"""Subtract 5-minute safety buffer from last-scan timestamp to handle clock skew."""
|
|
dt = datetime.strptime(last_scan, "%Y-%m-%d %H:%M:%S").replace(tzinfo=timezone.utc)
|
|
dt -= timedelta(minutes=5)
|
|
return dt.strftime("%Y-%m-%d %H:%M:%S")
|