import sys
import os
from pathlib import Path
from contextlib import asynccontextmanager
import asyncio
import functools
from concurrent.futures import ThreadPoolExecutor
from fastapi import FastAPI, HTTPException, File, UploadFile, Form
from fastapi.responses import Response, FileResponse
import subprocess
import rasterio
import numpy as np
from PIL import Image
from PIL.ExifTags import TAGS
import io
import base64
from pyproj import Transformer
from extract_data import get_metadata as get_meta
from extract_data import get_oil_type as get_oil
import time
from typing import List, Optional
import shutil
from datetime import datetime
from collections import Counter

# Make the mx15hdi pipeline step modules importable. They live in per-step
# script directories, so their folders are prepended to sys.path before the
# imports below.
_BASE_DIR = Path(__file__).parent
sys.path.insert(0, str(_BASE_DIR / 'mx15hdi' / 'Detect'))
sys.path.insert(0, str(_BASE_DIR / 'mx15hdi' / 'Metadata' / 'Scripts'))
sys.path.insert(0, str(_BASE_DIR / 'mx15hdi' / 'Georeference' / 'Scripts'))
sys.path.insert(0, str(_BASE_DIR / 'mx15hdi' / 'Polygon' / 'Scripts'))

from Inference import load_model, run_inference
from Export_Metadata_mx15hdi import run_metadata_export
from Create_Georeferenced_Images_nadir import run_georeference
from Oilshape import run_oilshape

# AI model, loaded once at server start-up by `lifespan`.
_model = None

