125 lines
3.8 KiB
Python
125 lines
3.8 KiB
Python
import io
|
|
import logging
|
|
import posixpath
|
|
import socket
|
|
import stat
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
|
|
import paramiko
|
|
|
|
# Timeout in seconds used for the TCP connect, the SSH banner wait and the
# SSH handshake when opening a transport (see _make_transport).
CONNECT_TIMEOUT: int = 30
|
|
|
|
import db
|
|
from config import SFTPConfig
|
|
|
|
log: logging.Logger = logging.getLogger(__name__)

# Private-key classes tried in this order when parsing key text; the first
# one whose parser accepts the text is used (see _load_private_key).
_KEY_CLASSES: list[type[paramiko.PKey]] = [
    paramiko.Ed25519Key,
    paramiko.RSAKey,
    paramiko.ECDSAKey,
]
|
|
|
|
|
|
@dataclass
class RemoteZip:
    """A ``.zip`` file found on the SFTP server during a directory walk."""

    # POSIX path on the server: the listed directory joined with the file name.
    remote_path: str
    # Size in bytes from the directory listing; 0 when the listing gave none.
    file_size: int
|
|
|
|
|
|
def _load_private_key(pem: str) -> paramiko.PKey:
    """Parse private-key text by trying each class in ``_KEY_CLASSES`` in order.

    No passphrase is ever supplied, so only unencrypted keys can be loaded.

    Args:
        pem: The private key text.

    Returns:
        The key object produced by the first class that accepts *pem*.

    Raises:
        ValueError: If no class can parse *pem*. The message lists each
            class's parse error and the last error is chained as the cause.
            Passphrase-protected keys get a dedicated message.
    """
    errors: list[str] = []
    last_exc: Exception | None = None
    for cls in _KEY_CLASSES:
        try:
            # Fresh buffer per attempt: a failed parse may have consumed it.
            return cls.from_private_key(io.StringIO(pem))
        except paramiko.PasswordRequiredException as exc:
            # The key is encrypted; without a passphrase no class can load it,
            # so stop here instead of reporting a misleading "bad format".
            raise ValueError(
                "Could not parse private key — it is passphrase-protected, "
                "which is not supported"
            ) from exc
        except Exception as exc:  # wrong key type, malformed data, ...
            errors.append(f"{cls.__name__}: {exc}")
            last_exc = exc
    detail = f" ({'; '.join(errors)})" if errors else ""
    raise ValueError(
        "Could not parse private key — unsupported format or bad PEM data" + detail
    ) from last_exc
|
|
|
|
|
|
def get_key_fingerprint(pem: str) -> str | None:
|
|
if not pem.strip():
|
|
return None
|
|
try:
|
|
key = _load_private_key(pem)
|
|
fp = ":".join(f"{b:02x}" for b in key.get_fingerprint())
|
|
return f"{key.get_name()} MD5:{fp}"
|
|
except Exception as e:
|
|
return f"Invalid key: {e}"
|
|
|
|
|
|
def _make_transport(cfg: SFTPConfig) -> paramiko.Transport:
    """Open an authenticated SSH transport to ``cfg.host:cfg.port``.

    Uses private-key auth when ``cfg.auth_method == "key"`` and ``cfg.key`` is
    set; otherwise uses password auth with ``cfg.password``. The caller owns
    the returned transport and must ``close()`` it.

    Raises:
        ValueError: If the configured private key cannot be parsed. This is
            raised before any network connection is opened.
        OSError: If the TCP connection fails or times out.
        Exception: Handshake/authentication errors from
            ``Transport.connect`` (typically ``paramiko.SSHException``
            subclasses) propagate unchanged. On any failure after the socket
            is opened, the transport and socket are closed before re-raising.
    """
    # Parse the key first so a bad key fails fast without opening a socket.
    pkey = _load_private_key(cfg.key) if cfg.auth_method == "key" and cfg.key else None

    log.debug("Opening TCP connection to %s:%s", cfg.host, cfg.port)
    sock = socket.create_connection((cfg.host, cfg.port), timeout=CONNECT_TIMEOUT)
    transport: paramiko.Transport | None = None
    try:
        transport = paramiko.Transport(sock)
        transport.banner_timeout = CONNECT_TIMEOUT
        transport.handshake_timeout = CONNECT_TIMEOUT
        log.debug("Starting SSH handshake")
        # NOTE(review): no ``hostkey`` is passed to connect(), so the server's
        # host key is not verified (MITM risk). Consider pinning a known key.
        if pkey is not None:
            transport.connect(username=cfg.user, pkey=pkey)
        else:
            transport.connect(username=cfg.user, password=cfg.password)
    except BaseException:
        # Without this, an auth failure leaves the transport thread and the
        # socket running. Transport.close() may skip closing the socket if
        # the transport never became active, so close the socket explicitly too.
        if transport is not None:
            transport.close()
        sock.close()
        raise
    log.debug("SSH authenticated")
    return transport
