iran prediction 47개 Python 파일을 prediction/ 디렉토리로 복제: - algorithms/ 14개 분석 알고리즘 (어구추론, 다크베셀, 스푸핑, 환적, 위험도 등) - pipeline/ 7단계 분류 파이프라인 - cache/vessel_store (24h 슬라이딩 윈도우) - db/ 어댑터 (snpdb 원본조회, kcgdb 결과저장) - chat/ AI 채팅 (Ollama, 후순위) - data/ 정적 데이터 (기선, 특정어업수역 GeoJSON) config.py를 kcgaidb로 재구성 (DB명, 사용자, 비밀번호) DB 연결 검증 완료 (kcgaidb 37개 테이블 접근 확인) Makefile에 dev-prediction / dev-all 타겟 추가 CLAUDE.md에 prediction 섹션 추가 Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
237 lines
8.5 KiB
Python
"""AI 해양분석 채팅 엔드포인트 — 사전 쿼리 + SSE 스트리밍 + Tool Calling."""
|
|
|
|
import json
|
|
import logging
|
|
|
|
import httpx
|
|
from fastapi import APIRouter
|
|
from fastapi.responses import StreamingResponse
|
|
from pydantic import BaseModel
|
|
|
|
from chat.cache import load_chat_history, save_chat_history, clear_chat_history
|
|
from chat.context_builder import MaritimeContextBuilder
|
|
from chat.tools import detect_prequery, execute_prequery, parse_tool_calls, execute_tool_call
|
|
from config import settings
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
router = APIRouter(prefix='/api/v1/chat', tags=['chat'])
|
|
|
|
|
|
class ChatRequest(BaseModel):
    """Request body for the chat endpoint."""

    # Free-text user message sent to the model.
    message: str
    # Key for the per-user chat-history cache; defaults to a shared anonymous bucket.
    user_id: str = 'anonymous'
    # True → reply is streamed as SSE; False → single JSON ChatResponse.
    stream: bool = True
|
|
|
|
|
|
class ChatResponse(BaseModel):
    """Non-streaming reply returned by the chat endpoint."""

    # Chat role of the reply; always 'assistant' in this API.
    role: str = 'assistant'
    # Full model answer (or a user-facing error message on failure).
    content: str
|
|
|
|
|
|
@router.post('')
async def chat(req: ChatRequest):
    """Maritime-analysis chat: pre-query context enrichment + Ollama call.

    Streams the answer as SSE when ``req.stream`` is set, otherwise returns a
    single ``ChatResponse``. Both paths support inline tool calling.
    """
    history = load_chat_history(req.user_id)

    # Base system prompt built from the maritime analysis context.
    system_prompt = MaritimeContextBuilder().build_system_prompt(user_message=req.message)

    # Pre-query: keyword pattern matching may trigger a DB lookup whose result
    # is appended to the system prompt as additional grounding context.
    params = detect_prequery(req.message)
    if params:
        result = execute_prequery(params)
        logger.info('prequery: params=%s, results=%d chars', params, len(result))
        if result:
            system_prompt = f'{system_prompt}\n\n{result}'

    # Conversation: system prompt, last 10 cached turns, then the new message.
    messages = [{'role': 'system', 'content': system_prompt}]
    messages.extend(history[-10:])
    messages.append({'role': 'user', 'content': req.message})

    ollama_payload = {
        'model': settings.OLLAMA_MODEL,
        'messages': messages,
        'stream': req.stream,
        'options': {
            'temperature': 0.3,
            'num_predict': 1024,
            'num_ctx': 2048,
        },
    }

    if not req.stream:
        return await _call_with_tools(ollama_payload, req.user_id, history, req.message)

    return StreamingResponse(
        _stream_with_tools(ollama_payload, req.user_id, history, req.message),
        media_type='text/event-stream',
        headers={
            'Cache-Control': 'no-cache',
            'Connection': 'keep-alive',
            'X-Accel-Buffering': 'no',
        },
    )
|
|
|
|
|
|
async def _iter_ollama_chunks(response):
    """Yield ``(content, done)`` pairs parsed from an Ollama NDJSON stream.

    Blank lines and lines that fail JSON parsing are skipped silently, matching
    the tolerant behavior both streaming passes relied on before extraction.
    """
    async for line in response.aiter_lines():
        if not line:
            continue
        try:
            chunk = json.loads(line)
        except json.JSONDecodeError:
            continue
        yield chunk.get('message', {}).get('content', ''), chunk.get('done', False)


def _sse(content: str, done: bool) -> str:
    """Format one SSE data frame; ensure_ascii=False keeps Korean text readable."""
    data = json.dumps({'content': content, 'done': done}, ensure_ascii=False)
    return f'data: {data}\n\n'