# Thread pool for the CPU/GPU-bound mx15hdi pipeline steps.
_executor = ThreadPoolExecutor(max_workers=4)


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the AI model once at start-up and drop the reference at shutdown."""
    global _model
    print("AI 모델 로딩 중 (epoch_165.pth)...")
    _model = load_model()
    print("AI 모델 로드 완료")
    yield
    _model = None


app = FastAPI(lifespan=lifespan)


def check_gps_info(image_path: str) -> bool:
    """Return True if the image at *image_path* contains an EXIF GPS block.

    Prints a diagnostic and returns False when the image has no EXIF data or
    no GPS IFD. The image is opened in a ``with`` block so its file handle is
    released before returning.
    """
    with Image.open(image_path) as image:
        exifdata = image.getexif()
        if not exifdata:
            print("EXIF 정보를 찾을 수 없습니다.")
            return False
        # 0x8825 is the GPSInfo IFD pointer tag. It is read while the image is
        # still open because Pillow may resolve IFDs lazily from the file.
        gps_ifd = exifdata.get_ifd(0x8825)
    if not gps_ifd:
        print("GPS 정보를 찾을 수 없습니다.")
        return False
    return True


def check_camera_info(image_file):
    """Return the EXIF ``Model`` (camera model) of an image.

    Args:
        image_file: a path or a seekable binary file object.

    Returns:
        The model value (stripped when it is a string), ``False`` when the
        image has no EXIF data, or ``None`` when EXIF has no ``Model`` tag.

    When *image_file* is a file object, its stream position is restored
    afterwards. Pillow reads the header from it, and callers such as
    ``stitch`` still need to copy the complete file.
    """
    start_pos = image_file.tell() if hasattr(image_file, "seek") else None
    try:
        # Pillow's context manager only closes file handles it opened itself,
        # so a caller-supplied file object stays open.
        with Image.open(image_file) as image:
            exifdata = image.getexif()
            if not exifdata:
                print("EXIF 정보를 찾을 수 없습니다.")
                return False
            for tag_id, value in exifdata.items():
                if TAGS.get(tag_id, tag_id) == "Model":
                    return value.strip() if isinstance(value, str) else value
            return None
    finally:
        if start_pos is not None:
            image_file.seek(start_pos)


def _is_safe_path_component(name: str) -> bool:
    """Return True if *name* is a single path component that cannot traverse.

    Rejects empty strings, ``.``/``..``, absolute paths and anything that
    contains a directory separator.
    """
    return bool(name) and name not in (".", "..") and os.path.basename(name) == name


async def _run_mx15hdi_pipeline(file_id: str):
    """Run the mx15hdi pipeline in-process.

    - Step 1 (AI inference) and step 2 (metadata export) run concurrently.
    - Step 3 (georeferencing) and then step 4 (polygon extraction) run in order.
    - Intermediate results are handed between steps in memory rather than
      through intermediate files.
    """
    loop = asyncio.get_running_loop()
    # Steps 1 + 2 concurrently; keep the inference result for step 3.
    inference_cache, _ = await asyncio.gather(
        loop.run_in_executor(_executor, run_inference, _model, file_id),
        loop.run_in_executor(_executor, run_metadata_export, file_id),
    )
    # Step 3: georeference using the in-memory inference result.
    georef_cache = await loop.run_in_executor(
        _executor, run_georeference, file_id, inference_cache
    )
    # Step 4: polygon extraction from the in-memory georeference result
    # (no Mask_Tif read from disk).
    await loop.run_in_executor(_executor, run_oilshape, file_id, georef_cache)


async def _run_starsafire_script(cam_ty: str, file_id: str) -> None:
    """Run the legacy starsafire pipeline script (``Main/Combine_module.py``).

    The blocking ``subprocess.run`` call is dispatched to the default executor
    so the event loop keeps serving other requests while the script runs.

    Raises:
        HTTPException: 404 when the script file does not exist.
        subprocess.TimeoutExpired: the script ran longer than 300 seconds.
    """
    script_dir = os.path.join(os.getcwd(), cam_ty, "Main")
    script_file = "Combine_module.py"
    if not os.path.exists(os.path.join(script_dir, script_file)):
        raise HTTPException(status_code=404, detail="Script not found")

    loop = asyncio.get_running_loop()
    result = await loop.run_in_executor(
        None,
        functools.partial(
            subprocess.run,
            ["python", script_file, file_id],
            cwd=script_dir,
            capture_output=True,
            text=True,
            timeout=300,
        ),
    )
    print(f"Subprocess stdout: {result.stdout}")
    print(f"Subprocess stderr: {result.stderr}")


# API that drives the whole pipeline for one uploaded image.
@app.post("/run-script/")
async def run_script(
    # pollId: int = Form(...),
    camTy: str = Form(...),
    fileId: str = Form(...),
    image: UploadFile = File(...)
):
    """Save an uploaded image, run the camera-specific pipeline, return results.

    Returns ``{"meta": ..., "data": ...}`` on success. When the image has no GPS
    EXIF, returns ``{"detail": "GPS Infomation Not Found"}``.

    Raises:
        HTTPException: 400 for an unknown ``camTy`` or unsafe ``fileId``, 404 when
            the starsafire script is missing, 500 on timeout or any other error.
    """
    try:
        print("start")
        start_time = time.perf_counter()

        if camTy not in ("mx15hdi", "starsafire"):
            raise HTTPException(status_code=400, detail="camTy must be 'mx15hdi' or 'starsafire'")
        # fileId becomes a directory name; refuse anything that could escape it.
        if not _is_safe_path_component(fileId):
            raise HTTPException(status_code=400, detail="Invalid fileId")

        upload_dir = os.path.join(camTy, "Metadata/Image/Original_Images", fileId)
        os.makedirs(upload_dir, exist_ok=True)

        # Keep only the final component of the client-supplied filename so a
        # crafted name such as "../../x" cannot write outside upload_dir.
        filename = os.path.basename(image.filename or "") or "image.jpg"
        image_path = os.path.join(upload_dir, filename)
        with open(image_path, "wb") as f:
            f.write(await image.read())

        if not check_gps_info(image_path):
            return {"detail": "GPS Infomation Not Found"}

        if camTy == "mx15hdi":
            # In-process pipeline: no model reload, steps 1 + 2 concurrent.
            await _run_mx15hdi_pipeline(fileId)
        else:
            # starsafire keeps the legacy subprocess-based pipeline.
            await _run_starsafire_script(camTy, fileId)

        meta_string = get_meta(camTy, fileId)
        oil_data = get_oil(camTy, fileId)

        end_time = time.perf_counter()
        print(f"Run time: {end_time - start_time:.4f} sec")

        return {
            "meta": meta_string,
            "data": oil_data
        }
    except HTTPException:
        # Deliberate client/server errors raised above keep their status code.
        raise
    except subprocess.TimeoutExpired:
        raise HTTPException(status_code=500, detail="Script execution timed out")
    except Exception as e:
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/get-metadata/{camTy}/{fileId}")
async def get_metadata(camTy: str, fileId: str):
    """Return the stored metadata and oil data for a processed file.

    Raises:
        HTTPException: 500 if reading either result fails.
    """
    try:
        meta_string = get_meta(camTy, fileId)
        oil_data = get_oil(camTy, fileId)
        return {
            "meta": meta_string,
            "data": oil_data
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/get-original-image/{camTy}/{fileId}")
async def get_original_image(camTy: str, fileId: str):
    """Return the uploaded original image (PNG/JPEG) as a base64 string.

    Raises:
        HTTPException: 404 when the directory or a matching image is missing,
            500 on any other error.
    """
    try:
        image_path = os.path.join(camTy, "Metadata/Image/Original_Images", fileId)
        # Sorted for a deterministic pick; match extensions case-insensitively
        # because camera firmware commonly writes ".JPG".
        target_file = [
            f for f in sorted(os.listdir(image_path))
            if f.lower().endswith((".png", ".jpg", ".jpeg"))
        ]
        if not target_file:
            raise HTTPException(status_code=404, detail="Image not found")
        image_file = os.path.join(image_path, target_file[0])

        with open(image_file, "rb") as origin_image:
            base64_string = base64.b64encode(origin_image.read()).decode("utf-8")
        print(base64_string[:100])
        return base64_string
    except HTTPException:
        raise
    except FileNotFoundError as e:
        raise HTTPException(status_code=404, detail=str(e))
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/get-image/{camTy}/{fileId}")
async def get_image(camTy: str, fileId: str):
    """Return the georeferenced GeoTIFF as a base64 PNG plus its lon/lat bounds.

    Bounds are reported in EPSG:4326 (``minLon``, ``minLat``, ``maxLon``,
    ``maxLat``). Rasters in another CRS have their lower-left and upper-right
    corners reprojected.

    Raises:
        HTTPException: 404 when the directory or a .tif file is missing,
            500 on any other error.
    """
    try:
        tif_file_path = os.path.join(camTy, "Georeference/Tif", fileId)
        target_file = [f for f in sorted(os.listdir(tif_file_path)) if f.lower().endswith(".tif")]
        if not target_file:
            raise HTTPException(status_code=404, detail="GeoTIFF not found")
        tif_file = os.path.join(tif_file_path, target_file[0])

        with rasterio.open(tif_file) as dataset:
            crs = dataset.crs
            bounds = dataset.bounds
            if crs != "EPSG:4326":
                transformer = Transformer.from_crs(crs, "EPSG:4326", always_xy=True)
                minx, miny = transformer.transform(bounds.left, bounds.bottom)
                maxx, maxy = transformer.transform(bounds.right, bounds.top)
            else:
                # Already lon/lat. Previously these names were left unbound here,
                # causing an UnboundLocalError.
                minx, miny = bounds.left, bounds.bottom
                maxx, maxy = bounds.right, bounds.top
            print(minx, miny, maxx, maxy)
            data = dataset.read()

        # rasterio returns (bands, rows, cols); PIL wants (rows, cols[, bands]).
        if data.shape[0] == 1:
            image_data = data[0]
        else:
            image_data = np.moveaxis(data, 0, -1)

        image = Image.fromarray(image_data)
        buffer = io.BytesIO()
        image.save(buffer, format="PNG")
        base64_string = base64.b64encode(buffer.getvalue()).decode("utf-8")
        print(base64_string[:100])

        return {
            "minLon": minx,
            "minLat": miny,
            "maxLon": maxx,
            "maxLat": maxy,
            "image": base64_string
        }
    except HTTPException:
        raise
    except FileNotFoundError as e:
        raise HTTPException(status_code=404, detail=str(e))
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


BASE_DIR = Path(__file__).parent
PIC_GPS_SCRIPT = BASE_DIR / "pic_gps.py"


@app.post("/stitch")
async def stitch(
    files: List[UploadFile] = File(..., description="합성할 이미지 파일들 (2장 이상)"),
    fileId: str = Form(...)
):
    """Save the uploaded images and stitch them with ``pic_gps.py``.

    Each upload is saved as ``{model}_{idx:03d}_{name}`` under
    ``stitch/<fileId>/``. The most common EXIF camera model is passed to the
    script as ``--model``.

    Returns:
        A ``FileResponse`` with the stitched JPEG.

    Raises:
        HTTPException: 400 for fewer than two images, an unsafe ``fileId`` or
            no camera model in any image's EXIF; 500 when the script fails,
            times out or produces no output.
    """
    if len(files) < 2:
        raise HTTPException(
            status_code=400,
            detail="최소 2장 이상의 이미지가 필요합니다."
        )
    # fileId becomes a directory name. `Path / "/abs"` would discard BASE_DIR
    # entirely, so accept only a single plain path component.
    if not fileId or fileId in (".", "..") or Path(fileId).name != fileId:
        raise HTTPException(status_code=400, detail="Invalid fileId")

    upload_dir = BASE_DIR / "stitch" / fileId
    upload_dir.mkdir(parents=True, exist_ok=True)
    # NOTE(review): the output is written into the same directory passed as
    # --input, so a repeated request with the same fileId will find earlier
    # uploads and the previous stitched image there. Confirm pic_gps.py
    # ignores them.
    output_filename = f"stitched_{fileId}.jpg"
    output_path = upload_dir / output_filename

    model_list = []
    for idx, file in enumerate(files):
        model = check_camera_info(file.file)
        model_list.append(model)

        # Strip any client-supplied directory part to prevent path traversal.
        original_filename = Path(file.filename or "").name or f"image_{idx}.jpg"
        file_path = upload_dir / f"{model}_{idx:03d}_{original_filename}"

        # Reading EXIF advances the upload stream. Rewind so the whole file,
        # header included, is written.
        file.file.seek(0)
        with open(file_path, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)

    # check_camera_info yields False/None when EXIF or the Model tag is
    # missing. Those values must not become a subprocess argument.
    known_models = [m for m in model_list if m]
    if not known_models:
        raise HTTPException(status_code=400, detail="Camera model not found in image EXIF")
    most_common_model = Counter(known_models).most_common(1)[0][0]

    cmd = [
        "python", str(PIC_GPS_SCRIPT),
        "--mode", "drone",
        "--input", str(upload_dir),
        "--out", str(output_path),
        "--model", most_common_model,
        "--enhance"
    ]
    print(cmd)

    # Run the (up to 300 s) script off the event loop.
    loop = asyncio.get_running_loop()
    try:
        result = await loop.run_in_executor(
            None,
            lambda: subprocess.run(cmd, capture_output=True, text=True, timeout=300),
        )
    except subprocess.TimeoutExpired:
        raise HTTPException(status_code=500, detail="Script execution timed out")

    print(f"Subprocess stdout: {result.stdout}")
    if result.returncode != 0:
        print(f"Subprocess stderr: {result.stderr}")
        raise HTTPException(status_code=500, detail=f"Script failed: {result.stderr}")
    if not output_path.exists():
        raise HTTPException(status_code=500, detail="Stitched image was not produced")

    return FileResponse(
        path=str(output_path),
        media_type="image/jpeg",
        filename=output_filename
    )


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=5001)