|
|
|
|
|
|
def test_connection(cfg: SFTPConfig) -> tuple[bool, str]:
    """Check that the server is reachable and ``cfg.remote_path`` can be listed.

    Returns:
        ``(True, summary)`` on success. The summary reports the number of
        entries in ``cfg.remote_path`` and how many of them end in ``.zip``.
        ``(False, reason)`` on any failure. This function never raises.
    """
    try:
        transport = _make_transport(cfg)
        try:
            sftp = paramiko.SFTPClient.from_transport(transport)
            # from_transport() returns None if no session channel could be opened.
            if sftp is None:
                raise paramiko.SSHException("Could not open an SFTP session")
            try:
                entries = sftp.listdir(cfg.remote_path)
            finally:
                sftp.close()
        finally:
            # Always release the SSH connection, even if the SFTP step failed.
            transport.close()
    except Exception as e:
        log.debug("SFTP connection test failed", exc_info=True)
        # Some exceptions stringify to ""; fall back to the type name so the
        # caller always gets a non-empty reason.
        return False, str(e) or type(e).__name__

    zip_count = sum(1 for name in entries if name.lower().endswith(".zip"))
    return True, f"Connected to {cfg.host}. {len(entries)} item(s) in {cfg.remote_path} ({zip_count} zip file(s) at top level)."
|
|
|
|
|
|
def list_new_zips(cfg: SFTPConfig) -> list[RemoteZip]:
    """Return zips under ``cfg.remote_path`` that are not yet marked processed.

    Walks the remote tree recursively (see ``_walk_zips``). Each zip is then
    checked with ``db.is_zip_processed`` using its remote path.

    Raises:
        Connection/authentication errors from ``_make_transport`` and SFTP
        session errors propagate. The SSH connection is always closed.
    """
    transport = _make_transport(cfg)
    try:
        sftp = paramiko.SFTPClient.from_transport(transport)
        # from_transport() returns None if no session channel could be opened.
        if sftp is None:
            raise paramiko.SSHException("Could not open an SFTP session")
        try:
            all_zips = _walk_zips(sftp, cfg.remote_path)
        finally:
            sftp.close()
    finally:
        # Close even if opening the SFTP session failed (previously leaked).
        transport.close()

    # DB filtering happens after the connection is closed so the SSH session
    # is not held open during per-zip lookups.
    new_zips = [z for z in all_zips if not db.is_zip_processed(z.remote_path)]
    log.info("Remote: %d zip(s) total, %d new", len(all_zips), len(new_zips))
    return new_zips
|
|
|
|
|
|
def download(cfg: SFTPConfig, remote_zip: RemoteZip, dest_dir: str) -> Path:
    """Download *remote_zip* into *dest_dir* and return the local path.

    The transfer is written to ``<name>.part`` first. It is renamed to the
    final name only after ``sftp.get`` succeeds, so a failed or interrupted
    transfer never leaves a truncated zip at the returned path.

    Args:
        cfg: Connection settings.
        remote_zip: The remote file to fetch.
        dest_dir: Local directory; created (with parents) if missing.

    Returns:
        ``Path(dest_dir) / <basename of remote_zip.remote_path>``.

    Raises:
        Connection, SFTP and local I/O errors propagate. The partial file is
        removed and the SSH connection closed first.
    """
    dest = Path(dest_dir)
    dest.mkdir(parents=True, exist_ok=True)
    # NOTE(review): only the basename is kept, so zips with the same name in
    # different remote directories map to the same local file.
    local_path = dest / Path(remote_zip.remote_path).name
    tmp_path = local_path.with_name(local_path.name + ".part")

    transport = _make_transport(cfg)
    try:
        sftp = paramiko.SFTPClient.from_transport(transport)
        # from_transport() returns None if no session channel could be opened.
        if sftp is None:
            raise paramiko.SSHException("Could not open an SFTP session")
        try:
            log.info("Downloading %s → %s", remote_zip.remote_path, local_path)
            sftp.get(remote_zip.remote_path, str(tmp_path))
        finally:
            sftp.close()
    except BaseException:
        # Do not leave a truncated download behind.
        tmp_path.unlink(missing_ok=True)
        raise
    finally:
        transport.close()

    # Atomic on the same filesystem; overwrites any existing file.
    tmp_path.replace(local_path)
    return local_path
|
|
|
|
|
|
def _walk_zips(sftp: paramiko.SFTPClient, remote_dir: str) -> list[RemoteZip]:
    """Recursively collect files under *remote_dir* whose names end in ``.zip``.

    The name match is case-insensitive. A directory that cannot be listed
    (``IOError``) is logged and contributes nothing; the walk continues
    elsewhere. An entry whose permissions are not reported (``st_mode`` is
    ``None``) cannot be identified as a directory, so it is treated as a
    file.

    Args:
        sftp: An open SFTP client.
        remote_dir: Directory on the server to start from.

    Returns:
        One ``RemoteZip`` per matching file, in listing order (depth-first).
    """
    try:
        entries = sftp.listdir_attr(remote_dir)
    except IOError as e:
        log.warning("Cannot list %s: %s", remote_dir, e)
        return []

    results: list[RemoteZip] = []
    for entry in entries:
        full_path = posixpath.join(remote_dir, entry.filename)
        # stat.S_ISDIR(None) raises TypeError; servers may omit permissions.
        if entry.st_mode is not None and stat.S_ISDIR(entry.st_mode):
            results.extend(_walk_zips(sftp, full_path))
        elif entry.filename.lower().endswith(".zip"):
            # st_size may also be None when the server omits it.
            results.append(RemoteZip(remote_path=full_path, file_size=entry.st_size or 0))
    return results
|