"""어구 모선 추론 episode continuity + prior bonus helper.""" from __future__ import annotations import json import math from dataclasses import dataclass from datetime import datetime, timezone from typing import Any, Iterable, Optional from uuid import uuid4 from config import qualified_table GEAR_GROUP_EPISODES = qualified_table('gear_group_episodes') GEAR_GROUP_EPISODE_SNAPSHOTS = qualified_table('gear_group_episode_snapshots') GEAR_GROUP_PARENT_CANDIDATE_SNAPSHOTS = qualified_table('gear_group_parent_candidate_snapshots') GEAR_PARENT_LABEL_SESSIONS = qualified_table('gear_parent_label_sessions') _ACTIVE_EPISODE_WINDOW_HOURS = 6 _EPISODE_PRIOR_WINDOW_HOURS = 24 _LINEAGE_PRIOR_WINDOW_DAYS = 7 _LABEL_PRIOR_WINDOW_DAYS = 30 _CONTINUITY_SCORE_THRESHOLD = 0.45 _MERGE_SCORE_THRESHOLD = 0.35 _CENTER_DISTANCE_THRESHOLD_NM = 12.0 _EPISODE_PRIOR_MAX = 0.10 _LINEAGE_PRIOR_MAX = 0.05 _LABEL_PRIOR_MAX = 0.10 _TOTAL_PRIOR_CAP = 0.20 def _clamp(value: float, floor: float = 0.0, ceil: float = 1.0) -> float: return max(floor, min(ceil, value)) def _haversine_nm(lat1: float, lon1: float, lat2: float, lon2: float) -> float: earth_radius_nm = 3440.065 phi1 = math.radians(lat1) phi2 = math.radians(lat2) dphi = math.radians(lat2 - lat1) dlam = math.radians(lon2 - lon1) a = math.sin(dphi / 2) ** 2 + math.cos(phi1) * math.cos(phi2) * math.sin(dlam / 2) ** 2 return earth_radius_nm * 2 * math.atan2(math.sqrt(a), math.sqrt(max(0.0, 1 - a))) def _json_list(value: Any) -> list[str]: if value is None: return [] if isinstance(value, list): return [str(item) for item in value if item] try: parsed = json.loads(value) except Exception: return [] if isinstance(parsed, list): return [str(item) for item in parsed if item] return [] @dataclass class GroupEpisodeInput: group_key: str normalized_parent_name: str sub_cluster_id: int member_mmsis: list[str] member_count: int center_lat: float center_lon: float @property def key(self) -> tuple[str, int]: return (self.group_key, self.sub_cluster_id) @dataclass class EpisodeState: episode_id: str lineage_key: str group_key: str normalized_parent_name: str current_sub_cluster_id: int member_mmsis: list[str] member_count: int center_lat: float center_lon: float last_snapshot_time: datetime status: str @dataclass class EpisodeAssignment: group_key: str sub_cluster_id: int normalized_parent_name: str episode_id: str continuity_source: str continuity_score: float split_from_episode_id: Optional[str] merged_from_episode_ids: list[str] member_mmsis: list[str] member_count: int center_lat: float center_lon: float @property def key(self) -> tuple[str, int]: return (self.group_key, self.sub_cluster_id) @dataclass class EpisodePlan: assignments: dict[tuple[str, int], EpisodeAssignment] expired_episode_ids: set[str] merged_episode_targets: dict[str, str] def _member_jaccard(left: Iterable[str], right: Iterable[str]) -> tuple[float, int]: left_set = {item for item in left if item} right_set = {item for item in right if item} if not left_set and not right_set: return 0.0, 0 overlap = len(left_set & right_set) union = len(left_set | right_set) return (overlap / union if union else 0.0), overlap def continuity_score(current: GroupEpisodeInput, previous: EpisodeState) -> tuple[float, int, float]: jaccard, overlap_count = _member_jaccard(current.member_mmsis, previous.member_mmsis) distance_nm = _haversine_nm(current.center_lat, current.center_lon, previous.center_lat, previous.center_lon) center_support = _clamp(1.0 - (distance_nm / _CENTER_DISTANCE_THRESHOLD_NM)) score = _clamp((0.75 * jaccard) + (0.25 * center_support)) return round(score, 6), overlap_count, round(distance_nm, 3) def load_active_episode_states(conn, lineage_keys: list[str]) -> dict[str, list[EpisodeState]]: if not lineage_keys: return {} cur = conn.cursor() try: cur.execute( f""" SELECT episode_id, lineage_key, group_key, normalized_parent_name, current_sub_cluster_id, current_member_mmsis, current_member_count, ST_Y(current_center_point) AS center_lat, ST_X(current_center_point) AS center_lon, last_snapshot_time, status FROM {GEAR_GROUP_EPISODES} WHERE lineage_key = ANY(%s) AND status = 'ACTIVE' AND last_snapshot_time >= NOW() - (%s * INTERVAL '1 hour') ORDER BY lineage_key, last_snapshot_time DESC, episode_id ASC """, (lineage_keys, _ACTIVE_EPISODE_WINDOW_HOURS), ) result: dict[str, list[EpisodeState]] = {} for row in cur.fetchall(): state = EpisodeState( episode_id=row[0], lineage_key=row[1], group_key=row[2], normalized_parent_name=row[3], current_sub_cluster_id=int(row[4] or 0), member_mmsis=_json_list(row[5]), member_count=int(row[6] or 0), center_lat=float(row[7] or 0.0), center_lon=float(row[8] or 0.0), last_snapshot_time=row[9], status=row[10], ) result.setdefault(state.lineage_key, []).append(state) return result finally: cur.close() def group_to_episode_input(group: dict[str, Any], normalized_parent_name: str) -> GroupEpisodeInput: members = group.get('members') or [] member_mmsis = sorted({str(member.get('mmsi')) for member in members if member.get('mmsi')}) member_count = len(member_mmsis) if members: center_lat = sum(float(member['lat']) for member in members) / len(members) center_lon = sum(float(member['lon']) for member in members) / len(members) else: center_lat = 0.0 center_lon = 0.0 return GroupEpisodeInput( group_key=group['parent_name'], normalized_parent_name=normalized_parent_name, sub_cluster_id=int(group.get('sub_cluster_id', 0)), member_mmsis=member_mmsis, member_count=member_count, center_lat=center_lat, center_lon=center_lon, ) def build_episode_plan( groups: list[GroupEpisodeInput], previous_by_lineage: dict[str, list[EpisodeState]], ) -> EpisodePlan: assignments: dict[tuple[str, int], EpisodeAssignment] = {} expired_episode_ids: set[str] = set() merged_episode_targets: dict[str, str] = {} groups_by_lineage: dict[str, list[GroupEpisodeInput]] = {} for group in groups: groups_by_lineage.setdefault(group.normalized_parent_name, []).append(group) for lineage_key, current_groups in groups_by_lineage.items(): previous_groups = previous_by_lineage.get(lineage_key, []) qualified_matches: dict[tuple[str, int], list[tuple[EpisodeState, float, int, float]]] = {} prior_to_currents: dict[str, list[tuple[GroupEpisodeInput, float, int, float]]] = {} for current in current_groups: for previous in previous_groups: score, overlap_count, distance_nm = continuity_score(current, previous) if score >= _CONTINUITY_SCORE_THRESHOLD or ( overlap_count > 0 and distance_nm <= _CENTER_DISTANCE_THRESHOLD_NM ): qualified_matches.setdefault(current.key, []).append((previous, score, overlap_count, distance_nm)) prior_to_currents.setdefault(previous.episode_id, []).append((current, score, overlap_count, distance_nm)) consumed_previous_ids: set[str] = set() assigned_current_keys: set[tuple[str, int]] = set() for current in current_groups: matches = sorted( qualified_matches.get(current.key, []), key=lambda item: (item[1], item[2], -item[3], item[0].last_snapshot_time), reverse=True, ) merge_candidates = [ item for item in matches if item[1] >= _MERGE_SCORE_THRESHOLD ] if len(merge_candidates) >= 2: episode_id = f"ep-{uuid4().hex[:12]}" merged_ids = [item[0].episode_id for item in merge_candidates] assignments[current.key] = EpisodeAssignment( group_key=current.group_key, sub_cluster_id=current.sub_cluster_id, normalized_parent_name=current.normalized_parent_name, episode_id=episode_id, continuity_source='MERGE_NEW', continuity_score=round(max(item[1] for item in merge_candidates), 6), split_from_episode_id=None, merged_from_episode_ids=merged_ids, member_mmsis=current.member_mmsis, member_count=current.member_count, center_lat=current.center_lat, center_lon=current.center_lon, ) assigned_current_keys.add(current.key) for merged_id in merged_ids: consumed_previous_ids.add(merged_id) merged_episode_targets[merged_id] = episode_id previous_ranked = sorted( previous_groups, key=lambda item: item.last_snapshot_time, reverse=True, ) for previous in previous_ranked: if previous.episode_id in consumed_previous_ids: continue matches = [ item for item in prior_to_currents.get(previous.episode_id, []) if item[0].key not in assigned_current_keys ] if not matches: continue matches.sort(key=lambda item: (item[1], item[2], -item[3]), reverse=True) current, score, _, _ = matches[0] split_candidate_count = len(prior_to_currents.get(previous.episode_id, [])) assignments[current.key] = EpisodeAssignment( group_key=current.group_key, sub_cluster_id=current.sub_cluster_id, normalized_parent_name=current.normalized_parent_name, episode_id=previous.episode_id, continuity_source='SPLIT_CONTINUE' if split_candidate_count > 1 else 'CONTINUED', continuity_score=score, split_from_episode_id=None, merged_from_episode_ids=[], member_mmsis=current.member_mmsis, member_count=current.member_count, center_lat=current.center_lat, center_lon=current.center_lon, ) assigned_current_keys.add(current.key) consumed_previous_ids.add(previous.episode_id) for current in current_groups: if current.key in assigned_current_keys: continue matches = sorted( qualified_matches.get(current.key, []), key=lambda item: (item[1], item[2], -item[3], item[0].last_snapshot_time), reverse=True, ) split_from_episode_id = None continuity_source = 'NEW' continuity_score_value = 0.0 if matches: best_previous, score, _, _ = matches[0] split_from_episode_id = best_previous.episode_id continuity_source = 'SPLIT_NEW' continuity_score_value = score assignments[current.key] = EpisodeAssignment( group_key=current.group_key, sub_cluster_id=current.sub_cluster_id, normalized_parent_name=current.normalized_parent_name, episode_id=f"ep-{uuid4().hex[:12]}", continuity_source=continuity_source, continuity_score=continuity_score_value, split_from_episode_id=split_from_episode_id, merged_from_episode_ids=[], member_mmsis=current.member_mmsis, member_count=current.member_count, center_lat=current.center_lat, center_lon=current.center_lon, ) assigned_current_keys.add(current.key) current_previous_ids = {assignment.episode_id for assignment in assignments.values() if assignment.normalized_parent_name == lineage_key} for previous in previous_groups: if previous.episode_id in merged_episode_targets: continue if previous.episode_id not in current_previous_ids: expired_episode_ids.add(previous.episode_id) return EpisodePlan( assignments=assignments, expired_episode_ids=expired_episode_ids, merged_episode_targets=merged_episode_targets, ) def load_episode_prior_stats(conn, episode_ids: list[str]) -> dict[tuple[str, str], dict[str, Any]]: if not episode_ids: return {} cur = conn.cursor() try: cur.execute( f""" SELECT episode_id, candidate_mmsi, COUNT(*) AS seen_count, SUM(CASE WHEN rank = 1 THEN 1 ELSE 0 END) AS top1_count, AVG(final_score) AS avg_score, MAX(observed_at) AS last_seen_at FROM {GEAR_GROUP_PARENT_CANDIDATE_SNAPSHOTS} WHERE episode_id = ANY(%s) AND observed_at >= NOW() - (%s * INTERVAL '1 hour') GROUP BY episode_id, candidate_mmsi """, (episode_ids, _EPISODE_PRIOR_WINDOW_HOURS), ) result: dict[tuple[str, str], dict[str, Any]] = {} for episode_id, candidate_mmsi, seen_count, top1_count, avg_score, last_seen_at in cur.fetchall(): result[(episode_id, candidate_mmsi)] = { 'seen_count': int(seen_count or 0), 'top1_count': int(top1_count or 0), 'avg_score': float(avg_score or 0.0), 'last_seen_at': last_seen_at, } return result finally: cur.close() def load_lineage_prior_stats(conn, lineage_keys: list[str]) -> dict[tuple[str, str], dict[str, Any]]: if not lineage_keys: return {} cur = conn.cursor() try: cur.execute( f""" SELECT normalized_parent_name, candidate_mmsi, COUNT(*) AS seen_count, SUM(CASE WHEN rank = 1 THEN 1 ELSE 0 END) AS top1_count, SUM(CASE WHEN rank <= 3 THEN 1 ELSE 0 END) AS top3_count, AVG(final_score) AS avg_score, MAX(observed_at) AS last_seen_at FROM {GEAR_GROUP_PARENT_CANDIDATE_SNAPSHOTS} WHERE normalized_parent_name = ANY(%s) AND observed_at >= NOW() - (%s * INTERVAL '1 day') GROUP BY normalized_parent_name, candidate_mmsi """, (lineage_keys, _LINEAGE_PRIOR_WINDOW_DAYS), ) result: dict[tuple[str, str], dict[str, Any]] = {} for lineage_key, candidate_mmsi, seen_count, top1_count, top3_count, avg_score, last_seen_at in cur.fetchall(): result[(lineage_key, candidate_mmsi)] = { 'seen_count': int(seen_count or 0), 'top1_count': int(top1_count or 0), 'top3_count': int(top3_count or 0), 'avg_score': float(avg_score or 0.0), 'last_seen_at': last_seen_at, } return result finally: cur.close() def load_label_prior_stats(conn, lineage_keys: list[str]) -> dict[tuple[str, str], dict[str, Any]]: if not lineage_keys: return {} cur = conn.cursor() try: cur.execute( f""" SELECT normalized_parent_name, label_parent_mmsi, COUNT(*) AS session_count, MAX(active_from) AS last_labeled_at FROM {GEAR_PARENT_LABEL_SESSIONS} WHERE normalized_parent_name = ANY(%s) AND active_from >= NOW() - (%s * INTERVAL '1 day') GROUP BY normalized_parent_name, label_parent_mmsi """, (lineage_keys, _LABEL_PRIOR_WINDOW_DAYS), ) result: dict[tuple[str, str], dict[str, Any]] = {} for lineage_key, candidate_mmsi, session_count, last_labeled_at in cur.fetchall(): result[(lineage_key, candidate_mmsi)] = { 'session_count': int(session_count or 0), 'last_labeled_at': last_labeled_at, } return result finally: cur.close() def _recency_support(observed_at: Optional[datetime], now: datetime, hours: float) -> float: if observed_at is None: return 0.0 if observed_at.tzinfo is None: observed_at = observed_at.replace(tzinfo=timezone.utc) delta_hours = max(0.0, (now - observed_at.astimezone(timezone.utc)).total_seconds() / 3600.0) return _clamp(1.0 - (delta_hours / hours)) def compute_prior_bonus_components( observed_at: datetime, normalized_parent_name: str, episode_id: str, candidate_mmsi: str, episode_prior_stats: dict[tuple[str, str], dict[str, Any]], lineage_prior_stats: dict[tuple[str, str], dict[str, Any]], label_prior_stats: dict[tuple[str, str], dict[str, Any]], ) -> dict[str, float]: episode_stats = episode_prior_stats.get((episode_id, candidate_mmsi), {}) lineage_stats = lineage_prior_stats.get((normalized_parent_name, candidate_mmsi), {}) label_stats = label_prior_stats.get((normalized_parent_name, candidate_mmsi), {}) episode_bonus = 0.0 if episode_stats: episode_bonus = _EPISODE_PRIOR_MAX * ( 0.35 * min(1.0, episode_stats.get('seen_count', 0) / 6.0) + 0.35 * min(1.0, episode_stats.get('top1_count', 0) / 3.0) + 0.15 * _clamp(float(episode_stats.get('avg_score', 0.0))) + 0.15 * _recency_support(episode_stats.get('last_seen_at'), observed_at, _EPISODE_PRIOR_WINDOW_HOURS) ) lineage_bonus = 0.0 if lineage_stats: lineage_bonus = _LINEAGE_PRIOR_MAX * ( 0.30 * min(1.0, lineage_stats.get('seen_count', 0) / 12.0) + 0.25 * min(1.0, lineage_stats.get('top3_count', 0) / 6.0) + 0.20 * min(1.0, lineage_stats.get('top1_count', 0) / 3.0) + 0.15 * _clamp(float(lineage_stats.get('avg_score', 0.0))) + 0.10 * _recency_support(lineage_stats.get('last_seen_at'), observed_at, _LINEAGE_PRIOR_WINDOW_DAYS * 24.0) ) label_bonus = 0.0 if label_stats: label_bonus = _LABEL_PRIOR_MAX * ( 0.70 * min(1.0, label_stats.get('session_count', 0) / 3.0) + 0.30 * _recency_support(label_stats.get('last_labeled_at'), observed_at, _LABEL_PRIOR_WINDOW_DAYS * 24.0) ) total = min(_TOTAL_PRIOR_CAP, episode_bonus + lineage_bonus + label_bonus) return { 'episodePriorBonus': round(episode_bonus, 6), 'lineagePriorBonus': round(lineage_bonus, 6), 'labelPriorBonus': round(label_bonus, 6), 'priorBonusTotal': round(total, 6), } def sync_episode_states(conn, observed_at: datetime, plan: EpisodePlan) -> None: cur = conn.cursor() try: if plan.expired_episode_ids: cur.execute( f""" UPDATE {GEAR_GROUP_EPISODES} SET status = 'EXPIRED', updated_at = %s WHERE episode_id = ANY(%s) """, (observed_at, list(plan.expired_episode_ids)), ) for previous_episode_id, merged_into_episode_id in plan.merged_episode_targets.items(): cur.execute( f""" UPDATE {GEAR_GROUP_EPISODES} SET status = 'MERGED', merged_into_episode_id = %s, updated_at = %s WHERE episode_id = %s """, (merged_into_episode_id, observed_at, previous_episode_id), ) for assignment in plan.assignments.values(): cur.execute( f""" INSERT INTO {GEAR_GROUP_EPISODES} ( episode_id, lineage_key, group_key, normalized_parent_name, current_sub_cluster_id, status, continuity_source, continuity_score, first_seen_at, last_seen_at, last_snapshot_time, current_member_count, current_member_mmsis, current_center_point, split_from_episode_id, merged_from_episode_ids, metadata, updated_at ) VALUES ( %s, %s, %s, %s, %s, 'ACTIVE', %s, %s, %s, %s, %s, %s, %s::jsonb, ST_SetSRID(ST_MakePoint(%s, %s), 4326), %s, %s::jsonb, '{{}}'::jsonb, %s ) ON CONFLICT (episode_id) DO UPDATE SET group_key = EXCLUDED.group_key, normalized_parent_name = EXCLUDED.normalized_parent_name, current_sub_cluster_id = EXCLUDED.current_sub_cluster_id, status = 'ACTIVE', continuity_source = EXCLUDED.continuity_source, continuity_score = EXCLUDED.continuity_score, last_seen_at = EXCLUDED.last_seen_at, last_snapshot_time = EXCLUDED.last_snapshot_time, current_member_count = EXCLUDED.current_member_count, current_member_mmsis = EXCLUDED.current_member_mmsis, current_center_point = EXCLUDED.current_center_point, split_from_episode_id = COALESCE(EXCLUDED.split_from_episode_id, {GEAR_GROUP_EPISODES}.split_from_episode_id), merged_from_episode_ids = EXCLUDED.merged_from_episode_ids, updated_at = EXCLUDED.updated_at """, ( assignment.episode_id, assignment.normalized_parent_name, assignment.group_key, assignment.normalized_parent_name, assignment.sub_cluster_id, assignment.continuity_source, assignment.continuity_score, observed_at, observed_at, observed_at, assignment.member_count, json.dumps(assignment.member_mmsis, ensure_ascii=False), assignment.center_lon, assignment.center_lat, assignment.split_from_episode_id, json.dumps(assignment.merged_from_episode_ids, ensure_ascii=False), observed_at, ), ) finally: cur.close() def insert_episode_snapshots( conn, observed_at: datetime, plan: EpisodePlan, snapshot_payloads: dict[tuple[str, int], dict[str, Any]], ) -> int: if not snapshot_payloads: return 0 rows: list[tuple[Any, ...]] = [] for key, payload in snapshot_payloads.items(): assignment = plan.assignments.get(key) if assignment is None: continue rows.append(( assignment.episode_id, assignment.normalized_parent_name, assignment.group_key, assignment.normalized_parent_name, assignment.sub_cluster_id, observed_at, assignment.member_count, json.dumps(assignment.member_mmsis, ensure_ascii=False), assignment.center_lon, assignment.center_lat, assignment.continuity_source, assignment.continuity_score, json.dumps(payload.get('parentEpisodeIds') or assignment.merged_from_episode_ids, ensure_ascii=False), payload.get('topCandidateMmsi'), payload.get('topCandidateScore'), payload.get('resolutionStatus'), json.dumps(payload.get('metadata') or {}, ensure_ascii=False), )) if not rows: return 0 cur = conn.cursor() try: from psycopg2.extras import execute_values execute_values( cur, f""" INSERT INTO {GEAR_GROUP_EPISODE_SNAPSHOTS} ( episode_id, lineage_key, group_key, normalized_parent_name, sub_cluster_id, observed_at, member_count, member_mmsis, center_point, continuity_source, continuity_score, parent_episode_ids, top_candidate_mmsi, top_candidate_score, resolution_status, metadata ) VALUES %s ON CONFLICT (episode_id, observed_at) DO NOTHING """, rows, template="(%s, %s, %s, %s, %s, %s, %s, %s::jsonb, ST_SetSRID(ST_MakePoint(%s, %s), 4326), %s, %s, %s::jsonb, %s, %s, %s, %s::jsonb)", page_size=200, ) return len(rows) finally: cur.close()