async def _stream_with_tools(payload: dict, user_id: str, history: list[dict], user_message: str):
    """SSE streaming — first LLM pass, then a second pass if tool calls are detected.

    Args:
        payload: Ollama /api/chat request body; mutated in place when a second
            pass appends the assistant turn and the tool results.
        user_id: cache key for persisting the conversation.
        history: prior turns; the new exchange is appended and saved at the end.
        user_message: the raw user message (stored in history verbatim).

    Yields SSE ``data:`` frames and a final ``[DONE]`` sentinel.
    """
    accumulated = ''
    try:
        async with httpx.AsyncClient(timeout=httpx.Timeout(settings.OLLAMA_TIMEOUT_SEC)) as client:
            # First LLM call.
            async with client.stream(
                'POST',
                f'{settings.OLLAMA_BASE_URL}/api/chat',
                json=payload,
            ) as response:
                async for content, done in _iter_ollama_chunks(response):
                    accumulated += content
                    # Never emit done=True here: a tool call may still require a
                    # second pass, so the true end is signaled later.
                    yield _sse(content, False)
                    if done:
                        break

            # Detect inline tool calls embedded in the first answer.
            tool_calls = parse_tool_calls(accumulated)
            if tool_calls:
                # Execute each requested tool and collect the textual results.
                tool_results = []
                for tc in tool_calls:
                    result = execute_tool_call(tc)
                    tool_results.append(result)
                    logger.info('tool call: %s → %d chars', tc.get('tool'), len(result))

                tool_context = '\n'.join(tool_results)

                # Second LLM call with the tool output appended to the dialogue.
                payload['messages'].append({'role': 'assistant', 'content': accumulated})
                payload['messages'].append({
                    'role': 'user',
                    'content': f'도구 조회 결과입니다. 이 데이터를 기반으로 사용자 질문에 답변하세요:\n{tool_context}',
                })

                # Visual separator between the raw tool-call text and the answer.
                yield _sse('\n\n---\n_데이터 조회 완료. 분석 결과:_\n\n', False)

                accumulated_2 = ''
                async with client.stream(
                    'POST',
                    f'{settings.OLLAMA_BASE_URL}/api/chat',
                    json=payload,
                ) as response2:
                    async for content, done in _iter_ollama_chunks(response2):
                        accumulated_2 += content
                        yield _sse(content, done)
                        if done:
                            break

                # Only the final (second-pass) answer is persisted to history.
                accumulated = accumulated_2 or accumulated

    except httpx.TimeoutException:
        # Bug fix: error frames previously omitted ensure_ascii=False, so the
        # Korean message was \u-escaped unlike every success-path frame.
        yield _sse('\n\n[응답 시간 초과]', True)
    except Exception as e:
        logger.error('ollama stream error: %s', e)
        yield _sse(f'[오류: {e}]', True)

    if accumulated:
        updated = history + [
            {'role': 'user', 'content': user_message},
            {'role': 'assistant', 'content': accumulated},
        ]
        save_chat_history(user_id, updated)

    yield 'data: [DONE]\n\n'
|
|
|
|
|
|
async def _call_with_tools(
    payload: dict, user_id: str, history: list[dict], user_message: str,
) -> ChatResponse:
    """Non-streaming completion with inline tool-calling support.

    Performs one LLM call; if the answer contains tool calls, executes them and
    runs a second call with the tool output appended. The final exchange is
    saved to the per-user history cache before returning.
    """
    url = f'{settings.OLLAMA_BASE_URL}/api/chat'
    try:
        async with httpx.AsyncClient(timeout=httpx.Timeout(settings.OLLAMA_TIMEOUT_SEC)) as client:
            # First pass: let the model answer (possibly with embedded tool calls).
            first = await client.post(url, json=payload)
            content = first.json().get('message', {}).get('content', '')

            tool_calls = parse_tool_calls(content)
            if tool_calls:
                # Run every requested tool and feed the results back for a second pass.
                tool_context = '\n'.join(execute_tool_call(tc) for tc in tool_calls)

                payload['messages'] += [
                    {'role': 'assistant', 'content': content},
                    {
                        'role': 'user',
                        'content': f'도구 조회 결과입니다. 이 데이터를 기반으로 답변하세요:\n{tool_context}',
                    },
                ]

                second = await client.post(url, json=payload)
                # Fall back to the first answer if the second pass yields nothing.
                content = second.json().get('message', {}).get('content', content)

            # Persist the exchange (final answer only) before responding.
            save_chat_history(user_id, history + [
                {'role': 'user', 'content': user_message},
                {'role': 'assistant', 'content': content},
            ])

            return ChatResponse(content=content)
    except Exception as e:
        logger.error('ollama sync error: %s', e)
        return ChatResponse(content=f'오류: AI 서버 연결 실패 ({e})')
|
|
|
|
|
|
@router.get('/history')
async def get_history(user_id: str = 'anonymous'):
    """Return the cached chat history for *user_id* (empty if none saved)."""
    return load_chat_history(user_id)
|
|
|
|
|
|
@router.delete('/history')
async def delete_history(user_id: str = 'anonymous'):
    """Clear the cached chat history for *user_id*."""
    clear_chat_history(user_id)
    return {'ok': True}
|