wing-ops/prediction/image/api.py

341 lines
11 KiB
Python

import sys
import os
from pathlib import Path
from contextlib import asynccontextmanager
import asyncio
from concurrent.futures import ThreadPoolExecutor
from fastapi import FastAPI, HTTPException, File, UploadFile, Form
from fastapi.responses import Response, FileResponse
import subprocess
import rasterio
import numpy as np
from PIL import Image
from PIL.ExifTags import TAGS
import io
import base64
from pyproj import Transformer
from extract_data import get_metadata as get_meta
from extract_data import get_oil_type as get_oil
import time
from typing import List, Optional
import shutil
from datetime import datetime
from collections import Counter
# Make the in-process mx15hdi pipeline modules (Inference, Export_Metadata_mx15hdi,
# Create_Georeferenced_Images_nadir, Oilshape) importable. Every directory is
# pushed onto the FRONT of sys.path, so the last one listed ends up first.
_BASE_DIR = Path(__file__).parent
for _sub in (
    ('Detect',),
    ('Metadata', 'Scripts'),
    ('Georeference', 'Scripts'),
    ('Polygon', 'Scripts'),
):
    sys.path.insert(0, str(_BASE_DIR.joinpath('mx15hdi', *_sub)))
del _sub
from Inference import load_model, run_inference
from Export_Metadata_mx15hdi import run_metadata_export
from Create_Georeferenced_Images_nadir import run_georeference
from Oilshape import run_oilshape
# AI model instance, loaded once at server startup by `lifespan`;
# None before startup and again after shutdown.
_model = None
# Thread pool for CPU/GPU-bound work: the mx15hdi pipeline steps are dispatched
# to it with loop.run_in_executor so they do not block the event loop.
_executor = ThreadPoolExecutor(max_workers=4)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: load the AI model at startup, drop it at shutdown.

    The loaded model is published through the module-level ``_model`` so the
    mx15hdi pipeline can reuse it on every request without reloading it.
    """
    global _model

    # -- startup: one-time model load ---------------------------------------
    print("AI 모델 로딩 중 (epoch_165.pth)...")
    _model = load_model()
    print("AI 모델 로드 완료")

    # -- the application serves requests while suspended here --------------
    yield

    # -- shutdown: release the model reference ------------------------------
    _model = None


# The FastAPI application; `lifespan` wraps its whole lifetime.
app = FastAPI(lifespan=lifespan)
def check_gps_info(image_path: str) -> bool:
    """Return True if the image at *image_path* carries an EXIF GPS IFD.

    Prints a diagnostic and returns False when the image has no EXIF data at
    all, or has EXIF data without a GPS IFD (tag 0x8825).
    """
    # Use a context manager so the file handle Pillow opened is released;
    # previously the image was never closed and the handle leaked.
    with Image.open(image_path) as image:
        exifdata = image.getexif()
        if not exifdata:
            print("EXIF 정보를 찾을 수 없습니다.")
            return False
        # Read the GPS IFD while the file is still open, since Pillow may need
        # to read it from the underlying file.
        gps_ifd = exifdata.get_ifd(0x8825)  # GPS IFD tag
    if not gps_ifd:
        print("GPS 정보를 찾을 수 없습니다.")
        return False
    return True
def check_camera_info(image_file):
    """Return the EXIF camera ``Model`` of an image.

    Args:
        image_file: a filesystem path or a binary file object (``stitch``
            passes ``UploadFile.file``).

    Returns:
        The ``Model`` tag value (whitespace-stripped when it is a ``str``),
        ``False`` when the image has no EXIF data, or ``None`` when EXIF data
        exists but contains no ``Model`` tag.

    When *image_file* is a file object, its read position is restored before
    returning so the caller can still copy the complete file afterwards.
    """
    # Image.open/getexif read from the stream; remember where it was so the
    # caller is not left with a half-consumed upload (which previously caused
    # stitch to save truncated images).
    is_stream = hasattr(image_file, "seek") and hasattr(image_file, "tell")
    start_pos = image_file.tell() if is_stream else None
    try:
        # The context manager closes files Pillow opened from a path and
        # leaves caller-supplied file objects open.
        with Image.open(image_file) as image:
            exifdata = image.getexif()
            if not exifdata:
                print("EXIF 정보를 찾을 수 없습니다.")
                return False
            for tag_id, value in exifdata.items():
                if TAGS.get(tag_id, tag_id) == "Model":
                    return value.strip() if isinstance(value, str) else value
            # EXIF present but no Model tag.
            return None
    finally:
        if is_stream:
            image_file.seek(start_pos)
async def _run_mx15hdi_pipeline(file_id: str):
    """Run the mx15hdi analysis pipeline in-process for *file_id*.

    - Step 1 (AI inference) and Step 2 (metadata export) run concurrently.
    - Step 3 (georeferencing) then Step 4 (polygon extraction) run in order.
    - Intermediate results (``inference_cache``, ``georef_cache``) are handed
      between steps in memory instead of through intermediate files.

    Every step is a synchronous call dispatched to ``_executor`` so the event
    loop stays free while the pipeline runs.
    """
    # Inside a coroutine, get_running_loop() is the documented way to obtain
    # the current loop (get_event_loop() is the legacy API).
    loop = asyncio.get_running_loop()

    # Steps 1 + 2 in parallel; keep the inference result for Step 3.
    inference_cache, _ = await asyncio.gather(
        loop.run_in_executor(_executor, run_inference, _model, file_id),
        loop.run_in_executor(_executor, run_metadata_export, file_id),
    )

    # Step 3: georeference using the in-memory inference results.
    georef_cache = await loop.run_in_executor(
        _executor, run_georeference, file_id, inference_cache
    )

    # Step 4: polygon extraction from the in-memory georeference results
    # (no Mask_Tif read from disk).
    await loop.run_in_executor(_executor, run_oilshape, file_id, georef_cache)
# 전체 과정을 구동하는 api
@app.post("/run-script/")
async def run_script(
    camTy: str = Form(...),
    fileId: str = Form(...),
    image: UploadFile = File(...)
):
    """Run the full oil-spill analysis for one uploaded image.

    Saves the upload under ``<camTy>/Metadata/Image/Original_Images/<fileId>/``,
    checks that it carries EXIF GPS data, runs the camera-specific pipeline
    (in-process for ``mx15hdi``; ``<camTy>/Main/Combine_module.py`` as a
    subprocess for ``starsafire``), then returns the extracted metadata and
    oil data.

    Returns:
        ``{"meta": ..., "data": ...}`` on success, or
        ``{"detail": "GPS Infomation Not Found"}`` (HTTP 200) when the image
        has no GPS EXIF data.

    Raises:
        HTTPException(400): unsupported ``camTy``, unsafe ``fileId``, or an
            upload without a filename.
        HTTPException(404): the starsafire script does not exist.
        HTTPException(500): script timeout or any other failure.
    """
    try:
        print("start")
        start_time = time.perf_counter()

        if camTy not in ("mx15hdi", "starsafire"):
            raise HTTPException(status_code=400, detail="camTy must be 'mx15hdi' or 'starsafire'")
        # fileId becomes a directory name; reject values that could escape
        # the upload root (e.g. "../x" or "a/b").
        if not fileId or fileId in (".", "..") or os.path.basename(fileId) != fileId:
            raise HTTPException(status_code=400, detail="Invalid fileId")
        # Keep only the last component of the client-supplied file name so the
        # upload cannot be written outside upload_dir.
        image_name = os.path.basename(image.filename or "")
        if not image_name:
            raise HTTPException(status_code=400, detail="Image filename is required")

        # Save the uploaded image.
        upload_dir = os.path.join(camTy, "Metadata/Image/Original_Images", fileId)
        os.makedirs(upload_dir, exist_ok=True)
        image_path = os.path.join(upload_dir, image_name)
        with open(image_path, "wb") as f:
            f.write(await image.read())

        if not check_gps_info(image_path):
            return {"detail": "GPS Infomation Not Found"}

        if camTy == "mx15hdi":
            # In-process pipeline: no model reload, Steps 1+2 run in parallel.
            await _run_mx15hdi_pipeline(fileId)
        else:
            # starsafire: keeps the existing subprocess approach.
            script_dir = os.path.join(os.getcwd(), camTy, "Main")
            script_file = "Combine_module.py"
            if not os.path.exists(os.path.join(script_dir, script_file)):
                raise HTTPException(status_code=404, detail="Script not found")
            # Run the blocking subprocess (up to 300 s) in a worker thread so
            # the event loop keeps serving other requests meanwhile.
            loop = asyncio.get_running_loop()
            result = await loop.run_in_executor(
                None,
                lambda: subprocess.run(
                    ["python", script_file, fileId],
                    cwd=script_dir,
                    capture_output=True,
                    text=True,
                    timeout=300,
                ),
            )
            print(f"Subprocess stdout: {result.stdout}")
            print(f"Subprocess stderr: {result.stderr}")

        meta_string = get_meta(camTy, fileId)
        oil_data = get_oil(camTy, fileId)
        end_time = time.perf_counter()
        print(f"Run time: {end_time - start_time:.4f} sec")
        return {
            "meta": meta_string,
            "data": oil_data
        }
    except HTTPException:
        # Let deliberate 4xx/5xx responses through unchanged; previously the
        # generic handler below converted them all into 500s.
        raise
    except subprocess.TimeoutExpired:
        raise HTTPException(status_code=500, detail="Script execution timed out")
    except Exception as e:
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/get-metadata/{camTy}/{fileId}")
async def get_metadata(camTy: str, fileId: str):
    """Return the stored metadata and oil data for an already processed image.

    Delegates to ``get_meta`` and ``get_oil`` (in that order) and reports any
    failure as HTTP 500.
    """
    try:
        payload = {
            "meta": get_meta(camTy, fileId),
            "data": get_oil(camTy, fileId),
        }
    except subprocess.TimeoutExpired:
        # NOTE(review): only reachable if get_meta/get_oil run a subprocess
        # internally — confirm against extract_data.
        raise HTTPException(status_code=500, detail="Script execution timed out")
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    return payload
@app.get("/get-original-image/{camTy}/{fileId}")
async def get_original_image(camTy: str, fileId: str):
    """Return the original uploaded image for *fileId* as a base64 string.

    Looks in ``<camTy>/Metadata/Image/Original_Images/<fileId>/`` for the first
    (alphabetically) ``.png``/``.jpg``/``.jpeg`` file; the extension match is
    case-insensitive, so camera-style names such as ``DSC001.JPG`` are found.

    Raises:
        HTTPException(404): the directory or a matching image does not exist.
        HTTPException(500): any other error while reading the file.
    """
    try:
        image_dir = os.path.join(camTy, "Metadata/Image/Original_Images", fileId)
        # Sort so the selection is deterministic (os.listdir order is arbitrary).
        candidates = sorted(
            name for name in os.listdir(image_dir)
            if name.lower().endswith((".png", ".jpg", ".jpeg"))
        )
        if not candidates:
            raise HTTPException(status_code=404, detail="Original image not found")
        with open(os.path.join(image_dir, candidates[0]), "rb") as origin_image:
            return base64.b64encode(origin_image.read()).decode("utf-8")
    except HTTPException:
        raise
    except FileNotFoundError:
        raise HTTPException(status_code=404, detail="Original image not found")
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/get-image/{camTy}/{fileId}")
async def get_image(camTy: str, fileId: str):
    """Return the georeferenced GeoTIFF for *fileId* as a PNG plus its bounds.

    Reads the first (alphabetically) ``.tif`` in
    ``<camTy>/Georeference/Tif/<fileId>/``, expresses the raster's
    bottom-left and top-right corners in EPSG:4326 lon/lat, and encodes the
    pixel data as a base64 PNG.

    Returns:
        ``{"minLon", "minLat", "maxLon", "maxLat", "image"}``.

    Raises:
        HTTPException(404): the directory or a ``.tif`` file does not exist.
        HTTPException(500): any other error while reading or encoding.
    """
    try:
        tif_dir = os.path.join(camTy, "Georeference/Tif", fileId)
        # Sort so the selection is deterministic (os.listdir order is arbitrary).
        tif_names = sorted(f for f in os.listdir(tif_dir) if f.endswith(".tif"))
        if not tif_names:
            raise HTTPException(status_code=404, detail="Georeferenced image not found")
        tif_file = os.path.join(tif_dir, tif_names[0])

        with rasterio.open(tif_file) as dataset:
            crs = dataset.crs
            bounds = dataset.bounds
            if crs != "EPSG:4326":
                # Reproject the bottom-left / top-right corners to lon/lat.
                transformer = Transformer.from_crs(crs, "EPSG:4326", always_xy=True)
                minx, miny = transformer.transform(bounds.left, bounds.bottom)
                maxx, maxy = transformer.transform(bounds.right, bounds.top)
            else:
                # Already lon/lat: use the bounds as-is. These names were
                # previously left unassigned here (UnboundLocalError).
                minx, miny = bounds.left, bounds.bottom
                maxx, maxy = bounds.right, bounds.top
            print(minx, miny, maxx, maxy)
            data = dataset.read()

        # dataset.read() is band-first; PIL wants (rows, cols[, bands]).
        if data.shape[0] == 1:
            image_data = data[0]
        else:
            image_data = np.moveaxis(data, 0, -1)
        image = Image.fromarray(image_data)
        buffer = io.BytesIO()
        image.save(buffer, format="PNG")
        base64_string = base64.b64encode(buffer.getvalue()).decode("utf-8")
        return {
            "minLon": minx,
            "minLat": miny,
            "maxLon": maxx,
            "maxLat": maxy,
            "image": base64_string
        }
    except HTTPException:
        raise
    except FileNotFoundError:
        raise HTTPException(status_code=404, detail="Georeferenced image not found")
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
# Directory containing this module; the stitching helper script lives beside it.
BASE_DIR: Path = Path(__file__).parent
PIC_GPS_SCRIPT = BASE_DIR.joinpath("pic_gps.py")
@app.post("/stitch")
async def stitch(
    files: List[UploadFile] = File(..., description="합성할 이미지 파일들 (2장 이상)"),
    fileId: str = Form(...)
):
    """Stitch two or more uploaded images into one JPEG with ``pic_gps.py``.

    Each upload is saved under ``stitch/<fileId>/`` as
    ``<model>_<idx>_<name>``, where ``<model>`` is the EXIF camera model.
    ``pic_gps.py --mode drone --enhance`` is then run with the most common
    valid camera model, and ``stitched_<fileId>.jpg`` is returned as a file
    download.

    Raises:
        HTTPException(400): fewer than 2 files, an unsafe ``fileId``, or no
            camera model could be read from any image.
        HTTPException(500): the script failed, timed out, or produced no
            output file.
    """
    if len(files) < 2:
        raise HTTPException(
            status_code=400,
            detail="최소 2장 이상의 이미지가 필요합니다."
        )
    # fileId becomes a directory name; reject values that could escape
    # BASE_DIR / "stitch" (e.g. "../x" or "a/b").
    if not fileId or fileId in (".", "..") or Path(fileId).name != fileId:
        raise HTTPException(status_code=400, detail="Invalid fileId")

    upload_dir = BASE_DIR / "stitch" / fileId
    upload_dir.mkdir(parents=True, exist_ok=True)
    output_filename = f"stitched_{fileId}.jpg"
    output_path = upload_dir / output_filename

    model_list = []
    for idx, file in enumerate(files):
        model = check_camera_info(file.file)
        model_list.append(model)
        # Keep only the last component of the client-supplied name so the
        # file cannot be written outside upload_dir.
        original_filename = Path(file.filename or "").name or f"image_{idx}.jpg"
        file_path = upload_dir / f"{model}_{idx:03d}_{original_filename}"
        # EXIF inspection reads from the stream; rewind so the complete file is
        # saved rather than only the bytes after the header.
        file.file.seek(0)
        with open(file_path, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)

    # check_camera_info yields False/None when no model is found; those cannot
    # be passed to the script (subprocess would raise TypeError).
    valid_models = [m for m in model_list if isinstance(m, str) and m]
    if not valid_models:
        raise HTTPException(status_code=400, detail="이미지에서 카메라 모델 정보를 찾을 수 없습니다.")
    most_common_model = Counter(valid_models).most_common(1)[0][0]

    cmd = [
        "python",
        str(PIC_GPS_SCRIPT),
        "--mode", "drone",
        "--input", str(upload_dir),
        "--out", str(output_path),
        "--model", most_common_model,
        "--enhance",
    ]
    print(cmd)
    try:
        # Run the blocking script (up to 300 s) in a worker thread so the event
        # loop keeps serving other requests meanwhile.
        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(
            None,
            lambda: subprocess.run(cmd, capture_output=True, text=True, timeout=300),
        )
    except subprocess.TimeoutExpired:
        raise HTTPException(status_code=500, detail="Script execution timed out")

    print(f"Subprocess stdout: {result.stdout}")
    if result.returncode != 0:
        print(f"Subprocess stderr: {result.stderr}")
        raise HTTPException(status_code=500, detail=f"Script failed: {result.stderr}")
    if not output_path.is_file():
        raise HTTPException(status_code=500, detail="Stitched image was not created")

    return FileResponse(
        path=str(output_path),
        media_type="image/jpeg",
        filename=output_filename
    )
if __name__ == "__main__":
    # Serve the API directly when this module is executed as a script;
    # uvicorn is imported here because it is only needed in that case.
    from uvicorn import run as serve

    serve(app, host="0.0.0.0", port=5001)