341 lines
11 KiB
Python
341 lines
11 KiB
Python
import sys
|
|
import os
|
|
from pathlib import Path
|
|
from contextlib import asynccontextmanager
|
|
import asyncio
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
|
|
from fastapi import FastAPI, HTTPException, File, UploadFile, Form
|
|
from fastapi.responses import Response, FileResponse
|
|
import subprocess
|
|
import rasterio
|
|
import numpy as np
|
|
from PIL import Image
|
|
from PIL.ExifTags import TAGS
|
|
import io
|
|
import base64
|
|
from pyproj import Transformer
|
|
from extract_data import get_metadata as get_meta
|
|
from extract_data import get_oil_type as get_oil
|
|
import time
|
|
|
|
from typing import List, Optional
|
|
import shutil
|
|
from datetime import datetime
|
|
from collections import Counter
|
|
|
|
# Make the mx15hdi pipeline script directories importable (they are plain
# script folders, not packages). Each directory is pushed onto the front of
# sys.path in turn, so the last one listed ends up first in the search order.
_BASE_DIR = Path(__file__).parent

for _pipeline_subdir in (
    ('Detect',),
    ('Metadata', 'Scripts'),
    ('Georeference', 'Scripts'),
    ('Polygon', 'Scripts'),
):
    sys.path.insert(0, str(_BASE_DIR.joinpath('mx15hdi', *_pipeline_subdir)))
del _pipeline_subdir
|
|
|
|
from Inference import load_model, run_inference
|
|
from Export_Metadata_mx15hdi import run_metadata_export
|
|
from Create_Georeferenced_Images_nadir import run_georeference
|
|
from Oilshape import run_oilshape
|
|
|
|
# AI model handle. It starts empty; `lifespan` loads it once when the server
# starts and resets it to None on shutdown.
_model = None

