iran prediction 47개 Python 파일을 prediction/ 디렉토리로 복제: - algorithms/ 14개 분석 알고리즘 (어구추론, 다크베셀, 스푸핑, 환적, 위험도 등) - pipeline/ 7단계 분류 파이프라인 - cache/vessel_store (24h 슬라이딩 윈도우) - db/ 어댑터 (snpdb 원본조회, kcgdb 결과저장) - chat/ AI 채팅 (Ollama, 후순위) - data/ 정적 데이터 (기선, 특정어업수역 GeoJSON) config.py를 kcgaidb로 재구성 (DB명, 사용자, 비밀번호) DB 연결 검증 완료 (kcgaidb 37개 테이블 접근 확인) Makefile에 dev-prediction / dev-all 타겟 추가 CLAUDE.md에 prediction 섹션 추가 Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
102 lines
3.5 KiB
Python
102 lines
3.5 KiB
Python
from collections import Counter
|
|
from typing import Dict, Optional
|
|
|
|
import numpy as np
|
|
import pandas as pd
|
|
|
|
from pipeline.constants import BIRCH_THRESHOLD, BIRCH_BRANCHING, MIN_CLUSTER_SIZE
|
|
|
|
|
|
class EnhancedBIRCHClusterer:
    """Cluster vessel trajectories with sklearn's Birch, with a K-means fallback.

    Based on the enhanced-BIRCH approach (Yan, Yang et al.):

    1. Sample each trajectory down to a fixed-length vector.
    2. Build a BIRCH CF-tree for memory-efficient hierarchical clustering.
    3. Relabel small clusters (< MIN_CLUSTER_SIZE) as noise (-1).
    """

    def __init__(
        self,
        threshold: float = BIRCH_THRESHOLD,
        branching: int = BIRCH_BRANCHING,
        n_clusters: Optional[int] = None,
    ) -> None:
        """Store the clustering hyper-parameters; nothing is fitted here.

        Args:
            threshold: forwarded to ``Birch(threshold=...)``.
            branching: forwarded to ``Birch(branching_factor=...)``.
            n_clusters: forwarded to ``Birch(n_clusters=...)``.
        """
        self.threshold = threshold
        self.branching = branching
        self.n_clusters = n_clusters
        # Assigned by fit_predict() each time the Birch path succeeds.
        self._model = None

    def _traj_to_vector(self, df_vessel: pd.DataFrame, n_points: int = 20) -> np.ndarray:
        """Convert a vessel trajectory DataFrame to a fixed-length vector.

        Picks ``n_points`` evenly spaced row positions (truncated to integer
        indices), concatenates the sampled latitudes followed by the sampled
        longitudes, and standardises the combined vector to zero mean / unit
        variance.

        Args:
            df_vessel: trajectory with ``lat`` and ``lon`` columns; must have
                at least one row, otherwise indexing fails.
            n_points: number of samples taken from each coordinate column.

        Returns:
            1-D array of length ``2 * n_points``.
        """
        lats = df_vessel['lat'].values
        lons = df_vessel['lon'].values
        # Evenly spaced positions from the first row to the last; casting to
        # int truncates, so short tracks simply repeat some rows.
        idx = np.linspace(0, len(lats) - 1, n_points).astype(int)
        vec = np.concatenate([lats[idx], lons[idx]])
        # NOTE(review): mean/std are computed over lat and lon together, not
        # per coordinate -- confirm this joint scaling is intended.
        # The epsilon keeps a constant vector from dividing by zero.
        vec = (vec - vec.mean()) / (vec.std() + 1e-9)
        return vec

    def fit_predict(self, vessels: Dict[str, pd.DataFrame]) -> Dict[str, int]:
        """Cluster vessel trajectories.

        Args:
            vessels: mapping of mmsi -> resampled trajectory DataFrame.

        Returns:
            Mapping of mmsi -> cluster_id. Vessels in clusters smaller than
            MIN_CLUSTER_SIZE are assigned cluster_id -1 (noise). Vessels with
            fewer than 20 points are excluded from the result. When fewer
            than 3 vessels remain, no clustering is run and every remaining
            vessel is assigned cluster_id 0.
        """
        mmsi_list: list[str] = []
        vectors: list[np.ndarray] = []

        for mmsi, df_v in vessels.items():
            # Tracks that are too short are dropped rather than clustered
            # (20 presumably mirrors the default n_points -- verify).
            if len(df_v) < 20:
                continue
            mmsi_list.append(mmsi)
            vectors.append(self._traj_to_vector(df_v))

        # Too few samples to cluster: put everything in a single cluster.
        if len(vectors) < 3:
            return {m: 0 for m in mmsi_list}

        X = np.array(vectors)

        try:
            # Imported lazily so the module still works without scikit-learn.
            from sklearn.cluster import Birch

            model = Birch(
                threshold=self.threshold,
                branching_factor=self.branching,
                n_clusters=self.n_clusters,
            )
            labels = model.fit_predict(X)
            self._model = model
        except ImportError:
            labels = self._simple_cluster(X)

        # Work with plain Python ints so the returned dict holds no numpy scalars.
        label_list = np.asarray(labels).tolist()
        sizes = Counter(label_list)
        # Clusters below the minimum size are treated as noise.
        final_labels = [
            lbl if sizes[lbl] >= MIN_CLUSTER_SIZE else -1 for lbl in label_list
        ]

        return dict(zip(mmsi_list, final_labels))

    @staticmethod
    def _simple_cluster(X: np.ndarray, k: int = 5) -> np.ndarray:
        """Fallback K-means used when sklearn is unavailable.

        Initial centres are ``k`` distinct samples drawn with the global
        ``np.random`` state, so results vary between runs unless the caller
        seeds it. Runs at most 20 assignment/update rounds and stops early
        once the centres stop moving.

        Args:
            X: array of samples, one sample per entry along the first axis.
            k: requested number of clusters; capped at the number of samples.

        Returns:
            Integer array with one cluster index per sample (empty when
            ``X`` is empty).
        """
        n = len(X)
        if n == 0:
            # Nothing to assign; avoids argmin on an empty distance matrix.
            return np.zeros(0, dtype=int)

        k = min(k, n)
        centers = X[np.random.choice(n, k, replace=False)]
        labels = np.zeros(n, dtype=int)

        for _ in range(20):
            # Broadcast to an (n, k, features) difference tensor and reduce
            # it in a single call instead of looping over every pair.
            diff = (X[:, np.newaxis] - centers[np.newaxis]).reshape(n, k, -1)
            dists = np.linalg.norm(diff, axis=2)
            labels = dists.argmin(axis=1)
            # An empty cluster keeps its previous centre.
            new_centers = np.array(
                [
                    X[labels == i].mean(axis=0) if (labels == i).any() else centers[i]
                    for i in range(k)
                ]
            )
            if np.allclose(centers, new_centers, atol=1e-6):
                break
            centers = new_centers

        return labels
|