from datetime import datetime

import numpy as np
import pandas as pd
import xarray as xr

from logger import get_logger
from utils import find_time_index, convert_and_round

logger = get_logger("extractUvFull")


def _bounding_slices(lon_2d, lat_2d, lon_range, lat_range):
    """Return (row_slice, col_slice) of the bounding box of in-range points.

    Raises:
        ValueError: If no grid point satisfies the requested ranges.
    """
    mask = np.ones(lon_2d.shape, dtype=bool)
    if lon_range is not None:
        min_lon, max_lon = lon_range
        mask &= (lon_2d >= min_lon) & (lon_2d <= max_lon)
    if lat_range is not None:
        min_lat, max_lat = lat_range
        mask &= (lat_2d >= min_lat) & (lat_2d <= max_lat)

    rows, cols = np.where(mask)
    if rows.size == 0:
        raise ValueError("No data within specified range")

    # Bounding box of the matches: on a curvilinear grid this may include
    # points that are themselves outside the requested ranges.
    return (
        slice(rows.min(), rows.max() + 1),
        slice(cols.min(), cols.max() + 1),
    )


def _steps(values):
    """Return successive differences of a 1D array as a list of floats."""
    return np.diff(values).astype(float).tolist()


def extract_uv_full(nc_file, target_time, category, skip=5, lon_range=None, lat_range=None) -> dict:
    """Extract the u/v field at one time step from a NetCDF file, subsampled.

    The field is optionally cropped to the bounding box of the grid points
    that fall inside ``lon_range`` / ``lat_range`` and then sampled every
    ``skip`` points along both grid axes.

    Args:
        nc_file: NetCDF file to open with ``xarray.open_dataset``.
        target_time: Requested time; resolved to an index through
            ``find_time_index(ds, target_time)``.
        category: ``"wind"`` reads ``x_wind`` / ``y_wind``; any other value
            reads ``ssu`` / ``ssv``.
        skip: Sampling stride applied to both grid axes.
        lon_range: Optional ``(min_lon, max_lon)`` inclusive bounds.
        lat_range: Optional ``(min_lat, max_lat)`` inclusive bounds.

    Returns:
        A dict with two entries:
        ``"data"`` holds ``modelFcstDt`` (the selected time formatted as
        ``%Y%m%d%H``) and ``values`` (``[u_list, v_list]`` as produced by
        ``convert_and_round``).
        ``"grid"`` holds ``lonInterval`` / ``latInterval`` (successive
        differences along the first row / first column of the sampled
        grid), ``boundLonLat`` (top/bottom/left/right extremes) and the
        sampled ``rows`` / ``cols`` counts.

    Raises:
        ValueError: If a range is given and no grid point lies inside it.
    """
    u_name, v_name = ("x_wind", "y_wind") if category == "wind" else ("ssu", "ssv")

    # The context manager guarantees the dataset is closed even when an
    # exception is raised while reading from it.
    with xr.open_dataset(nc_file) as ds:
        time_idx, selected_time = find_time_index(ds, target_time)

        lon = ds['lon'].values
        lat = ds['lat'].values

        u_var = ds[u_name]
        v_var = ds[v_name]

        # Index before calling .values so only the selected 2D slice is
        # read, instead of loading every time step. The rank of the u
        # variable decides the layout for both components.
        if u_var.ndim == 3:
            u_full = u_var[time_idx].values
            v_full = v_var[time_idx].values
        elif u_var.ndim == 4:
            # Second axis is taken at index 0 (presumably the first
            # vertical level -- confirm against the data files).
            u_full = u_var[time_idx, 0].values
            v_full = v_var[time_idx, 0].values
        else:
            u_full = u_var.values
            v_full = v_var.values

    # Everything below works on in-memory NumPy arrays only.
    if lon.ndim == 1 and lat.ndim == 1:
        lon_2d, lat_2d = np.meshgrid(lon, lat)
    else:
        lon_2d, lat_2d = lon, lat

    if lon_range is not None or lat_range is not None:
        row_sl, col_sl = _bounding_slices(lon_2d, lat_2d, lon_range, lat_range)
        u_full = u_full[row_sl, col_sl]
        v_full = v_full[row_sl, col_sl]
        lon_2d = lon_2d[row_sl, col_sl]
        lat_2d = lat_2d[row_sl, col_sl]

    u_region = u_full[::skip, ::skip]
    v_region = v_full[::skip, ::skip]
    lon_region = lon_2d[::skip, ::skip]
    lat_region = lat_2d[::skip, ::skip]

    logger.debug(f"Original size: {u_full.shape}")
    logger.debug(f"Sampled size: {u_region.shape}")

    # Points where both components are exactly zero are flagged and handed
    # to convert_and_round (named "land" by the original author; how the
    # helper treats them is not visible here).
    land_mask = (u_region == 0) & (v_region == 0)
    u_list = convert_and_round(u_region, land_mask)
    v_list = convert_and_round(v_region, land_mask)

    # Spacing is measured along the first row (longitude) and the first
    # column (latitude) of the sampled grid.
    lon_intervals = _steps(lon_region[0, :])
    lat_intervals = _steps(lat_region[:, 0])

    bound_lon_lat = {
        "top": float(lat_region.max()),
        "bottom": float(lat_region.min()),
        "left": float(lon_region.min()),
        "right": float(lon_region.max()),
    }

    return {
        "data": {
            "modelFcstDt": selected_time.strftime("%Y%m%d%H"),
            "values": [u_list, v_list],
        },
        "grid": {
            "lonInterval": lon_intervals,
            "boundLonLat": bound_lon_lat,
            "rows": int(u_region.shape[0]),
            "cols": int(u_region.shape[1]),
            "latInterval": lat_intervals,
        },
    }