from collections import Counter
from typing import Dict, Optional

import numpy as np
import pandas as pd

from pipeline.constants import BIRCH_THRESHOLD, BIRCH_BRANCHING, MIN_CLUSTER_SIZE


class EnhancedBIRCHClusterer:
    """Trajectory clustering using sklearn Birch with a simple K-means fallback.

    Based on the enhanced-BIRCH approach (Yan, Yang et al.):
      1. Resample each trajectory to a fixed-length vector.
      2. Build a BIRCH CF-tree for memory-efficient hierarchical clustering.
      3. Small clusters (< MIN_CLUSTER_SIZE) are relabelled as noise (-1).

    If scikit-learn cannot be imported, a small NumPy K-means
    (``_simple_cluster``) is used instead of Birch.
    """

    # Number of samples drawn from each trajectory. fit_predict skips
    # trajectories shorter than this so no sample position is repeated.
    _N_POINTS: int = 20

    # K used by the K-means fallback when ``n_clusters`` is not an integer.
    _FALLBACK_K: int = 5

    def __init__(
        self,
        threshold: float = BIRCH_THRESHOLD,
        branching: int = BIRCH_BRANCHING,
        n_clusters: Optional[int] = None,
        random_state: Optional[int] = None,
    ) -> None:
        """Store clustering parameters.

        Args:
            threshold: passed to sklearn ``Birch(threshold=...)``.
            branching: passed to sklearn ``Birch(branching_factor=...)``.
            n_clusters: passed to sklearn ``Birch(n_clusters=...)``; when it
                is an integer it is also the K used by the K-means fallback.
            random_state: seed for the fallback's initial centres. None draws
                from numpy's global RNG.
        """
        self.threshold = threshold
        self.branching = branching
        self.n_clusters = n_clusters
        self.random_state = random_state
        # Birch model fitted by the most recent fit_predict call, or None if
        # that call did not fit one (too few trajectories / sklearn missing).
        self._model = None

    def _traj_to_vector(self, df_vessel: pd.DataFrame, n_points: int = _N_POINTS) -> np.ndarray:
        """Convert a vessel trajectory DataFrame to a fixed-length vector.

        Samples ``n_points`` evenly spaced row positions (truncated to int)
        and returns the sampled latitudes followed by the sampled longitudes
        (length ``2 * n_points``; the two coordinates are concatenated, not
        interleaved). The vector is standardised with one mean/std computed
        over both coordinates together.

        Args:
            df_vessel: trajectory with ``lat`` and ``lon`` columns; must have
                at least one row.
            n_points: number of positions to sample.
        """
        lats = df_vessel['lat'].to_numpy()
        lons = df_vessel['lon'].to_numpy()
        idx = np.linspace(0, len(lats) - 1, n_points).astype(int)
        vec = np.concatenate([lats[idx], lons[idx]])
        # The epsilon avoids division by zero for a constant vector.
        return (vec - vec.mean()) / (vec.std() + 1e-9)

    def fit_predict(self, vessels: Dict[str, pd.DataFrame]) -> Dict[str, int]:
        """Cluster vessel trajectories.

        Args:
            vessels: mapping of mmsi -> resampled trajectory DataFrame.

        Returns:
            Mapping of mmsi -> cluster_id. Vessels with fewer than 20 points
            are excluded from the result. With fewer than 3 usable
            trajectories, all are placed in cluster 0. In every case, vessels
            in clusters with fewer than MIN_CLUSTER_SIZE members are assigned
            cluster_id -1 (noise).
        """
        # Never leave a model from a previous call describing stale data.
        self._model = None

        mmsi_list: list[str] = []
        vectors: list[np.ndarray] = []
        for mmsi, df_v in vessels.items():
            if len(df_v) < self._N_POINTS:
                continue
            mmsi_list.append(mmsi)
            vectors.append(self._traj_to_vector(df_v, n_points=self._N_POINTS))

        if not vectors:
            return {}

        X = np.array(vectors)
        if len(X) < 3:
            # Too few trajectories to build a meaningful hierarchy.
            labels = np.zeros(len(X), dtype=int)
        else:
            labels = self._cluster(X)

        return dict(zip(mmsi_list, self._mark_noise(labels).tolist()))

    def _cluster(self, X: np.ndarray) -> np.ndarray:
        """Label each row of X with sklearn Birch, or K-means if sklearn is missing."""
        try:
            from sklearn.cluster import Birch
        except ImportError:
            k = (
                self.n_clusters
                if isinstance(self.n_clusters, (int, np.integer))
                else self._FALLBACK_K
            )
            return self._simple_cluster(X, k=int(k), seed=self.random_state)

        model = Birch(
            threshold=self.threshold,
            branching_factor=self.branching,
            n_clusters=self.n_clusters,
        )
        labels = model.fit_predict(X)
        self._model = model
        return labels

    @staticmethod
    def _mark_noise(labels: np.ndarray) -> np.ndarray:
        """Return a copy of labels with members of small clusters set to -1."""
        plain = labels.tolist()
        sizes = Counter(plain)
        return np.array(
            [lbl if sizes[lbl] >= MIN_CLUSTER_SIZE else -1 for lbl in plain],
            dtype=int,
        )

    @staticmethod
    def _simple_cluster(X: np.ndarray, k: int = 5, seed: Optional[int] = None) -> np.ndarray:
        """Fallback K-means used when sklearn is unavailable.

        Runs at most 20 assign/update iterations, stopping early once the
        centres move by no more than 1e-6.

        Args:
            X: 2-D array, one row per sample.
            k: requested number of clusters; clamped to ``[1, len(X)]``.
            seed: seed for picking the initial centres. None draws from
                numpy's global RNG.

        Returns:
            Integer label in ``[0, k)`` for each row of X.
        """
        n = len(X)
        if n == 0:
            return np.zeros(0, dtype=int)
        k = max(1, min(k, n))
        rng = np.random if seed is None else np.random.default_rng(seed)
        centers = X[rng.choice(n, k, replace=False)]
        labels = np.zeros(n, dtype=int)
        for _ in range(20):
            # (n, k) distance matrix via broadcasting instead of a Python double loop.
            dists = np.linalg.norm(X[:, None, :] - centers[None, :, :], axis=2)
            labels = dists.argmin(axis=1)
            # An empty cluster keeps its previous centre.
            new_centers = np.array(
                [X[labels == i].mean(axis=0) if (labels == i).any() else centers[i] for i in range(k)]
            )
            if np.allclose(centers, new_centers, atol=1e-6):
                break
            centers = new_centers
        return labels