# Shared worker pool (4 threads) that keeps the CPU/GPU-heavy pipeline steps
# off the event loop thread.
_executor = ThreadPoolExecutor(4)
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the AI model once at server start-up and drop it at shutdown.

    The model returned by ``load_model()`` is stored in the module-level
    ``_model`` so request handlers can reuse it without reloading.

    The reset happens in a ``finally`` block. This drops the reference even
    when the context is exited with an exception. Code placed after a bare
    ``yield`` in an ``asynccontextmanager`` is skipped in that case.
    """
    global _model

    print("AI 모델 로딩 중 (epoch_165.pth)...")
    _model = load_model()
    print("AI 모델 로드 완료")

    try:
        yield
    finally:
        _model = None


app = FastAPI(lifespan=lifespan)
|
|
|
|
|
|
def check_gps_info(image_path: str) -> bool:
    """Return True if the image at *image_path* contains an EXIF GPS IFD.

    Prints a diagnostic message and returns False when the image has no EXIF
    data at all, or when its EXIF data has no GPS IFD (tag 0x8825).

    The image is opened in a ``with`` block so the file handle Pillow opens is
    released immediately. It is not left open until garbage collection, which
    on Windows would also keep the file locked.
    """
    with Image.open(image_path) as image:
        exifdata = image.getexif()

        if not exifdata:
            print("EXIF 정보를 찾을 수 없습니다.")
            return False

        # 0x8825 is the GPSInfo pointer tag; get_ifd returns that sub-IFD.
        gps_ifd = exifdata.get_ifd(0x8825)

    if not gps_ifd:
        print("GPS 정보를 찾을 수 없습니다.")
        return False

    return True
|
|
|
|
|
|
def check_camera_info(image_file):
    """Return the camera model recorded in an image's EXIF data.

    Args:
        image_file: A filesystem path, or a readable binary file object (for
            example an upload's ``.file``).

    Returns:
        The EXIF ``Model`` tag value, stripped of surrounding whitespace when
        it is a string. Returns ``False`` if the image has no EXIF data at all,
        and ``None`` if EXIF is present but contains no ``Model`` tag.

    When *image_file* is a file object, its read position is restored before
    returning. Parsing the header moves the stream position, and callers
    (e.g. ``stitch``) copy the same stream to disk afterwards. Using
    ``Image.open`` as a context manager releases Pillow's resources but does
    not close a file object that the caller passed in.
    """
    start_pos = image_file.tell() if hasattr(image_file, "seek") else None
    try:
        with Image.open(image_file) as image:
            exifdata = image.getexif()

            if not exifdata:
                print("EXIF 정보를 찾을 수 없습니다.")
                return False

            for tag_id, value in exifdata.items():
                if TAGS.get(tag_id, tag_id) == "Model":
                    return value.strip() if isinstance(value, str) else value
            return None
    finally:
        if start_pos is not None:
            image_file.seek(start_pos)
|
|
|
|
|
|
async def _run_mx15hdi_pipeline(file_id: str):
    """Run the mx15hdi pipeline in-process for *file_id*.

    - Step 1 (AI inference) and Step 2 (metadata export) run concurrently.
    - Step 3 (georeferencing) and then Step 4 (polygon extraction) run in order.
    - Intermediate results (``inference_cache``, ``georef_cache``) are passed
      between steps in memory rather than through intermediate files.

    Every step runs on the shared ``_executor`` thread pool so the event loop
    stays responsive. An exception raised by any step propagates to the caller.
    """
    # Inside a coroutine, get_running_loop() is the documented way to obtain
    # the current loop; get_event_loop() is discouraged for this use.
    loop = asyncio.get_running_loop()

    # Steps 1 + 2 in parallel; keep the inference result for step 3.
    inference_cache, _ = await asyncio.gather(
        loop.run_in_executor(_executor, run_inference, _model, file_id),
        loop.run_in_executor(_executor, run_metadata_export, file_id),
    )

    # Step 3: georeference from the in-memory inference result.
    georef_cache = await loop.run_in_executor(
        _executor, run_georeference, file_id, inference_cache
    )

    # Step 4: polygon extraction from the in-memory georeference result
    # (mask TIFFs are not re-read from disk).
    await loop.run_in_executor(_executor, run_oilshape, file_id, georef_cache)
|
|
|
|
|
|
# API endpoint that drives the whole processing pipeline.
@app.post("/run-script/")
async def run_script(
    camTy: str = Form(...),
    fileId: str = Form(...),
    image: UploadFile = File(...)
):
    """Save an uploaded image and run the processing pipeline for its camera.

    Form fields:
        camTy: Camera type; must be "mx15hdi" or "starsafire".
        fileId: Per-upload identifier, used as a directory name; it must be a
            single path component.
        image: The image file to process.

    The image is saved under ``<camTy>/Metadata/Image/Original_Images/<fileId>/``.
    If it carries no EXIF GPS data, ``{"detail": ...}`` is returned and the
    pipeline is skipped. For "mx15hdi" the in-process pipeline runs. For
    "starsafire" the script ``<cwd>/starsafire/Main/Combine_module.py`` runs
    as a subprocess with a 300 s timeout.

    Returns:
        ``{"meta": ..., "data": ...}`` built from ``get_meta`` / ``get_oil``.

    Raises:
        HTTPException 400: Invalid camTy, fileId, or image filename.
        HTTPException 404: The starsafire script is missing.
        HTTPException 500: Script timeout or any other processing failure.
    """
    try:
        print("start")
        start_time = time.perf_counter()

        if camTy not in ["mx15hdi", "starsafire"]:
            raise HTTPException(status_code=400, detail="camTy must be 'mx15hdi' or 'starsafire'")

        # fileId becomes a directory name: reject values that could escape the
        # upload root (separators, "." / "..").
        if not fileId or fileId in (".", "..") or os.path.basename(fileId) != fileId:
            raise HTTPException(status_code=400, detail="Invalid fileId")

        # Directory where the uploaded image is stored.
        upload_dir = os.path.join(camTy, "Metadata/Image/Original_Images", fileId)
        os.makedirs(upload_dir, exist_ok=True)

        # Keep only the final component of the client-supplied filename so it
        # cannot point outside upload_dir.
        filename = os.path.basename(image.filename or "")
        if not filename:
            raise HTTPException(status_code=400, detail="Invalid image filename")

        image_path = os.path.join(upload_dir, filename)
        with open(image_path, "wb") as f:
            f.write(await image.read())

        if not check_gps_info(image_path):
            return {"detail": "GPS Infomation Not Found"}

        if camTy == "mx15hdi":
            # In-process pipeline: model already loaded, steps 1+2 in parallel.
            await _run_mx15hdi_pipeline(fileId)
        else:
            # starsafire: run the existing standalone script as a subprocess.
            script_dir = os.path.join(os.getcwd(), camTy, "Main")
            script_file = "Combine_module.py"
            script_path = os.path.join(script_dir, script_file)

            if not os.path.exists(script_path):
                raise HTTPException(status_code=404, detail="Script not found")

            # subprocess.run blocks for up to the timeout; run it on a worker
            # thread so the event loop keeps serving other requests.
            loop = asyncio.get_running_loop()
            result = await loop.run_in_executor(
                None,
                lambda: subprocess.run(
                    ["python", script_file, fileId],
                    cwd=script_dir,
                    capture_output=True,
                    text=True,
                    timeout=300,
                ),
            )

            print(f"Subprocess stdout: {result.stdout}")
            print(f"Subprocess stderr: {result.stderr}")

        meta_string = get_meta(camTy, fileId)
        oil_data = get_oil(camTy, fileId)

        end_time = time.perf_counter()
        print(f"Run time: {end_time - start_time:.4f} sec")

        return {
            "meta": meta_string,
            "data": oil_data
        }

    except HTTPException:
        # Deliberate 4xx/5xx responses raised above must not be re-wrapped as
        # a generic 500 by the catch-all below.
        raise
    except subprocess.TimeoutExpired:
        raise HTTPException(status_code=500, detail="Script execution timed out")
    except Exception as e:
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
|
|
@app.get("/get-metadata/{camTy}/{fileId}")
async def get_metadata(camTy: str, fileId: str):
    """Return the stored metadata and oil data for one processed upload.

    Calls ``get_meta`` and then ``get_oil`` for (camTy, fileId) and returns
    ``{"meta": ..., "data": ...}``. A ``subprocess.TimeoutExpired`` is
    reported as a 500 "timed out" error; any other failure is reported as a
    500 carrying the exception text.
    """
    try:
        payload = {
            "meta": get_meta(camTy, fileId),
            "data": get_oil(camTy, fileId),
        }
    except subprocess.TimeoutExpired:
        raise HTTPException(status_code=500, detail="Script execution timed out")
    except Exception as err:
        raise HTTPException(status_code=500, detail=str(err))
    return payload
|
|
|
|
|
|
@app.get("/get-original-image/{camTy}/{fileId}")
async def get_original_image(camTy: str, fileId: str):
    """Return the originally uploaded image for *fileId* as a base64 string.

    Searches ``<camTy>/Metadata/Image/Original_Images/<fileId>`` (the
    directory ``run_script`` saves uploads into) for a PNG/JPEG file. The
    extension check is case-insensitive because uploads keep their
    client-side names (e.g. ``IMG_0001.JPG``). If several match, the
    alphabetically first one is used, so the choice is deterministic.

    Raises:
        HTTPException 404: The directory does not exist or holds no PNG/JPEG.
        HTTPException 500: Any other failure while reading the file.
    """
    try:
        image_dir = os.path.join(camTy, "Metadata/Image/Original_Images", fileId)
        try:
            files = os.listdir(image_dir)
        except FileNotFoundError:
            raise HTTPException(status_code=404, detail="Image directory not found") from None

        candidates = sorted(
            f for f in files if f.lower().endswith((".png", ".jpg", ".jpeg"))
        )
        if not candidates:
            raise HTTPException(status_code=404, detail="Original image not found")

        image_file = os.path.join(image_dir, candidates[0])
        with open(image_file, "rb") as origin_image:
            base64_string = base64.b64encode(origin_image.read()).decode("utf-8")

        print(base64_string[:100])

        return base64_string

    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
|
|
@app.get("/get-image/{camTy}/{fileId}")
async def get_image(camTy: str, fileId: str):
    """Return the georeferenced GeoTIFF for *fileId* as a base64 PNG, together
    with its bounding box in EPSG:4326 longitude/latitude.

    The first ``.tif`` listed in ``<camTy>/Georeference/Tif/<fileId>`` is read.
    If the raster is not already in EPSG:4326, its lower-left and upper-right
    corners are reprojected with pyproj (``always_xy=True``, so results are in
    lon/lat order). Otherwise the raster bounds are used as-is.

    Returns:
        A dict with ``minLon``, ``minLat``, ``maxLon``, ``maxLat`` and
        ``image``, where ``image`` is the raster encoded as a base64 PNG.

    Raises:
        HTTPException 404: The directory does not exist or holds no ``.tif``.
        HTTPException 500: Any other failure while reading or encoding.
    """
    try:
        tif_dir = os.path.join(camTy, "Georeference/Tif", fileId)
        try:
            files = os.listdir(tif_dir)
        except FileNotFoundError:
            raise HTTPException(status_code=404, detail="GeoTIFF directory not found") from None

        tif_names = [f for f in files if f.endswith(".tif")]
        if not tif_names:
            raise HTTPException(status_code=404, detail="GeoTIFF not found")
        tif_file = os.path.join(tif_dir, tif_names[0])

        with rasterio.open(tif_file) as dataset:
            crs = dataset.crs
            bounds = dataset.bounds

            if crs != "EPSG:4326":
                transformer = Transformer.from_crs(crs, "EPSG:4326", always_xy=True)
                minx, miny = transformer.transform(bounds.left, bounds.bottom)
                maxx, maxy = transformer.transform(bounds.right, bounds.top)
            else:
                # Already geographic: the bounds are lon/lat. Without this
                # branch the four names were never assigned (UnboundLocalError).
                minx, miny = bounds.left, bounds.bottom
                maxx, maxy = bounds.right, bounds.top

            print(minx, miny, maxx, maxy)

            data = dataset.read()

        # dataset.read() returns (bands, rows, cols). A single band becomes a
        # 2-D grayscale array; several bands become (rows, cols, bands) for PIL.
        if data.shape[0] == 1:
            image_data = data[0]
        else:
            image_data = np.moveaxis(data, 0, -1)

        buffer = io.BytesIO()
        Image.fromarray(image_data).save(buffer, format="PNG")
        base64_string = base64.b64encode(buffer.getvalue()).decode("utf-8")

        print(base64_string[:100])

        return {
            "minLon": minx,
            "minLat": miny,
            "maxLon": maxx,
            "maxLat": maxy,
            "image": base64_string
        }

    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
|
|
# Directory containing this module. Stitch uploads and the pic_gps.py helper
# script are both located relative to it.
BASE_DIR = Path(__file__).parent
PIC_GPS_SCRIPT = BASE_DIR.joinpath("pic_gps.py")
|
|
|
|
@app.post("/stitch")
async def stitch(
    files: List[UploadFile] = File(..., description="합성할 이미지 파일들 (2장 이상)"),
    fileId: str = Form(...)
):
    """Stitch two or more uploaded images into one JPEG and return it.

    Each upload is saved to ``BASE_DIR/stitch/<fileId>/`` with its EXIF
    camera model prefixed to the filename. The script ``pic_gps.py`` is then
    run in ``drone`` mode on that directory, passing the most common camera
    model among the uploads. The result, ``stitched_<fileId>.jpg``, is
    returned as a file download.

    Raises:
        HTTPException 400: Fewer than two files, an invalid fileId, or no
            upload reports a camera model.
        HTTPException 500: The script fails, times out, or produces no output.
    """
    if len(files) < 2:
        raise HTTPException(
            status_code=400,
            detail="최소 2장 이상의 이미지가 필요합니다."
        )

    # fileId becomes a directory name under BASE_DIR/stitch; keep it a single
    # path component so it cannot escape that root.
    if not fileId or fileId in (".", "..") or os.path.basename(fileId) != fileId:
        raise HTTPException(status_code=400, detail="Invalid fileId")

    try:
        upload_dir = BASE_DIR / "stitch" / fileId
        upload_dir.mkdir(parents=True, exist_ok=True)

        # The output is written into the input directory itself. Remove a
        # stale result from an earlier run so it is not stitched in as an input.
        output_filename = f"stitched_{fileId}.jpg"
        output_path = upload_dir / output_filename
        if output_path.exists():
            output_path.unlink()

        model_list = []
        for idx, file in enumerate(files):
            model = check_camera_info(file.file)
            model_list.append(model)

            # Strip any client-supplied directory components from the name.
            original_filename = os.path.basename(file.filename or "") or f"image_{idx}.jpg"
            file_path = upload_dir / f"{model}_{idx:03d}_{original_filename}"

            # Reading EXIF advanced the stream; rewind so the whole upload is
            # saved, not just the bytes after the parsed header.
            file.file.seek(0)
            with open(file_path, "wb") as buffer:
                shutil.copyfileobj(file.file, buffer)

        # check_camera_info yields None/False when no model is available, and
        # subprocess arguments must be strings. Only real model names vote.
        model_counter = Counter(m for m in model_list if isinstance(m, str))
        if not model_counter:
            raise HTTPException(status_code=400, detail="Camera model not found in image EXIF")
        most_common_model = model_counter.most_common(1)[0][0]

        cmd = [
            "python",
            str(PIC_GPS_SCRIPT),
            "--mode", "drone",
            "--input", str(upload_dir),
            "--out", str(output_path),
            "--model", most_common_model,
            "--enhance",
        ]

        print(cmd)

        # subprocess.run blocks for up to the timeout; run it on a worker
        # thread so the event loop keeps serving other requests.
        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(
            None,
            lambda: subprocess.run(cmd, capture_output=True, text=True, timeout=300),
        )

        print(f"Subprocess stdout: {result.stdout}")
        if result.returncode != 0:
            print(f"Subprocess stderr: {result.stderr}")
            raise HTTPException(status_code=500, detail=f"Script failed: {result.stderr}")

        if not output_path.exists():
            raise HTTPException(status_code=500, detail="Stitched output was not produced")

        return FileResponse(
            path=str(output_path),
            media_type="image/jpeg",
            filename=output_filename
        )
    except HTTPException:
        raise
    except subprocess.TimeoutExpired:
        raise HTTPException(status_code=500, detail="Script execution timed out") from None
|
|
|
|
|
|
if __name__ == "__main__":
    # Direct-execution entry point: serve the app with uvicorn on every
    # interface, port 5001.
    from uvicorn import run as _serve

    _serve(app, host="0.0.0.0", port=5001)
|