Merge pull request 'feat: ���� ����, ������ �г� �߰�, ������ ���� ����' (#97) from feature/prediction into develop
This commit is contained in:
커밋
63cf614365
@ -448,7 +448,7 @@ const ALGO_CD_TO_MODEL: Record<string, string> = {
|
||||
'POSEIDON': 'POSEIDON',
|
||||
};
|
||||
|
||||
interface TrajectoryResult {
|
||||
interface SingleModelTrajectoryResult {
|
||||
trajectory: Array<{ lat: number; lon: number; time: number; particle: number; stranded?: 0 | 1; model: string }>;
|
||||
summary: {
|
||||
remainingVolume: number;
|
||||
@ -462,7 +462,22 @@ interface TrajectoryResult {
|
||||
hydrData: ({ value: [number[][], number[][]]; grid: TrajectoryHydrGrid } | null)[];
|
||||
}
|
||||
|
||||
function transformTrajectoryResult(rawResult: TrajectoryTimeStep[], model: string): TrajectoryResult {
|
||||
interface TrajectoryResult {
|
||||
trajectory: Array<{ lat: number; lon: number; time: number; particle: number; stranded?: 0 | 1; model: string }>;
|
||||
summary: {
|
||||
remainingVolume: number;
|
||||
weatheredVolume: number;
|
||||
pollutionArea: number;
|
||||
beachedVolume: number;
|
||||
pollutionCoastLength: number;
|
||||
};
|
||||
centerPoints: Array<{ lat: number; lon: number; time: number; model: string }>;
|
||||
windDataByModel: Record<string, TrajectoryWindPoint[][]>;
|
||||
hydrDataByModel: Record<string, ({ value: [number[][], number[][]]; grid: TrajectoryHydrGrid } | null)[]>;
|
||||
summaryByModel: Record<string, SingleModelTrajectoryResult['summary']>;
|
||||
}
|
||||
|
||||
function transformTrajectoryResult(rawResult: TrajectoryTimeStep[], model: string): SingleModelTrajectoryResult {
|
||||
const trajectory = rawResult.flatMap((step, stepIdx) =>
|
||||
step.particles.map((p, i) => ({
|
||||
lat: p.lat,
|
||||
@ -513,8 +528,11 @@ export async function getAnalysisTrajectory(acdntSn: number): Promise<Trajectory
|
||||
let mergedTrajectory: TrajectoryResult['trajectory'] = [];
|
||||
let allCenterPoints: TrajectoryResult['centerPoints'] = [];
|
||||
|
||||
// summary/windData/hydrData: 가장 최근 완료된 OpenDrift 기준, 없으면 POSEIDON 기준
|
||||
let baseResult: TrajectoryResult | null = null;
|
||||
// summary: 가장 최근 완료된 OpenDrift 기준, 없으면 POSEIDON 기준
|
||||
let baseResult: SingleModelTrajectoryResult | null = null;
|
||||
const windDataByModel: Record<string, TrajectoryWindPoint[][]> = {};
|
||||
const hydrDataByModel: Record<string, ({ value: [number[][], number[][]]; grid: TrajectoryHydrGrid } | null)[]> = {};
|
||||
const summaryByModel: Record<string, SingleModelTrajectoryResult['summary']> = {};
|
||||
|
||||
// OpenDrift 우선, 없으면 POSEIDON 선택 (ORDER BY CMPL_DTM DESC이므로 첫 번째 행이 가장 최근)
|
||||
const opendriftRow = (rows as Array<Record<string, unknown>>).find((r) => r['algo_cd'] === 'OPENDRIFT');
|
||||
@ -528,8 +546,10 @@ export async function getAnalysisTrajectory(acdntSn: number): Promise<Trajectory
|
||||
const parsed = transformTrajectoryResult(row['rslt_data'] as TrajectoryTimeStep[], modelName);
|
||||
mergedTrajectory = mergedTrajectory.concat(parsed.trajectory);
|
||||
allCenterPoints = allCenterPoints.concat(parsed.centerPoints);
|
||||
windDataByModel[modelName] = parsed.windData;
|
||||
hydrDataByModel[modelName] = parsed.hydrData;
|
||||
summaryByModel[modelName] = parsed.summary;
|
||||
|
||||
// 기준 행의 결과를 baseResult로 사용
|
||||
if (row === baseRow) {
|
||||
baseResult = parsed;
|
||||
}
|
||||
@ -541,8 +561,9 @@ export async function getAnalysisTrajectory(acdntSn: number): Promise<Trajectory
|
||||
trajectory: mergedTrajectory,
|
||||
summary: baseResult.summary,
|
||||
centerPoints: allCenterPoints,
|
||||
windData: baseResult.windData,
|
||||
hydrData: baseResult.hydrData,
|
||||
windDataByModel,
|
||||
hydrDataByModel,
|
||||
summaryByModel,
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
@ -425,6 +425,301 @@ router.get('/status/:execSn', requireAuth, async (req: Request, res: Response) =
|
||||
}
|
||||
})
|
||||
|
||||
// ============================================================
|
||||
// POST /api/simulation/run-model (동기 방식)
|
||||
// 예측 완료 후 결과를 직접 반환한다.
|
||||
// ============================================================
|
||||
/**
|
||||
* 선택된 모델로 확산 시뮬레이션을 실행하고 완료될 때까지 대기한 후 결과를 반환한다.
|
||||
* 다중 모델은 병렬로 실행되며, 일부 모델 실패 시 성공한 모델 결과는 포함된다.
|
||||
*/
|
||||
router.post('/run-model', requireAuth, async (req: Request, res: Response) => {
|
||||
try {
|
||||
const { acdntSn: rawAcdntSn, acdntNm, spillUnit, spillTypeCd,
|
||||
lat, lon, runTime, matTy, matVol, spillTime, startTime,
|
||||
models: rawModels } = req.body
|
||||
|
||||
let requestedModels: string[] = Array.isArray(rawModels) && rawModels.length > 0
|
||||
? (rawModels as string[])
|
||||
: ['OpenDrift']
|
||||
|
||||
// 1. 필수 파라미터 검증
|
||||
if (lat === undefined || lon === undefined || runTime === undefined) {
|
||||
return res.status(400).json({ error: '필수 파라미터 누락', required: ['lat', 'lon', 'runTime'] })
|
||||
}
|
||||
if (!isValidLatitude(lat)) {
|
||||
return res.status(400).json({ error: '유효하지 않은 위도', message: '위도는 -90~90 범위여야 합니다.' })
|
||||
}
|
||||
if (!isValidLongitude(lon)) {
|
||||
return res.status(400).json({ error: '유효하지 않은 경도', message: '경도는 -180~180 범위여야 합니다.' })
|
||||
}
|
||||
if (!isValidNumber(runTime, 1, 720)) {
|
||||
return res.status(400).json({ error: '유효하지 않은 예측 시간', message: '예측 시간은 1~720 범위여야 합니다.' })
|
||||
}
|
||||
if (matVol !== undefined && !isValidNumber(matVol, 0, 1000000)) {
|
||||
return res.status(400).json({ error: '유효하지 않은 유출량' })
|
||||
}
|
||||
if (matTy !== undefined && (typeof matTy !== 'string' || !isValidStringLength(matTy, 50))) {
|
||||
return res.status(400).json({ error: '유효하지 않은 유종' })
|
||||
}
|
||||
if (!rawAcdntSn && (!acdntNm || typeof acdntNm !== 'string' || !acdntNm.trim())) {
|
||||
return res.status(400).json({ error: '사고를 선택하거나 사고명을 입력해야 합니다.' })
|
||||
}
|
||||
if (acdntNm && (typeof acdntNm !== 'string' || !isValidStringLength(acdntNm, 200))) {
|
||||
return res.status(400).json({ error: '사고명은 200자 이내여야 합니다.' })
|
||||
}
|
||||
|
||||
// 2. NC 파일 존재 여부 확인
|
||||
if (requestedModels.includes('OpenDrift')) {
|
||||
try {
|
||||
const checkRes = await fetch(`${PYTHON_API_URL}/check-nc`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ lat, lon, startTime }),
|
||||
signal: AbortSignal.timeout(5000),
|
||||
})
|
||||
if (!checkRes.ok) {
|
||||
// NC 파일 없으면 OpenDrift만 제외, 나머지 모델(POSEIDON 등)은 계속 진행
|
||||
requestedModels = requestedModels.filter(m => m !== 'OpenDrift')
|
||||
if (requestedModels.length === 0) {
|
||||
return res.status(409).json({
|
||||
error: '해당 좌표의 해양 기상 데이터가 없습니다.',
|
||||
message: 'NC 파일이 준비되지 않았습니다.',
|
||||
})
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
// Python 서버 미기동 — 이후 단계에서 처리
|
||||
}
|
||||
}
|
||||
|
||||
// 3. ACDNT/SPIL_DATA 생성 또는 조회
|
||||
let resolvedAcdntSn: number | null = rawAcdntSn ? Number(rawAcdntSn) : null
|
||||
let resolvedSpilDataSn: number | null = null
|
||||
let newlyCreatedAcdntSn: number | null = null
|
||||
let newlyCreatedSpilDataSn: number | null = null
|
||||
|
||||
if (!resolvedAcdntSn && acdntNm) {
|
||||
try {
|
||||
const occrn = startTime ?? new Date().toISOString()
|
||||
const acdntRes = await wingPool.query(
|
||||
`INSERT INTO wing.ACDNT
|
||||
(ACDNT_CD, ACDNT_NM, ACDNT_TP_CD, OCCRN_DTM, LAT, LNG, ACDNT_STTS_CD, USE_YN, REG_DTM)
|
||||
VALUES (
|
||||
'INC-' || EXTRACT(YEAR FROM NOW())::TEXT || '-' ||
|
||||
LPAD(
|
||||
(SELECT COALESCE(MAX(CAST(SPLIT_PART(ACDNT_CD, '-', 3) AS INTEGER)), 0) + 1
|
||||
FROM wing.ACDNT
|
||||
WHERE ACDNT_CD LIKE 'INC-' || EXTRACT(YEAR FROM NOW())::TEXT || '-%')::TEXT,
|
||||
4, '0'
|
||||
),
|
||||
$1, '유류유출', $2, $3, $4, 'ACTIVE', 'Y', NOW()
|
||||
)
|
||||
RETURNING ACDNT_SN`,
|
||||
[acdntNm.trim(), occrn, lat, lon]
|
||||
)
|
||||
resolvedAcdntSn = acdntRes.rows[0].acdnt_sn as number
|
||||
newlyCreatedAcdntSn = resolvedAcdntSn
|
||||
|
||||
const spilRes = await wingPool.query(
|
||||
`INSERT INTO wing.SPIL_DATA (ACDNT_SN, OIL_TP_CD, SPIL_QTY, SPIL_UNIT_CD, SPIL_TP_CD, FCST_HR, REG_DTM)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, NOW())
|
||||
RETURNING SPIL_DATA_SN`,
|
||||
[
|
||||
resolvedAcdntSn,
|
||||
OIL_DB_CODE_MAP[matTy as string] ?? 'BUNKER_C',
|
||||
matVol ?? 0,
|
||||
UNIT_MAP[spillUnit as string] ?? 'KL',
|
||||
SPIL_TYPE_MAP[spillTypeCd as string] ?? 'CONTINUOUS',
|
||||
runTime,
|
||||
]
|
||||
)
|
||||
resolvedSpilDataSn = spilRes.rows[0].spil_data_sn as number
|
||||
newlyCreatedSpilDataSn = resolvedSpilDataSn
|
||||
} catch (dbErr) {
|
||||
console.error('[simulation/run-model] ACDNT/SPIL_DATA INSERT 실패:', dbErr)
|
||||
return res.status(500).json({ error: '사고 정보 생성 실패' })
|
||||
}
|
||||
}
|
||||
|
||||
if (resolvedAcdntSn && !resolvedSpilDataSn) {
|
||||
try {
|
||||
const spilRes = await wingPool.query(
|
||||
`SELECT SPIL_DATA_SN FROM wing.SPIL_DATA WHERE ACDNT_SN = $1 ORDER BY SPIL_DATA_SN DESC LIMIT 1`,
|
||||
[resolvedAcdntSn]
|
||||
)
|
||||
if (spilRes.rows.length > 0) {
|
||||
resolvedSpilDataSn = spilRes.rows[0].spil_data_sn as number
|
||||
}
|
||||
} catch (dbErr) {
|
||||
console.error('[simulation/run-model] SPIL_DATA 조회 실패:', dbErr)
|
||||
}
|
||||
}
|
||||
|
||||
const odMatTy = matTy !== undefined ? (OIL_TYPE_MAP[matTy as string] ?? (matTy as string)) : undefined
|
||||
const execNmBase = `EXPC_${Date.now()}`
|
||||
|
||||
// KOSPS: PRED_EXEC INSERT(PENDING)만 수행
|
||||
const execSns: Array<{ model: string; execSn: number }> = []
|
||||
if (requestedModels.includes('KOSPS')) {
|
||||
try {
|
||||
const kospsExecNm = `${execNmBase}_KOSPS`
|
||||
const insertRes = await wingPool.query(
|
||||
`INSERT INTO wing.PRED_EXEC (ACDNT_SN, SPIL_DATA_SN, ALGO_CD, EXEC_STTS_CD, EXEC_NM, BGNG_DTM)
|
||||
VALUES ($1, $2, 'KOSPS', 'PENDING', $3, NOW())
|
||||
RETURNING PRED_EXEC_SN`,
|
||||
[resolvedAcdntSn, resolvedSpilDataSn, kospsExecNm]
|
||||
)
|
||||
execSns.push({ model: 'KOSPS', execSn: insertRes.rows[0].pred_exec_sn as number })
|
||||
} catch (dbErr) {
|
||||
console.error('[simulation/run-model] KOSPS PRED_EXEC INSERT 실패:', dbErr)
|
||||
}
|
||||
}
|
||||
|
||||
// 4. API 연동 모델 시작 및 완료 대기 (병렬)
|
||||
const apiModels = requestedModels.filter((m) => m !== 'KOSPS' && MODEL_ALGO_CD_MAP[m] !== undefined)
|
||||
|
||||
interface SyncModelResult {
|
||||
model: string
|
||||
execSn: number
|
||||
status: 'DONE' | 'ERROR'
|
||||
trajectory?: ReturnType<typeof transformResult>['trajectory']
|
||||
summary?: ReturnType<typeof transformResult>['summary']
|
||||
centerPoints?: ReturnType<typeof transformResult>['centerPoints']
|
||||
windData?: ReturnType<typeof transformResult>['windData']
|
||||
hydrData?: ReturnType<typeof transformResult>['hydrData']
|
||||
error?: string
|
||||
}
|
||||
|
||||
const modelResults = await Promise.all(
|
||||
apiModels.map(async (model): Promise<SyncModelResult> => {
|
||||
const algoCd = MODEL_ALGO_CD_MAP[model]
|
||||
const apiUrl = MODEL_API_URL_MAP[model]
|
||||
const execNm = `${execNmBase}_${algoCd}`
|
||||
|
||||
// PRED_EXEC INSERT
|
||||
let predExecSn: number
|
||||
try {
|
||||
const insertRes = await wingPool.query(
|
||||
`INSERT INTO wing.PRED_EXEC (ACDNT_SN, SPIL_DATA_SN, ALGO_CD, EXEC_STTS_CD, EXEC_NM, BGNG_DTM)
|
||||
VALUES ($1, $2, $3, 'PENDING', $4, NOW())
|
||||
RETURNING PRED_EXEC_SN`,
|
||||
[resolvedAcdntSn, resolvedSpilDataSn, algoCd, execNm]
|
||||
)
|
||||
predExecSn = insertRes.rows[0].pred_exec_sn as number
|
||||
} catch (dbErr) {
|
||||
console.error(`[simulation/run-model] ${model} PRED_EXEC INSERT 실패:`, dbErr)
|
||||
return { model, execSn: 0, status: 'ERROR', error: 'DB 오류' }
|
||||
}
|
||||
|
||||
execSns.push({ model, execSn: predExecSn })
|
||||
|
||||
// Python /run-model 호출
|
||||
let jobId: string | undefined
|
||||
try {
|
||||
const pythonRes = await fetch(`${apiUrl}/run-model`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ lat, lon, startTime, runTime, matTy: odMatTy, matVol, spillTime, name: execNm }),
|
||||
signal: AbortSignal.timeout(POLL_TIMEOUT_MS),
|
||||
})
|
||||
|
||||
if (pythonRes.status === 503) {
|
||||
const errData = await pythonRes.json() as { error?: string }
|
||||
const errMsg = errData.error || '분석 서버 포화'
|
||||
await wingPool.query(
|
||||
`UPDATE wing.PRED_EXEC SET EXEC_STTS_CD='FAILED', ERR_MSG=$1, CMPL_DTM=NOW() WHERE PRED_EXEC_SN=$2`,
|
||||
[errMsg, predExecSn]
|
||||
)
|
||||
return { model, execSn: predExecSn, status: 'ERROR', error: errMsg }
|
||||
}
|
||||
|
||||
if (!pythonRes.ok) {
|
||||
throw new Error(`Python 서버 응답 오류: ${pythonRes.status}`)
|
||||
}
|
||||
|
||||
const pythonData = await pythonRes.json() as {
|
||||
success?: boolean;
|
||||
result?: PythonTimeStep[];
|
||||
job_id?: string;
|
||||
error?: string;
|
||||
message?: string;
|
||||
error_code?: number;
|
||||
}
|
||||
|
||||
// 동기 성공 응답 (OpenDrift & POSEIDON 공통)
|
||||
if (Array.isArray(pythonData.result)) {
|
||||
await wingPool.query(
|
||||
`UPDATE wing.PRED_EXEC
|
||||
SET EXEC_STTS_CD='COMPLETED', RSLT_DATA=$1,
|
||||
CMPL_DTM=NOW(), REQD_SEC=EXTRACT(EPOCH FROM (NOW() - BGNG_DTM))::INTEGER
|
||||
WHERE PRED_EXEC_SN=$2`,
|
||||
[JSON.stringify(pythonData.result), predExecSn]
|
||||
)
|
||||
const { trajectory, summary, centerPoints, windData, hydrData } =
|
||||
transformResult(pythonData.result, model)
|
||||
return { model, execSn: predExecSn, status: 'DONE', trajectory, summary, centerPoints, windData, hydrData }
|
||||
}
|
||||
|
||||
// 비동기 응답 (하위 호환)
|
||||
if (pythonData.job_id) {
|
||||
jobId = pythonData.job_id
|
||||
} else {
|
||||
// 오류 응답 (success: false, HTTP 200)
|
||||
const errMsg = pythonData.error || pythonData.message || '분석 오류'
|
||||
await wingPool.query(
|
||||
`UPDATE wing.PRED_EXEC SET EXEC_STTS_CD='FAILED', ERR_MSG=$1, CMPL_DTM=NOW() WHERE PRED_EXEC_SN=$2`,
|
||||
[errMsg, predExecSn]
|
||||
)
|
||||
return { model, execSn: predExecSn, status: 'ERROR', error: errMsg }
|
||||
}
|
||||
} catch (fetchErr) {
|
||||
const errMsg = 'Python 분석 서버에 연결할 수 없습니다.'
|
||||
await wingPool.query(
|
||||
`UPDATE wing.PRED_EXEC SET EXEC_STTS_CD='FAILED', ERR_MSG=$1, CMPL_DTM=NOW() WHERE PRED_EXEC_SN=$2`,
|
||||
[errMsg, predExecSn]
|
||||
)
|
||||
return { model, execSn: predExecSn, status: 'ERROR', error: errMsg }
|
||||
}
|
||||
|
||||
// RUNNING 업데이트 (비동기 폴링 경로)
|
||||
await wingPool.query(
|
||||
`UPDATE wing.PRED_EXEC SET EXEC_STTS_CD='RUNNING' WHERE PRED_EXEC_SN=$1`,
|
||||
[predExecSn]
|
||||
)
|
||||
|
||||
// 결과 동기 대기
|
||||
try {
|
||||
const rawResult = await runModelSync(jobId!, predExecSn, apiUrl)
|
||||
const { trajectory, summary, centerPoints, windData, hydrData } = transformResult(rawResult, model)
|
||||
return { model, execSn: predExecSn, status: 'DONE', trajectory, summary, centerPoints, windData, hydrData }
|
||||
} catch (syncErr) {
|
||||
return { model, execSn: predExecSn, status: 'ERROR', error: (syncErr as Error).message }
|
||||
}
|
||||
})
|
||||
)
|
||||
|
||||
// 모든 모델이 실패하고 신규 생성한 ACDNT가 있으면 롤백
|
||||
const hasSuccess = modelResults.some((r) => r.status === 'DONE')
|
||||
if (!hasSuccess && newlyCreatedAcdntSn !== null) {
|
||||
for (const r of modelResults) {
|
||||
if (r.execSn) await rollbackNewRecords(r.execSn, null, null)
|
||||
}
|
||||
await rollbackNewRecords(null, newlyCreatedSpilDataSn, newlyCreatedAcdntSn)
|
||||
return res.status(503).json({ error: '분석 서버에 연결할 수 없습니다.' })
|
||||
}
|
||||
|
||||
res.json({
|
||||
success: true,
|
||||
acdntSn: resolvedAcdntSn,
|
||||
execSns: [...execSns, ...modelResults.map(({ model, execSn }) => ({ model, execSn }))],
|
||||
results: modelResults,
|
||||
})
|
||||
} catch {
|
||||
res.status(500).json({ error: '시뮬레이션 실행 실패', message: '서버 내부 오류가 발생했습니다.' })
|
||||
}
|
||||
})
|
||||
|
||||
// ============================================================
|
||||
// 백그라운드 폴링
|
||||
// ============================================================
|
||||
@ -474,6 +769,57 @@ async function pollAndSaveModel(jobId: string, execSn: number, apiUrl: string, a
|
||||
)
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// 동기 폴링: Python 결과 대기 후 반환
|
||||
// ============================================================
|
||||
async function runModelSync(jobId: string, execSn: number, apiUrl: string): Promise<PythonTimeStep[]> {
|
||||
const deadline = Date.now() + POLL_TIMEOUT_MS
|
||||
|
||||
while (Date.now() < deadline) {
|
||||
await new Promise<void>(resolve => setTimeout(resolve, POLL_INTERVAL_MS))
|
||||
|
||||
let data: PythonStatusResponse
|
||||
try {
|
||||
const pollRes = await fetch(`${apiUrl}/status/${jobId}`, {
|
||||
signal: AbortSignal.timeout(5000),
|
||||
})
|
||||
if (!pollRes.ok) continue
|
||||
data = await pollRes.json() as PythonStatusResponse
|
||||
} catch {
|
||||
// 네트워크 오류 — 재시도
|
||||
continue
|
||||
}
|
||||
|
||||
if (data.status === 'DONE' && data.result) {
|
||||
await wingPool.query(
|
||||
`UPDATE wing.PRED_EXEC
|
||||
SET EXEC_STTS_CD='COMPLETED',
|
||||
RSLT_DATA=$1,
|
||||
CMPL_DTM=NOW(),
|
||||
REQD_SEC=EXTRACT(EPOCH FROM (NOW() - BGNG_DTM))::INTEGER
|
||||
WHERE PRED_EXEC_SN=$2`,
|
||||
[JSON.stringify(data.result), execSn]
|
||||
)
|
||||
return data.result
|
||||
}
|
||||
|
||||
if (data.status === 'ERROR') {
|
||||
const errMsg = data.error ?? '분석 오류'
|
||||
await wingPool.query(
|
||||
`UPDATE wing.PRED_EXEC SET EXEC_STTS_CD='FAILED', ERR_MSG=$1, CMPL_DTM=NOW() WHERE PRED_EXEC_SN=$2`,
|
||||
[errMsg, execSn]
|
||||
)
|
||||
throw new Error(errMsg)
|
||||
}
|
||||
}
|
||||
|
||||
await wingPool.query(
|
||||
`UPDATE wing.PRED_EXEC SET EXEC_STTS_CD='FAILED', ERR_MSG='분석 시간 초과 (30분)', CMPL_DTM=NOW() WHERE PRED_EXEC_SN=$1`,
|
||||
[execSn]
|
||||
)
|
||||
throw new Error('분석 시간 초과 (30분)')
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// 타입 및 결과 변환
|
||||
// ============================================================
|
||||
|
||||
@ -93,6 +93,7 @@ const DEFAULT_MENU_CONFIG: MenuConfigItem[] = [
|
||||
{ id: 'board', label: '게시판', icon: '📌', enabled: true, order: 8 },
|
||||
{ id: 'weather', label: '기상정보', icon: '⛅', enabled: true, order: 9 },
|
||||
{ id: 'incidents', label: '통합조회', icon: '🔍', enabled: true, order: 10 },
|
||||
{ id: 'monitor', label: '실시간 상황관리', icon: '🛰', enabled: true, order: 11 },
|
||||
]
|
||||
|
||||
const VALID_MENU_IDS = DEFAULT_MENU_CONFIG.map(m => m.id)
|
||||
@ -103,18 +104,23 @@ export async function getMenuConfig(): Promise<MenuConfigItem[]> {
|
||||
|
||||
try {
|
||||
const parsed = JSON.parse(val) as MenuConfigItem[]
|
||||
const defaultMap = new Map(DEFAULT_MENU_CONFIG.map(m => [m.id, m]))
|
||||
|
||||
return parsed
|
||||
const dbMap = new Map(
|
||||
parsed
|
||||
.filter(item => VALID_MENU_IDS.includes(item.id))
|
||||
.map(item => {
|
||||
const defaults = defaultMap.get(item.id)!
|
||||
.map(item => [item.id, item])
|
||||
)
|
||||
|
||||
// DEFAULT 기준으로 머지 (DB에 없는 항목은 기본값 사용)
|
||||
return DEFAULT_MENU_CONFIG
|
||||
.map(defaultItem => {
|
||||
const dbItem = dbMap.get(defaultItem.id)
|
||||
if (!dbItem) return defaultItem
|
||||
return {
|
||||
id: item.id,
|
||||
label: item.label || defaults.label,
|
||||
icon: item.icon || defaults.icon,
|
||||
enabled: item.enabled,
|
||||
order: item.order,
|
||||
id: dbItem.id,
|
||||
label: dbItem.label || defaultItem.label,
|
||||
icon: dbItem.icon || defaultItem.icon,
|
||||
enabled: dbItem.enabled,
|
||||
order: dbItem.order,
|
||||
}
|
||||
})
|
||||
.sort((a, b) => a.order - b.order)
|
||||
|
||||
@ -278,7 +278,8 @@ INSERT INTO AUTH_PERM (ROLE_SN, RSRC_CD, OPER_CD, GRANT_YN) VALUES
|
||||
(1, 'incidents', 'READ', 'Y'), (1, 'incidents', 'CREATE', 'Y'), (1, 'incidents', 'UPDATE', 'Y'), (1, 'incidents', 'DELETE', 'Y'),
|
||||
(1, 'board', 'READ', 'Y'), (1, 'board', 'CREATE', 'Y'), (1, 'board', 'UPDATE', 'Y'), (1, 'board', 'DELETE', 'Y'),
|
||||
(1, 'weather', 'READ', 'Y'), (1, 'weather', 'CREATE', 'Y'), (1, 'weather', 'UPDATE', 'Y'), (1, 'weather', 'DELETE', 'Y'),
|
||||
(1, 'admin', 'READ', 'Y'), (1, 'admin', 'CREATE', 'Y'), (1, 'admin', 'UPDATE', 'Y'), (1, 'admin', 'DELETE', 'Y');
|
||||
(1, 'admin', 'READ', 'Y'), (1, 'admin', 'CREATE', 'Y'), (1, 'admin', 'UPDATE', 'Y'), (1, 'admin', 'DELETE', 'Y'),
|
||||
(1, 'monitor', 'READ', 'Y');
|
||||
|
||||
-- HQ_CLEANUP (ROLE_SN=2): 방제 관련 탭 RCUD + 기타 탭 READ/CREATE, admin 제외
|
||||
INSERT INTO AUTH_PERM (ROLE_SN, RSRC_CD, OPER_CD, GRANT_YN) VALUES
|
||||
@ -292,7 +293,8 @@ INSERT INTO AUTH_PERM (ROLE_SN, RSRC_CD, OPER_CD, GRANT_YN) VALUES
|
||||
(2, 'incidents', 'READ', 'Y'), (2, 'incidents', 'CREATE', 'Y'), (2, 'incidents', 'UPDATE', 'Y'), (2, 'incidents', 'DELETE', 'Y'),
|
||||
(2, 'board', 'READ', 'Y'), (2, 'board', 'CREATE', 'Y'), (2, 'board', 'UPDATE', 'Y'),
|
||||
(2, 'weather', 'READ', 'Y'), (2, 'weather', 'CREATE', 'Y'),
|
||||
(2, 'admin', 'READ', 'N');
|
||||
(2, 'admin', 'READ', 'N'),
|
||||
(2, 'monitor', 'READ', 'Y');
|
||||
|
||||
-- MANAGER (ROLE_SN=3): admin 탭 제외, RCUD 허용
|
||||
INSERT INTO AUTH_PERM (ROLE_SN, RSRC_CD, OPER_CD, GRANT_YN) VALUES
|
||||
@ -306,7 +308,8 @@ INSERT INTO AUTH_PERM (ROLE_SN, RSRC_CD, OPER_CD, GRANT_YN) VALUES
|
||||
(3, 'incidents', 'READ', 'Y'), (3, 'incidents', 'CREATE', 'Y'), (3, 'incidents', 'UPDATE', 'Y'), (3, 'incidents', 'DELETE', 'Y'),
|
||||
(3, 'board', 'READ', 'Y'), (3, 'board', 'CREATE', 'Y'), (3, 'board', 'UPDATE', 'Y'), (3, 'board', 'DELETE', 'Y'),
|
||||
(3, 'weather', 'READ', 'Y'), (3, 'weather', 'CREATE', 'Y'), (3, 'weather', 'UPDATE', 'Y'), (3, 'weather', 'DELETE', 'Y'),
|
||||
(3, 'admin', 'READ', 'N');
|
||||
(3, 'admin', 'READ', 'N'),
|
||||
(3, 'monitor', 'READ', 'Y');
|
||||
|
||||
-- USER (ROLE_SN=4): assets/admin 제외, 허용 탭은 READ/CREATE/UPDATE, DELETE 없음
|
||||
INSERT INTO AUTH_PERM (ROLE_SN, RSRC_CD, OPER_CD, GRANT_YN) VALUES
|
||||
@ -320,7 +323,8 @@ INSERT INTO AUTH_PERM (ROLE_SN, RSRC_CD, OPER_CD, GRANT_YN) VALUES
|
||||
(4, 'incidents', 'READ', 'Y'), (4, 'incidents', 'CREATE', 'Y'), (4, 'incidents', 'UPDATE', 'Y'),
|
||||
(4, 'board', 'READ', 'Y'), (4, 'board', 'CREATE', 'Y'), (4, 'board', 'UPDATE', 'Y'),
|
||||
(4, 'weather', 'READ', 'Y'),
|
||||
(4, 'admin', 'READ', 'N');
|
||||
(4, 'admin', 'READ', 'N'),
|
||||
(4, 'monitor', 'READ', 'Y');
|
||||
|
||||
-- VIEWER (ROLE_SN=5): 제한적 탭의 READ만 허용
|
||||
INSERT INTO AUTH_PERM (ROLE_SN, RSRC_CD, OPER_CD, GRANT_YN) VALUES
|
||||
@ -334,7 +338,8 @@ INSERT INTO AUTH_PERM (ROLE_SN, RSRC_CD, OPER_CD, GRANT_YN) VALUES
|
||||
(5, 'incidents', 'READ', 'Y'),
|
||||
(5, 'board', 'READ', 'Y'),
|
||||
(5, 'weather', 'READ', 'Y'),
|
||||
(5, 'admin', 'READ', 'N');
|
||||
(5, 'admin', 'READ', 'N'),
|
||||
(5, 'monitor', 'READ', 'Y');
|
||||
|
||||
|
||||
-- ============================================================
|
||||
|
||||
@ -4,6 +4,20 @@
|
||||
|
||||
## [Unreleased]
|
||||
|
||||
### 추가
|
||||
- 관리자: 방제장비 현황 패널 (CleanupEquipPanel) — 관할청·유형별 필터, 자산 수량 조회
|
||||
- 관리자: 자산 현행화 업로드 패널 (AssetUploadPanel) — 엑셀/CSV 드래그 드롭 업로드
|
||||
|
||||
### 변경
|
||||
- trajectory API 모델별 windData/hydrData 분리 반환
|
||||
- 예측 서비스(predictionService) 개선
|
||||
- 보고서: 유출유 확산 지도 패널 및 보고서 생성기 개선
|
||||
- 관리자: 권한/메뉴 구성 업데이트, AdminView 패널 등록
|
||||
- prediction/image 이미지 분석 서버 분리 (디렉토리 제거)
|
||||
|
||||
### 기타
|
||||
- DB: monitor 권한 트리 마이그레이션(022) 추가, auth_init 갱신
|
||||
|
||||
## [2026-03-17]
|
||||
|
||||
### 추가
|
||||
|
||||
@ -97,6 +97,8 @@ function App() {
|
||||
return <AdminView />
|
||||
case 'rescue':
|
||||
return <RescueView />
|
||||
case 'monitor':
|
||||
return null
|
||||
default:
|
||||
return <div className="flex items-center justify-center h-full text-text-3">준비 중입니다...</div>
|
||||
}
|
||||
|
||||
@ -39,9 +39,13 @@ export function TopBar({ activeTab, onTabChange }: TopBarProps) {
|
||||
{/* Left Section */}
|
||||
<div className="flex items-center gap-4">
|
||||
{/* Logo */}
|
||||
<div className="flex items-center">
|
||||
<button
|
||||
onClick={() => tabs[0] && onTabChange(tabs[0].id as MainTab)}
|
||||
className="flex items-center hover:opacity-80 transition-opacity cursor-pointer"
|
||||
title="홈으로 이동"
|
||||
>
|
||||
<img src="/wing_logo_white.svg" alt="WING 해양환경 위기대응" className="h-3.5" />
|
||||
</div>
|
||||
</button>
|
||||
|
||||
{/* Divider */}
|
||||
<div className="w-px h-6 bg-border-light" />
|
||||
@ -50,17 +54,28 @@ export function TopBar({ activeTab, onTabChange }: TopBarProps) {
|
||||
<div className="flex gap-0.5">
|
||||
{tabs.map((tab) => {
|
||||
const isIncident = tab.id === 'incidents'
|
||||
const isMonitor = tab.id === 'monitor'
|
||||
const handleClick = () => {
|
||||
if (isMonitor) {
|
||||
window.open(import.meta.env.VITE_SITUATIONAL_URL ?? 'https://kcg.gc-si.dev', '_blank')
|
||||
} else {
|
||||
onTabChange(tab.id as MainTab)
|
||||
}
|
||||
}
|
||||
return (
|
||||
<button
|
||||
key={tab.id}
|
||||
onClick={() => onTabChange(tab.id as MainTab)}
|
||||
onClick={handleClick}
|
||||
title={tab.label}
|
||||
className={`
|
||||
px-2.5 xl:px-4 py-2 rounded-sm text-[13px] transition-all duration-200
|
||||
font-korean tracking-[0.2px]
|
||||
${isIncident ? 'font-extrabold border-l border-l-[rgba(99,102,241,0.2)] ml-1' : 'font-semibold'}
|
||||
${isMonitor ? 'border-l border-l-[rgba(239,68,68,0.25)] ml-1 flex items-center gap-1.5' : ''}
|
||||
${
|
||||
activeTab === tab.id
|
||||
isMonitor
|
||||
? 'text-[#f87171] hover:text-[#fca5a5] hover:bg-[rgba(239,68,68,0.1)]'
|
||||
: activeTab === tab.id
|
||||
? isIncident
|
||||
? 'text-[#a5b4fc] bg-[rgba(99,102,241,0.18)] shadow-[0_0_8px_rgba(99,102,241,0.3)]'
|
||||
: 'text-[#22d3ee] bg-[rgba(6,182,212,0.15)] shadow-[0_0_8px_rgba(6,182,212,0.3)]'
|
||||
@ -70,30 +85,23 @@ export function TopBar({ activeTab, onTabChange }: TopBarProps) {
|
||||
}
|
||||
`}
|
||||
>
|
||||
{isMonitor ? (
|
||||
<>
|
||||
<span className="hidden xl:flex items-center gap-1.5">
|
||||
<span className="w-1.5 h-1.5 rounded-full bg-[#f87171] animate-pulse inline-block" />
|
||||
{tab.label}
|
||||
</span>
|
||||
<span className="xl:hidden text-[16px] leading-none">{tab.icon}</span>
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
<span className="xl:hidden text-[16px] leading-none">{tab.icon}</span>
|
||||
<span className="hidden xl:inline">{tab.label}</span>
|
||||
</>
|
||||
)}
|
||||
</button>
|
||||
)
|
||||
})}
|
||||
|
||||
{/* 실시간 상황관리 */}
|
||||
<button
|
||||
onClick={() => window.open(import.meta.env.VITE_SITUATIONAL_URL ?? 'https://kcg.gc-si.dev', '_blank')}
|
||||
className={`
|
||||
px-2.5 xl:px-4 py-2 rounded-sm text-[13px] transition-all duration-200
|
||||
font-korean tracking-[0.2px] font-semibold
|
||||
border-l border-l-[rgba(239,68,68,0.25)] ml-1
|
||||
text-[#f87171] hover:text-[#fca5a5] hover:bg-[rgba(239,68,68,0.1)]
|
||||
flex items-center gap-1.5
|
||||
`}
|
||||
title="실시간 상황관리"
|
||||
>
|
||||
<span className="hidden xl:flex items-center gap-1.5">
|
||||
<span className="w-1.5 h-1.5 rounded-full bg-[#f87171] animate-pulse inline-block" />
|
||||
실시간 상황관리
|
||||
</span>
|
||||
<span className="xl:hidden text-[16px] leading-none">🛰</span>
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
|
||||
@ -399,6 +399,36 @@ function FitBoundsController({ fitBoundsTarget }: { fitBoundsTarget?: { north: n
|
||||
return null
|
||||
}
|
||||
|
||||
// Map 중앙 좌표 + 줌 추적 컴포넌트 (Map 내부에서 useMap() 사용)
|
||||
function MapCenterTracker({
|
||||
onCenterChange,
|
||||
}: {
|
||||
onCenterChange: (lat: number, lng: number, zoom: number) => void;
|
||||
}) {
|
||||
const { current: map } = useMap()
|
||||
|
||||
useEffect(() => {
|
||||
if (!map) return
|
||||
|
||||
const update = () => {
|
||||
const center = map.getCenter()
|
||||
const zoom = map.getZoom()
|
||||
onCenterChange(center.lat, center.lng, zoom)
|
||||
}
|
||||
|
||||
update()
|
||||
map.on('move', update)
|
||||
map.on('zoom', update)
|
||||
|
||||
return () => {
|
||||
map.off('move', update)
|
||||
map.off('zoom', update)
|
||||
}
|
||||
}, [map, onCenterChange])
|
||||
|
||||
return null
|
||||
}
|
||||
|
||||
// 3D 모드 pitch/bearing 제어 컴포넌트 (Map 내부에서 useMap() 사용)
|
||||
function MapPitchController({ threeD }: { threeD: boolean }) {
|
||||
const { current: map } = useMap()
|
||||
@ -519,12 +549,19 @@ export function MapView({
|
||||
const { mapToggles } = useMapStore()
|
||||
const isControlled = externalCurrentTime !== undefined
|
||||
const [currentPosition, setCurrentPosition] = useState<[number, number]>(DEFAULT_CENTER)
|
||||
const [mapCenter, setMapCenter] = useState<[number, number]>(DEFAULT_CENTER)
|
||||
const [mapZoom, setMapZoom] = useState<number>(DEFAULT_ZOOM)
|
||||
const [internalCurrentTime, setInternalCurrentTime] = useState(0)
|
||||
const [isPlaying, setIsPlaying] = useState(false)
|
||||
const [playbackSpeed, setPlaybackSpeed] = useState(1)
|
||||
const [popupInfo, setPopupInfo] = useState<PopupInfo | null>(null)
|
||||
const currentTime = isControlled ? externalCurrentTime : internalCurrentTime
|
||||
|
||||
const handleMapCenterChange = useCallback((lat: number, lng: number, zoom: number) => {
|
||||
setMapCenter([lat, lng])
|
||||
setMapZoom(zoom)
|
||||
}, [])
|
||||
|
||||
const handleMapClick = useCallback((e: MapLayerMouseEvent) => {
|
||||
const { lng, lat } = e.lngLat
|
||||
setCurrentPosition([lat, lng])
|
||||
@ -1207,6 +1244,8 @@ export function MapView({
|
||||
>
|
||||
{/* 지도 캡처 셋업 */}
|
||||
{mapCaptureRef && <MapCaptureSetup captureRef={mapCaptureRef} />}
|
||||
{/* 지도 중앙 좌표 + 줌 추적 */}
|
||||
<MapCenterTracker onCenterChange={handleMapCenterChange} />
|
||||
{/* 3D 모드 pitch 제어 */}
|
||||
<MapPitchController threeD={mapToggles.threeD} />
|
||||
{/* 사고 지점 변경 시 지도 이동 */}
|
||||
@ -1303,7 +1342,8 @@ export function MapView({
|
||||
|
||||
{/* 좌표 표시 */}
|
||||
{showOverlays && <CoordinateDisplay
|
||||
position={incidentCoord ? [incidentCoord.lat, incidentCoord.lon] : currentPosition}
|
||||
position={mapCenter}
|
||||
zoom={mapZoom}
|
||||
/>}
|
||||
|
||||
{/* 타임라인 컨트롤 (외부 제어 모드에서는 숨김 — 하단 플레이어가 대신 담당) */}
|
||||
@ -1499,16 +1539,23 @@ function MapLegend({ dispersionResult, incidentCoord, oilTrajectory = [], select
|
||||
}
|
||||
|
||||
// 좌표 표시
|
||||
function CoordinateDisplay({ position }: { position: [number, number] }) {
|
||||
function CoordinateDisplay({ position, zoom }: { position: [number, number]; zoom: number }) {
|
||||
const [lat, lng] = position
|
||||
const latDirection = lat >= 0 ? 'N' : 'S'
|
||||
const lngDirection = lng >= 0 ? 'E' : 'W'
|
||||
|
||||
// MapLibre 줌 → 축척 변환 (96 DPI 기준)
|
||||
const metersPerPixel = (40075016.686 * Math.cos((lat * Math.PI) / 180)) / (256 * Math.pow(2, zoom))
|
||||
const scaleRatio = Math.round(metersPerPixel * (96 / 0.0254))
|
||||
const scaleLabel = scaleRatio >= 1000000
|
||||
? `1:${(scaleRatio / 1000000).toFixed(1)}M`
|
||||
: `1:${scaleRatio.toLocaleString()}`
|
||||
|
||||
return (
|
||||
<div className="cod">
|
||||
<span>위도 <span className="cov">{Math.abs(lat).toFixed(4)}°{latDirection}</span></span>
|
||||
<span>경도 <span className="cov">{Math.abs(lng).toFixed(4)}°{lngDirection}</span></span>
|
||||
<span>축척 <span className="cov">1:50,000</span></span>
|
||||
<span>축척 <span className="cov">{scaleLabel}</span></span>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
@ -1585,7 +1632,11 @@ function TimelineControl({
|
||||
</div>
|
||||
<div className="tli">
|
||||
{/* eslint-disable-next-line react-hooks/purity */}
|
||||
<div className="tlct">+{currentTime.toFixed(0)}h — {new Date(Date.now() + currentTime * 3600000).toLocaleDateString('ko-KR', { month: '2-digit', day: '2-digit', hour: '2-digit', minute: '2-digit' })} KST</div>
|
||||
<div className="tlct">+{currentTime.toFixed(0)}h — {(() => {
|
||||
const base = simulationStartTime ? new Date(simulationStartTime) : new Date();
|
||||
const d = new Date(base.getTime() + currentTime * 3600 * 1000);
|
||||
return `${String(d.getMonth() + 1).padStart(2, '0')}/${String(d.getDate()).padStart(2, '0')} ${String(d.getHours()).padStart(2, '0')}:${String(d.getMinutes()).padStart(2, '0')} KST`;
|
||||
})()}</div>
|
||||
<div className="tlss">
|
||||
<div className="tls"><span className="tlsl">진행률</span><span className="tlsv">{progressPercent.toFixed(0)}%</span></div>
|
||||
<div className="tls"><span className="tlsl">속도</span><span className="tlsv">{playbackSpeed}×</span></div>
|
||||
|
||||
@ -1 +1 @@
|
||||
export type MainTab = 'prediction' | 'hns' | 'rescue' | 'reports' | 'aerial' | 'assets' | 'scat' | 'incidents' | 'board' | 'weather' | 'admin';
|
||||
export type MainTab = 'prediction' | 'hns' | 'rescue' | 'reports' | 'aerial' | 'assets' | 'scat' | 'incidents' | 'board' | 'weather' | 'monitor' | 'admin';
|
||||
|
||||
@ -8,6 +8,8 @@ import MenusPanel from './MenusPanel';
|
||||
import SettingsPanel from './SettingsPanel';
|
||||
import BoardMgmtPanel from './BoardMgmtPanel';
|
||||
import VesselSignalPanel from './VesselSignalPanel';
|
||||
import CleanupEquipPanel from './CleanupEquipPanel';
|
||||
import AssetUploadPanel from './AssetUploadPanel';
|
||||
|
||||
/** 기존 패널이 있는 메뉴 ID 매핑 */
|
||||
const PANEL_MAP: Record<string, () => JSX.Element> = {
|
||||
@ -19,6 +21,8 @@ const PANEL_MAP: Record<string, () => JSX.Element> = {
|
||||
board: () => <BoardMgmtPanel initialCategory="DATA" />,
|
||||
qna: () => <BoardMgmtPanel initialCategory="QNA" />,
|
||||
'collect-vessel-signal': () => <VesselSignalPanel />,
|
||||
'cleanup-equip': () => <CleanupEquipPanel />,
|
||||
'asset-upload': () => <AssetUploadPanel />,
|
||||
};
|
||||
|
||||
export function AdminView() {
|
||||
|
||||
257
frontend/src/tabs/admin/components/AssetUploadPanel.tsx
Normal file
257
frontend/src/tabs/admin/components/AssetUploadPanel.tsx
Normal file
@ -0,0 +1,257 @@
|
||||
import { useState, useEffect, useRef } from 'react';
|
||||
import { fetchUploadLogs } from '@tabs/assets/services/assetsApi';
|
||||
import type { UploadLogItem } from '@tabs/assets/services/assetsApi';
|
||||
|
||||
// Selectable asset categories for the upload form; '전체' means "all".
const ASSET_CATEGORIES = ['전체', '방제선', '유회수기', '이송펌프', '방제차량', '살포장치', '오일붐', '흡착재', '기타'];
// Regional coast-guard jurisdictions; '전체' means "all".
const JURISDICTIONS = ['전체', '남해청', '서해청', '중부청', '동해청', '제주청'];

// Static role/permission legend rendered in the right-hand panel
// (icon + role name + description + badge colors).
const PERM_ITEMS = [
  { icon: '👑', role: '시스템관리자', desc: '전체 자산 업로드/삭제 가능', bg: 'rgba(245,158,11,0.15)', color: 'text-yellow-400' },
  { icon: '🔧', role: '운영관리자', desc: '관할청 내 자산 업로드 가능', bg: 'rgba(6,182,212,0.15)', color: 'text-primary-cyan' },
  { icon: '👁', role: '조회자', desc: '현황 조회만 가능', bg: 'rgba(148,163,184,0.15)', color: 'text-text-2' },
  { icon: '🚫', role: '게스트', desc: '접근 불가', bg: 'rgba(239,68,68,0.1)', color: 'text-red-400' },
];
|
||||
|
||||
function formatDate(dtm: string) {
|
||||
const d = new Date(dtm);
|
||||
if (isNaN(d.getTime())) return dtm;
|
||||
return `${d.getFullYear()}-${String(d.getMonth() + 1).padStart(2, '0')}-${String(d.getDate()).padStart(2, '0')}`;
|
||||
}
|
||||
|
||||
function AssetUploadPanel() {
|
||||
const [assetCategory, setAssetCategory] = useState('전체');
|
||||
const [jurisdiction, setJurisdiction] = useState('전체');
|
||||
const [uploadMode, setUploadMode] = useState<'add' | 'replace'>('add');
|
||||
const [uploaded, setUploaded] = useState(false);
|
||||
const [uploadHistory, setUploadHistory] = useState<UploadLogItem[]>([]);
|
||||
const [dragging, setDragging] = useState(false);
|
||||
const [selectedFile, setSelectedFile] = useState<File | null>(null);
|
||||
const fileInputRef = useRef<HTMLInputElement>(null);
|
||||
const resetTimerRef = useRef<ReturnType<typeof setTimeout> | null>(null);
|
||||
|
||||
useEffect(() => {
|
||||
fetchUploadLogs(10)
|
||||
.then(setUploadHistory)
|
||||
.catch(err => console.error('[AssetUploadPanel] 이력 로드 실패:', err));
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
return () => {
|
||||
if (resetTimerRef.current) clearTimeout(resetTimerRef.current);
|
||||
};
|
||||
}, []);
|
||||
|
||||
const handleFileSelect = (file: File | null) => {
|
||||
if (!file) return;
|
||||
const ext = file.name.split('.').pop()?.toLowerCase();
|
||||
if (ext !== 'xlsx' && ext !== 'csv') return;
|
||||
setSelectedFile(file);
|
||||
};
|
||||
|
||||
const handleDrop = (e: React.DragEvent) => {
|
||||
e.preventDefault();
|
||||
setDragging(false);
|
||||
const file = e.dataTransfer.files[0] ?? null;
|
||||
handleFileSelect(file);
|
||||
};
|
||||
|
||||
const handleUpload = () => {
|
||||
if (!selectedFile) return;
|
||||
setUploaded(true);
|
||||
resetTimerRef.current = setTimeout(() => {
|
||||
setUploaded(false);
|
||||
setSelectedFile(null);
|
||||
}, 3000);
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="flex flex-col h-full">
|
||||
{/* 헤더 */}
|
||||
<div className="px-6 py-4 border-b border-border flex-shrink-0">
|
||||
<h1 className="text-lg font-bold text-text-1 font-korean">자산 현행화</h1>
|
||||
<p className="text-xs text-text-3 mt-1 font-korean">자산 데이터를 업로드하여 현행화합니다</p>
|
||||
</div>
|
||||
|
||||
{/* 본문 */}
|
||||
<div className="flex-1 overflow-auto p-6">
|
||||
<div className="flex gap-6 h-full">
|
||||
{/* 좌측: 파일 업로드 */}
|
||||
<div className="flex-1 max-w-[560px] space-y-4">
|
||||
<div className="rounded-lg border border-border bg-bg-1 overflow-hidden">
|
||||
<div className="px-5 py-3 border-b border-border">
|
||||
<h2 className="text-sm font-bold text-text-1 font-korean">파일 업로드</h2>
|
||||
</div>
|
||||
<div className="px-5 py-4 space-y-4">
|
||||
{/* 드롭존 */}
|
||||
<div
|
||||
onDragOver={e => { e.preventDefault(); setDragging(true); }}
|
||||
onDragLeave={() => setDragging(false)}
|
||||
onDrop={handleDrop}
|
||||
onClick={() => fileInputRef.current?.click()}
|
||||
className={`rounded-lg border-2 border-dashed py-8 text-center cursor-pointer transition-colors ${
|
||||
dragging
|
||||
? 'border-primary-cyan bg-[rgba(6,182,212,0.05)]'
|
||||
: 'border-border hover:border-primary-cyan/50 bg-bg-2'
|
||||
}`}
|
||||
>
|
||||
<div className="text-3xl mb-2 opacity-40">📁</div>
|
||||
{selectedFile ? (
|
||||
<div className="text-xs font-semibold text-primary-cyan font-korean mb-1">{selectedFile.name}</div>
|
||||
) : (
|
||||
<>
|
||||
<div className="text-xs font-semibold text-text-2 font-korean mb-1">파일을 드래그하거나 클릭하여 업로드</div>
|
||||
<div className="text-[10px] text-text-3 font-korean mb-3">엑셀(.xlsx), CSV 파일 지원 · 최대 10MB</div>
|
||||
<button
|
||||
type="button"
|
||||
className="px-4 py-1.5 text-xs font-semibold rounded-md bg-primary-cyan text-bg-0
|
||||
hover:shadow-[0_0_12px_rgba(6,182,212,0.3)] transition-all font-korean"
|
||||
onClick={e => { e.stopPropagation(); fileInputRef.current?.click(); }}
|
||||
>
|
||||
파일 선택
|
||||
</button>
|
||||
</>
|
||||
)}
|
||||
<input
|
||||
ref={fileInputRef}
|
||||
type="file"
|
||||
accept=".xlsx,.csv"
|
||||
className="hidden"
|
||||
onChange={e => handleFileSelect(e.target.files?.[0] ?? null)}
|
||||
/>
|
||||
</div>
|
||||
|
||||
{/* 자산 분류 */}
|
||||
<div>
|
||||
<label className="block text-[11px] font-semibold text-text-2 font-korean mb-1.5">자산 분류</label>
|
||||
<select
|
||||
value={assetCategory}
|
||||
onChange={e => setAssetCategory(e.target.value)}
|
||||
className="w-full px-3 py-2 text-xs bg-bg-2 border border-border rounded-md
|
||||
text-text-1 focus:border-primary-cyan focus:outline-none font-korean"
|
||||
>
|
||||
{ASSET_CATEGORIES.map(c => (
|
||||
<option key={c} value={c}>{c}</option>
|
||||
))}
|
||||
</select>
|
||||
</div>
|
||||
|
||||
{/* 대상 관할 */}
|
||||
<div>
|
||||
<label className="block text-[11px] font-semibold text-text-2 font-korean mb-1.5">대상 관할</label>
|
||||
<select
|
||||
value={jurisdiction}
|
||||
onChange={e => setJurisdiction(e.target.value)}
|
||||
className="w-full px-3 py-2 text-xs bg-bg-2 border border-border rounded-md
|
||||
text-text-1 focus:border-primary-cyan focus:outline-none font-korean"
|
||||
>
|
||||
{JURISDICTIONS.map(j => (
|
||||
<option key={j} value={j}>{j}</option>
|
||||
))}
|
||||
</select>
|
||||
</div>
|
||||
|
||||
{/* 업로드 방식 */}
|
||||
<div>
|
||||
<label className="block text-[11px] font-semibold text-text-2 font-korean mb-1.5">업로드 방식</label>
|
||||
<div className="flex gap-4">
|
||||
<label className="flex items-center gap-1.5 cursor-pointer text-xs text-text-2 font-korean">
|
||||
<input
|
||||
type="radio"
|
||||
checked={uploadMode === 'add'}
|
||||
onChange={() => setUploadMode('add')}
|
||||
className="accent-primary-cyan"
|
||||
/>
|
||||
추가 (기존 + 신규)
|
||||
</label>
|
||||
<label className="flex items-center gap-1.5 cursor-pointer text-xs text-text-2 font-korean">
|
||||
<input
|
||||
type="radio"
|
||||
checked={uploadMode === 'replace'}
|
||||
onChange={() => setUploadMode('replace')}
|
||||
className="accent-primary-cyan"
|
||||
/>
|
||||
덮어쓰기 (전체 교체)
|
||||
</label>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* 업로드 버튼 */}
|
||||
<button
|
||||
type="button"
|
||||
onClick={handleUpload}
|
||||
disabled={!selectedFile || uploaded}
|
||||
className={`w-full py-2.5 text-xs font-semibold rounded-md transition-all font-korean disabled:opacity-50 ${
|
||||
uploaded
|
||||
? 'bg-[rgba(34,197,94,0.15)] text-status-green border border-status-green/30'
|
||||
: 'bg-primary-cyan text-bg-0 hover:shadow-[0_0_12px_rgba(6,182,212,0.3)]'
|
||||
}`}
|
||||
>
|
||||
{uploaded ? '✅ 업로드 완료!' : '📤 업로드 실행'}
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* 우측 */}
|
||||
<div className="w-[400px] space-y-4 flex-shrink-0">
|
||||
{/* 수정 권한 체계 */}
|
||||
<div className="rounded-lg border border-border bg-bg-1 overflow-hidden">
|
||||
<div className="px-5 py-3 border-b border-border">
|
||||
<h2 className="text-sm font-bold text-text-1 font-korean">수정 권한 체계</h2>
|
||||
</div>
|
||||
<div className="px-5 py-4 space-y-2">
|
||||
{PERM_ITEMS.map(p => (
|
||||
<div
|
||||
key={p.role}
|
||||
className="flex items-center gap-3 px-4 py-3 bg-bg-2 border border-border rounded-md"
|
||||
>
|
||||
<div
|
||||
className="w-8 h-8 rounded-full flex items-center justify-center text-sm flex-shrink-0"
|
||||
style={{ background: p.bg }}
|
||||
>
|
||||
{p.icon}
|
||||
</div>
|
||||
<div>
|
||||
<div className={`text-xs font-bold font-korean ${p.color}`}>{p.role}</div>
|
||||
<div className="text-[10px] text-text-3 font-korean mt-0.5">{p.desc}</div>
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* 최근 업로드 이력 */}
|
||||
<div className="rounded-lg border border-border bg-bg-1 overflow-hidden">
|
||||
<div className="px-5 py-3 border-b border-border">
|
||||
<h2 className="text-sm font-bold text-text-1 font-korean">최근 업로드 이력</h2>
|
||||
</div>
|
||||
<div className="px-5 py-4 space-y-2">
|
||||
{uploadHistory.length === 0 ? (
|
||||
<div className="text-[11px] text-text-3 font-korean text-center py-4">이력이 없습니다.</div>
|
||||
) : uploadHistory.map(h => (
|
||||
<div
|
||||
key={h.logSn}
|
||||
className="flex justify-between items-center px-4 py-3 bg-bg-2 border border-border rounded-md"
|
||||
>
|
||||
<div>
|
||||
<div className="text-xs font-semibold text-text-1 font-korean">{h.fileNm}</div>
|
||||
<div className="text-[10px] text-text-3 mt-0.5 font-korean">
|
||||
{formatDate(h.regDtm)} · {h.uploaderNm} · {h.uploadCnt.toLocaleString()}건
|
||||
</div>
|
||||
</div>
|
||||
<span className="px-2 py-0.5 rounded-full text-[10px] font-semibold
|
||||
bg-[rgba(34,197,94,0.15)] text-status-green flex-shrink-0">
|
||||
완료
|
||||
</span>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
export default AssetUploadPanel;
|
||||
230
frontend/src/tabs/admin/components/CleanupEquipPanel.tsx
Normal file
230
frontend/src/tabs/admin/components/CleanupEquipPanel.tsx
Normal file
@ -0,0 +1,230 @@
|
||||
import { useState, useEffect, useMemo } from 'react';
|
||||
import { fetchOrganizations } from '@tabs/assets/services/assetsApi';
|
||||
import type { AssetOrgCompat } from '@tabs/assets/services/assetsApi';
|
||||
import { typeTagCls } from '@tabs/assets/components/assetTypes';
|
||||
|
||||
// Number of organization rows shown per table page.
const PAGE_SIZE = 20;
|
||||
|
||||
const regionShort = (j: string) =>
|
||||
j.includes('남해') ? '남해청' : j.includes('서해') ? '서해청' :
|
||||
j.includes('중부') ? '중부청' : j.includes('동해') ? '동해청' :
|
||||
j.includes('제주') ? '제주청' : j;
|
||||
|
||||
function CleanupEquipPanel() {
|
||||
const [organizations, setOrganizations] = useState<AssetOrgCompat[]>([]);
|
||||
const [loading, setLoading] = useState(true);
|
||||
const [searchTerm, setSearchTerm] = useState('');
|
||||
const [regionFilter, setRegionFilter] = useState('전체');
|
||||
const [typeFilter, setTypeFilter] = useState('전체');
|
||||
const [currentPage, setCurrentPage] = useState(1);
|
||||
|
||||
const load = () => {
|
||||
setLoading(true);
|
||||
fetchOrganizations()
|
||||
.then(setOrganizations)
|
||||
.catch(err => console.error('[CleanupEquipPanel] 데이터 로드 실패:', err))
|
||||
.finally(() => setLoading(false));
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
let cancelled = false;
|
||||
fetchOrganizations()
|
||||
.then(data => { if (!cancelled) setOrganizations(data); })
|
||||
.catch(err => console.error('[CleanupEquipPanel] 데이터 로드 실패:', err))
|
||||
.finally(() => { if (!cancelled) setLoading(false); });
|
||||
return () => { cancelled = true; };
|
||||
}, []);
|
||||
|
||||
const typeOptions = useMemo(() => {
|
||||
const set = new Set(organizations.map(o => o.type));
|
||||
return Array.from(set).sort();
|
||||
}, [organizations]);
|
||||
|
||||
const filtered = useMemo(() =>
|
||||
organizations
|
||||
.filter(o => regionFilter === '전체' || o.jurisdiction.includes(regionFilter))
|
||||
.filter(o => typeFilter === '전체' || o.type === typeFilter)
|
||||
.filter(o => !searchTerm || o.name.includes(searchTerm) || o.address.includes(searchTerm)),
|
||||
[organizations, regionFilter, typeFilter, searchTerm]
|
||||
);
|
||||
|
||||
const totalPages = Math.max(1, Math.ceil(filtered.length / PAGE_SIZE));
|
||||
const safePage = Math.min(currentPage, totalPages);
|
||||
const paged = filtered.slice((safePage - 1) * PAGE_SIZE, safePage * PAGE_SIZE);
|
||||
|
||||
const handleFilterChange = (setter: (v: string) => void) => (e: React.ChangeEvent<HTMLSelectElement>) => {
|
||||
setter(e.target.value);
|
||||
setCurrentPage(1);
|
||||
};
|
||||
|
||||
const pageNumbers = (() => {
|
||||
const range: number[] = [];
|
||||
const start = Math.max(1, safePage - 2);
|
||||
const end = Math.min(totalPages, safePage + 2);
|
||||
for (let i = start; i <= end; i++) range.push(i);
|
||||
return range;
|
||||
})();
|
||||
|
||||
return (
|
||||
<div className="flex flex-col h-full">
|
||||
{/* 헤더 */}
|
||||
<div className="flex items-center justify-between px-6 py-4 border-b border-border">
|
||||
<div>
|
||||
<h1 className="text-lg font-bold text-text-1 font-korean">방제장비 현황</h1>
|
||||
<p className="text-xs text-text-3 mt-1 font-korean">총 {filtered.length}개 기관</p>
|
||||
</div>
|
||||
<div className="flex items-center gap-3">
|
||||
<select
|
||||
value={regionFilter}
|
||||
onChange={handleFilterChange(setRegionFilter)}
|
||||
className="px-3 py-2 text-xs bg-bg-2 border border-border rounded-md text-text-1 focus:border-primary-cyan focus:outline-none font-korean"
|
||||
>
|
||||
<option value="전체">전체 관할청</option>
|
||||
<option value="남해">남해청</option>
|
||||
<option value="서해">서해청</option>
|
||||
<option value="중부">중부청</option>
|
||||
<option value="동해">동해청</option>
|
||||
<option value="제주">제주청</option>
|
||||
</select>
|
||||
<select
|
||||
value={typeFilter}
|
||||
onChange={handleFilterChange(setTypeFilter)}
|
||||
className="px-3 py-2 text-xs bg-bg-2 border border-border rounded-md text-text-1 focus:border-primary-cyan focus:outline-none font-korean"
|
||||
>
|
||||
<option value="전체">전체 유형</option>
|
||||
{typeOptions.map(t => (
|
||||
<option key={t} value={t}>{t}</option>
|
||||
))}
|
||||
</select>
|
||||
<input
|
||||
type="text"
|
||||
placeholder="기관명, 주소 검색..."
|
||||
value={searchTerm}
|
||||
onChange={e => { setSearchTerm(e.target.value); setCurrentPage(1); }}
|
||||
className="w-56 px-3 py-2 text-xs bg-bg-2 border border-border rounded-md text-text-1 placeholder-text-3 focus:border-primary-cyan focus:outline-none font-korean"
|
||||
/>
|
||||
<button
|
||||
onClick={load}
|
||||
className="px-4 py-2 text-xs font-semibold rounded-md bg-bg-2 border border-border text-text-2 hover:border-primary-cyan hover:text-primary-cyan transition-all font-korean"
|
||||
>
|
||||
새로고침
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* 테이블 */}
|
||||
<div className="flex-1 overflow-auto">
|
||||
{loading ? (
|
||||
<div className="flex items-center justify-center h-32 text-text-3 text-sm font-korean">
|
||||
불러오는 중...
|
||||
</div>
|
||||
) : (
|
||||
<table className="w-full">
|
||||
<thead>
|
||||
<tr className="border-b border-border bg-bg-1">
|
||||
<th className="px-4 py-3 text-left text-[11px] font-semibold text-text-3 font-korean w-10">번호</th>
|
||||
<th className="px-4 py-3 text-left text-[11px] font-semibold text-text-3 font-korean">유형</th>
|
||||
<th className="px-4 py-3 text-left text-[11px] font-semibold text-text-3 font-korean">관할청</th>
|
||||
<th className="px-4 py-3 text-left text-[11px] font-semibold text-text-3 font-korean">기관명</th>
|
||||
<th className="px-4 py-3 text-left text-[11px] font-semibold text-text-3 font-korean">주소</th>
|
||||
<th className="px-4 py-3 text-center text-[11px] font-semibold text-text-3 font-korean">방제선</th>
|
||||
<th className="px-4 py-3 text-center text-[11px] font-semibold text-text-3 font-korean">유회수기</th>
|
||||
<th className="px-4 py-3 text-center text-[11px] font-semibold text-text-3 font-korean">이송펌프</th>
|
||||
<th className="px-4 py-3 text-center text-[11px] font-semibold text-text-3 font-korean">방제차량</th>
|
||||
<th className="px-4 py-3 text-center text-[11px] font-semibold text-text-3 font-korean">살포장치</th>
|
||||
<th className="px-4 py-3 text-center text-[11px] font-semibold text-text-3 font-korean">총자산</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{paged.length === 0 ? (
|
||||
<tr>
|
||||
<td colSpan={11} className="px-6 py-10 text-center text-xs text-text-3 font-korean">
|
||||
조회된 기관이 없습니다.
|
||||
</td>
|
||||
</tr>
|
||||
) : paged.map((org, idx) => (
|
||||
<tr key={org.id} className="border-b border-border hover:bg-[rgba(255,255,255,0.02)] transition-colors">
|
||||
<td className="px-4 py-3 text-[11px] text-text-3 font-mono text-center">
|
||||
{(safePage - 1) * PAGE_SIZE + idx + 1}
|
||||
</td>
|
||||
<td className="px-4 py-3">
|
||||
<span className={`text-[10px] px-1.5 py-0.5 rounded font-bold font-korean ${typeTagCls(org.type)}`}>
|
||||
{org.type}
|
||||
</span>
|
||||
</td>
|
||||
<td className="px-4 py-3 text-[11px] text-text-2 font-korean">
|
||||
{regionShort(org.jurisdiction)}
|
||||
</td>
|
||||
<td className="px-4 py-3 text-[11px] text-text-1 font-korean font-semibold">
|
||||
{org.name}
|
||||
</td>
|
||||
<td className="px-4 py-3 text-[11px] text-text-3 font-korean max-w-[200px] truncate">
|
||||
{org.address}
|
||||
</td>
|
||||
<td className="px-4 py-3 text-[11px] font-mono text-center text-text-2">
|
||||
{org.vessel > 0 ? <span className="text-text-1">{org.vessel}</span> : <span className="text-text-3">—</span>}
|
||||
</td>
|
||||
<td className="px-4 py-3 text-[11px] font-mono text-center text-text-2">
|
||||
{org.skimmer > 0 ? <span className="text-text-1">{org.skimmer}</span> : <span className="text-text-3">—</span>}
|
||||
</td>
|
||||
<td className="px-4 py-3 text-[11px] font-mono text-center text-text-2">
|
||||
{org.pump > 0 ? <span className="text-text-1">{org.pump}</span> : <span className="text-text-3">—</span>}
|
||||
</td>
|
||||
<td className="px-4 py-3 text-[11px] font-mono text-center text-text-2">
|
||||
{org.vehicle > 0 ? <span className="text-text-1">{org.vehicle}</span> : <span className="text-text-3">—</span>}
|
||||
</td>
|
||||
<td className="px-4 py-3 text-[11px] font-mono text-center text-text-2">
|
||||
{org.sprayer > 0 ? <span className="text-text-1">{org.sprayer}</span> : <span className="text-text-3">—</span>}
|
||||
</td>
|
||||
<td className="px-4 py-3 text-[11px] font-mono text-center font-bold text-primary-cyan">
|
||||
{org.totalAssets.toLocaleString()}
|
||||
</td>
|
||||
</tr>
|
||||
))}
|
||||
</tbody>
|
||||
</table>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* 페이지네이션 */}
|
||||
{!loading && filtered.length > 0 && (
|
||||
<div className="flex items-center justify-between px-6 py-3 border-t border-border">
|
||||
<span className="text-[11px] text-text-3 font-korean">
|
||||
{(safePage - 1) * PAGE_SIZE + 1}–{Math.min(safePage * PAGE_SIZE, filtered.length)} / 전체 {filtered.length}개
|
||||
</span>
|
||||
<div className="flex items-center gap-1">
|
||||
<button
|
||||
onClick={() => setCurrentPage(p => Math.max(1, p - 1))}
|
||||
disabled={safePage === 1}
|
||||
className="px-2.5 py-1 text-[11px] border border-border rounded text-text-2 hover:border-primary-cyan hover:text-primary-cyan disabled:opacity-40 transition-colors"
|
||||
>
|
||||
<
|
||||
</button>
|
||||
{pageNumbers.map(p => (
|
||||
<button
|
||||
key={p}
|
||||
onClick={() => setCurrentPage(p)}
|
||||
className="px-2.5 py-1 text-[11px] border rounded transition-colors"
|
||||
style={p === safePage
|
||||
? { borderColor: 'var(--cyan)', color: 'var(--cyan)', background: 'rgba(6,182,212,0.1)' }
|
||||
: { borderColor: 'var(--border)', color: 'var(--text-2)' }
|
||||
}
|
||||
>
|
||||
{p}
|
||||
</button>
|
||||
))}
|
||||
<button
|
||||
onClick={() => setCurrentPage(p => Math.min(totalPages, p + 1))}
|
||||
disabled={safePage === totalPages}
|
||||
className="px-2.5 py-1 text-[11px] border border-border rounded text-text-2 hover:border-primary-cyan hover:text-primary-cyan disabled:opacity-40 transition-colors"
|
||||
>
|
||||
>
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
export default CleanupEquipPanel;
|
||||
@ -294,6 +294,7 @@ interface RolePermTabProps {
|
||||
setSelectedRoleSn: (sn: number | null) => void
|
||||
dirty: boolean
|
||||
saving: boolean
|
||||
saveError: string | null
|
||||
handleSave: () => Promise<void>
|
||||
handleToggleExpand: (code: string) => void
|
||||
handleTogglePerm: (code: string, oper: OperCode, currentState: PermState) => void
|
||||
@ -328,6 +329,7 @@ function RolePermTab({
|
||||
setSelectedRoleSn,
|
||||
dirty,
|
||||
saving,
|
||||
saveError,
|
||||
handleSave,
|
||||
handleToggleExpand,
|
||||
handleTogglePerm,
|
||||
@ -378,6 +380,9 @@ function RolePermTab({
|
||||
>
|
||||
{saving ? '저장 중...' : '변경사항 저장'}
|
||||
</button>
|
||||
{saveError && (
|
||||
<span className="text-[11px] text-status-red font-korean">{saveError}</span>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* 역할 탭 바 */}
|
||||
@ -861,6 +866,7 @@ function PermissionsPanel() {
|
||||
const [permTree, setPermTree] = useState<PermTreeNode[]>([])
|
||||
const [loading, setLoading] = useState(true)
|
||||
const [saving, setSaving] = useState(false)
|
||||
const [saveError, setSaveError] = useState<string | null>(null)
|
||||
const [dirty, setDirty] = useState(false)
|
||||
const [showCreateForm, setShowCreateForm] = useState(false)
|
||||
const [newRoleCode, setNewRoleCode] = useState('')
|
||||
@ -962,6 +968,7 @@ function PermissionsPanel() {
|
||||
|
||||
const handleSave = async () => {
|
||||
setSaving(true)
|
||||
setSaveError(null)
|
||||
try {
|
||||
for (const role of roles) {
|
||||
const perms = rolePerms.get(role.sn)
|
||||
@ -981,6 +988,7 @@ function PermissionsPanel() {
|
||||
setDirty(false)
|
||||
} catch (err) {
|
||||
console.error('권한 저장 실패:', err)
|
||||
setSaveError('권한 저장에 실패했습니다. 다시 시도해주세요.')
|
||||
} finally {
|
||||
setSaving(false)
|
||||
}
|
||||
@ -1096,6 +1104,7 @@ function PermissionsPanel() {
|
||||
setSelectedRoleSn={setSelectedRoleSn}
|
||||
dirty={dirty}
|
||||
saving={saving}
|
||||
saveError={saveError}
|
||||
handleSave={handleSave}
|
||||
handleToggleExpand={handleToggleExpand}
|
||||
handleTogglePerm={handleTogglePerm}
|
||||
|
||||
@ -51,6 +51,7 @@ export const ADMIN_MENU: AdminMenuItem[] = [
|
||||
id: 'coast-guard-assets', label: '해경자산',
|
||||
children: [
|
||||
{ id: 'cleanup-equip', label: '방제장비' },
|
||||
{ id: 'asset-upload', label: '자산현행화' },
|
||||
{ id: 'dispersant-zone', label: '유처리제 제한구역' },
|
||||
{ id: 'vessel-materials', label: '방제선 보유자재' },
|
||||
{ id: 'cleanup-resource', label: '방제자원' },
|
||||
|
||||
@ -13,9 +13,7 @@ import type { BoomLine, AlgorithmSettings, ContainmentResult, BoomLineCoord } fr
|
||||
import type { BacktrackPhase, BacktrackVessel, BacktrackConditions, ReplayShip, CollisionEvent } from '@common/types/backtrack'
|
||||
import { TOTAL_REPLAY_FRAMES } from '@common/types/backtrack'
|
||||
import { fetchBacktrackByAcdnt, createBacktrack, fetchPredictionDetail, fetchAnalysisTrajectory } from '../services/predictionApi'
|
||||
import type { CenterPoint, HydrDataStep, ImageAnalyzeResult, OilParticle, PredictionDetail, SimulationRunResponse, SimulationSummary, WindPoint } from '../services/predictionApi'
|
||||
import { useMultiSimulationStatus } from '../hooks/useSimulationStatus'
|
||||
import type { ModelExecRef } from '../hooks/useSimulationStatus'
|
||||
import type { CenterPoint, HydrDataStep, ImageAnalyzeResult, OilParticle, PredictionDetail, RunModelSyncResponse, SimulationSummary, WindPoint } from '../services/predictionApi'
|
||||
import SimulationLoadingOverlay from './SimulationLoadingOverlay'
|
||||
import SimulationErrorModal from './SimulationErrorModal'
|
||||
import { api } from '@common/services/api'
|
||||
@ -124,6 +122,8 @@ export function OilSpillView() {
|
||||
const [hydrDataByModel, setHydrDataByModel] = useState<Record<string, (HydrDataStep | null)[]>>({})
|
||||
const [windHydrModel, setWindHydrModel] = useState<string>('OpenDrift')
|
||||
const [isRunningSimulation, setIsRunningSimulation] = useState(false)
|
||||
const [simulationProgress, setSimulationProgress] = useState(0)
|
||||
const progressTimerRef = useRef<ReturnType<typeof setInterval> | null>(null)
|
||||
const [simulationError, setSimulationError] = useState<string | null>(null)
|
||||
const [selectedModels, setSelectedModels] = useState<Set<PredictionModel>>(new Set(['OpenDrift']))
|
||||
const [visibleModels, setVisibleModels] = useState<Set<PredictionModel>>(new Set(['OpenDrift']))
|
||||
@ -191,9 +191,8 @@ export function OilSpillView() {
|
||||
|
||||
// 재계산 상태
|
||||
const [recalcModalOpen, setRecalcModalOpen] = useState(false)
|
||||
const [pendingExecSns, setPendingExecSns] = useState<ModelExecRef[]>([])
|
||||
const [simulationSummary, setSimulationSummary] = useState<SimulationSummary | null>(null)
|
||||
const { allDone: simAllDone, anyError: simAnyError, results: simResults, errors: simErrors } = useMultiSimulationStatus(pendingExecSns)
|
||||
const [summaryByModel, setSummaryByModel] = useState<Record<string, SimulationSummary>>({})
|
||||
|
||||
// 오염분석 상태
|
||||
const [analysisTab, setAnalysisTab] = useState<'polygon' | 'circle'>('polygon')
|
||||
@ -392,91 +391,30 @@ export function OilSpillView() {
|
||||
}
|
||||
}, [])
|
||||
|
||||
// 시뮬레이션 폴링 결과 처리 (다중 모델)
|
||||
useEffect(() => {
|
||||
if (pendingExecSns.length === 0) return;
|
||||
|
||||
if (simAllDone) {
|
||||
// 모든 모델의 trajectory 병합 (model 필드 포함)
|
||||
const merged: OilParticle[] = [];
|
||||
let latestSummary: SimulationSummary | null = null;
|
||||
let latestCenterPoints: CenterPoint[] = [];
|
||||
const newWindDataByModel: Record<string, WindPoint[][]> = {};
|
||||
const newHydrDataByModel: Record<string, (HydrDataStep | null)[]> = {};
|
||||
|
||||
simResults.forEach((statusData, model) => {
|
||||
if (statusData.trajectory) {
|
||||
const withModel = statusData.trajectory.map(p => ({ ...p, model }));
|
||||
merged.push(...withModel);
|
||||
}
|
||||
// summary는 OpenDrift 우선, 없으면 다른 모델
|
||||
if (model === 'OpenDrift' || !latestSummary) {
|
||||
if (statusData.summary) latestSummary = statusData.summary;
|
||||
}
|
||||
// windData/hydrData는 모델별로 저장
|
||||
if (statusData.windData) newWindDataByModel[model] = statusData.windData;
|
||||
if (statusData.hydrData) newHydrDataByModel[model] = statusData.hydrData;
|
||||
// centerPoints는 모든 모델 누적 (model 필드 포함 보장)
|
||||
if (statusData.centerPoints) {
|
||||
const withModel = statusData.centerPoints.map(p => ({ ...p, model }));
|
||||
latestCenterPoints = [...latestCenterPoints, ...withModel];
|
||||
}
|
||||
});
|
||||
|
||||
if (merged.length > 0) {
|
||||
setOilTrajectory(merged);
|
||||
const doneModels = new Set<PredictionModel>(
|
||||
Array.from(simResults.entries())
|
||||
.filter(([, s]) => s.trajectory && s.trajectory.length > 0)
|
||||
.map(([m]) => m as PredictionModel)
|
||||
)
|
||||
setVisibleModels(doneModels)
|
||||
setSimulationSummary(latestSummary);
|
||||
setCenterPoints(latestCenterPoints);
|
||||
|
||||
// 데이터가 없는 모델에 OpenDrift(또는 첫 번째 보유 모델) 데이터 복사
|
||||
const refWindData = newWindDataByModel['OpenDrift'] ?? Object.values(newWindDataByModel)[0];
|
||||
const refHydrData = newHydrDataByModel['OpenDrift'] ?? Object.values(newHydrDataByModel)[0];
|
||||
doneModels.forEach(model => {
|
||||
if (!newWindDataByModel[model] && refWindData) newWindDataByModel[model] = refWindData;
|
||||
if (!newHydrDataByModel[model] && refHydrData) newHydrDataByModel[model] = refHydrData;
|
||||
});
|
||||
|
||||
setWindDataByModel(newWindDataByModel);
|
||||
setHydrDataByModel(newHydrDataByModel);
|
||||
setWindHydrModel('OpenDrift');
|
||||
if (incidentCoord) {
|
||||
const booms = generateAIBoomLines(merged, incidentCoord, algorithmSettings);
|
||||
setBoomLines(booms);
|
||||
}
|
||||
setSensitiveResources(DEMO_SENSITIVE_RESOURCES);
|
||||
setCurrentStep(0);
|
||||
setIsPlaying(true);
|
||||
if (incidentCoord) {
|
||||
setFlyToCoord({ lon: incidentCoord.lon, lat: incidentCoord.lat });
|
||||
}
|
||||
}
|
||||
setIsRunningSimulation(false);
|
||||
setPendingExecSns([]);
|
||||
}
|
||||
|
||||
if (simAnyError) {
|
||||
setIsRunningSimulation(false);
|
||||
setPendingExecSns([]);
|
||||
const errorMessages = Array.from(simErrors.values()).join('; ');
|
||||
setSimulationError(errorMessages || '시뮬레이션 처리 중 오류가 발생했습니다.');
|
||||
}
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, [simAllDone, simAnyError, simResults, simErrors, pendingExecSns.length, incidentCoord, algorithmSettings]);
|
||||
|
||||
// trajectory 변경 시 플레이어 스텝 초기화 (재생은 각 경로에서 별도 처리)
|
||||
useEffect(() => {
|
||||
if (oilTrajectory.length > 0) {
|
||||
// eslint-disable-next-line react-hooks/set-state-in-effect
|
||||
setCurrentStep(0);
|
||||
}
|
||||
}, [oilTrajectory.length]);
|
||||
|
||||
useEffect(() => {
|
||||
return () => {
|
||||
if (progressTimerRef.current) clearInterval(progressTimerRef.current);
|
||||
};
|
||||
}, []);
|
||||
|
||||
// windHydrModel이 visibleModels에 없으면 자동으로 적절한 모델로 전환
|
||||
useEffect(() => {
|
||||
if (visibleModels.size === 0) return;
|
||||
if (!visibleModels.has(windHydrModel as PredictionModel)) {
|
||||
const preferred: PredictionModel[] = ['OpenDrift', 'POSEIDON', 'KOSPS'];
|
||||
const next = preferred.find(m => visibleModels.has(m)) ?? Array.from(visibleModels)[0];
|
||||
setWindHydrModel(next);
|
||||
}
|
||||
}, [visibleModels, windHydrModel]);
|
||||
|
||||
// 플레이어 재생 애니메이션 (1x = 1초/스텝, 2x = 0.5초/스텝, 4x = 0.25초/스텝)
|
||||
const timeSteps = useMemo(() => {
|
||||
if (oilTrajectory.length === 0) return [];
|
||||
@ -500,7 +438,6 @@ export function OilSpillView() {
|
||||
useEffect(() => {
|
||||
if (!isPlaying || timeSteps.length === 0) return;
|
||||
if (currentStep >= maxTime) {
|
||||
// eslint-disable-next-line react-hooks/set-state-in-effect
|
||||
setIsPlaying(false);
|
||||
return;
|
||||
}
|
||||
@ -560,17 +497,19 @@ export function OilSpillView() {
|
||||
: incidentCoord
|
||||
const demoModels = Array.from(models.size > 0 ? models : new Set<PredictionModel>(['KOSPS']))
|
||||
|
||||
// OpenDrift 완료된 경우 실제 궤적 로드, 없으면 데모로 fallback
|
||||
if (analysis.opendriftStatus === 'completed') {
|
||||
// 완료된 모델이 있는 경우 실제 궤적 로드, 없으면 데모로 fallback
|
||||
const hasCompletedModel =
|
||||
analysis.opendriftStatus === 'completed' || analysis.poseidonStatus === 'completed';
|
||||
if (hasCompletedModel) {
|
||||
try {
|
||||
const { trajectory, summary, centerPoints: cp, windData: wd, hydrData: hd } = await fetchAnalysisTrajectory(analysis.acdntSn)
|
||||
const { trajectory, summary, centerPoints: cp, windDataByModel: wdByModel, hydrDataByModel: hdByModel, summaryByModel: sbModel } = await fetchAnalysisTrajectory(analysis.acdntSn)
|
||||
if (trajectory && trajectory.length > 0) {
|
||||
setOilTrajectory(trajectory)
|
||||
if (summary) setSimulationSummary(summary)
|
||||
setCenterPoints(cp ?? [])
|
||||
setWindDataByModel(wd && wd.length > 0 ? { 'OpenDrift': wd } : {})
|
||||
setHydrDataByModel(hd && hd.length > 0 ? { 'OpenDrift': hd } : {})
|
||||
setWindHydrModel('OpenDrift')
|
||||
setWindDataByModel(wdByModel ?? {});
|
||||
setHydrDataByModel(hdByModel ?? {});
|
||||
if (sbModel) setSummaryByModel(sbModel);
|
||||
if (coord) setBoomLines(generateAIBoomLines(trajectory, coord, algorithmSettings))
|
||||
setSensitiveResources(DEMO_SENSITIVE_RESOURCES)
|
||||
// incidentCoord가 변경된 경우 flyTo 완료 후 재생, 그렇지 않으면 즉시 재생
|
||||
@ -586,7 +525,10 @@ export function OilSpillView() {
|
||||
}
|
||||
}
|
||||
|
||||
// 데모 궤적 생성 (fallback)
|
||||
// 데모 궤적 생성 (fallback) — stale wind/current 데이터 초기화
|
||||
setWindDataByModel({})
|
||||
setHydrDataByModel({})
|
||||
setSummaryByModel({})
|
||||
const demoTrajectory = generateDemoTrajectory(coord ?? { lat: 37.39, lon: 126.64 }, demoModels, parseInt(analysis.duration) || 48)
|
||||
setOilTrajectory(demoTrajectory)
|
||||
if (coord) setBoomLines(generateAIBoomLines(demoTrajectory, coord, algorithmSettings))
|
||||
@ -690,50 +632,81 @@ export function OilSpillView() {
|
||||
})
|
||||
}, [])
|
||||
|
||||
const handleRunSimulation = async () => {
|
||||
const startProgressTimer = useCallback((runTimeHours: number) => {
|
||||
const expectedMs = runTimeHours * 6000;
|
||||
const startTime = Date.now();
|
||||
progressTimerRef.current = setInterval(() => {
|
||||
const elapsed = Date.now() - startTime;
|
||||
setSimulationProgress(Math.min(90, Math.round((elapsed / expectedMs) * 90)));
|
||||
}, 500);
|
||||
}, []);
|
||||
|
||||
const stopProgressTimer = useCallback((completed: boolean) => {
|
||||
if (progressTimerRef.current) {
|
||||
clearInterval(progressTimerRef.current);
|
||||
progressTimerRef.current = null;
|
||||
}
|
||||
if (completed) {
|
||||
setSimulationProgress(100);
|
||||
setTimeout(() => setSimulationProgress(0), 800);
|
||||
} else {
|
||||
setSimulationProgress(0);
|
||||
}
|
||||
}, []);
|
||||
|
||||
const handleRunSimulation = async (overrides?: {
|
||||
models?: Set<PredictionModel>;
|
||||
oilType?: string;
|
||||
spillAmount?: number;
|
||||
spillType?: string;
|
||||
predictionTime?: number;
|
||||
incidentCoord?: { lat: number; lon: number } | null;
|
||||
}) => {
|
||||
// incidentName이 있으면 직접 입력 모드 — 기존 selectedAnalysis.acdntSn 무시하고 새 사고 생성
|
||||
const isDirectInput = incidentName.trim().length > 0;
|
||||
const existingAcdntSn = isDirectInput
|
||||
? undefined
|
||||
: (selectedAnalysis?.acdntSn ?? analysisDetail?.acdnt?.acdntSn);
|
||||
|
||||
// 선택 모드인데 사고도 없으면 실행 불가, 직접 입력 모드인데 사고명 없으면 실행 불가
|
||||
if (!isDirectInput && !existingAcdntSn) {
|
||||
return;
|
||||
}
|
||||
if (!incidentCoord) {
|
||||
return;
|
||||
}
|
||||
const effectiveCoord = overrides?.incidentCoord ?? incidentCoord;
|
||||
if (!isDirectInput && !existingAcdntSn) return;
|
||||
if (!effectiveCoord) return;
|
||||
|
||||
const effectiveOilType = overrides?.oilType ?? oilType;
|
||||
const effectiveSpillAmount = overrides?.spillAmount ?? spillAmount;
|
||||
const effectiveSpillType = overrides?.spillType ?? spillType;
|
||||
const effectivePredictionTime = overrides?.predictionTime ?? predictionTime;
|
||||
const effectiveModels = overrides?.models ?? selectedModels;
|
||||
|
||||
setIsRunningSimulation(true);
|
||||
setSimulationSummary(null);
|
||||
startProgressTimer(effectivePredictionTime);
|
||||
let simulationSucceeded = false;
|
||||
try {
|
||||
const payload: Record<string, unknown> = {
|
||||
acdntSn: existingAcdntSn,
|
||||
lat: incidentCoord.lat,
|
||||
lon: incidentCoord.lon,
|
||||
runTime: predictionTime,
|
||||
matTy: oilType,
|
||||
matVol: spillAmount,
|
||||
spillTime: spillType === '연속' ? predictionTime : 0,
|
||||
lat: effectiveCoord.lat,
|
||||
lon: effectiveCoord.lon,
|
||||
runTime: effectivePredictionTime,
|
||||
matTy: effectiveOilType,
|
||||
matVol: effectiveSpillAmount,
|
||||
spillTime: effectiveSpillType === '연속' ? effectivePredictionTime : 0,
|
||||
startTime: accidentTime
|
||||
? `${accidentTime}:00`
|
||||
: analysisDetail?.acdnt?.occurredAt,
|
||||
models: Array.from(effectiveModels),
|
||||
};
|
||||
|
||||
// 직접 입력 모드: 백엔드에서 ACDNT + SPIL_DATA 생성에 필요한 필드 추가
|
||||
if (isDirectInput) {
|
||||
payload.acdntNm = incidentName.trim();
|
||||
payload.spillUnit = spillUnit;
|
||||
payload.spillTypeCd = spillType;
|
||||
}
|
||||
|
||||
payload.models = Array.from(selectedModels);
|
||||
|
||||
const { data } = await api.post<SimulationRunResponse>('/simulation/run', payload);
|
||||
setPendingExecSns(
|
||||
data.execSns ?? (data.execSn ? [{ model: 'OpenDrift', execSn: data.execSn }] : [])
|
||||
);
|
||||
// 동기 방식: 예측 완료 시 결과를 직접 반환 (최대 35분 대기)
|
||||
const { data } = await api.post<RunModelSyncResponse>('/simulation/run-model', payload, {
|
||||
timeout: 35 * 60 * 1000,
|
||||
});
|
||||
|
||||
// 직접 입력으로 신규 생성된 경우: selectedAnalysis 갱신 + incidentName 초기화
|
||||
if (data.acdntSn && isDirectInput) {
|
||||
@ -747,8 +720,8 @@ export function OilSpillView() {
|
||||
oilType,
|
||||
volume: spillAmount,
|
||||
location: '',
|
||||
lat: incidentCoord.lat,
|
||||
lon: incidentCoord.lon,
|
||||
lat: effectiveCoord.lat,
|
||||
lon: effectiveCoord.lon,
|
||||
kospsStatus: 'pending',
|
||||
poseidonStatus: 'pending',
|
||||
opendriftStatus: 'pending',
|
||||
@ -756,16 +729,79 @@ export function OilSpillView() {
|
||||
analyst: '',
|
||||
officeName: '',
|
||||
} as Analysis);
|
||||
// 다음 실행 시 동일 사고 재생성 방지 — 이후에는 selectedAnalysis.acdntSn 사용
|
||||
setIncidentName('');
|
||||
}
|
||||
// setIsRunningSimulation(false)는 폴링 결과 useEffect에서 처리
|
||||
|
||||
// 결과 처리
|
||||
const merged: OilParticle[] = [];
|
||||
let latestSummary: SimulationSummary | null = null;
|
||||
let latestCenterPoints: CenterPoint[] = [];
|
||||
const newWindDataByModel: Record<string, WindPoint[][]> = {};
|
||||
const newHydrDataByModel: Record<string, (HydrDataStep | null)[]> = {};
|
||||
const newSummaryByModel: Record<string, SimulationSummary> = {};
|
||||
const errors: string[] = [];
|
||||
|
||||
data.results.forEach(({ model, status, trajectory, summary, centerPoints, windData, hydrData, error }) => {
|
||||
if (status === 'ERROR') {
|
||||
errors.push(error || `${model} 분석 중 오류가 발생했습니다.`);
|
||||
return;
|
||||
}
|
||||
if (trajectory) {
|
||||
merged.push(...trajectory.map(p => ({ ...p, model })));
|
||||
}
|
||||
if (summary) {
|
||||
newSummaryByModel[model] = summary;
|
||||
if (model === 'OpenDrift' || !latestSummary) latestSummary = summary;
|
||||
}
|
||||
if (windData) newWindDataByModel[model] = windData;
|
||||
if (hydrData) newHydrDataByModel[model] = hydrData;
|
||||
if (centerPoints) {
|
||||
latestCenterPoints = [...latestCenterPoints, ...centerPoints.map(p => ({ ...p, model }))];
|
||||
}
|
||||
});
|
||||
|
||||
if (merged.length > 0) {
|
||||
setOilTrajectory(merged);
|
||||
const doneModels = new Set<PredictionModel>(
|
||||
data.results
|
||||
.filter(r => r.status === 'DONE' && r.trajectory && r.trajectory.length > 0)
|
||||
.map(r => r.model as PredictionModel)
|
||||
);
|
||||
setVisibleModels(doneModels);
|
||||
setSimulationSummary(latestSummary);
|
||||
setCenterPoints(latestCenterPoints);
|
||||
|
||||
const refWindData = newWindDataByModel['OpenDrift'] ?? Object.values(newWindDataByModel)[0];
|
||||
const refHydrData = newHydrDataByModel['OpenDrift'] ?? Object.values(newHydrDataByModel)[0];
|
||||
doneModels.forEach(model => {
|
||||
if (!newWindDataByModel[model] && refWindData) newWindDataByModel[model] = refWindData;
|
||||
if (!newHydrDataByModel[model] && refHydrData) newHydrDataByModel[model] = refHydrData;
|
||||
});
|
||||
|
||||
setWindDataByModel(newWindDataByModel);
|
||||
setHydrDataByModel(newHydrDataByModel);
|
||||
setSummaryByModel(newSummaryByModel);
|
||||
const booms = generateAIBoomLines(merged, effectiveCoord, algorithmSettings);
|
||||
setBoomLines(booms);
|
||||
setSensitiveResources(DEMO_SENSITIVE_RESOURCES);
|
||||
setCurrentStep(0);
|
||||
setIsPlaying(true);
|
||||
setFlyToCoord({ lon: effectiveCoord.lon, lat: effectiveCoord.lat });
|
||||
}
|
||||
|
||||
if (errors.length > 0 && merged.length === 0) {
|
||||
setSimulationError(errors.join('; '));
|
||||
} else {
|
||||
simulationSucceeded = true;
|
||||
}
|
||||
} catch (err) {
|
||||
setIsRunningSimulation(false);
|
||||
const msg =
|
||||
(err as { message?: string })?.message
|
||||
?? '시뮬레이션 실행 중 오류가 발생했습니다.';
|
||||
setSimulationError(msg);
|
||||
} finally {
|
||||
stopProgressTimer(simulationSucceeded);
|
||||
setIsRunningSimulation(false);
|
||||
}
|
||||
}
|
||||
|
||||
@ -808,7 +844,13 @@ export function OilSpillView() {
|
||||
weather: wx
|
||||
? { windDir: wx.wind, windSpeed: wx.wind, waveHeight: wx.wave, temp: wx.temp }
|
||||
: null,
|
||||
spread: { kosps: '—', openDrift: '—', poseidon: '—' },
|
||||
spread: (() => {
|
||||
const fmt = (model: string) => {
|
||||
const s = summaryByModel[model];
|
||||
return s ? `${s.pollutionArea.toFixed(2)} km²` : '—';
|
||||
};
|
||||
return { kosps: fmt('KOSPS'), openDrift: fmt('OpenDrift'), poseidon: fmt('POSEIDON') };
|
||||
})(),
|
||||
coastal: {
|
||||
firstTime: (() => {
|
||||
const beachedTimes = oilTrajectory.filter(p => p.stranded === 1).map(p => p.time);
|
||||
@ -1077,7 +1119,8 @@ export function OilSpillView() {
|
||||
<div style={{ display: 'flex', flexDirection: 'column', alignItems: 'flex-end', gap: '4px', flexShrink: 0, minWidth: '200px' }}>
|
||||
<div style={{ fontSize: '14px', fontWeight: 600, color: 'var(--cyan)', fontFamily: 'var(--fM)' }}>
|
||||
+{currentStep}h — {(() => {
|
||||
const d = new Date(); d.setHours(d.getHours() + currentStep);
|
||||
const base = accidentTime ? new Date(accidentTime) : new Date();
|
||||
const d = new Date(base.getTime() + currentStep * 3600 * 1000);
|
||||
return `${String(d.getMonth() + 1).padStart(2, '0')}/${String(d.getDate()).padStart(2, '0')} ${String(d.getHours()).padStart(2, '0')}:${String(d.getMinutes()).padStart(2, '0')} KST`;
|
||||
})()}
|
||||
</div>
|
||||
@ -1150,7 +1193,7 @@ export function OilSpillView() {
|
||||
{isRunningSimulation && (
|
||||
<SimulationLoadingOverlay
|
||||
status="RUNNING"
|
||||
progress={undefined}
|
||||
progress={simulationProgress}
|
||||
/>
|
||||
)}
|
||||
|
||||
@ -1179,7 +1222,14 @@ export function OilSpillView() {
|
||||
setPredictionTime(params.predictionTime)
|
||||
setIncidentCoord(params.incidentCoord)
|
||||
setSelectedModels(params.selectedModels)
|
||||
handleRunSimulation()
|
||||
handleRunSimulation({
|
||||
models: params.selectedModels,
|
||||
oilType: params.oilType,
|
||||
spillAmount: params.spillAmount,
|
||||
spillType: params.spillType,
|
||||
predictionTime: params.predictionTime,
|
||||
incidentCoord: params.incidentCoord,
|
||||
})
|
||||
}}
|
||||
/>
|
||||
|
||||
|
||||
@ -49,7 +49,6 @@ const PredictionInputSection = ({
|
||||
isRunningSimulation,
|
||||
selectedModels,
|
||||
onModelsChange,
|
||||
visibleModels,
|
||||
onVisibleModelsChange,
|
||||
hasResults,
|
||||
predictionTime,
|
||||
@ -393,20 +392,17 @@ const PredictionInputSection = ({
|
||||
] as const).map(m => (
|
||||
<div
|
||||
key={m.id}
|
||||
className={`prd-mc ${(hasResults ? (visibleModels ?? selectedModels) : selectedModels).has(m.id) ? 'on' : ''} cursor-pointer`}
|
||||
className={`prd-mc ${selectedModels.has(m.id) ? 'on' : ''} cursor-pointer`}
|
||||
onClick={() => {
|
||||
if (!m.ready) {
|
||||
alert(`${m.id} 모델은 현재 준비중입니다.`)
|
||||
return
|
||||
}
|
||||
if (hasResults && onVisibleModelsChange) {
|
||||
const next = new Set(visibleModels ?? selectedModels)
|
||||
if (next.has(m.id)) { next.delete(m.id) } else { next.add(m.id) }
|
||||
onVisibleModelsChange(next)
|
||||
} else {
|
||||
const next = new Set(selectedModels)
|
||||
if (next.has(m.id)) { next.delete(m.id) } else { next.add(m.id) }
|
||||
onModelsChange(next)
|
||||
if (hasResults && onVisibleModelsChange) {
|
||||
onVisibleModelsChange(new Set(next))
|
||||
}
|
||||
}}
|
||||
>
|
||||
|
||||
@ -1,86 +0,0 @@
|
||||
import { useQuery, useQueries } from '@tanstack/react-query';
|
||||
import { api } from '@common/services/api';
|
||||
import type { SimulationStatusResponse } from '../services/predictionApi';
|
||||
|
||||
export const useSimulationStatus = (execSn: number | null) => {
|
||||
return useQuery<SimulationStatusResponse>({
|
||||
queryKey: ['simulationStatus', execSn],
|
||||
queryFn: () => api.get<SimulationStatusResponse>(`/simulation/status/${execSn}`).then(r => r.data),
|
||||
enabled: execSn !== null,
|
||||
refetchInterval: (query) => {
|
||||
const status = query.state.data?.status;
|
||||
if (status === 'DONE' || status === 'ERROR') return false;
|
||||
return 3000;
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
export interface ModelExecRef {
|
||||
model: string;
|
||||
execSn: number;
|
||||
}
|
||||
|
||||
interface MultiSimulationStatus {
|
||||
allDone: boolean;
|
||||
anyError: boolean;
|
||||
isLoading: boolean;
|
||||
results: Map<string, SimulationStatusResponse>;
|
||||
errors: Map<string, string>;
|
||||
}
|
||||
|
||||
export const useMultiSimulationStatus = (execSns: ModelExecRef[]): MultiSimulationStatus => {
|
||||
const queries = useQueries({
|
||||
queries: execSns.map(({ model, execSn }) => ({
|
||||
queryKey: ['simulationStatus', execSn],
|
||||
queryFn: () =>
|
||||
api.get<SimulationStatusResponse>(`/simulation/status/${execSn}`).then(r => r.data),
|
||||
enabled: execSns.length > 0,
|
||||
refetchInterval: (query: { state: { data?: SimulationStatusResponse } }) => {
|
||||
const status = query.state.data?.status;
|
||||
if (status === 'DONE' || status === 'ERROR') return false;
|
||||
return 3000;
|
||||
},
|
||||
meta: { model },
|
||||
})),
|
||||
});
|
||||
|
||||
if (execSns.length === 0) {
|
||||
return {
|
||||
allDone: false,
|
||||
anyError: false,
|
||||
isLoading: false,
|
||||
results: new Map(),
|
||||
errors: new Map(),
|
||||
};
|
||||
}
|
||||
|
||||
const results = new Map<string, SimulationStatusResponse>();
|
||||
const errors = new Map<string, string>();
|
||||
|
||||
execSns.forEach(({ model }, index) => {
|
||||
const query = queries[index];
|
||||
if (query.data) {
|
||||
results.set(model, query.data);
|
||||
}
|
||||
if (query.error) {
|
||||
const err = query.error;
|
||||
errors.set(model, err instanceof Error ? err.message : String(err));
|
||||
}
|
||||
});
|
||||
|
||||
const allDone =
|
||||
execSns.length > 0 &&
|
||||
execSns.every((_, index) => {
|
||||
const status = queries[index].data?.status;
|
||||
return status === 'DONE' || status === 'ERROR';
|
||||
});
|
||||
|
||||
const anyError = execSns.some((_, index) => {
|
||||
const status = queries[index].data?.status;
|
||||
return status === 'ERROR' || queries[index].isError;
|
||||
});
|
||||
|
||||
const isLoading = execSns.some((_, index) => queries[index].isLoading);
|
||||
|
||||
return { allDone, anyError, isLoading, results, errors };
|
||||
};
|
||||
@ -184,12 +184,32 @@ export interface SimulationStatusResponse {
|
||||
error?: string;
|
||||
}
|
||||
|
||||
export interface RunModelSyncResult {
|
||||
model: string;
|
||||
execSn: number;
|
||||
status: 'DONE' | 'ERROR';
|
||||
trajectory?: OilParticle[];
|
||||
summary?: SimulationSummary;
|
||||
centerPoints?: CenterPoint[];
|
||||
windData?: WindPoint[][];
|
||||
hydrData?: (HydrDataStep | null)[];
|
||||
error?: string;
|
||||
}
|
||||
|
||||
export interface RunModelSyncResponse {
|
||||
success: boolean;
|
||||
acdntSn: number | null;
|
||||
execSns: Array<{ model: string; execSn: number }>;
|
||||
results: RunModelSyncResult[];
|
||||
}
|
||||
|
||||
export interface TrajectoryResponse {
|
||||
trajectory: OilParticle[] | null;
|
||||
summary: SimulationSummary | null;
|
||||
centerPoints?: CenterPoint[];
|
||||
windData?: WindPoint[][];
|
||||
hydrData?: (HydrDataStep | null)[];
|
||||
windDataByModel?: Record<string, WindPoint[][]>;
|
||||
hydrDataByModel?: Record<string, (HydrDataStep | null)[]>;
|
||||
summaryByModel?: Record<string, SimulationSummary>;
|
||||
}
|
||||
|
||||
export const fetchAnalysisTrajectory = async (acdntSn: number): Promise<TrajectoryResponse> => {
|
||||
|
||||
@ -47,6 +47,7 @@ const OilSpreadMapPanel = ({ mapData, capturedImage, onCapture, onReset }: OilSp
|
||||
simulationStartTime={mapData.simulationStartTime || undefined}
|
||||
mapCaptureRef={captureRef}
|
||||
showOverlays={false}
|
||||
lightMode
|
||||
/>
|
||||
|
||||
{/* 캡처 이미지 오버레이 — 우측 상단 */}
|
||||
|
||||
@ -7,9 +7,6 @@ import OilSpreadMapPanel from './OilSpreadMapPanel';
|
||||
import { saveReport } from '../services/reportsApi';
|
||||
import {
|
||||
CATEGORIES,
|
||||
sampleOilData,
|
||||
sampleHnsData,
|
||||
sampleRescueData,
|
||||
type ReportCategory,
|
||||
type ReportSection,
|
||||
} from './reportTypes';
|
||||
@ -83,8 +80,8 @@ function ReportGenerator({ onSave }: ReportGeneratorProps) {
|
||||
report.incident.pollutant = oilPayload.pollution.oilType;
|
||||
report.incident.spillAmount = oilPayload.pollution.spillAmount;
|
||||
} else {
|
||||
report.incident.pollutant = sampleOilData.pollution.oilType;
|
||||
report.incident.spillAmount = sampleOilData.pollution.spillAmount;
|
||||
report.incident.pollutant = '';
|
||||
report.incident.spillAmount = '';
|
||||
}
|
||||
}
|
||||
if (activeCat === 0 && oilMapCaptured) {
|
||||
@ -102,7 +99,7 @@ function ReportGenerator({ onSave }: ReportGeneratorProps) {
|
||||
const handleDownload = () => {
|
||||
const secColor = cat.color === 'var(--cyan)' ? '#06b6d4' : cat.color === 'var(--orange)' ? '#f97316' : '#ef4444';
|
||||
const sectionHTML = activeSections.map(sec => {
|
||||
let content = `<p style="font-size:12px;color:#666;">${sec.desc}</p>`;
|
||||
let content = `<p style="font-size:12px;color:#999;">—</p>`;
|
||||
|
||||
// OIL 섹션에 실 데이터 삽입
|
||||
if (activeCat === 0) {
|
||||
@ -123,6 +120,15 @@ function ReportGenerator({ onSave }: ReportGeneratorProps) {
|
||||
content = `${mapImg}<table style="width:100%;border-collapse:collapse;font-size:12px;"><tr>${tds}</tr></table>`;
|
||||
}
|
||||
}
|
||||
if (activeCat === 0 && sec.id === 'oil-coastal') {
|
||||
if (oilPayload) {
|
||||
const coastLength = oilPayload.pollution.coastLength;
|
||||
const hasNoCoastal = !coastLength || coastLength === '—' || coastLength.startsWith('0.00');
|
||||
content = hasNoCoastal
|
||||
? `<p style="font-size:12px;">유출유의 해안 부착이 없습니다.</p>`
|
||||
: `<p style="font-size:12px;">최초 부착시간: <b>${oilPayload.coastal?.firstTime ?? '—'}</b> / 부착 해안길이: <b>${coastLength}</b></p>`;
|
||||
}
|
||||
}
|
||||
if (activeCat === 0 && oilPayload) {
|
||||
if (sec.id === 'oil-pollution') {
|
||||
const rows = [
|
||||
@ -322,9 +328,9 @@ function ReportGenerator({ onSave }: ReportGeneratorProps) {
|
||||
/>
|
||||
<div className="grid grid-cols-3 gap-3">
|
||||
{[
|
||||
{ label: 'KOSPS', value: oilPayload?.spread.kosps || sampleOilData.spread.kosps, color: '#06b6d4' },
|
||||
{ label: 'OpenDrift', value: oilPayload?.spread.openDrift || sampleOilData.spread.openDrift, color: '#ef4444' },
|
||||
{ label: 'POSEIDON', value: oilPayload?.spread.poseidon || sampleOilData.spread.poseidon, color: '#f97316' },
|
||||
{ label: 'KOSPS', value: oilPayload?.spread.kosps || '—', color: '#06b6d4' },
|
||||
{ label: 'OpenDrift', value: oilPayload?.spread.openDrift || '—', color: '#ef4444' },
|
||||
{ label: 'POSEIDON', value: oilPayload?.spread.poseidon || '—', color: '#f97316' },
|
||||
].map((m, i) => (
|
||||
<div key={i} className="bg-bg-1 border border-border rounded-lg p-4 text-center">
|
||||
<p className="text-[10px] text-text-3 font-korean mb-1">{m.label}</p>
|
||||
@ -345,9 +351,9 @@ function ReportGenerator({ onSave }: ReportGeneratorProps) {
|
||||
<colgroup><col style={{ width: '25%' }} /><col style={{ width: '25%' }} /><col style={{ width: '25%' }} /><col style={{ width: '25%' }} /></colgroup>
|
||||
<tbody>
|
||||
{[
|
||||
['유출량', oilPayload?.pollution.spillAmount || sampleOilData.pollution.spillAmount, '풍화량', oilPayload?.pollution.weathered || sampleOilData.pollution.weathered],
|
||||
['해상잔유량', oilPayload?.pollution.seaRemain || sampleOilData.pollution.seaRemain, '오염해역면적', oilPayload?.pollution.pollutionArea || sampleOilData.pollution.pollutionArea],
|
||||
['연안부착량', oilPayload?.pollution.coastAttach || sampleOilData.pollution.coastAttach, '오염해안길이', oilPayload?.pollution.coastLength || sampleOilData.pollution.coastLength],
|
||||
['유출량', oilPayload?.pollution.spillAmount || '—', '풍화량', oilPayload?.pollution.weathered || '—'],
|
||||
['해상잔유량', oilPayload?.pollution.seaRemain || '—', '오염해역면적', oilPayload?.pollution.pollutionArea || '—'],
|
||||
['연안부착량', oilPayload?.pollution.coastAttach || '—', '오염해안길이', oilPayload?.pollution.coastLength || '—'],
|
||||
].map((row, i) => (
|
||||
<tr key={i} className="border-b border-border">
|
||||
<td className="px-4 py-3 text-[11px] text-text-3 font-korean bg-[rgba(255,255,255,0.02)]">{row[0]}</td>
|
||||
@ -361,20 +367,20 @@ function ReportGenerator({ onSave }: ReportGeneratorProps) {
|
||||
</>
|
||||
)}
|
||||
{sec.id === 'oil-sensitive' && (
|
||||
<>
|
||||
<p className="text-[11px] text-text-3 font-korean mb-3">반경 10 NM 기준</p>
|
||||
<div className="flex flex-wrap gap-2">
|
||||
{sampleOilData.sensitive.map((item, i) => (
|
||||
<span key={i} className="px-3 py-1.5 text-[11px] font-semibold rounded-md bg-bg-3 border border-border text-text-2 font-korean">{item.label}</span>
|
||||
))}
|
||||
</div>
|
||||
</>
|
||||
<p className="text-[12px] text-text-3 font-korean italic">
|
||||
현재 민감자원 데이터가 없습니다.
|
||||
</p>
|
||||
)}
|
||||
{sec.id === 'oil-coastal' && (() => {
|
||||
const coastLength = oilPayload?.pollution.coastLength;
|
||||
const hasNoCoastal = oilPayload && (
|
||||
!coastLength || coastLength === '—' || coastLength.startsWith('0.00')
|
||||
if (!oilPayload) {
|
||||
return (
|
||||
<p className="text-[12px] text-text-3 font-korean italic">
|
||||
현재 해안 부착 데이터가 없습니다.
|
||||
</p>
|
||||
);
|
||||
}
|
||||
const coastLength = oilPayload.pollution.coastLength;
|
||||
const hasNoCoastal = !coastLength || coastLength === '—' || coastLength.startsWith('0.00');
|
||||
if (hasNoCoastal) {
|
||||
return (
|
||||
<p className="text-[12px] text-text-2 font-korean">
|
||||
@ -384,9 +390,9 @@ function ReportGenerator({ onSave }: ReportGeneratorProps) {
|
||||
}
|
||||
return (
|
||||
<p className="text-[12px] text-text-2 font-korean">
|
||||
최초 부착시간: <span className="font-semibold text-text-1">{oilPayload?.coastal?.firstTime ?? sampleOilData.coastal.firstTime}</span>
|
||||
최초 부착시간: <span className="font-semibold text-text-1">{oilPayload.coastal?.firstTime ?? '—'}</span>
|
||||
{' / '}
|
||||
부착 해안길이: <span className="font-semibold text-text-1">{coastLength || sampleOilData.coastal.coastLength}</span>
|
||||
부착 해안길이: <span className="font-semibold text-text-1">{coastLength}</span>
|
||||
</p>
|
||||
);
|
||||
})()}
|
||||
@ -399,20 +405,9 @@ function ReportGenerator({ onSave }: ReportGeneratorProps) {
|
||||
</div>
|
||||
)}
|
||||
{sec.id === 'oil-tide' && (
|
||||
<>
|
||||
<p className="text-[12px] text-text-2 font-korean">
|
||||
고조: <span className="font-semibold text-text-1">{sampleOilData.tide.highTide1}</span>
|
||||
{' / '}저조: <span className="font-semibold text-text-1">{sampleOilData.tide.lowTide}</span>
|
||||
{' / '}고조: <span className="font-semibold text-text-1">{sampleOilData.tide.highTide2}</span>
|
||||
<p className="text-[12px] text-text-3 font-korean italic">
|
||||
현재 조석·기상 데이터가 없습니다.
|
||||
</p>
|
||||
{oilPayload?.weather && (
|
||||
<p className="text-[11px] text-text-3 font-korean mt-2">
|
||||
기상: 풍향/풍속 <span className="text-text-2 font-semibold">{oilPayload.weather.windDir}</span>
|
||||
{' / '}파고 <span className="text-text-2 font-semibold">{oilPayload.weather.waveHeight}</span>
|
||||
{' / '}기온 <span className="text-text-2 font-semibold">{oilPayload.weather.temp}</span>
|
||||
</p>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
|
||||
{/* ── HNS 대기확산 섹션들 ── */}
|
||||
@ -432,7 +427,7 @@ function ReportGenerator({ onSave }: ReportGeneratorProps) {
|
||||
)}
|
||||
<div className="grid grid-cols-3 gap-3">
|
||||
{[
|
||||
{ label: hnsPayload?.atm.model || 'ALOHA', value: hnsPayload?.atm.maxDistance || sampleHnsData.atm.aloha, color: '#f97316', desc: '최대 확산거리' },
|
||||
{ label: hnsPayload?.atm.model || 'ALOHA', value: hnsPayload?.atm.maxDistance || '—', color: '#f97316', desc: '최대 확산거리' },
|
||||
{ label: '최대 농도', value: hnsPayload?.maxConcentration || '—', color: '#ef4444', desc: '지상 1.5m 기준' },
|
||||
{ label: 'AEGL-1 면적', value: hnsPayload?.aeglAreas.aegl1 || '—', color: '#06b6d4', desc: '확산 영향 면적' },
|
||||
].map((m, i) => (
|
||||
@ -448,9 +443,9 @@ function ReportGenerator({ onSave }: ReportGeneratorProps) {
|
||||
{sec.id === 'hns-hazard' && (
|
||||
<div className="grid grid-cols-3 gap-3">
|
||||
{[
|
||||
{ label: 'AEGL-3 구역', value: hnsPayload?.hazard.aegl3 || sampleHnsData.hazard.erpg3, area: hnsPayload?.aeglAreas.aegl3, color: '#ef4444', desc: '생명 위협' },
|
||||
{ label: 'AEGL-2 구역', value: hnsPayload?.hazard.aegl2 || sampleHnsData.hazard.erpg2, area: hnsPayload?.aeglAreas.aegl2, color: '#f97316', desc: '건강 피해' },
|
||||
{ label: 'AEGL-1 구역', value: hnsPayload?.hazard.aegl1 || sampleHnsData.hazard.evacuation, area: hnsPayload?.aeglAreas.aegl1, color: '#eab308', desc: '불쾌감' },
|
||||
{ label: 'AEGL-3 구역', value: hnsPayload?.hazard.aegl3 || '—', area: hnsPayload?.aeglAreas.aegl3, color: '#ef4444', desc: '생명 위협' },
|
||||
{ label: 'AEGL-2 구역', value: hnsPayload?.hazard.aegl2 || '—', area: hnsPayload?.aeglAreas.aegl2, color: '#f97316', desc: '건강 피해' },
|
||||
{ label: 'AEGL-1 구역', value: hnsPayload?.hazard.aegl1 || '—', area: hnsPayload?.aeglAreas.aegl1, color: '#eab308', desc: '불쾌감' },
|
||||
].map((h, i) => (
|
||||
<div key={i} className="bg-bg-1 border border-border rounded-lg p-4 text-center">
|
||||
<p className="text-[9px] font-bold font-korean mb-1" style={{ color: h.color }}>{h.label}</p>
|
||||
@ -464,10 +459,10 @@ function ReportGenerator({ onSave }: ReportGeneratorProps) {
|
||||
{sec.id === 'hns-substance' && (
|
||||
<div className="grid grid-cols-2 gap-2 text-[11px]">
|
||||
{[
|
||||
{ k: '물질명', v: hnsPayload?.substance.name || sampleHnsData.substance.name },
|
||||
{ k: 'UN번호', v: hnsPayload?.substance.un || sampleHnsData.substance.un },
|
||||
{ k: 'CAS번호', v: hnsPayload?.substance.cas || sampleHnsData.substance.cas },
|
||||
{ k: '위험등급', v: hnsPayload?.substance.class || sampleHnsData.substance.class },
|
||||
{ k: '물질명', v: hnsPayload?.substance.name || '—' },
|
||||
{ k: 'UN번호', v: hnsPayload?.substance.un || '—' },
|
||||
{ k: 'CAS번호', v: hnsPayload?.substance.cas || '—' },
|
||||
{ k: '위험등급', v: hnsPayload?.substance.class || '—' },
|
||||
].map((r, i) => (
|
||||
<div key={i} className="flex justify-between px-3 py-2 bg-bg-1 rounded border border-border">
|
||||
<span className="text-text-3 font-korean">{r.k}</span>
|
||||
@ -476,25 +471,21 @@ function ReportGenerator({ onSave }: ReportGeneratorProps) {
|
||||
))}
|
||||
<div className="col-span-2 flex justify-between px-3 py-2 bg-bg-1 rounded border border-[rgba(239,68,68,0.3)]">
|
||||
<span className="text-text-3 font-korean">독성기준</span>
|
||||
<span className="text-[var(--red)] font-semibold font-mono text-[10px]">{hnsPayload?.substance.toxicity || sampleHnsData.substance.toxicity}</span>
|
||||
<span className="text-[var(--red)] font-semibold font-mono text-[10px]">{hnsPayload?.substance.toxicity || '—'}</span>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
{sec.id === 'hns-ppe' && (
|
||||
<div className="flex flex-wrap gap-2">
|
||||
{sampleHnsData.ppe.map((item, i) => (
|
||||
<span key={i} className="px-3 py-1.5 text-[11px] font-semibold rounded-md border text-text-2 font-korean" style={{ background: 'rgba(249,115,22,0.06)', borderColor: 'rgba(249,115,22,0.2)' }}>
|
||||
🛡 {item}
|
||||
</span>
|
||||
))}
|
||||
<span className="text-text-3 font-korean text-[11px]">—</span>
|
||||
</div>
|
||||
)}
|
||||
{sec.id === 'hns-facility' && (
|
||||
<div className="grid grid-cols-3 gap-3">
|
||||
{[
|
||||
{ label: '인근 학교', value: `${sampleHnsData.facility.schools}개소`, icon: '🏫' },
|
||||
{ label: '의료시설', value: `${sampleHnsData.facility.hospitals}개소`, icon: '🏥' },
|
||||
{ label: '주변 인구', value: sampleHnsData.facility.population, icon: '👥' },
|
||||
{ label: '인근 학교', value: '—', icon: '🏫' },
|
||||
{ label: '의료시설', value: '—', icon: '🏥' },
|
||||
{ label: '주변 인구', value: '—', icon: '👥' },
|
||||
].map((f, i) => (
|
||||
<div key={i} className="bg-bg-1 border border-border rounded-lg p-4 text-center">
|
||||
<div className="text-[18px] mb-1">{f.icon}</div>
|
||||
@ -512,10 +503,10 @@ function ReportGenerator({ onSave }: ReportGeneratorProps) {
|
||||
{sec.id === 'hns-weather' && (
|
||||
<div className="grid grid-cols-4 gap-3">
|
||||
{[
|
||||
{ label: '풍향', value: hnsPayload?.weather.windDir || 'NE 42°', icon: '🌬' },
|
||||
{ label: '풍속', value: hnsPayload?.weather.windSpeed || '5.2 m/s', icon: '💨' },
|
||||
{ label: '대기안정도', value: hnsPayload?.weather.stability || 'D (중립)', icon: '🌡' },
|
||||
{ label: '기온', value: hnsPayload?.weather.temperature || '8.5°C', icon: '☀️' },
|
||||
{ label: '풍향', value: hnsPayload?.weather.windDir || '—', icon: '🌬' },
|
||||
{ label: '풍속', value: hnsPayload?.weather.windSpeed || '—', icon: '💨' },
|
||||
{ label: '대기안정도', value: hnsPayload?.weather.stability || '—', icon: '🌡' },
|
||||
{ label: '기온', value: hnsPayload?.weather.temperature || '—', icon: '☀️' },
|
||||
].map((w, i) => (
|
||||
<div key={i} className="bg-bg-1 border border-border rounded-lg p-3 text-center">
|
||||
<div className="text-[16px] mb-0.5">{w.icon}</div>
|
||||
@ -530,10 +521,10 @@ function ReportGenerator({ onSave }: ReportGeneratorProps) {
|
||||
{sec.id === 'rescue-safety' && (
|
||||
<div className="grid grid-cols-4 gap-3">
|
||||
{[
|
||||
{ label: 'GM (복원력)', value: sampleRescueData.safety.gm, color: '#f97316' },
|
||||
{ label: '경사각 (Heel)', value: sampleRescueData.safety.heel, color: '#ef4444' },
|
||||
{ label: '트림 (Trim)', value: sampleRescueData.safety.trim, color: '#06b6d4' },
|
||||
{ label: '안전 상태', value: sampleRescueData.safety.status, color: '#f97316' },
|
||||
{ label: 'GM (복원력)', value: '—', color: '#f97316' },
|
||||
{ label: '경사각 (Heel)', value: '—', color: '#ef4444' },
|
||||
{ label: '트림 (Trim)', value: '—', color: '#06b6d4' },
|
||||
{ label: '안전 상태', value: '—', color: '#f97316' },
|
||||
].map((s, i) => (
|
||||
<div key={i} className="bg-bg-1 border border-border rounded-lg p-3 text-center">
|
||||
<p className="text-[9px] text-text-3 font-korean mb-1">{s.label}</p>
|
||||
@ -544,26 +535,18 @@ function ReportGenerator({ onSave }: ReportGeneratorProps) {
|
||||
)}
|
||||
{sec.id === 'rescue-timeline' && (
|
||||
<div className="flex flex-col gap-2">
|
||||
{[
|
||||
{ time: '06:28', event: '충돌 발생 — ORIENTAL GLORY ↔ HAI FENG 168', color: '#ef4444' },
|
||||
{ time: '06:30', event: 'No.1P 탱크 파공, 벙커C유 유출 개시', color: '#f97316' },
|
||||
{ time: '06:35', event: 'VHF Ch.16 조난통신, 해경 출동 요청', color: '#eab308' },
|
||||
{ time: '07:15', event: '해경 3009함 현장 도착, 방제 개시', color: '#06b6d4' },
|
||||
].map((e, i) => (
|
||||
<div key={i} className="flex items-center gap-3 px-3 py-2 bg-bg-1 rounded border border-border">
|
||||
<span className="font-mono text-[11px] font-bold min-w-[40px]" style={{ color: e.color }}>{e.time}</span>
|
||||
<span className="text-[11px] text-text-2 font-korean">{e.event}</span>
|
||||
<div className="flex items-center gap-3 px-3 py-2 bg-bg-1 rounded border border-border">
|
||||
<span className="text-[11px] text-text-3 font-korean">—</span>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
{sec.id === 'rescue-casualty' && (
|
||||
<div className="grid grid-cols-4 gap-3">
|
||||
{[
|
||||
{ label: '총원', value: sampleRescueData.casualty.total },
|
||||
{ label: '구조완료', value: sampleRescueData.casualty.rescued, color: '#22c55e' },
|
||||
{ label: '실종', value: sampleRescueData.casualty.missing, color: '#ef4444' },
|
||||
{ label: '부상', value: sampleRescueData.casualty.injured, color: '#f97316' },
|
||||
{ label: '총원', value: '—' },
|
||||
{ label: '구조완료', value: '—', color: '#22c55e' },
|
||||
{ label: '실종', value: '—', color: '#ef4444' },
|
||||
{ label: '부상', value: '—', color: '#f97316' },
|
||||
].map((c, i) => (
|
||||
<div key={i} className="bg-bg-1 border border-border rounded-lg p-4 text-center">
|
||||
<p className="text-[9px] text-text-3 font-korean mb-1">{c.label}</p>
|
||||
@ -584,30 +567,18 @@ function ReportGenerator({ onSave }: ReportGeneratorProps) {
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{sampleRescueData.resources.map((r, i) => (
|
||||
<tr key={i} className="border-b border-border">
|
||||
<td className="px-3 py-2 text-text-2 font-korean">{r.type}</td>
|
||||
<td className="px-3 py-2 text-text-1 font-mono font-semibold">{r.name}</td>
|
||||
<td className="px-3 py-2 text-text-2 text-center font-mono">{r.eta}</td>
|
||||
<td className="px-3 py-2 text-center">
|
||||
<span className="px-2 py-0.5 rounded text-[10px] font-semibold font-korean" style={{
|
||||
background: r.status === '투입중' ? 'rgba(34,197,94,0.15)' : r.status === '이동중' ? 'rgba(249,115,22,0.15)' : 'rgba(138,150,168,0.15)',
|
||||
color: r.status === '투입중' ? '#22c55e' : r.status === '이동중' ? '#f97316' : '#8a96a8',
|
||||
}}>
|
||||
{r.status}
|
||||
</span>
|
||||
</td>
|
||||
<tr className="border-b border-border">
|
||||
<td colSpan={4} className="px-3 py-3 text-center text-text-3 font-korean text-[11px]">—</td>
|
||||
</tr>
|
||||
))}
|
||||
</tbody>
|
||||
</table>
|
||||
)}
|
||||
{sec.id === 'rescue-grounding' && (
|
||||
<div className="grid grid-cols-3 gap-3">
|
||||
{[
|
||||
{ label: '좌초 위험도', value: sampleRescueData.grounding.risk, color: '#ef4444' },
|
||||
{ label: '최근 천해', value: sampleRescueData.grounding.nearestShallow, color: '#f97316' },
|
||||
{ label: '현재 수심', value: sampleRescueData.grounding.depth, color: '#06b6d4' },
|
||||
{ label: '좌초 위험도', value: '—', color: '#ef4444' },
|
||||
{ label: '최근 천해', value: '—', color: '#f97316' },
|
||||
{ label: '현재 수심', value: '—', color: '#06b6d4' },
|
||||
].map((g, i) => (
|
||||
<div key={i} className="bg-bg-1 border border-border rounded-lg p-4 text-center">
|
||||
<p className="text-[9px] text-text-3 font-korean mb-1">{g.label}</p>
|
||||
@ -619,10 +590,10 @@ function ReportGenerator({ onSave }: ReportGeneratorProps) {
|
||||
{sec.id === 'rescue-weather' && (
|
||||
<div className="grid grid-cols-4 gap-3">
|
||||
{[
|
||||
{ label: '파고', value: '1.5 m', icon: '🌊' },
|
||||
{ label: '풍속', value: '5.2 m/s', icon: '🌬' },
|
||||
{ label: '조류', value: '1.2 kts NE', icon: '🌀' },
|
||||
{ label: '시정', value: '8 km', icon: '👁' },
|
||||
{ label: '파고', value: '—', icon: '🌊' },
|
||||
{ label: '풍속', value: '—', icon: '🌬' },
|
||||
{ label: '조류', value: '—', icon: '🌀' },
|
||||
{ label: '시정', value: '—', icon: '👁' },
|
||||
].map((w, i) => (
|
||||
<div key={i} className="bg-bg-1 border border-border rounded-lg p-3 text-center">
|
||||
<div className="text-[16px] mb-0.5">{w.icon}</div>
|
||||
|
||||
@ -11,6 +11,7 @@ import {
|
||||
generateReportHTML,
|
||||
exportAsPDF,
|
||||
exportAsHWP,
|
||||
buildReportGetVal,
|
||||
typeColors,
|
||||
statusColors,
|
||||
analysisCatColors,
|
||||
@ -284,16 +285,7 @@ export function ReportsView() {
|
||||
onClick={() => {
|
||||
const tpl = templateTypes.find(t => t.id === previewReport.reportType)
|
||||
if (tpl) {
|
||||
const getVal = (key: string) => {
|
||||
if (key === 'author') return previewReport.author
|
||||
if (key.startsWith('incident.')) {
|
||||
const f = key.split('.')[1]
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
return (previewReport.incident as any)[f] || ''
|
||||
}
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
return (previewReport as any)[key] || ''
|
||||
}
|
||||
const getVal = buildReportGetVal(previewReport)
|
||||
const html = generateReportHTML(tpl.label, { writeTime: previewReport.incident.writeTime, author: previewReport.author, jurisdiction: previewReport.jurisdiction }, tpl.sections, getVal)
|
||||
exportAsPDF(html, previewReport.title || tpl.label)
|
||||
}
|
||||
@ -307,16 +299,7 @@ export function ReportsView() {
|
||||
onClick={() => {
|
||||
const tpl = templateTypes.find(t => t.id === previewReport.reportType) as TemplateType | undefined
|
||||
if (tpl) {
|
||||
const getVal = (key: string) => {
|
||||
if (key === 'author') return previewReport.author
|
||||
if (key.startsWith('incident.')) {
|
||||
const f = key.split('.')[1]
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
return (previewReport.incident as any)[f] || ''
|
||||
}
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
return (previewReport as any)[key] || ''
|
||||
}
|
||||
const getVal = buildReportGetVal(previewReport)
|
||||
const meta = { writeTime: previewReport.incident.writeTime, author: previewReport.author, jurisdiction: previewReport.jurisdiction }
|
||||
const filename = previewReport.title || tpl.label
|
||||
exportAsHWP(tpl.label, meta, tpl.sections, getVal, filename)
|
||||
|
||||
@ -75,15 +75,14 @@ export const templateTypes: TemplateType[] = [
|
||||
{ key: 'incident.pollutant', label: '유출유종', type: 'text' },
|
||||
{ key: 'incident.spillAmount', label: '유출량', type: 'text' },
|
||||
]},
|
||||
{ title: '3. 해양기상 현황', fields: [
|
||||
{ key: 'weatherSummary', label: '', type: 'textarea' },
|
||||
]},
|
||||
{ title: '4. 확산예측 결과', fields: [
|
||||
{ key: 'spreadResult', label: '', type: 'textarea' },
|
||||
]},
|
||||
{ title: '5. 민감자원 영향', fields: [
|
||||
{ key: 'sensitiveImpact', label: '', type: 'textarea' },
|
||||
]},
|
||||
{ title: '3. 조석 현황', fields: [{ key: '__tide', label: '', type: 'textarea' }] },
|
||||
{ title: '4. 해양기상 현황', fields: [{ key: '__weather', label: '', type: 'textarea' }] },
|
||||
{ title: '5. 확산예측 결과', fields: [{ key: '__spread', label: '', type: 'textarea' }] },
|
||||
{ title: '6. 민감자원 현황', fields: [{ key: '__sensitive', label: '', type: 'textarea' }] },
|
||||
{ title: '7. 방제 자원', fields: [{ key: '__vessels', label: '', type: 'textarea' }] },
|
||||
{ title: '8. 회수·처리 현황',fields: [{ key: '__recovery', label: '', type: 'textarea' }] },
|
||||
{ title: '9. 최종 결과', fields: [{ key: '__result', label: '', type: 'textarea' }] },
|
||||
{ title: '10. 분석 의견', fields: [{ key: 'analysis', label: '', type: 'textarea' }] },
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@ -86,3 +86,122 @@ export function inferAnalysisCategory(report: OilSpillReportData): string {
|
||||
if (t.includes('유출유') || t.includes('확산예측') || t.includes('민감자원') || t.includes('유출사고') || t.includes('오염') || t.includes('방제') || rt === '유출유 보고' || rt === '예측보고서') return '유출유 확산예측'
|
||||
return ''
|
||||
}
|
||||
|
||||
// ─── PDF/HWP 섹션 포맷 헬퍼 ─────────────────────────────────
|
||||
const TH = 'padding:6px 8px;border:1px solid #d1d5db;background:#f0f4f8;font-weight:600;font-size:11px;'
|
||||
const TD = 'padding:6px 8px;border:1px solid #d1d5db;font-size:11px;'
|
||||
const TABLE = 'width:100%;border-collapse:collapse;'
|
||||
|
||||
function formatTideTable(tide: OilSpillReportData['tide']): string {
|
||||
if (!tide?.length) return ''
|
||||
const header = `<tr><th style="${TH}">날짜</th><th style="${TH}">조형</th><th style="${TH}">간조1</th><th style="${TH}">만조1</th><th style="${TH}">간조2</th><th style="${TH}">만조2</th></tr>`
|
||||
const rows = tide.map(t =>
|
||||
`<tr><td style="${TD}">${t.date}</td><td style="${TD}">${t.tideType}</td><td style="${TD}">${t.lowTide1}</td><td style="${TD}">${t.highTide1}</td><td style="${TD}">${t.lowTide2}</td><td style="${TD}">${t.highTide2}</td></tr>`
|
||||
).join('')
|
||||
return `<table style="${TABLE}">${header}${rows}</table>`
|
||||
}
|
||||
|
||||
function formatWeatherTable(weather: OilSpillReportData['weather']): string {
|
||||
if (!weather?.length) return ''
|
||||
const header = `<tr><th style="${TH}">시각</th><th style="${TH}">풍향</th><th style="${TH}">풍속</th><th style="${TH}">유향</th><th style="${TH}">유속</th><th style="${TH}">파고</th></tr>`
|
||||
const rows = weather.map(w =>
|
||||
`<tr><td style="${TD}">${w.time}</td><td style="${TD}">${w.windDir}</td><td style="${TD}">${w.windSpeed}</td><td style="${TD}">${w.currentDir}</td><td style="${TD}">${w.currentSpeed}</td><td style="${TD}">${w.waveHeight}</td></tr>`
|
||||
).join('')
|
||||
return `<table style="${TABLE}">${header}${rows}</table>`
|
||||
}
|
||||
|
||||
function formatSpreadTable(spread: OilSpillReportData['spread']): string {
|
||||
if (!spread?.length) return ''
|
||||
const header = `<tr><th style="${TH}">경과시간</th><th style="${TH}">풍화량</th><th style="${TH}">해상잔유량</th><th style="${TH}">연안부착량</th><th style="${TH}">면적</th></tr>`
|
||||
const rows = spread.map(s =>
|
||||
`<tr><td style="${TD}">${s.elapsed}</td><td style="${TD}">${s.weathered}</td><td style="${TD}">${s.seaRemain}</td><td style="${TD}">${s.coastAttach}</td><td style="${TD}">${s.area}</td></tr>`
|
||||
).join('')
|
||||
return `<table style="${TABLE}">${header}${rows}</table>`
|
||||
}
|
||||
|
||||
function formatSensitiveTable(r: OilSpillReportData): string {
|
||||
const parts: string[] = []
|
||||
if (r.aquaculture?.length) {
|
||||
const h = `<tr><th style="${TH}">종류</th><th style="${TH}">면적</th><th style="${TH}">거리</th></tr>`
|
||||
const rows = r.aquaculture.map(a => `<tr><td style="${TD}">${a.type}</td><td style="${TD}">${a.area}</td><td style="${TD}">${a.distance}</td></tr>`).join('')
|
||||
parts.push(`<p style="font-size:11px;font-weight:600;margin:8px 0 4px;">양식업</p><table style="${TABLE}">${h}${rows}</table>`)
|
||||
}
|
||||
if (r.beaches?.length) {
|
||||
const h = `<tr><th style="${TH}">해수욕장명</th><th style="${TH}">거리</th></tr>`
|
||||
const rows = r.beaches.map(b => `<tr><td style="${TD}">${b.name}</td><td style="${TD}">${b.distance}</td></tr>`).join('')
|
||||
parts.push(`<p style="font-size:11px;font-weight:600;margin:8px 0 4px;">해수욕장</p><table style="${TABLE}">${h}${rows}</table>`)
|
||||
}
|
||||
if (r.markets?.length) {
|
||||
const h = `<tr><th style="${TH}">수산시장명</th><th style="${TH}">거리</th></tr>`
|
||||
const rows = r.markets.map(m => `<tr><td style="${TD}">${m.name}</td><td style="${TD}">${m.distance}</td></tr>`).join('')
|
||||
parts.push(`<p style="font-size:11px;font-weight:600;margin:8px 0 4px;">수산시장</p><table style="${TABLE}">${h}${rows}</table>`)
|
||||
}
|
||||
if (r.esi?.length) {
|
||||
const h = `<tr><th style="${TH}">코드</th><th style="${TH}">유형</th><th style="${TH}">길이</th></tr>`
|
||||
const rows = r.esi.map(e => `<tr><td style="${TD}">${e.code}</td><td style="${TD}">${e.type}</td><td style="${TD}">${e.length}</td></tr>`).join('')
|
||||
parts.push(`<p style="font-size:11px;font-weight:600;margin:8px 0 4px;">ESI 해안선</p><table style="${TABLE}">${h}${rows}</table>`)
|
||||
}
|
||||
if (r.species?.length) {
|
||||
const h = `<tr><th style="${TH}">분류</th><th style="${TH}">종명</th></tr>`
|
||||
const rows = r.species.map(s => `<tr><td style="${TD}">${s.category}</td><td style="${TD}">${s.species}</td></tr>`).join('')
|
||||
parts.push(`<p style="font-size:11px;font-weight:600;margin:8px 0 4px;">보호생물</p><table style="${TABLE}">${h}${rows}</table>`)
|
||||
}
|
||||
if (r.habitat?.length) {
|
||||
const h = `<tr><th style="${TH}">유형</th><th style="${TH}">면적</th></tr>`
|
||||
const rows = r.habitat.map(h2 => `<tr><td style="${TD}">${h2.type}</td><td style="${TD}">${h2.area}</td></tr>`).join('')
|
||||
parts.push(`<p style="font-size:11px;font-weight:600;margin:8px 0 4px;">서식지</p><table style="${TABLE}">${h}${rows}</table>`)
|
||||
}
|
||||
if (r.sensitivity?.length) {
|
||||
const h = `<tr><th style="${TH}">민감도</th><th style="${TH}">면적</th></tr>`
|
||||
const rows = r.sensitivity.map(s => `<tr><td style="${TD}">${s.level}</td><td style="${TD}">${s.area}</td></tr>`).join('')
|
||||
parts.push(`<p style="font-size:11px;font-weight:600;margin:8px 0 4px;">민감도 등급</p><table style="${TABLE}">${h}${rows}</table>`)
|
||||
}
|
||||
return parts.join('')
|
||||
}
|
||||
|
||||
function formatVesselsTable(vessels: OilSpillReportData['vessels']): string {
|
||||
if (!vessels?.length) return ''
|
||||
const header = `<tr><th style="${TH}">선명</th><th style="${TH}">기관</th><th style="${TH}">거리</th><th style="${TH}">속력</th><th style="${TH}">톤수</th><th style="${TH}">회수장비</th><th style="${TH}">오일붐</th></tr>`
|
||||
const rows = vessels.map(v =>
|
||||
`<tr><td style="${TD}">${v.name}</td><td style="${TD}">${v.org}</td><td style="${TD}">${v.dist}</td><td style="${TD}">${v.speed}</td><td style="${TD}">${v.ton}</td><td style="${TD}">${v.collectorType} ${v.collectorCap}</td><td style="${TD}">${v.boomType} ${v.boomLength}</td></tr>`
|
||||
).join('')
|
||||
return `<table style="${TABLE}">${header}${rows}</table>`
|
||||
}
|
||||
|
||||
function formatRecoveryTable(recovery: OilSpillReportData['recovery']): string {
|
||||
if (!recovery?.length) return ''
|
||||
const header = `<tr><th style="${TH}">선박명</th><th style="${TH}">회수 기간</th></tr>`
|
||||
const rows = recovery.map(r =>
|
||||
`<tr><td style="${TD}">${r.shipName}</td><td style="${TD}">${r.period}</td></tr>`
|
||||
).join('')
|
||||
return `<table style="${TABLE}">${header}${rows}</table>`
|
||||
}
|
||||
|
||||
function formatResultTable(result: OilSpillReportData['result']): string {
|
||||
if (!result) return ''
|
||||
return `<table style="${TABLE}">
|
||||
<tr><td style="${TH}">유출총량</td><td style="${TD}">${result.spillTotal}</td><td style="${TH}">풍화총량</td><td style="${TD}">${result.weatheredTotal}</td></tr>
|
||||
<tr><td style="${TH}">회수총량</td><td style="${TD}">${result.recoveredTotal}</td><td style="${TH}">해상잔유량</td><td style="${TD}">${result.seaRemainTotal}</td></tr>
|
||||
<tr><td style="${TH}">연안부착량</td><td style="${TD}" colspan="3">${result.coastAttachTotal}</td></tr>
|
||||
</table>`
|
||||
}
|
||||
|
||||
export function buildReportGetVal(report: OilSpillReportData) {
|
||||
return (key: string): string => {
|
||||
if (key === 'author') return report.author ?? ''
|
||||
if (key.startsWith('incident.')) {
|
||||
const f = key.split('.')[1]
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
return (report.incident as any)[f] || ''
|
||||
}
|
||||
if (key === '__tide') return formatTideTable(report.tide)
|
||||
if (key === '__weather') return formatWeatherTable(report.weather)
|
||||
if (key === '__spread') return formatSpreadTable(report.spread)
|
||||
if (key === '__sensitive') return formatSensitiveTable(report)
|
||||
if (key === '__vessels') return formatVesselsTable(report.vessels)
|
||||
if (key === '__recovery') return formatRecoveryTable(report.recovery)
|
||||
if (key === '__result') return formatResultTable(report.result)
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
return (report as any)[key] || ''
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,13 +0,0 @@
|
||||
__pycache__/
|
||||
stitch/
|
||||
|
||||
mx15hdi/Detect/Mask_result/
|
||||
mx15hdi/Detect/result/
|
||||
|
||||
mx15hdi/Georeference/Mask_Tif/
|
||||
mx15hdi/Georeference/Tif/
|
||||
|
||||
mx15hdi/Metadata/CSV/
|
||||
mx15hdi/Metadata/Image/Original_Images/
|
||||
|
||||
mx15hdi/Polygon/Shp/
|
||||
@ -1,376 +0,0 @@
|
||||
# wing-image-analysis Docker 사용 가이드
|
||||
|
||||
드론 영상 기반 유류 오염 분석 FastAPI 서버를 Docker 컨테이너로 빌드하고 실행하는 방법을 설명한다.
|
||||
|
||||
---
|
||||
|
||||
## 목차
|
||||
|
||||
1. [사전 요구사항](#1-사전-요구사항)
|
||||
2. [빠른 시작](#2-빠른-시작)
|
||||
3. [빌드 명령어](#3-빌드-명령어)
|
||||
4. [실행 명령어](#4-실행-명령어)
|
||||
5. [환경변수 설정](#5-환경변수-설정)
|
||||
6. [볼륨 구조](#6-볼륨-구조)
|
||||
7. [API 엔드포인트 사용 예시](#7-api-엔드포인트-사용-예시)
|
||||
8. [로그 확인 및 디버깅](#8-로그-확인-및-디버깅)
|
||||
9. [컨테이너 관리](#9-컨테이너-관리)
|
||||
10. [주의사항](#10-주의사항)
|
||||
11. [CPU 전용 환경 실행](#11-cpu-전용-환경-실행)
|
||||
|
||||
---
|
||||
|
||||
## 1. 사전 요구사항
|
||||
|
||||
| 항목 | 최소 버전 | 확인 명령어 |
|
||||
|------|----------|-------------|
|
||||
| Docker Engine | 24.0 이상 | `docker --version` |
|
||||
| Docker Compose | 2.20 이상 | `docker compose version` |
|
||||
| NVIDIA 드라이버 | 525 이상 (CUDA 12.1 지원) | `nvidia-smi` |
|
||||
| nvidia-container-toolkit | 최신 | `nvidia-ctk --version` |
|
||||
|
||||
### nvidia-container-toolkit 설치 (Ubuntu 기준)
|
||||
|
||||
```bash
|
||||
# GPG 키 및 저장소 추가
|
||||
curl -fsSL https://nvidia.github.io/libnvidia-container/gpgkey \
|
||||
| sudo gpg --dearmor -o /usr/share/keyrings/nvidia-container-toolkit-keyring.gpg
|
||||
|
||||
curl -s -L https://nvidia.github.io/libnvidia-container/stable/deb/nvidia-container-toolkit.list \
|
||||
| sed 's#deb https://#deb [signed-by=/usr/share/keyrings/nvidia-container-toolkit-keyring.gpg] https://#g' \
|
||||
| sudo tee /etc/apt/sources.list.d/nvidia-container-toolkit.list
|
||||
|
||||
sudo apt-get update && sudo apt-get install -y nvidia-container-toolkit
|
||||
|
||||
# Docker 런타임 설정 및 재시작
|
||||
sudo nvidia-ctk runtime configure --runtime=docker
|
||||
sudo systemctl restart docker
|
||||
```
|
||||
|
||||
### GPU 동작 확인
|
||||
|
||||
```bash
|
||||
docker run --rm --gpus all nvidia/cuda:12.1-base-ubuntu22.04 nvidia-smi
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 2. 빠른 시작
|
||||
|
||||
```bash
|
||||
# 1. prediction/image/ 디렉토리로 이동
|
||||
cd prediction/image
|
||||
|
||||
# 2. 환경변수 파일 준비 (필요 시)
|
||||
cp .env.example .env
|
||||
|
||||
# 3. 빌드 + 실행 (백그라운드)
|
||||
docker compose up -d --build
|
||||
|
||||
# 4. 서버 상태 확인
|
||||
curl http://localhost:5001/docs
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 3. 빌드 명령어
|
||||
|
||||
### docker compose (권장)
|
||||
|
||||
```bash
|
||||
# 이미지 빌드만 수행 (실행 안 함)
|
||||
docker compose build
|
||||
|
||||
# 빌드 로그를 상세하게 출력
|
||||
docker compose build --progress=plain
|
||||
|
||||
# 캐시 없이 처음부터 빌드 (의존성 변경 시)
|
||||
docker compose build --no-cache
|
||||
```
|
||||
|
||||
### docker build (단독)
|
||||
|
||||
```bash
|
||||
# prediction/image/ 디렉토리에서 실행
|
||||
docker build -t wing-image-analysis:latest .
|
||||
|
||||
# 빌드 태그 지정
|
||||
docker build -t wing-image-analysis:1.0.0 .
|
||||
|
||||
# 캐시 없이 빌드
|
||||
docker build --no-cache -t wing-image-analysis:latest .
|
||||
```
|
||||
|
||||
> **참고**: 첫 빌드는 PyTorch base 이미지(약 8GB) + GDAL/Python 패키지 설치로 **30~60분** 소요될 수 있다.
|
||||
> 이후 빌드는 레이어 캐시로 수 분 내 완료된다.
|
||||
|
||||
---
|
||||
|
||||
## 4. 실행 명령어
|
||||
|
||||
### docker compose (권장)
|
||||
|
||||
```bash
|
||||
# 백그라운드 실행
|
||||
docker compose up -d
|
||||
|
||||
# 빌드 후 즉시 실행
|
||||
docker compose up -d --build
|
||||
|
||||
# 포그라운드 실행 (로그 바로 출력)
|
||||
docker compose up
|
||||
|
||||
# 중지
|
||||
docker compose down
|
||||
|
||||
# 중지 + 볼륨 삭제 (데이터 초기화 시)
|
||||
docker compose down -v
|
||||
```
|
||||
|
||||
### docker run (단독 — 테스트용)
|
||||
|
||||
```bash
|
||||
docker run --rm \
|
||||
--gpus all \
|
||||
-p 5001:5001 \
|
||||
--env-file .env \
|
||||
-v "$(pwd)/mx15hdi/Metadata/Image/Original_Images:/app/mx15hdi/Metadata/Image/Original_Images" \
|
||||
wing-image-analysis:latest
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 5. 환경변수 설정
|
||||
|
||||
`.env.example`을 복사하여 `.env`를 생성한다.
|
||||
|
||||
```bash
|
||||
cp .env.example .env
|
||||
```
|
||||
|
||||
| 변수 | 설명 | 기본값 |
|
||||
|------|------|--------|
|
||||
| `API_HOST` | 서버 바인드 주소 | `0.0.0.0` |
|
||||
| `API_PORT` | 서버 포트 | `5001` |
|
||||
|
||||
---
|
||||
|
||||
## 6. 볼륨 구조
|
||||
|
||||
컨테이너 내부 경로와 호스트 경로의 매핑이다. 이미지/결과 데이터는 컨테이너 외부에 저장되어 컨테이너를 재시작해도 유지된다.
|
||||
|
||||
```
|
||||
호스트 (prediction/image/) 컨테이너 (/app/)
|
||||
─────────────────────────────────────────────────────────────────────
|
||||
mx15hdi/Metadata/Image/Original_Images/ → mx15hdi/Metadata/Image/Original_Images/ ← 원본 이미지 입력
|
||||
mx15hdi/Metadata/CSV/ → mx15hdi/Metadata/CSV/ ← 메타데이터 출력
|
||||
mx15hdi/Georeference/Tif/ → mx15hdi/Georeference/Tif/ ← GeoTIFF 출력
|
||||
mx15hdi/Georeference/Mask_Tif/ → mx15hdi/Georeference/Mask_Tif/ ← 마스크 GeoTIFF
|
||||
mx15hdi/Polygon/Shp/ → mx15hdi/Polygon/Shp/ ← Shapefile 출력
|
||||
mx15hdi/Detect/result/ → mx15hdi/Detect/result/ ← 블렌딩 결과
|
||||
mx15hdi/Detect/Mask_result/ → mx15hdi/Detect/Mask_result/ ← 마스크 결과
|
||||
starsafire/Metadata/Image/Original_Images → starsafire/Metadata/Image/Original_Images ← 열화상 입력
|
||||
starsafire/{기타}/ → starsafire/{기타}/ ← 열화상 출력
|
||||
stitch/ → stitch/ ← 스티칭 결과
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 7. API 엔드포인트 사용 예시
|
||||
|
||||
서버 기동 후 `http://localhost:5001/docs`에서 Swagger UI로 전체 API를 확인할 수 있다.
|
||||
|
||||
### 7.1 전체 분석 파이프라인 실행
|
||||
|
||||
```bash
|
||||
curl -X POST http://localhost:5001/run-script/ \
|
||||
-F "files=@/path/to/drone_image.jpg" \
|
||||
-F "camTy=mx15hdi" \
|
||||
-F "fileId=20240310_001"
|
||||
```
|
||||
|
||||
**응답 예시**:
|
||||
```json
|
||||
{
|
||||
"meta": "drone_image.jpg,37,30,0,126,55,0,...",
|
||||
"data": [
|
||||
{
|
||||
"classId": 2,
|
||||
"area": 1234.56,
|
||||
"volume": 0.1234,
|
||||
"note": "갈색",
|
||||
"thickness": 0.0001,
|
||||
"wkt": "POLYGON((...))"
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
### 7.2 메타데이터 조회
|
||||
|
||||
```bash
|
||||
curl http://localhost:5001/get-metadata/mx15hdi/20240310_001
|
||||
```
|
||||
|
||||
### 7.3 원본 이미지 조회 (Base64)
|
||||
|
||||
```bash
|
||||
curl http://localhost:5001/get-original-image/mx15hdi/20240310_001
|
||||
```
|
||||
|
||||
### 7.4 GeoTIFF + 좌표 조회
|
||||
|
||||
```bash
|
||||
curl http://localhost:5001/get-image/mx15hdi/20240310_001
|
||||
```
|
||||
|
||||
### 7.5 이미지 스티칭
|
||||
|
||||
```bash
|
||||
curl -X POST http://localhost:5001/stitch \
|
||||
-F "files=@photo1.jpg" \
|
||||
-F "files=@photo2.jpg" \
|
||||
-F "mode=drone"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 8. 로그 확인 및 디버깅
|
||||
|
||||
```bash
|
||||
# 실시간 로그 출력
|
||||
docker logs wing-image-analysis -f
|
||||
|
||||
# 최근 100줄만 출력
|
||||
docker logs wing-image-analysis --tail 100
|
||||
|
||||
# 컨테이너 내부 쉘 접속
|
||||
docker exec -it wing-image-analysis bash
|
||||
|
||||
# GPU 사용 현황 확인 (컨테이너 내부)
|
||||
docker exec wing-image-analysis nvidia-smi
|
||||
|
||||
# Python 패키지 목록 확인
|
||||
docker exec wing-image-analysis pip list
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 9. 컨테이너 관리
|
||||
|
||||
```bash
|
||||
# 상태 확인
|
||||
docker compose ps
|
||||
|
||||
# 재시작
|
||||
docker compose restart
|
||||
|
||||
# 중지 (볼륨 유지)
|
||||
docker compose down
|
||||
|
||||
# 이미지 삭제
|
||||
docker rmi wing-image-analysis:latest
|
||||
|
||||
# 사용하지 않는 리소스 정리
|
||||
docker system prune -f
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 10. 주의사항
|
||||
|
||||
### GPU 자동 감지
|
||||
- 서버 기동 시 `torch.cuda.is_available()`로 GPU 유무를 자동 감지한다.
|
||||
- GPU가 있으면 `cuda:0`, 없으면 `cpu`로 자동 폴백된다.
|
||||
- 환경변수 `DEVICE`로 device를 명시 지정할 수 있다 (예: `DEVICE=cpu`, `DEVICE=cuda:1`).
|
||||
|
||||
### 첫 기동 시간
|
||||
- AI 모델 로드: 약 **10~30초** 소요 (GPU 메모리에 로딩)
|
||||
- 준비 완료 후 로그에 `Application startup complete` 메시지가 출력된다.
|
||||
|
||||
### workers=1 고정
|
||||
- GPU 모델은 프로세스 간 공유가 불가하므로 uvicorn workers는 반드시 `1`로 유지해야 한다.
|
||||
- 병렬 처리는 내부 `ThreadPoolExecutor`(max_workers=4)로 처리된다.
|
||||
|
||||
### 포트 충돌
|
||||
- 기본 포트 `5001`이 다른 서비스와 충돌하면 `docker-compose.yml`의 `ports` 항목을 수정한다:
|
||||
```yaml
|
||||
ports:
|
||||
- "5002:5001" # 호스트 5002 → 컨테이너 5001
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 11. CPU 전용 환경 실행
|
||||
|
||||
GPU(NVIDIA)가 없는 환경에서는 CPU 전용 설정을 사용한다.
|
||||
|
||||
### 사전 요구사항 (CPU 모드)
|
||||
|
||||
| 항목 | 최소 버전 | 확인 명령어 |
|
||||
|------|----------|-------------|
|
||||
| Docker Engine | 24.0 이상 | `docker --version` |
|
||||
| Docker Compose | 2.20 이상 | `docker compose version` |
|
||||
| NVIDIA 드라이버 | **불필요** | — |
|
||||
|
||||
### 빠른 시작 (CPU)
|
||||
|
||||
```bash
|
||||
# prediction/image/ 디렉토리로 이동
|
||||
cd prediction/image
|
||||
|
||||
# 환경변수 파일 준비 (필요 시)
|
||||
cp .env.example .env
|
||||
|
||||
# CPU 이미지 빌드 + 실행
|
||||
docker compose -f docker-compose.cpu.yml up -d --build
|
||||
|
||||
# 서버 상태 확인
|
||||
curl http://localhost:5001/docs
|
||||
```
|
||||
|
||||
### 빌드 명령어 (CPU)
|
||||
|
||||
```bash
|
||||
# CPU 이미지만 빌드
|
||||
docker compose -f docker-compose.cpu.yml build
|
||||
|
||||
# 캐시 없이 빌드
|
||||
docker compose -f docker-compose.cpu.yml build --no-cache
|
||||
```
|
||||
|
||||
> **참고**: CPU 기반 PyTorch 이미지는 GPU 이미지(~8GB) 대비 약 70% 용량이 절감된다.
|
||||
> 단, CPU 추론은 GPU 대비 처리 속도가 느리므로 대용량 이미지 분석 시 시간이 더 소요된다.
|
||||
|
||||
### 실행 명령어 (CPU)
|
||||
|
||||
```bash
|
||||
# 백그라운드 실행
|
||||
docker compose -f docker-compose.cpu.yml up -d
|
||||
|
||||
# 포그라운드 실행 (로그 바로 출력)
|
||||
docker compose -f docker-compose.cpu.yml up
|
||||
|
||||
# 중지
|
||||
docker compose -f docker-compose.cpu.yml down
|
||||
```
|
||||
|
||||
### 로컬 직접 실행 (Docker 없이)
|
||||
|
||||
```bash
|
||||
# GPU 있으면 자동으로 cuda:0 사용, 없으면 cpu로 폴백
|
||||
python api.py
|
||||
|
||||
# device 강제 지정
|
||||
DEVICE=cpu python api.py
|
||||
DEVICE=cuda:1 python api.py
|
||||
```
|
||||
|
||||
### GPU/CPU 모드 확인
|
||||
|
||||
서버 기동 로그에서 사용 device를 확인할 수 있다:
|
||||
|
||||
```
|
||||
[Inference] 사용 device: cpu ← CPU 모드
|
||||
[Inference] 사용 device: cuda:0 ← GPU 모드
|
||||
```
|
||||
@ -1,84 +0,0 @@
|
||||
# ==============================================================================
|
||||
# wing-image-analysis — 드론 영상 유류 분석 FastAPI 서버
|
||||
#
|
||||
# Base: PyTorch 1.9.1 + CUDA 11.1 + cuDNN 8
|
||||
# (mmsegmentation 0.25.0 / mmcv-full 1.4.3 호환 환경)
|
||||
# GPU: NVIDIA GPU 필수 (MMSegmentation 추론)
|
||||
# Port: 5001
|
||||
# ==============================================================================
|
||||
FROM pytorch/pytorch:1.9.1-cuda11.1-cudnn8-devel
|
||||
|
||||
ENV DEBIAN_FRONTEND=noninteractive \
|
||||
PYTHONDONTWRITEBYTECODE=1 \
|
||||
PYTHONUNBUFFERED=1
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# 시스템 패키지: GDAL / PROJ / GEOS (rasterio, geopandas 빌드 의존성)
|
||||
# libpq-dev: psycopg2-binary 런타임 의존성
|
||||
# libspatialindex-dev: geopandas 공간 인덱스
|
||||
# ------------------------------------------------------------------------------
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
gdal-bin \
|
||||
libgdal-dev \
|
||||
libproj-dev \
|
||||
libgeos-dev \
|
||||
libspatialindex-dev \
|
||||
gcc \
|
||||
g++ \
|
||||
git \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# rasterio는 GDAL 헤더 버전을 맞춰 빌드해야 한다
|
||||
ENV GDAL_VERSION=3.4.1
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# mmcv-full 1.4.3 — CUDA 11.1 + PyTorch 1.9.0 pre-built 휠
|
||||
# (소스 컴파일 없이 수 초 내 설치)
|
||||
# ------------------------------------------------------------------------------
|
||||
RUN pip install --no-cache-dir \
|
||||
mmcv-full==1.4.3 \
|
||||
-f https://download.openmmlab.com/mmcv/dist/cu111/torch1.9.0/index.html
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Python 의존성 설치
|
||||
# ------------------------------------------------------------------------------
|
||||
COPY requirements.txt .
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# 로컬 mmsegmentation 설치 (mx15hdi/Detect/mmsegmentation/)
|
||||
# 번들 소스를 먼저 복사한 뒤 editable 설치한다
|
||||
# ------------------------------------------------------------------------------
|
||||
COPY mx15hdi/Detect/mmsegmentation/ /tmp/mmsegmentation/
|
||||
RUN pip install --no-cache-dir -e /tmp/mmsegmentation/
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# 소스 코드 전체 복사
|
||||
# 대용량 데이터 디렉토리(Original_Images, result 등)는
|
||||
# docker-compose.yml의 볼륨 마운트로 외부에서 주입된다
|
||||
# ------------------------------------------------------------------------------
|
||||
COPY . .
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# .dockerignore로 제외된 런타임 출력 디렉토리를 빈 폴더로 생성
|
||||
# (볼륨 마운트 전에도 경로가 존재해야 한다)
|
||||
# ------------------------------------------------------------------------------
|
||||
RUN mkdir -p \
|
||||
/app/stitch \
|
||||
/app/mx15hdi/Detect/Mask_result \
|
||||
/app/mx15hdi/Detect/result \
|
||||
/app/mx15hdi/Georeference/Mask_Tif \
|
||||
/app/mx15hdi/Georeference/Tif \
|
||||
/app/mx15hdi/Metadata/CSV \
|
||||
/app/mx15hdi/Metadata/Image/Original_Images \
|
||||
/app/mx15hdi/Polygon/Shp
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# 런타임 설정
|
||||
# ------------------------------------------------------------------------------
|
||||
EXPOSE 5001
|
||||
|
||||
# workers=1: GPU 모델을 프로세스 하나에서만 로드 (메모리 공유 불가)
|
||||
CMD ["uvicorn", "api:app", "--host", "0.0.0.0", "--port", "5001", "--workers", "1"]
|
||||
@ -1,112 +0,0 @@
|
||||
# ==============================================================================
|
||||
# wing-image-analysis — 드론 영상 유류 분석 FastAPI 서버 (CPU 전용)
|
||||
#
|
||||
# Base: python:3.9-slim + PyTorch 1.9.0 CPU 빌드
|
||||
# (mmsegmentation 0.25.0 / mmcv-full 1.4.3 호환 환경)
|
||||
# python:3.9 필수 — numpy 1.26.4, geopandas 0.14.4가 Python >=3.9 요구
|
||||
# GPU: 불필요 (CPU 추론)
|
||||
# Port: 5001
|
||||
# ==============================================================================
|
||||
FROM python:3.9-slim
|
||||
|
||||
ENV DEBIAN_FRONTEND=noninteractive \
|
||||
PYTHONDONTWRITEBYTECODE=1 \
|
||||
PYTHONUNBUFFERED=1 \
|
||||
DEVICE=cpu
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# 시스템 패키지: GDAL / PROJ / GEOS (rasterio, geopandas 빌드 의존성)
|
||||
# libspatialindex-dev: geopandas 공간 인덱스
|
||||
# opencv-contrib-python-headless 런타임 SO 의존성 (python:3.9-slim에 미포함):
|
||||
# libgl1 — libGL.so.1
|
||||
# libglib2.0-0 — libgthread-2.0.so.0, libgobject-2.0.so.0, libglib-2.0.so.0
|
||||
# libsm6 — libSM.so.6
|
||||
# libxext6 — libXext.so.6
|
||||
# libxrender1 — libXrender.so.1
|
||||
# libgomp1 — libgomp.so.1 (OpenMP, numpy/opencv 병렬 처리)
|
||||
# ------------------------------------------------------------------------------
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
gdal-bin \
|
||||
libgdal-dev \
|
||||
libproj-dev \
|
||||
libgeos-dev \
|
||||
libspatialindex-dev \
|
||||
libgl1 \
|
||||
libglib2.0-0 \
|
||||
libsm6 \
|
||||
libxext6 \
|
||||
libxrender1 \
|
||||
libgomp1 \
|
||||
gcc \
|
||||
g++ \
|
||||
git \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# rasterio는 GDAL 헤더 버전을 맞춰 빌드해야 한다
|
||||
ENV GDAL_VERSION=3.4.1
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# GDAL Python 바인딩 (osgeo 모듈) — 시스템 GDAL 버전과 일치해야 한다
|
||||
# python:3.9-slim은 conda 없이 pip 환경이므로 명시적 설치 필요
|
||||
# ------------------------------------------------------------------------------
|
||||
RUN pip install --no-cache-dir GDAL=="$(gdal-config --version)"
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# PyTorch 1.9.0 CPU 버전 설치
|
||||
# (mmsegmentation 0.25.0 / mmcv-full 1.4.3 호환)
|
||||
# ------------------------------------------------------------------------------
|
||||
RUN pip install --no-cache-dir \
|
||||
torch==1.9.0+cpu \
|
||||
torchvision==0.10.0+cpu \
|
||||
-f https://download.pytorch.org/whl/torch_stable.html
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# mmcv-full 1.4.3 CPU 휠 (CUDA ops 없는 경량 빌드, 추론에 충분)
|
||||
# ------------------------------------------------------------------------------
|
||||
RUN pip install --no-cache-dir \
|
||||
mmcv-full==1.4.3 \
|
||||
-f https://download.openmmlab.com/mmcv/dist/cpu/torch1.9.0/index.html
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Python 의존성 설치
|
||||
# ------------------------------------------------------------------------------
|
||||
COPY requirements.txt .
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# 로컬 mmsegmentation 설치 (mx15hdi/Detect/mmsegmentation/)
|
||||
# 번들 소스를 먼저 복사한 뒤 editable 설치한다
|
||||
# ------------------------------------------------------------------------------
|
||||
COPY mx15hdi/Detect/mmsegmentation/ /tmp/mmsegmentation/
|
||||
RUN pip install --no-cache-dir -e /tmp/mmsegmentation/
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# 소스 코드 전체 복사
|
||||
# 대용량 데이터 디렉토리(Original_Images, result 등)는
|
||||
# docker-compose.cpu.yml의 볼륨 마운트로 외부에서 주입된다
|
||||
# ------------------------------------------------------------------------------
|
||||
COPY . .
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# .dockerignore로 제외된 런타임 출력 디렉토리를 빈 폴더로 생성
|
||||
# (볼륨 마운트 전에도 경로가 존재해야 한다)
|
||||
# ------------------------------------------------------------------------------
|
||||
RUN mkdir -p \
|
||||
/app/stitch \
|
||||
/app/mx15hdi/Detect/Mask_result \
|
||||
/app/mx15hdi/Detect/result \
|
||||
/app/mx15hdi/Georeference/Mask_Tif \
|
||||
/app/mx15hdi/Georeference/Tif \
|
||||
/app/mx15hdi/Metadata/CSV \
|
||||
/app/mx15hdi/Metadata/Image/Original_Images \
|
||||
/app/mx15hdi/Polygon/Shp
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# 런타임 설정
|
||||
# ------------------------------------------------------------------------------
|
||||
EXPOSE 5001
|
||||
|
||||
# workers=1: 모델을 프로세스 하나에서만 로드 (메모리 공유 불가)
|
||||
CMD ["uvicorn", "api:app", "--host", "0.0.0.0", "--port", "5001", "--workers", "1"]
|
||||
@ -1,340 +0,0 @@
|
||||
import sys
|
||||
import os
|
||||
from pathlib import Path
|
||||
from contextlib import asynccontextmanager
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
from fastapi import FastAPI, HTTPException, File, UploadFile, Form
|
||||
from fastapi.responses import Response, FileResponse
|
||||
import subprocess
|
||||
import rasterio
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from PIL.ExifTags import TAGS
|
||||
import io
|
||||
import base64
|
||||
from pyproj import Transformer
|
||||
from extract_data import get_metadata as get_meta
|
||||
from extract_data import get_oil_type as get_oil
|
||||
import time
|
||||
|
||||
from typing import List, Optional
|
||||
import shutil
|
||||
from datetime import datetime
|
||||
from collections import Counter
|
||||
|
||||
# mx15hdi 파이프라인 모듈 임포트를 위한 sys.path 설정
|
||||
_BASE_DIR = Path(__file__).parent
|
||||
sys.path.insert(0, str(_BASE_DIR / 'mx15hdi' / 'Detect'))
|
||||
sys.path.insert(0, str(_BASE_DIR / 'mx15hdi' / 'Metadata' / 'Scripts'))
|
||||
sys.path.insert(0, str(_BASE_DIR / 'mx15hdi' / 'Georeference' / 'Scripts'))
|
||||
sys.path.insert(0, str(_BASE_DIR / 'mx15hdi' / 'Polygon' / 'Scripts'))
|
||||
|
||||
from Inference import load_model, run_inference
|
||||
from Export_Metadata_mx15hdi import run_metadata_export
|
||||
from Create_Georeferenced_Images_nadir import run_georeference
|
||||
from Oilshape import run_oilshape
|
||||
|
||||
# AI 모델 (서버 시작 시 1회 로드)
|
||||
_model = None
|
||||
# CPU/GPU 바운드 작업용 스레드 풀
|
||||
_executor = ThreadPoolExecutor(max_workers=4)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""서버 시작 시 AI 모델을 1회 로드하고, 종료 시 해제한다."""
|
||||
global _model
|
||||
print("AI 모델 로딩 중 (epoch_165.pth)...")
|
||||
_model = load_model()
|
||||
print("AI 모델 로드 완료")
|
||||
yield
|
||||
_model = None
|
||||
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
|
||||
def check_gps_info(image_path: str):
|
||||
# Pillow로 이미지 열기
|
||||
image = Image.open(image_path)
|
||||
|
||||
# EXIF 데이터 추출
|
||||
exifdata = image.getexif()
|
||||
|
||||
if not exifdata:
|
||||
print("EXIF 정보를 찾을 수 없습니다.")
|
||||
return False
|
||||
|
||||
# GPS 정보 추출
|
||||
gps_ifd = exifdata.get_ifd(0x8825) # GPS IFD 태그
|
||||
if not gps_ifd:
|
||||
print("GPS 정보를 찾을 수 없습니다.")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def check_camera_info(image_file):
|
||||
# Pillow로 이미지 열기
|
||||
image = Image.open(image_file)
|
||||
|
||||
# EXIF 데이터 추출
|
||||
exifdata = image.getexif()
|
||||
|
||||
if not exifdata:
|
||||
print("EXIF 정보를 찾을 수 없습니다.")
|
||||
return False
|
||||
|
||||
for tag_id, value in exifdata.items():
|
||||
tag_name = TAGS.get(tag_id, tag_id)
|
||||
if tag_name == "Model":
|
||||
return value.strip() if isinstance(value, str) else value
|
||||
|
||||
|
||||
async def _run_mx15hdi_pipeline(file_id: str):
|
||||
"""
|
||||
mx15hdi 파이프라인을 in-process로 실행한다.
|
||||
- Step 1 (AI 추론) + Step 2 (메타데이터 추출) 병렬 실행
|
||||
- Step 3 (지리참조) → Step 4 (폴리곤 추출) 순차 실행
|
||||
- 중간 파일 I/O 없이 numpy 배열을 메모리로 전달
|
||||
"""
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
# Step 1 + Step 2 병렬 실행 — inference_cache 캡처
|
||||
inference_cache, _ = await asyncio.gather(
|
||||
loop.run_in_executor(_executor, run_inference, _model, file_id),
|
||||
loop.run_in_executor(_executor, run_metadata_export, file_id),
|
||||
)
|
||||
|
||||
# Step 3: Georeference — inference_cache 메모리로 전달, georef_cache 반환
|
||||
georef_cache = await loop.run_in_executor(
|
||||
_executor, run_georeference, file_id, inference_cache
|
||||
)
|
||||
|
||||
# Step 4: Polygon 추출 — georef_cache 메모리로 전달 (Mask_Tif 디스크 읽기 없음)
|
||||
await loop.run_in_executor(_executor, run_oilshape, file_id, georef_cache)
|
||||
|
||||
|
||||
# 전체 과정을 구동하는 api
|
||||
@app.post("/run-script/")
|
||||
async def run_script(
|
||||
# pollId: int = Form(...),
|
||||
camTy: str = Form(...),
|
||||
fileId: str = Form(...),
|
||||
image: UploadFile = File(...)
|
||||
):
|
||||
try:
|
||||
print("start")
|
||||
start_time = time.perf_counter()
|
||||
|
||||
if camTy not in ["mx15hdi", "starsafire"]:
|
||||
raise HTTPException(status_code=400, detail="string1 must be 'mx15hdi' or 'starsafire'")
|
||||
|
||||
# 저장할 이미지 경로 설정
|
||||
upload_dir = os.path.join(camTy, "Metadata/Image/Original_Images", fileId)
|
||||
os.makedirs(upload_dir, exist_ok=True)
|
||||
|
||||
# 이미지 파일 저장
|
||||
image_path = os.path.join(upload_dir, image.filename)
|
||||
with open(image_path, "wb") as f:
|
||||
f.write(await image.read())
|
||||
|
||||
gps_flage = check_gps_info(image_path)
|
||||
if not gps_flage:
|
||||
return {"detail": "GPS Infomation Not Found"}
|
||||
|
||||
if camTy == "mx15hdi":
|
||||
# in-process 파이프라인 실행 (모델 재로딩 없음, Step1+2 병렬)
|
||||
await _run_mx15hdi_pipeline(fileId)
|
||||
else:
|
||||
# starsafire: 기존 subprocess 방식 유지
|
||||
script_dir = os.path.join(os.getcwd(), camTy, "Main")
|
||||
script_file = "Combine_module.py"
|
||||
script_path = os.path.join(script_dir, script_file)
|
||||
|
||||
if not os.path.exists(script_path):
|
||||
raise HTTPException(status_code=404, detail="Script not found")
|
||||
|
||||
result = subprocess.run(
|
||||
["python", script_file, fileId],
|
||||
cwd=script_dir,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=300
|
||||
)
|
||||
print(f"Subprocess stdout: {result.stdout}")
|
||||
print(f"Subprocess stderr: {result.stderr}")
|
||||
|
||||
meta_string = get_meta(camTy, fileId)
|
||||
oil_data = get_oil(camTy, fileId)
|
||||
end_time = time.perf_counter()
|
||||
print(f"Run time: {end_time - start_time:.4f} sec")
|
||||
|
||||
return {
|
||||
"meta": meta_string,
|
||||
"data": oil_data
|
||||
}
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
raise HTTPException(status_code=500, detail="Script execution timed out")
|
||||
except Exception as e:
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@app.get("/get-metadata/{camTy}/{fileId}")
|
||||
async def get_metadata(camTy: str, fileId: str):
|
||||
try:
|
||||
meta_string = get_meta(camTy, fileId)
|
||||
oil_data = get_oil(camTy, fileId)
|
||||
|
||||
return {
|
||||
"meta": meta_string,
|
||||
"data": oil_data
|
||||
}
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
raise HTTPException(status_code=500, detail="Script execution timed out")
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@app.get("/get-original-image/{camTy}/{fileId}")
|
||||
async def get_original_image(camTy: str, fileId: str):
|
||||
try:
|
||||
image_path = os.path.join(camTy, "Metadata/Image/Original_Images", fileId)
|
||||
files = os.listdir(image_path)
|
||||
target_file = [f for f in files if f.endswith(".png") or f.endswith(".jpg")]
|
||||
image_file = os.path.join(image_path, target_file[0])
|
||||
|
||||
with open(image_file, "rb") as origin_image:
|
||||
base64_string = base64.b64encode(origin_image.read()).decode("utf-8")
|
||||
print(base64_string[:100])
|
||||
|
||||
return base64_string
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@app.get("/get-image/{camTy}/{fileId}")
|
||||
async def get_image(camTy: str, fileId: str):
|
||||
try:
|
||||
tif_file_path = os.path.join(camTy, "Georeference/Tif", fileId)
|
||||
files = os.listdir(tif_file_path)
|
||||
target_file = [f for f in files if f.endswith(".tif")]
|
||||
tif_file = os.path.join(tif_file_path, target_file[0])
|
||||
|
||||
with rasterio.open(tif_file) as dataset:
|
||||
crs = dataset.crs
|
||||
|
||||
bounds = dataset.bounds
|
||||
|
||||
if crs != "EPSG:4326":
|
||||
transformer = Transformer.from_crs(crs, "EPSG:4326", always_xy=True)
|
||||
minx, miny = transformer.transform(bounds.left, bounds.bottom)
|
||||
maxx, maxy = transformer.transform(bounds.right, bounds.top)
|
||||
|
||||
print(minx, miny, maxx, maxy)
|
||||
|
||||
data = dataset.read()
|
||||
if data.shape[0] == 1:
|
||||
image_data = data[0]
|
||||
else:
|
||||
image_data = np.moveaxis(data, 0, -1)
|
||||
|
||||
image = Image.fromarray(image_data)
|
||||
buffer = io.BytesIO()
|
||||
image.save(buffer, format="PNG")
|
||||
|
||||
base64_string = base64.b64encode(buffer.getvalue()).decode("utf-8")
|
||||
|
||||
print(base64_string[:100])
|
||||
return {
|
||||
"minLon": minx,
|
||||
"minLat": miny,
|
||||
"maxLon": maxx,
|
||||
"maxLat": maxy,
|
||||
"image": base64_string
|
||||
}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
BASE_DIR = Path(__file__).parent
|
||||
PIC_GPS_SCRIPT = BASE_DIR / "pic_gps.py"
|
||||
|
||||
@app.post("/stitch")
|
||||
async def stitch(
|
||||
files: List[UploadFile] = File(..., description="합성할 이미지 파일들 (2장 이상)"),
|
||||
fileId: str = Form(...)
|
||||
):
|
||||
if len(files) < 2:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="최소 2장 이상의 이미지가 필요합니다."
|
||||
)
|
||||
|
||||
try:
|
||||
today = datetime.now().strftime("%Y%m%d")
|
||||
upload_dir = BASE_DIR / "stitch" / fileId
|
||||
upload_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
model_list = []
|
||||
for idx, file in enumerate(files):
|
||||
|
||||
model = check_camera_info(file.file)
|
||||
model_list.append(model)
|
||||
|
||||
original_filename = file.filename or f"image_{idx}.jpg"
|
||||
filename = f"{model}_{idx:03d}_{original_filename}"
|
||||
file_path = upload_dir / filename
|
||||
|
||||
output_filename = f"stitched_{fileId}.jpg"
|
||||
output_path = upload_dir / output_filename
|
||||
|
||||
# 파일 저장
|
||||
with open(file_path, "wb") as buffer:
|
||||
shutil.copyfileobj(file.file, buffer)
|
||||
|
||||
model_counter = Counter(model_list)
|
||||
most_common_model = model_counter.most_common(1)
|
||||
|
||||
cmd = [
|
||||
"python",
|
||||
str(PIC_GPS_SCRIPT),
|
||||
"--mode", "drone",
|
||||
"--input", str(upload_dir),
|
||||
"--out", str(output_path),
|
||||
"--model", most_common_model[0][0],
|
||||
"--enhance"
|
||||
]
|
||||
|
||||
print(cmd)
|
||||
|
||||
result = subprocess.run(
|
||||
cmd,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=300
|
||||
)
|
||||
|
||||
print(f"Subprocess stdout: {result.stdout}")
|
||||
if result.returncode != 0:
|
||||
print(f"Subprocess stderr: {result.stderr}")
|
||||
raise HTTPException(status_code=500, detail=f"Script failed: {result.stderr}")
|
||||
|
||||
return FileResponse(
|
||||
path=str(output_path),
|
||||
media_type="image/jpeg",
|
||||
filename=output_filename
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
uvicorn.run(app, host="0.0.0.0", port=5001)
|
||||
@ -1,46 +0,0 @@
|
||||
version: "3.9"
|
||||
|
||||
# CPU 전용 docker-compose 설정
|
||||
# GPU(nvidia-container-toolkit) 없이도 실행 가능
|
||||
# 실행: docker compose -f docker-compose.cpu.yml up -d --build
|
||||
|
||||
services:
|
||||
image-analysis:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: Dockerfile.cpu
|
||||
image: wing-image-analysis:cpu
|
||||
container_name: wing-image-analysis
|
||||
ports:
|
||||
- "5001:5001"
|
||||
environment:
|
||||
- DEVICE=cpu
|
||||
|
||||
volumes:
|
||||
# ── mx15hdi (EO 드론 카메라) ────────────────────────────────────────
|
||||
# 입력: 업로드된 원본 이미지
|
||||
- ./mx15hdi/Metadata/Image/Original_Images:/app/mx15hdi/Metadata/Image/Original_Images
|
||||
# 출력: 메타데이터 CSV
|
||||
- ./mx15hdi/Metadata/CSV:/app/mx15hdi/Metadata/CSV
|
||||
# 출력: 지리참조 GeoTIFF (컬러 / 마스크)
|
||||
- ./mx15hdi/Georeference/Tif:/app/mx15hdi/Georeference/Tif
|
||||
- ./mx15hdi/Georeference/Mask_Tif:/app/mx15hdi/Georeference/Mask_Tif
|
||||
# 출력: 유류 폴리곤 Shapefile
|
||||
- ./mx15hdi/Polygon/Shp:/app/mx15hdi/Polygon/Shp
|
||||
# 출력: 블렌딩 추론 결과 / 마스크 이미지
|
||||
- ./mx15hdi/Detect/result:/app/mx15hdi/Detect/result
|
||||
- ./mx15hdi/Detect/Mask_result:/app/mx15hdi/Detect/Mask_result
|
||||
# ── starsafire (열화상 카메라) ──────────────────────────────────────
|
||||
- ./starsafire/Metadata/Image/Original_Images:/app/starsafire/Metadata/Image/Original_Images
|
||||
- ./starsafire/Metadata/CSV:/app/starsafire/Metadata/CSV
|
||||
- ./starsafire/Georeference/Tif:/app/starsafire/Georeference/Tif
|
||||
- ./starsafire/Georeference/Mask_Tif:/app/starsafire/Georeference/Mask_Tif
|
||||
- ./starsafire/Polygon/Shp:/app/starsafire/Polygon/Shp
|
||||
- ./starsafire/Detect/result:/app/starsafire/Detect/result
|
||||
- ./starsafire/Detect/Mask_result:/app/starsafire/Detect/Mask_result
|
||||
# ── 스티칭 결과 ─────────────────────────────────────────────────────
|
||||
- ./stitch:/app/stitch
|
||||
|
||||
# GPU deploy 섹션 없음 — CPU 전용 실행
|
||||
|
||||
restart: unless-stopped
|
||||
@ -1,47 +0,0 @@
|
||||
version: "3.9"
|
||||
|
||||
services:
|
||||
image-analysis:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: Dockerfile
|
||||
image: wing-image-analysis:latest
|
||||
container_name: wing-image-analysis
|
||||
ports:
|
||||
- "5001:5001"
|
||||
|
||||
volumes:
|
||||
# ── mx15hdi (EO 드론 카메라) ────────────────────────────────────────
|
||||
# 입력: 업로드된 원본 이미지
|
||||
- ./mx15hdi/Metadata/Image/Original_Images:/app/mx15hdi/Metadata/Image/Original_Images
|
||||
# 출력: 메타데이터 CSV
|
||||
- ./mx15hdi/Metadata/CSV:/app/mx15hdi/Metadata/CSV
|
||||
# 출력: 지리참조 GeoTIFF (컬러 / 마스크)
|
||||
- ./mx15hdi/Georeference/Tif:/app/mx15hdi/Georeference/Tif
|
||||
- ./mx15hdi/Georeference/Mask_Tif:/app/mx15hdi/Georeference/Mask_Tif
|
||||
# 출력: 유류 폴리곤 Shapefile
|
||||
- ./mx15hdi/Polygon/Shp:/app/mx15hdi/Polygon/Shp
|
||||
# 출력: 블렌딩 추론 결과 / 마스크 이미지
|
||||
- ./mx15hdi/Detect/result:/app/mx15hdi/Detect/result
|
||||
- ./mx15hdi/Detect/Mask_result:/app/mx15hdi/Detect/Mask_result
|
||||
# ── starsafire (열화상 카메라) ──────────────────────────────────────
|
||||
- ./starsafire/Metadata/Image/Original_Images:/app/starsafire/Metadata/Image/Original_Images
|
||||
- ./starsafire/Metadata/CSV:/app/starsafire/Metadata/CSV
|
||||
- ./starsafire/Georeference/Tif:/app/starsafire/Georeference/Tif
|
||||
- ./starsafire/Georeference/Mask_Tif:/app/starsafire/Georeference/Mask_Tif
|
||||
- ./starsafire/Polygon/Shp:/app/starsafire/Polygon/Shp
|
||||
- ./starsafire/Detect/result:/app/starsafire/Detect/result
|
||||
- ./starsafire/Detect/Mask_result:/app/starsafire/Detect/Mask_result
|
||||
# ── 스티칭 결과 ─────────────────────────────────────────────────────
|
||||
- ./stitch:/app/stitch
|
||||
|
||||
# NVIDIA GPU 할당 (nvidia-container-toolkit 필수)
|
||||
deploy:
|
||||
resources:
|
||||
reservations:
|
||||
devices:
|
||||
- driver: nvidia
|
||||
count: 1
|
||||
capabilities: [gpu]
|
||||
|
||||
restart: unless-stopped
|
||||
@ -1,97 +0,0 @@
|
||||
import csv
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
import geopandas as gpd
|
||||
import json
|
||||
|
||||
def get_metadata(camTy: str, fileId: str):
|
||||
|
||||
# CSV 파일 경로 설정
|
||||
# base_dir = "mx15hdi" if pollId == "1" else "starsafire"
|
||||
if camTy == "mx15hdi":
|
||||
csv_path = f"{camTy}/Metadata/CSV/{fileId}/mx15hdi_interpolation.csv"
|
||||
elif camTy == "starsafire":
|
||||
csv_path = f"{camTy}/Metadata/CSV/{fileId}/Metadata_Extracted.csv"
|
||||
|
||||
try:
|
||||
# CSV 파일 읽기
|
||||
with open(csv_path, 'r', newline='', encoding='utf-8-sig') as csvfile:
|
||||
reader = csv.reader(csvfile)
|
||||
next(reader, None)
|
||||
row = next(reader, None)
|
||||
return ','.join(row)
|
||||
|
||||
except FileNotFoundError:
|
||||
print(f"CSV file not found: {csv_path}")
|
||||
raise
|
||||
except ValueError as e:
|
||||
print(f"Value error: {str(e)}")
|
||||
raise
|
||||
except Exception as e:
|
||||
print(f"Error processing CSV: {e}")
|
||||
raise
|
||||
|
||||
|
||||
def get_oil_type(camTy: str, fileId: str):
|
||||
# Shapefile 경로 설정
|
||||
path = f"{camTy}/Polygon/Shp/{fileId}"
|
||||
shp_file = list(Path(path).glob("*.shp"))
|
||||
if not shp_file:
|
||||
return []
|
||||
shp_path = f"{camTy}/Polygon/Shp/{fileId}/{shp_file[0].name}"
|
||||
print(shp_path)
|
||||
# if camTy == "mx15hdi":
|
||||
# fileSub = f"{Path(fileName).stem}_gsd"
|
||||
# elif camTy == "starsafire":
|
||||
# fileSub = f"{Path(fileName).stem}"
|
||||
|
||||
# shp_path = f"{camTy}/Polygon/Shp/{fileId}/{fileSub}.shp"
|
||||
|
||||
# 두께 정보
|
||||
class_thickness_mm = {
|
||||
1: 1.0, # Black oil (Emulsion)
|
||||
2: 0.1, # Brown oil (Crude)
|
||||
3: 0.0003, # Rainbow oil (Slick)
|
||||
4: 0.0001 # Silver oil (Slick)
|
||||
}
|
||||
# 알고리즘 정보
|
||||
algorithm = {
|
||||
1: "검정",
|
||||
2: "갈색",
|
||||
3: "무지개",
|
||||
4: "은색"
|
||||
}
|
||||
|
||||
try:
|
||||
# Shapefile 읽기
|
||||
gdf = gpd.read_file(shp_path)
|
||||
if gdf.crs != "epsg:4326":
|
||||
gdf = gdf.to_crs("epsg:4326")
|
||||
|
||||
# 데이터 준비
|
||||
data = []
|
||||
for _, row in gdf.iterrows():
|
||||
class_id = row.get('class_id', None)
|
||||
area_m2 = row.get('area_m2', None)
|
||||
volume_m3 = row.get('volume_m3', None)
|
||||
note = row.get('note', None)
|
||||
thickness_m = class_thickness_mm.get(class_id, 0) / 1000.0
|
||||
geom_wkt = row.geometry.wkt if row.geometry else None
|
||||
result = {
|
||||
"classId": algorithm.get(class_id, 0),
|
||||
"area": area_m2,
|
||||
"volume": volume_m3,
|
||||
"note": note,
|
||||
"thickness": thickness_m,
|
||||
"wkt": geom_wkt
|
||||
}
|
||||
data.append(result)
|
||||
|
||||
return data
|
||||
|
||||
except FileNotFoundError:
|
||||
print(f"Shapefile not found: {shp_path}")
|
||||
raise
|
||||
except Exception as e:
|
||||
print(f"Error processing shapefile or database: {str(e)}")
|
||||
raise
|
||||
@ -1,238 +0,0 @@
|
||||
# 이미지 업로드 유류 분석 기능 구현 계획
|
||||
|
||||
## Context
|
||||
|
||||
드론/항공 촬영 이미지를 업로드하면 AI 세그멘테이션으로 유류 확산 정보(위치·유종·면적·부피)를 자동 추출하고, 결과를 DB에 저장한 뒤 예측정보 입력 폼에 자동 채워주는 기능이다.
|
||||
이미지 분석 서버(`prediction/image/api.py`, FastAPI, 포트 5001)는 이미 구현되어 있으며, 프론트↔백엔드↔이미지 분석 서버 연동 및 결과 자동 채우기를 구현한다.
|
||||
|
||||
---
|
||||
|
||||
## 전체 흐름
|
||||
|
||||
```
|
||||
[프론트] 이미지 선택 → 분석 요청 버튼
|
||||
↓ POST /api/prediction/image-analyze (multipart: image)
|
||||
[백엔드]
|
||||
├─ fileId = UUID 생성
|
||||
├─ camTy = "mx15hdi" (하드코딩, 추후 이미지 EXIF 카메라 정보로 자동 판별 예정)
|
||||
├─ 이미지 분석 서버로 전달 POST http://IMAGE_API_URL/run-script/
|
||||
├─ 응답 파싱: meta(위경도 DMS→십진수 변환), data[0].classId→유종
|
||||
├─ ACDNT INSERT (lat/lon/임시사고명)
|
||||
├─ SPIL_DATA INSERT (유종/면적/img_rslt_data JSONB)
|
||||
└─ 응답: { acdntSn, lat, lon, oilType, area, volume }
|
||||
↓
|
||||
[프론트] 폼 자동 채우기 (좌표·유종·유출량)
|
||||
→ 사용자가 나머지 입력 후 "확산예측 실행"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 구현 단계
|
||||
|
||||
### Step 1 — DB 마이그레이션 (`database/migration/017_spil_img_rslt.sql`)
|
||||
|
||||
`SPIL_DATA` 테이블에 이미지 분석 결과 컬럼 추가.
|
||||
|
||||
```sql
|
||||
ALTER TABLE wing.spil_data
|
||||
ADD COLUMN IF NOT EXISTS img_rslt_data JSONB;
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Step 2 — 백엔드: 이미지 분석 엔드포인트
|
||||
|
||||
**파일**: `backend/src/prediction/predictionRouter.ts` (라우트 등록)
|
||||
**신규 파일**: `backend/src/prediction/imageAnalyzeService.ts`
|
||||
|
||||
#### 엔드포인트
|
||||
|
||||
```
|
||||
POST /api/prediction/image-analyze
|
||||
Content-Type: multipart/form-data
|
||||
Body: image (file)
|
||||
```
|
||||
|
||||
#### `imageAnalyzeService.ts` 핵심 로직
|
||||
|
||||
```typescript
|
||||
// 1. fileId 생성 (crypto.randomUUID)
|
||||
|
||||
// 2. 이미지 분석 서버 호출
|
||||
// camTy는 현재 "mx15hdi"로 하드코딩한다.
|
||||
// TODO: 추후 이미지 EXIF에서 카메라 모델명을 읽어 camTy를 자동 판별하는 로직을
|
||||
// 이미지 분석 서버(api.py)에 추가할 예정이다. (check_camera_info 함수 활용)
|
||||
// FormData: { camTy: 'mx15hdi', fileId, image }
|
||||
// → POST ${IMAGE_API_URL}/run-script/
|
||||
// 응답: { meta: string, data: OilPolygon[] }
|
||||
|
||||
// 3. meta 문자열 파싱 (mx15hdi CSV 컬럼 순서 사용)
|
||||
// [Filename, Tlat_d, Tlat_m, Tlat_s, Tlon_d, Tlon_m, Tlon_s, ...]
|
||||
// DMS → 십진수: d + m/60 + s/3600
|
||||
|
||||
// 4. 유종 매핑 (data[0].classId → UI 유종명)
|
||||
// classId → oilType: { '검정': '벙커C유', '갈색': '벙커C유', '무지개': '경유', '은색': '등유' }
|
||||
|
||||
// 5. ACDNT INSERT (임시 사고명 = "이미지분석_YYYY-MM-DD HH:mm", lat, lon, occurredAt = 촬영시각)
|
||||
// 6. SPIL_DATA INSERT (acdntSn, matTyCd, matVol=data[0].volume, imgRsltData=JSON.stringify(response))
|
||||
|
||||
// 7. 반환
|
||||
interface ImageAnalyzeResult {
|
||||
acdntSn: number;
|
||||
lat: number;
|
||||
lon: number;
|
||||
oilType: string; // UI 유종명 (벙커C유 등)
|
||||
area: number; // m²
|
||||
volume: number; // m³
|
||||
fileId: string;
|
||||
}
|
||||
```
|
||||
|
||||
#### 환경변수 추가 (`backend/.env`)
|
||||
|
||||
```
|
||||
IMAGE_API_URL=http://localhost:5001
|
||||
```
|
||||
|
||||
#### 에러 처리
|
||||
|
||||
| 조건 | 응답 |
|
||||
|------|------|
|
||||
| 이미지에 GPS EXIF 없음 | 422 `{ error: 'GPS_NOT_FOUND' }` |
|
||||
| 이미지 서버 타임아웃(300s) | 504 |
|
||||
|
||||
---
|
||||
|
||||
### Step 3 — 프론트엔드: API 서비스
|
||||
|
||||
**파일**: `frontend/src/tabs/prediction/services/predictionApi.ts`
|
||||
|
||||
```typescript
|
||||
interface ImageAnalyzeResult {
|
||||
acdntSn: number;
|
||||
lat: number;
|
||||
lon: number;
|
||||
oilType: string;
|
||||
area: number;
|
||||
volume: number;
|
||||
fileId: string;
|
||||
}
|
||||
|
||||
export const analyzeImage = async (
|
||||
file: File
|
||||
): Promise<ImageAnalyzeResult> => {
|
||||
const formData = new FormData();
|
||||
formData.append('image', file);
|
||||
const { data } = await api.post<ImageAnalyzeResult>(
|
||||
'/prediction/image-analyze',
|
||||
formData,
|
||||
{ headers: { 'Content-Type': 'multipart/form-data' }, timeout: 330000 }
|
||||
);
|
||||
return data;
|
||||
};
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Step 4 — 프론트엔드: Props 타입 확장
|
||||
|
||||
**파일**: `frontend/src/tabs/prediction/components/leftPanelTypes.ts`
|
||||
|
||||
```typescript
|
||||
// 기존 Props에 추가
|
||||
onImageAnalysisResult?: (result: ImageAnalyzeResult) => void;
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Step 5 — 프론트엔드: PredictionInputSection 수정
|
||||
|
||||
**파일**: `frontend/src/tabs/prediction/components/PredictionInputSection.tsx`
|
||||
|
||||
#### 변경 사항
|
||||
|
||||
1. **"이미지 분석 실행" 버튼** (이미지 선택 후 활성화)
|
||||
- 클릭 시 `analyzeImage(file)` 호출 (camTy는 백엔드에서 처리)
|
||||
- 로딩 스피너 표시 (분석 소요시간 수십 초~수 분)
|
||||
|
||||
2. **분석 결과 표시** (성공 시)
|
||||
- "분석 완료: 위도 XX.XXXX / 경도 XXX.XXXX / 유종: OO" 요약 메시지
|
||||
|
||||
3. **`onImageAnalysisResult` 콜백 호출**
|
||||
- 분석 성공 시 부모로 결과 전달
|
||||
|
||||
4. **에러 처리**
|
||||
- GPS_NOT_FOUND: "GPS 정보가 없는 이미지입니다" 메시지 표시
|
||||
- 타임아웃: "분석 서버 응답 없음" 메시지 표시
|
||||
|
||||
5. **로컬 상태 교체**: `uploadedImage` (Base64 DataURL) 제거, `uploadedFile: File | null`로 교체
|
||||
|
||||
---
|
||||
|
||||
### Step 6 — 프론트엔드: OilSpillView 결과 처리
|
||||
|
||||
**파일**: `frontend/src/tabs/prediction/components/OilSpillView.tsx`
|
||||
|
||||
```typescript
|
||||
const handleImageAnalysisResult = useCallback((result: ImageAnalyzeResult) => {
|
||||
// 1. 사고 좌표 자동 채우기
|
||||
setIncidentCoord({ lat: result.lat, lon: result.lon })
|
||||
setFlyToCoord({ lat: result.lat, lon: result.lon })
|
||||
|
||||
// 2. 유종/유출량 자동 채우기
|
||||
setOilType(result.oilType)
|
||||
setSpillAmount(parseFloat(result.volume.toFixed(4)))
|
||||
setSpillUnit('m³')
|
||||
|
||||
// 3. 분석 선택 상태 갱신 (acdntSn 연결 — 시뮬레이션 실행 시 기존 사고 사용)
|
||||
setSelectedAnalysis({
|
||||
acdntSn: result.acdntSn,
|
||||
acdntNm: '',
|
||||
// ... 나머지 기본값
|
||||
})
|
||||
}, [])
|
||||
```
|
||||
|
||||
`LeftPanel`에 `onImageAnalysisResult={handleImageAnalysisResult}` 전달.
|
||||
|
||||
---
|
||||
|
||||
## 수정 대상 파일 요약
|
||||
|
||||
| 파일 | 변경 유형 |
|
||||
|------|---------|
|
||||
| `database/migration/017_spil_img_rslt.sql` | **신규** — SPIL_DATA 컬럼 추가 |
|
||||
| `backend/src/prediction/imageAnalyzeService.ts` | **신규** — 이미지 분석 서비스 |
|
||||
| `backend/src/prediction/predictionRouter.ts` | **수정** — 라우트 추가 |
|
||||
| `backend/.env` | **수정** — IMAGE_API_URL 추가 |
|
||||
| `frontend/src/tabs/prediction/services/predictionApi.ts` | **수정** — analyzeImage 함수 추가 |
|
||||
| `frontend/src/tabs/prediction/components/leftPanelTypes.ts` | **수정** — Props 타입 추가 |
|
||||
| `frontend/src/tabs/prediction/components/PredictionInputSection.tsx` | **수정** — 분석 실행 UI |
|
||||
| `frontend/src/tabs/prediction/components/OilSpillView.tsx` | **수정** — 결과 처리 핸들러 |
|
||||
|
||||
---
|
||||
|
||||
## 검증 방법
|
||||
|
||||
1. **이미지 분석 서버 직접 테스트**
|
||||
```bash
|
||||
curl -X POST http://localhost:5001/run-script/ \
|
||||
-F "camTy=mx15hdi" -F "fileId=test001" -F "image=@drone_image.jpg"
|
||||
```
|
||||
|
||||
2. **백엔드 엔드포인트 테스트**
|
||||
```bash
|
||||
curl -X POST http://localhost:3001/api/prediction/image-analyze \
|
||||
-F "image=@drone_image.jpg" \
|
||||
-H "Cookie: <auth_cookie>"
|
||||
```
|
||||
- 응답: `{ acdntSn, lat, lon, oilType, area, volume, fileId }`
|
||||
- DB 확인: ACDNT, SPIL_DATA 레코드 생성 여부
|
||||
|
||||
3. **프론트엔드 E2E 테스트**
|
||||
- 이미지 업로드 모드 선택 → GPS EXIF 있는 이미지 업로드 → "이미지 분석 실행" 클릭
|
||||
- 로딩 표시 → 완료 시: 지도 이동, 유종/좌표 폼 자동 채워짐 확인
|
||||
- 나머지 필드(예상시각·유출시간 등) 직접 입력 후 "확산예측 실행" → 정상 시뮬레이션 확인
|
||||
|
||||
4. **에러 케이스 확인**
|
||||
- GPS 없는 이미지 → "GPS 정보가 없는 이미지입니다" 메시지
|
||||
@ -1,131 +0,0 @@
|
||||
import os, mmcv, cv2, json
|
||||
import numpy as np
|
||||
import torch
|
||||
from pathlib import Path
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
from mmseg.apis import init_segmentor, inference_segmentor
|
||||
from shapely.geometry import Polygon, mapping
|
||||
import sys
|
||||
|
||||
_DETECT_DIR = Path(__file__).parent # mx15hdi/Detect/
|
||||
_MX15HDI_DIR = _DETECT_DIR.parent # mx15hdi/
|
||||
|
||||
|
||||
def load_model():
|
||||
"""서버 시작 시 1회 호출. 로드된 모델 객체를 반환한다."""
|
||||
# 우선순위: 환경변수 DEVICE > GPU 자동감지 > CPU 폴백
|
||||
env_device = os.environ.get('DEVICE', '').strip()
|
||||
if env_device:
|
||||
device = env_device
|
||||
elif torch.cuda.is_available():
|
||||
device = 'cuda:0'
|
||||
else:
|
||||
device = 'cpu'
|
||||
print(f'[Inference] 사용 device: {device}')
|
||||
|
||||
config = str(_DETECT_DIR / 'V7_SPECIAL.py')
|
||||
checkpoint = str(_DETECT_DIR / 'epoch_165.pth')
|
||||
model = init_segmentor(config, checkpoint, device=device)
|
||||
model.PALETTE = [
|
||||
[0, 0, 0], # background
|
||||
[0, 0, 204], # black
|
||||
[180, 180, 180], # brown
|
||||
[255, 255, 0], # rainbow
|
||||
[178, 102, 255] # silver
|
||||
]
|
||||
return model
|
||||
|
||||
|
||||
def blend_images(original_img, color_mask, alpha=0.6):
|
||||
"""
|
||||
Blend original image and color mask with alpha transparency.
|
||||
Inputs: numpy arrays HWC uint8
|
||||
"""
|
||||
blended = cv2.addWeighted(original_img, 1 - alpha, color_mask, alpha, 0)
|
||||
return blended
|
||||
|
||||
|
||||
def run_inference(model, file_id: str, write_files: bool = False) -> dict:
|
||||
"""
|
||||
사전 로드된 모델로 file_id 폴더 내 이미지를 세그멘테이션한다.
|
||||
|
||||
Args:
|
||||
model: load_model()로 로드된 모델 객체.
|
||||
file_id: 처리할 이미지 폴더명.
|
||||
write_files: True이면 Detect/result/ 와 Detect/Mask_result/ 에 중간 파일 저장.
|
||||
False이면 디스크 쓰기 생략 (기본값).
|
||||
|
||||
Returns:
|
||||
inference_cache: {image_filename: {'blended': ndarray, 'mask': ndarray, 'ext': str}}
|
||||
이 값을 run_georeference()에 전달하면 중간 파일 읽기를 생략할 수 있다.
|
||||
"""
|
||||
img_path = str(_MX15HDI_DIR / 'Metadata' / 'Image' / 'Original_Images' / file_id)
|
||||
output_folder = str(_DETECT_DIR / 'result' / file_id)
|
||||
mask_folder = str(_DETECT_DIR / 'Mask_result' / file_id)
|
||||
|
||||
if not os.path.exists(img_path):
|
||||
raise FileNotFoundError(f"이미지 폴더가 존재하지 않습니다: {img_path}")
|
||||
|
||||
if write_files:
|
||||
os.makedirs(output_folder, exist_ok=True)
|
||||
os.makedirs(mask_folder, exist_ok=True)
|
||||
|
||||
image_files = [
|
||||
f for f in os.listdir(img_path)
|
||||
if f.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff'))
|
||||
]
|
||||
|
||||
# palette_array는 이미지마다 동일하므로 루프 외부에서 1회 생성
|
||||
palette_array = np.array(model.PALETTE, dtype=np.uint8)
|
||||
|
||||
inference_cache = {}
|
||||
|
||||
for image_file in tqdm(image_files, desc="Processing images"):
|
||||
image_path = os.path.join(img_path, image_file)
|
||||
image_name, image_ext = os.path.splitext(image_file)
|
||||
image_ext = image_ext.lower()
|
||||
|
||||
# 이미지를 1회만 읽어 inference와 blending 모두에 재사용
|
||||
img_bgr = cv2.imread(image_path)
|
||||
img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
|
||||
|
||||
# 이미 로드된 배열을 inference_segmentor에 전달 (경로 전달 시 내부에서 재읽기 발생)
|
||||
result = inference_segmentor(model, img_bgr)
|
||||
seg_map = result[0]
|
||||
|
||||
# Create color mask from palette
|
||||
color_mask = palette_array[seg_map]
|
||||
|
||||
# blended image
|
||||
blended = blend_images(img_rgb, color_mask, alpha=0.6)
|
||||
blended_bgr = cv2.cvtColor(blended, cv2.COLOR_RGB2BGR)
|
||||
|
||||
# mask — numpy 슬라이싱으로 cv2.cvtColor 호출 1회 제거
|
||||
mask_bgr = color_mask[:, :, ::-1].copy()
|
||||
|
||||
# 결과를 메모리 캐시에 저장 (georeference 단계에서 재사용)
|
||||
# mask는 palette 원본(RGB) 그대로 저장 — Oilshape의 class_colors가 RGB 기준이므로 BGR로 저장 시 색상 매칭 실패
|
||||
inference_cache[image_file] = {
|
||||
'blended': blended_bgr,
|
||||
'mask': color_mask,
|
||||
'ext': image_ext,
|
||||
}
|
||||
|
||||
if write_files:
|
||||
cv2.imwrite(
|
||||
os.path.join(output_folder, f"{image_name}{image_ext}"),
|
||||
blended_bgr,
|
||||
[cv2.IMWRITE_JPEG_QUALITY, 85]
|
||||
)
|
||||
cv2.imwrite(os.path.join(mask_folder, f"{image_name}{image_ext}"), mask_bgr)
|
||||
|
||||
return inference_cache
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if len(sys.argv) < 2:
|
||||
raise ValueError("파라미터가 제공되지 않았습니다. 폴더 이름을 명령줄 인자로 입력해주세요.")
|
||||
_model = load_model()
|
||||
# CLI 단독 실행 시에는 중간 파일도 디스크에 저장
|
||||
run_inference(_model, sys.argv[1], write_files=True)
|
||||
@ -1,196 +0,0 @@
|
||||
# mmsegmentation dumped (fully-expanded) config.
# Model: EncoderDecoder with a 'DAHead' decode head on a ResNet-101-v1c
# backbone, 5 classes, on the custom dataset 'data/my_dataset_v7'.
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
    type='EncoderDecoder',
    pretrained='open-mmlab://resnet101_v1c',
    backbone=dict(
        type='ResNetV1c',
        depth=101,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        dilations=(1, 1, 2, 4),
        strides=(1, 2, 1, 1),
        norm_cfg=dict(type='SyncBN', requires_grad=True),
        norm_eval=False,
        style='pytorch',
        contract_dilation=True),
    decode_head=dict(
        type='DAHead',
        in_channels=2048,
        in_index=3,
        channels=512,
        pam_channels=64,
        dropout_ratio=0.1,
        num_classes=5,
        norm_cfg=dict(type='SyncBN', requires_grad=True),
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
    # Auxiliary FCN head on stage 3 (in_index=2) with down-weighted loss (0.4).
    auxiliary_head=dict(
        type='FCNHead',
        in_channels=1024,
        in_index=2,
        channels=256,
        num_convs=1,
        concat_input=False,
        dropout_ratio=0.1,
        num_classes=5,
        norm_cfg=dict(type='SyncBN', requires_grad=True),
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
    train_cfg=dict(),
    test_cfg=dict(mode='whole'))
# --- dataset settings ---
dataset_type = 'CustomDataset'
data_root = 'data/my_dataset_v7'
# Per-channel normalization statistics (values are repeated verbatim inside
# the pipelines below because this is a dumped config).
img_norm_cfg = dict(
    mean=[119.54541993, 107.13545011, 96.71320316],
    std=[60.3273945, 56.33692515, 55.71005772],
    to_rgb=True)
crop_size = (512, 512)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations'),
    dict(type='Resize', img_scale=(512, 512)),
    dict(type='RandomCrop', crop_size=(512, 512), cat_max_ratio=0.75),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='PhotoMetricDistortion'),
    dict(
        type='Normalize',
        mean=[119.54541993, 107.13545011, 96.71320316],
        std=[60.3273945, 56.33692515, 55.71005772],
        to_rgb=True),
    dict(type='Pad', size=(512, 512), pad_val=0, seg_pad_val=255),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_semantic_seg'])
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(512, 512),
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(
                type='Normalize',
                mean=[119.54541993, 107.13545011, 96.71320316],
                std=[60.3273945, 56.33692515, 55.71005772],
                to_rgb=True),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img'])
        ])
]
data = dict(
    samples_per_gpu=4,
    workers_per_gpu=4,
    train=dict(
        type='CustomDataset',
        data_root='data/my_dataset_v7',
        img_dir='img_dir/train',
        ann_dir='ann_dir/train',
        pipeline=[
            dict(type='LoadImageFromFile'),
            dict(type='LoadAnnotations'),
            dict(type='Resize', img_scale=(512, 512)),
            dict(type='RandomCrop', crop_size=(512, 512), cat_max_ratio=0.75),
            dict(type='RandomFlip', flip_ratio=0.5),
            dict(type='PhotoMetricDistortion'),
            dict(
                type='Normalize',
                mean=[119.54541993, 107.13545011, 96.71320316],
                std=[60.3273945, 56.33692515, 55.71005772],
                to_rgb=True),
            dict(type='Pad', size=(512, 512), pad_val=0, seg_pad_val=255),
            dict(type='DefaultFormatBundle'),
            dict(type='Collect', keys=['img', 'gt_semantic_seg'])
        ]),
    val=dict(
        type='CustomDataset',
        data_root='data/my_dataset_v7',
        img_dir='img_dir/val',
        ann_dir='ann_dir/val',
        pipeline=[
            dict(type='LoadImageFromFile'),
            dict(
                type='MultiScaleFlipAug',
                img_scale=(512, 512),
                flip=False,
                transforms=[
                    dict(type='Resize', keep_ratio=True),
                    dict(type='RandomFlip'),
                    dict(
                        type='Normalize',
                        mean=[119.54541993, 107.13545011, 96.71320316],
                        std=[60.3273945, 56.33692515, 55.71005772],
                        to_rgb=True),
                    dict(type='ImageToTensor', keys=['img']),
                    dict(type='Collect', keys=['img'])
                ])
        ]),
    test=dict(
        type='CustomDataset',
        data_root='data/my_dataset_v7',
        img_dir='img_dir/test',
        ann_dir='ann_dir/test',
        pipeline=[
            dict(type='LoadImageFromFile'),
            dict(
                type='MultiScaleFlipAug',
                img_scale=(512, 512),
                flip=False,
                transforms=[
                    dict(type='Resize', keep_ratio=True),
                    dict(type='RandomFlip'),
                    dict(
                        type='Normalize',
                        mean=[119.54541993, 107.13545011, 96.71320316],
                        std=[60.3273945, 56.33692515, 55.71005772],
                        to_rgb=True),
                    dict(type='ImageToTensor', keys=['img']),
                    dict(type='Collect', keys=['img'])
                ])
        ],
        split=None,
        img_suffix='.png',
        seg_map_suffix='.png'))
# --- runtime settings ---
dist_params = dict(backend='nccl')
log_level = 'INFO'
load_from = None
resume_from = None
#workflow = [('train', 1), ('val', 1)]
# NOTE(review): 'test' is a non-standard workflow phase (mmcv runners usually
# use 'train'/'val') — confirm this is intentional.
workflow = [('test', 1)]
cudnn_benchmark = True
# AdamW with no weight decay on norm/positional params and 10x LR on heads.
optimizer = dict(
    type='AdamW',
    lr=3e-05,
    betas=(0.9, 0.999),
    weight_decay=0.01,
    paramwise_cfg=dict(
        custom_keys=dict(
            pos_block=dict(decay_mult=0.0),
            norm=dict(decay_mult=0.0),
            head=dict(lr_mult=10.0))))
optimizer_config = dict()
lr_config = dict(
    policy='poly',
    warmup='linear',
    warmup_iters=1500,
    warmup_ratio=1e-06,
    power=1.0,
    min_lr=0.0,
    by_epoch=False)
runner = dict(type='EpochBasedRunner', max_epochs=200)
checkpoint_config = dict(by_epoch=True, interval=1)
evaluation = dict(by_epoch=True, interval=1, metric='mIoU')
log_config = dict(
    interval=1000,
    hooks=[
        dict(type='TextLoggerHook'),
        dict(
            type='WandbLoggerHook',
            init_kwargs=dict(project='Oil_Spill', name='V7_SPECIAL'))
    ])
auto_resume = False
work_dir = 'work_dirs/V7_SPECIAL'
gpu_ids = [0]
|
||||
@ -1,161 +0,0 @@
|
||||
# CircleCI pipeline: lint + docstring-coverage gate, then CPU builds across
# several torch versions, then a CUDA 10.1 GPU build that runs the unit tests.
version: 2.1

jobs:
  lint:
    docker:
      - image: cimg/python:3.7.4
    steps:
      - checkout
      - run:
          name: Install dependencies
          command: |
            sudo apt-add-repository ppa:brightbox/ruby-ng -y
            sudo apt-get update
            sudo apt-get install -y ruby2.7
      - run:
          name: Install pre-commit hook
          command: |
            pip install pre-commit
            pre-commit install
      - run:
          name: Linting
          command: pre-commit run --all-files
      - run:
          name: Check docstring coverage
          command: |
            pip install interrogate
            interrogate -v --ignore-init-method --ignore-module --ignore-nested-functions --ignore-regex "__repr__" --fail-under 50 mmseg

  # CPU-only build/test matrix; parameterized on python/torch/torchvision.
  build_cpu:
    parameters:
      # The python version must match available image tags in
      # https://circleci.com/developer/images/image/cimg/python
      python:
        type: string
        default: "3.7.4"
      torch:
        type: string
      torchvision:
        type: string
    docker:
      - image: cimg/python:<< parameters.python >>
    resource_class: large
    steps:
      - checkout
      - run:
          name: Install Libraries
          command: |
            sudo apt-get update
            sudo apt-get install -y ninja-build libglib2.0-0 libsm6 libxrender-dev libxext6 libgl1-mesa-glx libjpeg-dev zlib1g-dev libtinfo-dev libncurses5
      - run:
          name: Configure Python & pip
          command: |
            python -m pip install --upgrade pip
            python -m pip install wheel
      - run:
          name: Install PyTorch
          command: |
            python -V
            python -m pip install torch==<< parameters.torch >>+cpu torchvision==<< parameters.torchvision >>+cpu -f https://download.pytorch.org/whl/torch_stable.html
      - run:
          name: Install mmseg dependencies
          command: |
            python -m pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cpu/torch<< parameters.torch >>/index.html
            python -m pip install mmdet
            python -m pip install -r requirements.txt
      - run:
          name: Build and install
          command: |
            python -m pip install -e .
      - run:
          name: Run unittests
          command: |
            python -m pip install timm
            python -m coverage run --branch --source mmseg -m pytest tests/
            python -m coverage xml
            python -m coverage report -m

  # GPU build on a CUDA 10.1 machine image (torch 1.6 pinned).
  build_cu101:
    machine:
      image: ubuntu-1604-cuda-10.1:201909-23
    resource_class: gpu.nvidia.small
    steps:
      - checkout
      - run:
          name: Install Libraries
          command: |
            sudo apt-get update
            sudo apt-get install -y git ninja-build libglib2.0-0 libsm6 libxrender-dev libxext6 libgl1-mesa-glx
      - run:
          name: Configure Python & pip
          command: |
            pyenv global 3.7.0
            python -m pip install --upgrade pip
            python -m pip install wheel
      - run:
          name: Install PyTorch
          command: |
            python -V
            python -m pip install torch==1.6.0+cu101 torchvision==0.7.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html
      - run:
          name: Install mmseg dependencies
          # python -m pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cu101/torch${{matrix.torch_version}}/index.html
          command: |
            python -m pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cu101/torch1.6.0/index.html
            python -m pip install mmdet
            python -m pip install -r requirements.txt
      - run:
          name: Build and install
          command: |
            python setup.py check -m -s
            TORCH_CUDA_ARCH_LIST=7.0 python -m pip install -e .
      - run:
          name: Run unittests
          command: |
            python -m pip install timm
            python -m pytest tests/

workflows:
  unit_tests:
    jobs:
      - lint
      - build_cpu:
          name: build_cpu_th1.6
          torch: 1.6.0
          torchvision: 0.7.0
          requires:
            - lint
      - build_cpu:
          name: build_cpu_th1.7
          torch: 1.7.0
          torchvision: 0.8.1
          requires:
            - lint
      - build_cpu:
          name: build_cpu_th1.8_py3.9
          torch: 1.8.0
          torchvision: 0.9.0
          python: "3.9.0"
          requires:
            - lint
      - build_cpu:
          name: build_cpu_th1.9_py3.8
          torch: 1.9.0
          torchvision: 0.10.0
          python: "3.8.0"
          requires:
            - lint
      - build_cpu:
          name: build_cpu_th1.9_py3.9
          torch: 1.9.0
          torchvision: 0.10.0
          python: "3.9.0"
          requires:
            - lint
      # GPU job only runs once every CPU variant has passed.
      - build_cu101:
          requires:
            - build_cpu_th1.6
            - build_cpu_th1.7
            - build_cpu_th1.8_py3.9
            - build_cpu_th1.9_py3.8
            - build_cpu_th1.9_py3.9
|
||||
@ -1,133 +0,0 @@
|
||||
# yapf: disable
# Inference Speed is tested on NVIDIA V100
# Benchmark registry consumed by the benchmark-inference script: one list per
# model family; each entry pairs a config path with its released checkpoint
# and the expected evaluation metric.
hrnet = [
    dict(
        config='configs/hrnet/fcn_hr18s_512x512_160k_ade20k.py',
        checkpoint='fcn_hr18s_512x512_160k_ade20k_20200614_214413-870f65ac.pth',  # noqa
        eval='mIoU',
        metric=dict(mIoU=33.0),
    ),
    dict(
        config='configs/hrnet/fcn_hr18s_512x1024_160k_cityscapes.py',
        checkpoint='fcn_hr18s_512x1024_160k_cityscapes_20200602_190901-4a0797ea.pth',  # noqa
        eval='mIoU',
        metric=dict(mIoU=76.31),
    ),
    dict(
        config='configs/hrnet/fcn_hr48_512x512_160k_ade20k.py',
        checkpoint='fcn_hr48_512x512_160k_ade20k_20200614_214407-a52fc02c.pth',
        eval='mIoU',
        metric=dict(mIoU=42.02),
    ),
    dict(
        config='configs/hrnet/fcn_hr48_512x1024_160k_cityscapes.py',
        checkpoint='fcn_hr48_512x1024_160k_cityscapes_20200602_190946-59b7973e.pth',  # noqa
        eval='mIoU',
        metric=dict(mIoU=80.65),
    ),
]
pspnet = [
    dict(
        config='configs/pspnet/pspnet_r50-d8_512x1024_80k_cityscapes.py',
        checkpoint='pspnet_r50-d8_512x1024_80k_cityscapes_20200606_112131-2376f12b.pth',  # noqa
        eval='mIoU',
        metric=dict(mIoU=78.55),
    ),
    dict(
        config='configs/pspnet/pspnet_r101-d8_512x1024_80k_cityscapes.py',
        checkpoint='pspnet_r101-d8_512x1024_80k_cityscapes_20200606_112211-e1e1100f.pth',  # noqa
        eval='mIoU',
        metric=dict(mIoU=79.76),
    ),
    dict(
        config='configs/pspnet/pspnet_r101-d8_512x512_160k_ade20k.py',
        checkpoint='pspnet_r101-d8_512x512_160k_ade20k_20200615_100650-967c316f.pth',  # noqa
        eval='mIoU',
        metric=dict(mIoU=44.39),
    ),
    dict(
        config='configs/pspnet/pspnet_r50-d8_512x512_160k_ade20k.py',
        checkpoint='pspnet_r50-d8_512x512_160k_ade20k_20200615_184358-1890b0bd.pth',  # noqa
        eval='mIoU',
        metric=dict(mIoU=42.48),
    ),
]
resnest = [
    dict(
        config='configs/resnest/pspnet_s101-d8_512x512_160k_ade20k.py',
        checkpoint='pspnet_s101-d8_512x512_160k_ade20k_20200807_145416-a6daa92a.pth',  # noqa
        eval='mIoU',
        metric=dict(mIoU=45.44),
    ),
    dict(
        config='configs/resnest/pspnet_s101-d8_512x1024_80k_cityscapes.py',
        checkpoint='pspnet_s101-d8_512x1024_80k_cityscapes_20200807_140631-c75f3b99.pth',  # noqa
        eval='mIoU',
        metric=dict(mIoU=78.57),
    ),
]
fastscnn = [
    dict(
        config='configs/fastscnn/fast_scnn_lr0.12_8x4_160k_cityscapes.py',
        checkpoint='fast_scnn_8x4_160k_lr0.12_cityscapes-0cec9937.pth',
        eval='mIoU',
        metric=dict(mIoU=70.96),
    )
]
deeplabv3plus = [
    dict(
        config='configs/deeplabv3plus/deeplabv3plus_r101-d8_769x769_80k_cityscapes.py',  # noqa
        checkpoint='deeplabv3plus_r101-d8_769x769_80k_cityscapes_20200607_000405-a7573d20.pth',  # noqa
        eval='mIoU',
        metric=dict(mIoU=80.98),
    ),
    dict(
        config='configs/deeplabv3plus/deeplabv3plus_r101-d8_512x1024_80k_cityscapes.py',  # noqa
        checkpoint='deeplabv3plus_r101-d8_512x1024_80k_cityscapes_20200606_114143-068fcfe9.pth',  # noqa
        eval='mIoU',
        metric=dict(mIoU=80.97),
    ),
    dict(
        config='configs/deeplabv3plus/deeplabv3plus_r50-d8_512x1024_80k_cityscapes.py',  # noqa
        checkpoint='deeplabv3plus_r50-d8_512x1024_80k_cityscapes_20200606_114049-f9fb496d.pth',  # noqa
        eval='mIoU',
        metric=dict(mIoU=80.09),
    ),
    dict(
        config='configs/deeplabv3plus/deeplabv3plus_r50-d8_769x769_80k_cityscapes.py',  # noqa
        checkpoint='deeplabv3plus_r50-d8_769x769_80k_cityscapes_20200606_210233-0e9dfdc4.pth',  # noqa
        eval='mIoU',
        metric=dict(mIoU=79.83),
    ),
]
vit = [
    dict(
        config='configs/vit/upernet_vit-b16_ln_mln_512x512_160k_ade20k.py',
        checkpoint='upernet_vit-b16_ln_mln_512x512_160k_ade20k-f444c077.pth',
        eval='mIoU',
        metric=dict(mIoU=47.73),
    ),
    dict(
        config='configs/vit/upernet_deit-s16_ln_mln_512x512_160k_ade20k.py',
        checkpoint='upernet_deit-s16_ln_mln_512x512_160k_ade20k-c0cd652f.pth',
        eval='mIoU',
        metric=dict(mIoU=43.52),
    ),
]
fp16 = [
    dict(
        config='configs/deeplabv3plus/deeplabv3plus_r101-d8_fp16_512x1024_80k_cityscapes.py',  # noqa
        checkpoint='deeplabv3plus_r101-d8_fp16_512x1024_80k_cityscapes_20200717_230920-f1104f4b.pth',  # noqa
        eval='mIoU',
        metric=dict(mIoU=80.46),
    )
]
swin = [
    dict(
        config='configs/swin/upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K.py',  # noqa
        checkpoint='upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K_20210531_112542-e380ad3e.pth',  # noqa
        eval='mIoU',
        metric=dict(mIoU=44.41),
    )
]
# yapf: enable
|
||||
@ -1,19 +0,0 @@
|
||||
configs/hrnet/fcn_hr18s_512x512_160k_ade20k.py
|
||||
configs/hrnet/fcn_hr18s_512x1024_160k_cityscapes.py
|
||||
configs/hrnet/fcn_hr48_512x512_160k_ade20k.py
|
||||
configs/hrnet/fcn_hr48_512x1024_160k_cityscapes.py
|
||||
configs/pspnet/pspnet_r50-d8_512x1024_80k_cityscapes.py
|
||||
configs/pspnet/pspnet_r101-d8_512x1024_80k_cityscapes.py
|
||||
configs/pspnet/pspnet_r101-d8_512x512_160k_ade20k.py
|
||||
configs/pspnet/pspnet_r50-d8_512x512_160k_ade20k.py
|
||||
configs/resnest/pspnet_s101-d8_512x512_160k_ade20k.py
|
||||
configs/resnest/pspnet_s101-d8_512x1024_80k_cityscapes.py
|
||||
configs/fastscnn/fast_scnn_lr0.12_8x4_160k_cityscapes.py
|
||||
configs/deeplabv3plus/deeplabv3plus_r101-d8_769x769_80k_cityscapes.py
|
||||
configs/deeplabv3plus/deeplabv3plus_r101-d8_512x1024_80k_cityscapes.py
|
||||
configs/deeplabv3plus/deeplabv3plus_r50-d8_512x1024_80k_cityscapes.py
|
||||
configs/deeplabv3plus/deeplabv3plus_r50-d8_769x769_80k_cityscapes.py
|
||||
configs/vit/upernet_vit-b16_ln_mln_512x512_160k_ade20k.py
|
||||
configs/vit/upernet_deit-s16_ln_mln_512x512_160k_ade20k.py
|
||||
configs/deeplabv3plus/deeplabv3plus_r101-d8_fp16_512x1024_80k_cityscapes.py
|
||||
configs/swin/upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K.py
|
||||
@ -1,41 +0,0 @@
|
||||
# Launch slurm evaluation jobs for every benchmark model in parallel.
# Usage: <script> PARTITION CHECKPOINT_DIR
# Each job gets a unique dist_params.port (28171-28189) so the parallel
# distributed runs do not collide; all jobs are backgrounded with '&'.
PARTITION=$1
CHECKPOINT_DIR=$2

echo 'configs/hrnet/fcn_hr18s_512x512_160k_ade20k.py' &
GPUS=4 GPUS_PER_NODE=4 CPUS_PER_TASK=2 tools/slurm_test.sh $PARTITION fcn_hr18s_512x512_160k_ade20k configs/hrnet/fcn_hr18s_512x512_160k_ade20k.py $CHECKPOINT_DIR/fcn_hr18s_512x512_160k_ade20k_20200614_214413-870f65ac.pth --eval mIoU --work-dir work_dirs/benchmark_evaluation/fcn_hr18s_512x512_160k_ade20k --cfg-options dist_params.port=28171 &
echo 'configs/hrnet/fcn_hr18s_512x1024_160k_cityscapes.py' &
GPUS=4 GPUS_PER_NODE=4 CPUS_PER_TASK=2 tools/slurm_test.sh $PARTITION fcn_hr18s_512x1024_160k_cityscapes configs/hrnet/fcn_hr18s_512x1024_160k_cityscapes.py $CHECKPOINT_DIR/fcn_hr18s_512x1024_160k_cityscapes_20200602_190901-4a0797ea.pth --eval mIoU --work-dir work_dirs/benchmark_evaluation/fcn_hr18s_512x1024_160k_cityscapes --cfg-options dist_params.port=28172 &
echo 'configs/hrnet/fcn_hr48_512x512_160k_ade20k.py' &
GPUS=4 GPUS_PER_NODE=4 CPUS_PER_TASK=2 tools/slurm_test.sh $PARTITION fcn_hr48_512x512_160k_ade20k configs/hrnet/fcn_hr48_512x512_160k_ade20k.py $CHECKPOINT_DIR/fcn_hr48_512x512_160k_ade20k_20200614_214407-a52fc02c.pth --eval mIoU --work-dir work_dirs/benchmark_evaluation/fcn_hr48_512x512_160k_ade20k --cfg-options dist_params.port=28173 &
echo 'configs/hrnet/fcn_hr48_512x1024_160k_cityscapes.py' &
GPUS=4 GPUS_PER_NODE=4 CPUS_PER_TASK=2 tools/slurm_test.sh $PARTITION fcn_hr48_512x1024_160k_cityscapes configs/hrnet/fcn_hr48_512x1024_160k_cityscapes.py $CHECKPOINT_DIR/fcn_hr48_512x1024_160k_cityscapes_20200602_190946-59b7973e.pth --eval mIoU --work-dir work_dirs/benchmark_evaluation/fcn_hr48_512x1024_160k_cityscapes --cfg-options dist_params.port=28174 &
echo 'configs/pspnet/pspnet_r50-d8_512x1024_80k_cityscapes.py' &
GPUS=4 GPUS_PER_NODE=4 CPUS_PER_TASK=2 tools/slurm_test.sh $PARTITION pspnet_r50-d8_512x1024_80k_cityscapes configs/pspnet/pspnet_r50-d8_512x1024_80k_cityscapes.py $CHECKPOINT_DIR/pspnet_r50-d8_512x1024_80k_cityscapes_20200606_112131-2376f12b.pth --eval mIoU --work-dir work_dirs/benchmark_evaluation/pspnet_r50-d8_512x1024_80k_cityscapes --cfg-options dist_params.port=28175 &
echo 'configs/pspnet/pspnet_r101-d8_512x1024_80k_cityscapes.py' &
GPUS=4 GPUS_PER_NODE=4 CPUS_PER_TASK=2 tools/slurm_test.sh $PARTITION pspnet_r101-d8_512x1024_80k_cityscapes configs/pspnet/pspnet_r101-d8_512x1024_80k_cityscapes.py $CHECKPOINT_DIR/pspnet_r101-d8_512x1024_80k_cityscapes_20200606_112211-e1e1100f.pth --eval mIoU --work-dir work_dirs/benchmark_evaluation/pspnet_r101-d8_512x1024_80k_cityscapes --cfg-options dist_params.port=28176 &
echo 'configs/pspnet/pspnet_r101-d8_512x512_160k_ade20k.py' &
GPUS=4 GPUS_PER_NODE=4 CPUS_PER_TASK=2 tools/slurm_test.sh $PARTITION pspnet_r101-d8_512x512_160k_ade20k configs/pspnet/pspnet_r101-d8_512x512_160k_ade20k.py $CHECKPOINT_DIR/pspnet_r101-d8_512x512_160k_ade20k_20200615_100650-967c316f.pth --eval mIoU --work-dir work_dirs/benchmark_evaluation/pspnet_r101-d8_512x512_160k_ade20k --cfg-options dist_params.port=28177 &
echo 'configs/pspnet/pspnet_r50-d8_512x512_160k_ade20k.py' &
GPUS=4 GPUS_PER_NODE=4 CPUS_PER_TASK=2 tools/slurm_test.sh $PARTITION pspnet_r50-d8_512x512_160k_ade20k configs/pspnet/pspnet_r50-d8_512x512_160k_ade20k.py $CHECKPOINT_DIR/pspnet_r50-d8_512x512_160k_ade20k_20200615_184358-1890b0bd.pth --eval mIoU --work-dir work_dirs/benchmark_evaluation/pspnet_r50-d8_512x512_160k_ade20k --cfg-options dist_params.port=28178 &
echo 'configs/resnest/pspnet_s101-d8_512x512_160k_ade20k.py' &
GPUS=4 GPUS_PER_NODE=4 CPUS_PER_TASK=2 tools/slurm_test.sh $PARTITION pspnet_s101-d8_512x512_160k_ade20k configs/resnest/pspnet_s101-d8_512x512_160k_ade20k.py $CHECKPOINT_DIR/pspnet_s101-d8_512x512_160k_ade20k_20200807_145416-a6daa92a.pth --eval mIoU --work-dir work_dirs/benchmark_evaluation/pspnet_s101-d8_512x512_160k_ade20k --cfg-options dist_params.port=28179 &
echo 'configs/resnest/pspnet_s101-d8_512x1024_80k_cityscapes.py' &
GPUS=4 GPUS_PER_NODE=4 CPUS_PER_TASK=2 tools/slurm_test.sh $PARTITION pspnet_s101-d8_512x1024_80k_cityscapes configs/resnest/pspnet_s101-d8_512x1024_80k_cityscapes.py $CHECKPOINT_DIR/pspnet_s101-d8_512x1024_80k_cityscapes_20200807_140631-c75f3b99.pth --eval mIoU --work-dir work_dirs/benchmark_evaluation/pspnet_s101-d8_512x1024_80k_cityscapes --cfg-options dist_params.port=28180 &
echo 'configs/fastscnn/fast_scnn_lr0.12_8x4_160k_cityscapes.py' &
GPUS=4 GPUS_PER_NODE=4 CPUS_PER_TASK=2 tools/slurm_test.sh $PARTITION fast_scnn_lr0.12_8x4_160k_cityscapes configs/fastscnn/fast_scnn_lr0.12_8x4_160k_cityscapes.py $CHECKPOINT_DIR/fast_scnn_8x4_160k_lr0.12_cityscapes-0cec9937.pth --eval mIoU --work-dir work_dirs/benchmark_evaluation/fast_scnn_lr0.12_8x4_160k_cityscapes --cfg-options dist_params.port=28181 &
echo 'configs/deeplabv3plus/deeplabv3plus_r101-d8_769x769_80k_cityscapes.py' &
GPUS=4 GPUS_PER_NODE=4 CPUS_PER_TASK=2 tools/slurm_test.sh $PARTITION deeplabv3plus_r101-d8_769x769_80k_cityscapes configs/deeplabv3plus/deeplabv3plus_r101-d8_769x769_80k_cityscapes.py $CHECKPOINT_DIR/deeplabv3plus_r101-d8_769x769_80k_cityscapes_20200607_000405-a7573d20.pth --eval mIoU --work-dir work_dirs/benchmark_evaluation/deeplabv3plus_r101-d8_769x769_80k_cityscapes --cfg-options dist_params.port=28182 &
echo 'configs/deeplabv3plus/deeplabv3plus_r101-d8_512x1024_80k_cityscapes.py' &
GPUS=4 GPUS_PER_NODE=4 CPUS_PER_TASK=2 tools/slurm_test.sh $PARTITION deeplabv3plus_r101-d8_512x1024_80k_cityscapes configs/deeplabv3plus/deeplabv3plus_r101-d8_512x1024_80k_cityscapes.py $CHECKPOINT_DIR/deeplabv3plus_r101-d8_512x1024_80k_cityscapes_20200606_114143-068fcfe9.pth --eval mIoU --work-dir work_dirs/benchmark_evaluation/deeplabv3plus_r101-d8_512x1024_80k_cityscapes --cfg-options dist_params.port=28183 &
echo 'configs/deeplabv3plus/deeplabv3plus_r50-d8_512x1024_80k_cityscapes.py' &
GPUS=4 GPUS_PER_NODE=4 CPUS_PER_TASK=2 tools/slurm_test.sh $PARTITION deeplabv3plus_r50-d8_512x1024_80k_cityscapes configs/deeplabv3plus/deeplabv3plus_r50-d8_512x1024_80k_cityscapes.py $CHECKPOINT_DIR/deeplabv3plus_r50-d8_512x1024_80k_cityscapes_20200606_114049-f9fb496d.pth --eval mIoU --work-dir work_dirs/benchmark_evaluation/deeplabv3plus_r50-d8_512x1024_80k_cityscapes --cfg-options dist_params.port=28184 &
echo 'configs/deeplabv3plus/deeplabv3plus_r50-d8_769x769_80k_cityscapes.py' &
GPUS=4 GPUS_PER_NODE=4 CPUS_PER_TASK=2 tools/slurm_test.sh $PARTITION deeplabv3plus_r50-d8_769x769_80k_cityscapes configs/deeplabv3plus/deeplabv3plus_r50-d8_769x769_80k_cityscapes.py $CHECKPOINT_DIR/deeplabv3plus_r50-d8_769x769_80k_cityscapes_20200606_210233-0e9dfdc4.pth --eval mIoU --work-dir work_dirs/benchmark_evaluation/deeplabv3plus_r50-d8_769x769_80k_cityscapes --cfg-options dist_params.port=28185 &
echo 'configs/vit/upernet_vit-b16_ln_mln_512x512_160k_ade20k.py' &
GPUS=4 GPUS_PER_NODE=4 CPUS_PER_TASK=2 tools/slurm_test.sh $PARTITION upernet_vit-b16_ln_mln_512x512_160k_ade20k configs/vit/upernet_vit-b16_ln_mln_512x512_160k_ade20k.py $CHECKPOINT_DIR/upernet_vit-b16_ln_mln_512x512_160k_ade20k-f444c077.pth --eval mIoU --work-dir work_dirs/benchmark_evaluation/upernet_vit-b16_ln_mln_512x512_160k_ade20k --cfg-options dist_params.port=28186 &
echo 'configs/vit/upernet_deit-s16_ln_mln_512x512_160k_ade20k.py' &
GPUS=4 GPUS_PER_NODE=4 CPUS_PER_TASK=2 tools/slurm_test.sh $PARTITION upernet_deit-s16_ln_mln_512x512_160k_ade20k configs/vit/upernet_deit-s16_ln_mln_512x512_160k_ade20k.py $CHECKPOINT_DIR/upernet_deit-s16_ln_mln_512x512_160k_ade20k-c0cd652f.pth --eval mIoU --work-dir work_dirs/benchmark_evaluation/upernet_deit-s16_ln_mln_512x512_160k_ade20k --cfg-options dist_params.port=28187 &
echo 'configs/deeplabv3plus/deeplabv3plus_r101-d8_fp16_512x1024_80k_cityscapes.py' &
GPUS=4 GPUS_PER_NODE=4 CPUS_PER_TASK=2 tools/slurm_test.sh $PARTITION deeplabv3plus_r101-d8_fp16_512x1024_80k_cityscapes configs/deeplabv3plus/deeplabv3plus_r101-d8_fp16_512x1024_80k_cityscapes.py $CHECKPOINT_DIR/deeplabv3plus_r101-d8_512x1024_80k_fp16_cityscapes-cc58bc8d.pth --eval mIoU --work-dir work_dirs/benchmark_evaluation/deeplabv3plus_r101-d8_512x1024_80k_fp16_cityscapes --cfg-options dist_params.port=28188 &
echo 'configs/swin/upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K.py' &
GPUS=4 GPUS_PER_NODE=4 CPUS_PER_TASK=2 tools/slurm_test.sh $PARTITION upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K configs/swin/upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K.py $CHECKPOINT_DIR/upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K_20210531_112542-e380ad3e.pth --eval mIoU --work-dir work_dirs/benchmark_evaluation/upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K --cfg-options dist_params.port=28189 &
|
||||
@ -1,149 +0,0 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import hashlib
|
||||
import logging
|
||||
import os
|
||||
import os.path as osp
|
||||
import warnings
|
||||
from argparse import ArgumentParser
|
||||
|
||||
import requests
|
||||
from mmcv import Config
|
||||
|
||||
from mmseg.apis import inference_segmentor, init_segmentor, show_result_pyplot
|
||||
from mmseg.utils import get_root_logger
|
||||
|
||||
# ignore warnings when segmentors inference
|
||||
warnings.filterwarnings('ignore')
|
||||
|
||||
|
||||
def download_checkpoint(checkpoint_name, model_name, config_name, collect_dir):
    """Download a checkpoint from download.openmmlab.com and verify its hash.

    The expected hash is the 8-hex-char sha256 prefix embedded in the
    checkpoint file name after the LAST '-'. On mismatch the partial file is
    removed and an ``AssertionError`` is raised.

    Args:
        checkpoint_name (str): Checkpoint file name, e.g.
            ``'fcn_hr18s_512x512_160k_ade20k_20200614_214413-870f65ac.pth'``.
        model_name (str): Model family directory on the server, e.g. 'hrnet'.
        config_name (str): Config stem used to build the download URL.
        collect_dir (str): Local directory the checkpoint is written into.

    Raises:
        AssertionError: If the URL is forbidden (403) or the downloaded
            file's sha256 prefix does not match the name-embedded hash.
    """
    url = f'https://download.openmmlab.com/mmsegmentation/v0.5/{model_name}/{config_name}/{checkpoint_name}'  # noqa

    r = requests.get(url)
    assert r.status_code != 403, f'{url} Access denied.'

    ckpt_path = osp.join(collect_dir, checkpoint_name)
    with open(ckpt_path, 'wb') as code:
        code.write(r.content)

    # BUG FIX: use the LAST '-' segment. split('-')[1] picked the wrong piece
    # for hyphenated names such as 'pspnet_r50-d8_..._20200606_112131-2376f12b'.
    true_hash_code = osp.splitext(checkpoint_name)[0].split('-')[-1]

    # check hash code
    with open(ckpt_path, 'rb') as fp:
        sha256_cal = hashlib.sha256()
        sha256_cal.update(fp.read())
        cur_hash_code = sha256_cal.hexdigest()[:8]

    # BUG FIX: the original asserted first with a truncated message (the
    # second string literal was a standalone no-op statement), which also made
    # the os.remove() cleanup unreachable on mismatch. Clean up, then raise.
    if cur_hash_code != true_hash_code:
        os.remove(ckpt_path)
        raise AssertionError(f'{url} download failed, '
                             'incomplete downloaded file or url invalid.')
|
||||
|
||||
|
||||
def parse_args():
    """Build and evaluate the CLI argument parser for the benchmark script.

    Returns:
        argparse.Namespace with: config, checkpoint_root, img, aug,
        model_name, show, device.
    """
    parser = ArgumentParser()
    # Required positionals: benchmark config and the checkpoint root dir.
    parser.add_argument('config', help='test config file path')
    parser.add_argument('checkpoint_root', help='Checkpoint file root path')
    # Optional flags with their defaults.
    parser.add_argument('-i', '--img', default='demo/demo.png', help='Image file')
    parser.add_argument('-a', '--aug', action='store_true', help='aug test')
    parser.add_argument('-m', '--model-name', help='model name to inference')
    parser.add_argument('-s', '--show', action='store_true', help='show results')
    parser.add_argument('-d', '--device', default='cuda:0', help='Device used for inference')
    return parser.parse_args()
|
||||
|
||||
|
||||
def inference_model(config_name, checkpoint, args, logger=None):
    """Run single-image inference for one config/checkpoint pair.

    When ``args.aug`` is set, multi-scale + flip test-time augmentation is
    enabled on the test pipeline if its second step supports it; otherwise
    the failure is reported and inference proceeds without augmentation.
    """
    cfg = Config.fromfile(config_name)
    if args.aug:
        # Second pipeline step is expected to be the MultiScaleFlipAug wrapper.
        aug_step = cfg.data.test.pipeline[1]
        if 'flip' in aug_step and 'img_scale' in aug_step:
            aug_step.img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
            aug_step.flip = True
        else:
            message = f'{config_name}: unable to start aug test'
            if logger is None:
                print(message, flush=True)
            else:
                logger.error(message)

    model = init_segmentor(cfg, checkpoint, device=args.device)
    # Run inference on the single demo image.
    result = inference_segmentor(model, args.img)

    # Optionally visualize the prediction.
    if args.show:
        show_result_pyplot(model, args.img, result)
    return result
|
||||
|
||||
|
||||
# Sample test whether the inference code is correct
def main(args):
    """Run benchmark inference for the models listed in the config.

    If ``args.model_name`` is given, only that config entry is tested
    (checkpoints are expected to exist already in this branch); otherwise
    every model is tested, downloading any missing checkpoint first.

    Args:
        args: Parsed CLI namespace from ``parse_args()``.

    Raises:
        RuntimeError: If ``args.model_name`` is not a key of the config.
    """
    config = Config.fromfile(args.config)

    if not os.path.exists(args.checkpoint_root):
        os.makedirs(args.checkpoint_root, 0o775)

    # test single model
    if args.model_name:
        if args.model_name in config:
            model_infos = config[args.model_name]
            if not isinstance(model_infos, list):
                model_infos = [model_infos]
            for model_info in model_infos:
                config_name = model_info['config'].strip()
                print(f'processing: {config_name}', flush=True)
                checkpoint = osp.join(args.checkpoint_root,
                                      model_info['checkpoint'].strip())
                try:
                    # build the model from a config file and a checkpoint file
                    inference_model(config_name, checkpoint, args)
                except Exception:
                    print(f'{config_name} test failed!')
                    continue
            return
        else:
            raise RuntimeError('model name input error.')

    # test all model
    logger = get_root_logger(
        log_file='benchmark_inference_image.log', log_level=logging.ERROR)

    for model_name in config:
        model_infos = config[model_name]

        if not isinstance(model_infos, list):
            model_infos = [model_infos]
        for model_info in model_infos:
            print('processing: ', model_info['config'], flush=True)
            config_path = model_info['config'].strip()
            config_name = osp.splitext(osp.basename(config_path))[0]
            checkpoint_name = model_info['checkpoint'].strip()
            checkpoint = osp.join(args.checkpoint_root, checkpoint_name)

            # ensure checkpoint exists
            try:
                if not osp.exists(checkpoint):
                    # BUG FIX: config_name already had its extension removed
                    # by splitext above. The previous config_name.rstrip('.py')
                    # stripped the CHARACTER SET {'.', 'p', 'y'} from the right
                    # and would corrupt stems ending in those characters.
                    download_checkpoint(checkpoint_name, model_name,
                                        config_name, args.checkpoint_root)
            except Exception:
                logger.error(f'{checkpoint_name} download error')
                continue

            # test model inference with checkpoint
            try:
                # build the model from a config file and a checkpoint file
                inference_model(config_path, checkpoint, args, logger)
            except Exception as e:
                logger.error(f'{config_path} " : {repr(e)}')
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Parse CLI arguments and run the benchmark driver.
    main(parse_args())
|
||||
@ -1,40 +0,0 @@
|
||||
PARTITION=$1
|
||||
|
||||
echo 'configs/hrnet/fcn_hr18s_512x512_160k_ade20k.py' &
|
||||
GPUS=4 GPUS_PER_NODE=4 CPUS_PER_TASK=2 ./tools/slurm_train.sh $PARTITION fcn_hr18s_512x512_160k_ade20k configs/hrnet/fcn_hr18s_512x512_160k_ade20k.py --cfg-options checkpoint_config.max_keep_ckpts=1 dist_params.port=24727 --work-dir work_dirs/hrnet/fcn_hr18s_512x512_160k_ade20k >/dev/null &
|
||||
echo 'configs/hrnet/fcn_hr18s_512x1024_160k_cityscapes.py' &
|
||||
GPUS=4 GPUS_PER_NODE=4 CPUS_PER_TASK=2 ./tools/slurm_train.sh $PARTITION fcn_hr18s_512x1024_160k_cityscapes configs/hrnet/fcn_hr18s_512x1024_160k_cityscapes.py --cfg-options checkpoint_config.max_keep_ckpts=1 dist_params.port=24728 --work-dir work_dirs/hrnet/fcn_hr18s_512x1024_160k_cityscapes >/dev/null &
|
||||
echo 'configs/hrnet/fcn_hr48_512x512_160k_ade20k.py' &
|
||||
GPUS=4 GPUS_PER_NODE=4 CPUS_PER_TASK=2 ./tools/slurm_train.sh $PARTITION fcn_hr48_512x512_160k_ade20k configs/hrnet/fcn_hr48_512x512_160k_ade20k.py --cfg-options checkpoint_config.max_keep_ckpts=1 dist_params.port=24729 --work-dir work_dirs/hrnet/fcn_hr48_512x512_160k_ade20k >/dev/null &
|
||||
echo 'configs/hrnet/fcn_hr48_512x1024_160k_cityscapes.py' &
|
||||
GPUS=4 GPUS_PER_NODE=4 CPUS_PER_TASK=2 ./tools/slurm_train.sh $PARTITION fcn_hr48_512x1024_160k_cityscapes configs/hrnet/fcn_hr48_512x1024_160k_cityscapes.py --cfg-options checkpoint_config.max_keep_ckpts=1 dist_params.port=24730 --work-dir work_dirs/hrnet/fcn_hr48_512x1024_160k_cityscapes >/dev/null &
|
||||
echo 'configs/pspnet/pspnet_r50-d8_512x1024_80k_cityscapes.py' &
|
||||
GPUS=4 GPUS_PER_NODE=4 CPUS_PER_TASK=2 ./tools/slurm_train.sh $PARTITION pspnet_r50-d8_512x1024_80k_cityscapes configs/pspnet/pspnet_r50-d8_512x1024_80k_cityscapes.py --cfg-options checkpoint_config.max_keep_ckpts=1 dist_params.port=24731 --work-dir work_dirs/pspnet/pspnet_r50-d8_512x1024_80k_cityscapes >/dev/null &
|
||||
echo 'configs/pspnet/pspnet_r101-d8_512x1024_80k_cityscapes.py' &
|
||||
GPUS=4 GPUS_PER_NODE=4 CPUS_PER_TASK=2 ./tools/slurm_train.sh $PARTITION pspnet_r101-d8_512x1024_80k_cityscapes configs/pspnet/pspnet_r101-d8_512x1024_80k_cityscapes.py --cfg-options checkpoint_config.max_keep_ckpts=1 dist_params.port=24732 --work-dir work_dirs/pspnet/pspnet_r101-d8_512x1024_80k_cityscapes >/dev/null &
|
||||
echo 'configs/pspnet/pspnet_r101-d8_512x512_160k_ade20k.py' &
|
||||
GPUS=4 GPUS_PER_NODE=4 CPUS_PER_TASK=2 ./tools/slurm_train.sh $PARTITION pspnet_r101-d8_512x512_160k_ade20k configs/pspnet/pspnet_r101-d8_512x512_160k_ade20k.py --cfg-options checkpoint_config.max_keep_ckpts=1 dist_params.port=24733 --work-dir work_dirs/pspnet/pspnet_r101-d8_512x512_160k_ade20k >/dev/null &
|
||||
echo 'configs/pspnet/pspnet_r50-d8_512x512_160k_ade20k.py' &
|
||||
GPUS=4 GPUS_PER_NODE=4 CPUS_PER_TASK=2 ./tools/slurm_train.sh $PARTITION pspnet_r50-d8_512x512_160k_ade20k configs/pspnet/pspnet_r50-d8_512x512_160k_ade20k.py --cfg-options checkpoint_config.max_keep_ckpts=1 dist_params.port=24734 --work-dir work_dirs/pspnet/pspnet_r50-d8_512x512_160k_ade20k >/dev/null &
|
||||
echo 'configs/resnest/pspnet_s101-d8_512x512_160k_ade20k.py' &
|
||||
GPUS=4 GPUS_PER_NODE=4 CPUS_PER_TASK=2 ./tools/slurm_train.sh $PARTITION pspnet_s101-d8_512x512_160k_ade20k configs/resnest/pspnet_s101-d8_512x512_160k_ade20k.py --cfg-options checkpoint_config.max_keep_ckpts=1 dist_params.port=24735 --work-dir work_dirs/resnest/pspnet_s101-d8_512x512_160k_ade20k >/dev/null &
|
||||
echo 'configs/resnest/pspnet_s101-d8_512x1024_80k_cityscapes.py' &
|
||||
GPUS=4 GPUS_PER_NODE=4 CPUS_PER_TASK=2 ./tools/slurm_train.sh $PARTITION pspnet_s101-d8_512x1024_80k_cityscapes configs/resnest/pspnet_s101-d8_512x1024_80k_cityscapes.py --cfg-options checkpoint_config.max_keep_ckpts=1 dist_params.port=24736 --work-dir work_dirs/resnest/pspnet_s101-d8_512x1024_80k_cityscapes >/dev/null &
|
||||
echo 'configs/fastscnn/fast_scnn_lr0.12_8x4_160k_cityscapes.py' &
|
||||
GPUS=4 GPUS_PER_NODE=4 CPUS_PER_TASK=2 ./tools/slurm_train.sh $PARTITION fast_scnn_lr0.12_8x4_160k_cityscapes configs/fastscnn/fast_scnn_lr0.12_8x4_160k_cityscapes.py --cfg-options checkpoint_config.max_keep_ckpts=1 dist_params.port=24737 --work-dir work_dirs/fastscnn/fast_scnn_lr0.12_8x4_160k_cityscapes >/dev/null &
|
||||
echo 'configs/deeplabv3plus/deeplabv3plus_r101-d8_769x769_80k_cityscapes.py' &
|
||||
GPUS=4 GPUS_PER_NODE=4 CPUS_PER_TASK=2 ./tools/slurm_train.sh $PARTITION deeplabv3plus_r101-d8_769x769_80k_cityscapes configs/deeplabv3plus/deeplabv3plus_r101-d8_769x769_80k_cityscapes.py --cfg-options checkpoint_config.max_keep_ckpts=1 dist_params.port=24738 --work-dir work_dirs/deeplabv3plus/deeplabv3plus_r101-d8_769x769_80k_cityscapes >/dev/null &
|
||||
echo 'configs/deeplabv3plus/deeplabv3plus_r101-d8_512x1024_80k_cityscapes.py' &
|
||||
GPUS=4 GPUS_PER_NODE=4 CPUS_PER_TASK=2 ./tools/slurm_train.sh $PARTITION deeplabv3plus_r101-d8_512x1024_80k_cityscapes configs/deeplabv3plus/deeplabv3plus_r101-d8_512x1024_80k_cityscapes.py --cfg-options checkpoint_config.max_keep_ckpts=1 dist_params.port=24739 --work-dir work_dirs/deeplabv3plus/deeplabv3plus_r101-d8_512x1024_80k_cityscapes >/dev/null &
|
||||
echo 'configs/deeplabv3plus/deeplabv3plus_r50-d8_512x1024_80k_cityscapes.py' &
|
||||
GPUS=4 GPUS_PER_NODE=4 CPUS_PER_TASK=2 ./tools/slurm_train.sh $PARTITION deeplabv3plus_r50-d8_512x1024_80k_cityscapes configs/deeplabv3plus/deeplabv3plus_r50-d8_512x1024_80k_cityscapes.py --cfg-options checkpoint_config.max_keep_ckpts=1 dist_params.port=24740 --work-dir work_dirs/deeplabv3plus/deeplabv3plus_r50-d8_512x1024_80k_cityscapes >/dev/null &
|
||||
echo 'configs/deeplabv3plus/deeplabv3plus_r50-d8_769x769_80k_cityscapes.py' &
|
||||
GPUS=4 GPUS_PER_NODE=4 CPUS_PER_TASK=2 ./tools/slurm_train.sh $PARTITION deeplabv3plus_r50-d8_769x769_80k_cityscapes configs/deeplabv3plus/deeplabv3plus_r50-d8_769x769_80k_cityscapes.py --cfg-options checkpoint_config.max_keep_ckpts=1 dist_params.port=24741 --work-dir work_dirs/deeplabv3plus/deeplabv3plus_r50-d8_769x769_80k_cityscapes >/dev/null &
|
||||
echo 'configs/vit/upernet_vit-b16_ln_mln_512x512_160k_ade20k.py' &
|
||||
GPUS=8 GPUS_PER_NODE=8 CPUS_PER_TASK=2 ./tools/slurm_train.sh $PARTITION upernet_vit-b16_ln_mln_512x512_160k_ade20k configs/vit/upernet_vit-b16_ln_mln_512x512_160k_ade20k.py --cfg-options checkpoint_config.max_keep_ckpts=1 dist_params.port=24742 --work-dir work_dirs/vit/upernet_vit-b16_ln_mln_512x512_160k_ade20k >/dev/null &
|
||||
echo 'configs/vit/upernet_deit-s16_ln_mln_512x512_160k_ade20k.py' &
|
||||
GPUS=8 GPUS_PER_NODE=8 CPUS_PER_TASK=2 ./tools/slurm_train.sh $PARTITION upernet_deit-s16_ln_mln_512x512_160k_ade20k configs/vit/upernet_deit-s16_ln_mln_512x512_160k_ade20k.py --cfg-options checkpoint_config.max_keep_ckpts=1 dist_params.port=24743 --work-dir work_dirs/vit/upernet_deit-s16_ln_mln_512x512_160k_ade20k >/dev/null &
|
||||
echo 'configs/deeplabv3plus/deeplabv3plus_r101-d8_fp16_512x1024_80k_cityscapes.py' &
|
||||
GPUS=4 GPUS_PER_NODE=4 CPUS_PER_TASK=2 ./tools/slurm_train.sh $PARTITION deeplabv3plus_r101-d8_512x1024_80k_fp16_cityscapes configs/deeplabv3plus/deeplabv3plus_r101-d8_fp16_512x1024_80k_cityscapes.py --cfg-options checkpoint_config.max_keep_ckpts=1 dist_params.port=24744 --work-dir work_dirs/deeplabv3plus/deeplabv3plus_r101-d8_512x1024_80k_fp16_cityscapes >/dev/null &
|
||||
echo 'configs/swin/upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K.py' &
|
||||
GPUS=8 GPUS_PER_NODE=8 CPUS_PER_TASK=2 ./tools/slurm_train.sh $PARTITION upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K configs/swin/upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K.py --cfg-options checkpoint_config.max_keep_ckpts=1 dist_params.port=24745 --work-dir work_dirs/swin/upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K >/dev/null &
|
||||
@ -1,101 +0,0 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import logging
|
||||
import os
|
||||
from argparse import ArgumentParser
|
||||
|
||||
import requests
|
||||
import yaml as yml
|
||||
|
||||
from mmseg.utils import get_root_logger
|
||||
|
||||
|
||||
def check_url(url):
|
||||
"""Check url response status.
|
||||
|
||||
Args:
|
||||
url (str): url needed to check.
|
||||
|
||||
Returns:
|
||||
int, bool: status code and check flag.
|
||||
"""
|
||||
flag = True
|
||||
r = requests.head(url)
|
||||
status_code = r.status_code
|
||||
if status_code == 403 or status_code == 404:
|
||||
flag = False
|
||||
|
||||
return status_code, flag
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = ArgumentParser('url valid check.')
|
||||
parser.add_argument(
|
||||
'-m',
|
||||
'--model-name',
|
||||
type=str,
|
||||
help='Select the model needed to check')
|
||||
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
model_name = args.model_name
|
||||
|
||||
# yml path generate.
|
||||
# If model_name is not set, script will check all of the models.
|
||||
if model_name is not None:
|
||||
yml_list = [(model_name, f'configs/{model_name}/{model_name}.yml')]
|
||||
else:
|
||||
# check all
|
||||
yml_list = [(x, f'configs/{x}/{x}.yml') for x in os.listdir('configs/')
|
||||
if x != '_base_']
|
||||
|
||||
logger = get_root_logger(log_file='url_check.log', log_level=logging.ERROR)
|
||||
|
||||
for model_name, yml_path in yml_list:
|
||||
# Default yaml loader unsafe.
|
||||
model_infos = yml.load(
|
||||
open(yml_path, 'r'), Loader=yml.CLoader)['Models']
|
||||
for model_info in model_infos:
|
||||
config_name = model_info['Name']
|
||||
checkpoint_url = model_info['Weights']
|
||||
# checkpoint url check
|
||||
status_code, flag = check_url(checkpoint_url)
|
||||
if flag:
|
||||
logger.info(f'checkpoint | {config_name} | {checkpoint_url} | '
|
||||
f'{status_code} valid')
|
||||
else:
|
||||
logger.error(
|
||||
f'checkpoint | {config_name} | {checkpoint_url} | '
|
||||
f'{status_code} | error')
|
||||
# log_json check
|
||||
checkpoint_name = checkpoint_url.split('/')[-1]
|
||||
model_time = '-'.join(checkpoint_name.split('-')[:-1]).replace(
|
||||
f'{config_name}_', '')
|
||||
# two style of log_json name
|
||||
# use '_' to link model_time (will be deprecated)
|
||||
log_json_url_1 = f'https://download.openmmlab.com/mmsegmentation/v0.5/{model_name}/{config_name}/{config_name}_{model_time}.log.json' # noqa
|
||||
status_code_1, flag_1 = check_url(log_json_url_1)
|
||||
# use '-' to link model_time
|
||||
log_json_url_2 = f'https://download.openmmlab.com/mmsegmentation/v0.5/{model_name}/{config_name}/{config_name}-{model_time}.log.json' # noqa
|
||||
status_code_2, flag_2 = check_url(log_json_url_2)
|
||||
if flag_1 or flag_2:
|
||||
if flag_1:
|
||||
logger.info(
|
||||
f'log.json | {config_name} | {log_json_url_1} | '
|
||||
f'{status_code_1} | valid')
|
||||
else:
|
||||
logger.info(
|
||||
f'log.json | {config_name} | {log_json_url_2} | '
|
||||
f'{status_code_2} | valid')
|
||||
else:
|
||||
logger.error(
|
||||
f'log.json | {config_name} | {log_json_url_1} & '
|
||||
f'{log_json_url_2} | {status_code_1} & {status_code_2} | '
|
||||
'error')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@ -1,91 +0,0 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import argparse
|
||||
import glob
|
||||
import os.path as osp
|
||||
|
||||
import mmcv
|
||||
from mmcv import Config
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Gather benchmarked model evaluation results')
|
||||
parser.add_argument('config', help='test config file path')
|
||||
parser.add_argument(
|
||||
'root',
|
||||
type=str,
|
||||
help='root path of benchmarked models to be gathered')
|
||||
parser.add_argument(
|
||||
'--out',
|
||||
type=str,
|
||||
default='benchmark_evaluation_info.json',
|
||||
help='output path of gathered metrics and compared '
|
||||
'results to be stored')
|
||||
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
args = parse_args()
|
||||
|
||||
root_path = args.root
|
||||
metrics_out = args.out
|
||||
result_dict = {}
|
||||
|
||||
cfg = Config.fromfile(args.config)
|
||||
|
||||
for model_key in cfg:
|
||||
model_infos = cfg[model_key]
|
||||
if not isinstance(model_infos, list):
|
||||
model_infos = [model_infos]
|
||||
for model_info in model_infos:
|
||||
previous_metrics = model_info['metric']
|
||||
config = model_info['config'].strip()
|
||||
fname, _ = osp.splitext(osp.basename(config))
|
||||
|
||||
# Load benchmark evaluation json
|
||||
metric_json_dir = osp.join(root_path, fname)
|
||||
if not osp.exists(metric_json_dir):
|
||||
print(f'{metric_json_dir} not existed.')
|
||||
continue
|
||||
|
||||
json_list = glob.glob(osp.join(metric_json_dir, '*.json'))
|
||||
if len(json_list) == 0:
|
||||
print(f'There is no eval json in {metric_json_dir}.')
|
||||
continue
|
||||
|
||||
log_json_path = list(sorted(json_list))[-1]
|
||||
metric = mmcv.load(log_json_path)
|
||||
if config not in metric.get('config', {}):
|
||||
print(f'{config} not included in {log_json_path}')
|
||||
continue
|
||||
|
||||
# Compare between new benchmark results and previous metrics
|
||||
differential_results = dict()
|
||||
new_metrics = dict()
|
||||
for record_metric_key in previous_metrics:
|
||||
if record_metric_key not in metric['metric']:
|
||||
raise KeyError('record_metric_key not exist, please '
|
||||
'check your config')
|
||||
old_metric = previous_metrics[record_metric_key]
|
||||
new_metric = round(metric['metric'][record_metric_key] * 100,
|
||||
2)
|
||||
|
||||
differential = new_metric - old_metric
|
||||
flag = '+' if differential > 0 else '-'
|
||||
differential_results[
|
||||
record_metric_key] = f'{flag}{abs(differential):.2f}'
|
||||
new_metrics[record_metric_key] = new_metric
|
||||
|
||||
result_dict[config] = dict(
|
||||
differential=differential_results,
|
||||
previous=previous_metrics,
|
||||
new=new_metrics)
|
||||
|
||||
if metrics_out:
|
||||
mmcv.dump(result_dict, metrics_out, indent=4)
|
||||
print('===================================')
|
||||
for config_name, metrics in result_dict.items():
|
||||
print(config_name, metrics)
|
||||
print('===================================')
|
||||
@ -1,100 +0,0 @@
|
||||
import argparse
|
||||
import glob
|
||||
import os.path as osp
|
||||
|
||||
import mmcv
|
||||
from gather_models import get_final_results
|
||||
from mmcv import Config
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Gather benchmarked models train results')
|
||||
parser.add_argument('config', help='test config file path')
|
||||
parser.add_argument(
|
||||
'root',
|
||||
type=str,
|
||||
help='root path of benchmarked models to be gathered')
|
||||
parser.add_argument(
|
||||
'--out',
|
||||
type=str,
|
||||
default='benchmark_train_info.json',
|
||||
help='output path of gathered metrics to be stored')
|
||||
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
args = parse_args()
|
||||
|
||||
root_path = args.root
|
||||
metrics_out = args.out
|
||||
|
||||
evaluation_cfg = Config.fromfile(args.config)
|
||||
|
||||
result_dict = {}
|
||||
for model_key in evaluation_cfg:
|
||||
model_infos = evaluation_cfg[model_key]
|
||||
if not isinstance(model_infos, list):
|
||||
model_infos = [model_infos]
|
||||
for model_info in model_infos:
|
||||
config = model_info['config']
|
||||
|
||||
# benchmark train dir
|
||||
model_name = osp.split(osp.dirname(config))[1]
|
||||
config_name = osp.splitext(osp.basename(config))[0]
|
||||
exp_dir = osp.join(root_path, model_name, config_name)
|
||||
if not osp.exists(exp_dir):
|
||||
print(f'{config} hasn\'t {exp_dir}')
|
||||
continue
|
||||
|
||||
# parse config
|
||||
cfg = mmcv.Config.fromfile(config)
|
||||
total_iters = cfg.runner.max_iters
|
||||
exp_metric = cfg.evaluation.metric
|
||||
if not isinstance(exp_metric, list):
|
||||
exp_metrics = [exp_metric]
|
||||
|
||||
# determine whether total_iters ckpt exists
|
||||
ckpt_path = f'iter_{total_iters}.pth'
|
||||
if not osp.exists(osp.join(exp_dir, ckpt_path)):
|
||||
print(f'{config} hasn\'t {ckpt_path}')
|
||||
continue
|
||||
|
||||
# only the last log json counts
|
||||
log_json_path = list(
|
||||
sorted(glob.glob(osp.join(exp_dir, '*.log.json'))))[-1]
|
||||
|
||||
# extract metric value
|
||||
model_performance = get_final_results(log_json_path, total_iters)
|
||||
if model_performance is None:
|
||||
print(f'log file error: {log_json_path}')
|
||||
continue
|
||||
|
||||
differential_results = dict()
|
||||
old_results = dict()
|
||||
new_results = dict()
|
||||
for metric_key in model_performance:
|
||||
if metric_key in ['mIoU']:
|
||||
metric = round(model_performance[metric_key] * 100, 2)
|
||||
old_metric = model_info['metric'][metric_key]
|
||||
old_results[metric_key] = old_metric
|
||||
new_results[metric_key] = metric
|
||||
differential = metric - old_metric
|
||||
flag = '+' if differential > 0 else '-'
|
||||
differential_results[
|
||||
metric_key] = f'{flag}{abs(differential):.2f}'
|
||||
result_dict[config] = dict(
|
||||
differential_results=differential_results,
|
||||
old_results=old_results,
|
||||
new_results=new_results,
|
||||
)
|
||||
|
||||
# 4 save or print results
|
||||
if metrics_out:
|
||||
mmcv.dump(result_dict, metrics_out, indent=4)
|
||||
print('===================================')
|
||||
for config_name, metrics in result_dict.items():
|
||||
print(config_name, metrics)
|
||||
print('===================================')
|
||||
@ -1,211 +0,0 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import argparse
|
||||
import glob
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import os.path as osp
|
||||
import shutil
|
||||
|
||||
import mmcv
|
||||
import torch
|
||||
|
||||
# build schedule look-up table to automatically find the final model
|
||||
RESULTS_LUT = ['mIoU', 'mAcc', 'aAcc']
|
||||
|
||||
|
||||
def calculate_file_sha256(file_path):
|
||||
"""calculate file sha256 hash code."""
|
||||
with open(file_path, 'rb') as fp:
|
||||
sha256_cal = hashlib.sha256()
|
||||
sha256_cal.update(fp.read())
|
||||
return sha256_cal.hexdigest()
|
||||
|
||||
|
||||
def process_checkpoint(in_file, out_file):
|
||||
checkpoint = torch.load(in_file, map_location='cpu')
|
||||
# remove optimizer for smaller file size
|
||||
if 'optimizer' in checkpoint:
|
||||
del checkpoint['optimizer']
|
||||
# if it is necessary to remove some sensitive data in checkpoint['meta'],
|
||||
# add the code here.
|
||||
torch.save(checkpoint, out_file)
|
||||
# The hash code calculation and rename command differ on different system
|
||||
# platform.
|
||||
sha = calculate_file_sha256(out_file)
|
||||
final_file = out_file.rstrip('.pth') + '-{}.pth'.format(sha[:8])
|
||||
os.rename(out_file, final_file)
|
||||
|
||||
# Remove prefix and suffix
|
||||
final_file_name = osp.split(final_file)[1]
|
||||
final_file_name = osp.splitext(final_file_name)[0]
|
||||
|
||||
return final_file_name
|
||||
|
||||
|
||||
def get_final_iter(config):
|
||||
iter_num = config.split('_')[-2]
|
||||
assert iter_num.endswith('k')
|
||||
return int(iter_num[:-1]) * 1000
|
||||
|
||||
|
||||
def get_final_results(log_json_path, iter_num):
|
||||
result_dict = dict()
|
||||
last_iter = 0
|
||||
with open(log_json_path, 'r') as f:
|
||||
for line in f.readlines():
|
||||
log_line = json.loads(line)
|
||||
if 'mode' not in log_line.keys():
|
||||
continue
|
||||
|
||||
# When evaluation, the 'iter' of new log json is the evaluation
|
||||
# steps on single gpu.
|
||||
flag1 = ('aAcc' in log_line) or (log_line['mode'] == 'val')
|
||||
flag2 = (last_iter == iter_num - 50) or (last_iter == iter_num)
|
||||
if flag1 and flag2:
|
||||
result_dict.update({
|
||||
key: log_line[key]
|
||||
for key in RESULTS_LUT if key in log_line
|
||||
})
|
||||
return result_dict
|
||||
|
||||
last_iter = log_line['iter']
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description='Gather benchmarked models')
|
||||
parser.add_argument(
|
||||
'-f', '--config-name', type=str, help='Process the selected config.')
|
||||
parser.add_argument(
|
||||
'-w',
|
||||
'--work-dir',
|
||||
default='work_dirs/',
|
||||
type=str,
|
||||
help='Ckpt storage root folder of benchmarked models to be gathered.')
|
||||
parser.add_argument(
|
||||
'-c',
|
||||
'--collect-dir',
|
||||
default='work_dirs/gather',
|
||||
type=str,
|
||||
help='Ckpt collect root folder of gathered models.')
|
||||
parser.add_argument(
|
||||
'--all', action='store_true', help='whether include .py and .log')
|
||||
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
work_dir = args.work_dir
|
||||
collect_dir = args.collect_dir
|
||||
selected_config_name = args.config_name
|
||||
mmcv.mkdir_or_exist(collect_dir)
|
||||
|
||||
# find all models in the root directory to be gathered
|
||||
raw_configs = list(mmcv.scandir('./configs', '.py', recursive=True))
|
||||
|
||||
# filter configs that is not trained in the experiments dir
|
||||
used_configs = []
|
||||
for raw_config in raw_configs:
|
||||
config_name = osp.splitext(osp.basename(raw_config))[0]
|
||||
if osp.exists(osp.join(work_dir, config_name)):
|
||||
if (selected_config_name is None
|
||||
or selected_config_name == config_name):
|
||||
used_configs.append(raw_config)
|
||||
print(f'Find {len(used_configs)} models to be gathered')
|
||||
|
||||
# find final_ckpt and log file for trained each config
|
||||
# and parse the best performance
|
||||
model_infos = []
|
||||
for used_config in used_configs:
|
||||
config_name = osp.splitext(osp.basename(used_config))[0]
|
||||
exp_dir = osp.join(work_dir, config_name)
|
||||
# check whether the exps is finished
|
||||
final_iter = get_final_iter(used_config)
|
||||
final_model = 'iter_{}.pth'.format(final_iter)
|
||||
model_path = osp.join(exp_dir, final_model)
|
||||
|
||||
# skip if the model is still training
|
||||
if not osp.exists(model_path):
|
||||
print(f'{used_config} train not finished yet')
|
||||
continue
|
||||
|
||||
# get logs
|
||||
log_json_paths = glob.glob(osp.join(exp_dir, '*.log.json'))
|
||||
log_json_path = log_json_paths[0]
|
||||
model_performance = None
|
||||
for idx, _log_json_path in enumerate(log_json_paths):
|
||||
model_performance = get_final_results(_log_json_path, final_iter)
|
||||
if model_performance is not None:
|
||||
log_json_path = _log_json_path
|
||||
break
|
||||
|
||||
if model_performance is None:
|
||||
print(f'{used_config} model_performance is None')
|
||||
continue
|
||||
|
||||
model_time = osp.split(log_json_path)[-1].split('.')[0]
|
||||
model_infos.append(
|
||||
dict(
|
||||
config_name=config_name,
|
||||
results=model_performance,
|
||||
iters=final_iter,
|
||||
model_time=model_time,
|
||||
log_json_path=osp.split(log_json_path)[-1]))
|
||||
|
||||
# publish model for each checkpoint
|
||||
publish_model_infos = []
|
||||
for model in model_infos:
|
||||
config_name = model['config_name']
|
||||
model_publish_dir = osp.join(collect_dir, config_name)
|
||||
|
||||
publish_model_path = osp.join(model_publish_dir,
|
||||
config_name + '_' + model['model_time'])
|
||||
trained_model_path = osp.join(work_dir, config_name,
|
||||
'iter_{}.pth'.format(model['iters']))
|
||||
if osp.exists(model_publish_dir):
|
||||
for file in os.listdir(model_publish_dir):
|
||||
if file.endswith('.pth'):
|
||||
print(f'model {file} found')
|
||||
model['model_path'] = osp.abspath(
|
||||
osp.join(model_publish_dir, file))
|
||||
break
|
||||
if 'model_path' not in model:
|
||||
print(f'dir {model_publish_dir} exists, no model found')
|
||||
|
||||
else:
|
||||
mmcv.mkdir_or_exist(model_publish_dir)
|
||||
|
||||
# convert model
|
||||
final_model_path = process_checkpoint(trained_model_path,
|
||||
publish_model_path)
|
||||
model['model_path'] = final_model_path
|
||||
|
||||
new_json_path = f'{config_name}_{model["log_json_path"]}'
|
||||
# copy log
|
||||
shutil.copy(
|
||||
osp.join(work_dir, config_name, model['log_json_path']),
|
||||
osp.join(model_publish_dir, new_json_path))
|
||||
|
||||
if args.all:
|
||||
new_txt_path = new_json_path.rstrip('.json')
|
||||
shutil.copy(
|
||||
osp.join(work_dir, config_name,
|
||||
model['log_json_path'].rstrip('.json')),
|
||||
osp.join(model_publish_dir, new_txt_path))
|
||||
|
||||
if args.all:
|
||||
# copy config to guarantee reproducibility
|
||||
raw_config = osp.join('./configs', f'{config_name}.py')
|
||||
mmcv.Config.fromfile(raw_config).dump(
|
||||
osp.join(model_publish_dir, osp.basename(raw_config)))
|
||||
|
||||
publish_model_infos.append(model)
|
||||
|
||||
models = dict(models=publish_model_infos)
|
||||
mmcv.dump(models, osp.join(collect_dir, 'model_infos.json'), indent=4)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@ -1,114 +0,0 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import argparse
|
||||
import os.path as osp
|
||||
|
||||
from mmcv import Config
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Convert benchmark test model list to script')
|
||||
parser.add_argument('config', help='test config file path')
|
||||
parser.add_argument('--port', type=int, default=28171, help='dist port')
|
||||
parser.add_argument(
|
||||
'--work-dir',
|
||||
default='work_dirs/benchmark_evaluation',
|
||||
help='the dir to save metric')
|
||||
parser.add_argument(
|
||||
'--out',
|
||||
type=str,
|
||||
default='.dev/benchmark_evaluation.sh',
|
||||
help='path to save model benchmark script')
|
||||
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def process_model_info(model_info, work_dir):
|
||||
config = model_info['config'].strip()
|
||||
fname, _ = osp.splitext(osp.basename(config))
|
||||
job_name = fname
|
||||
checkpoint = model_info['checkpoint'].strip()
|
||||
work_dir = osp.join(work_dir, fname)
|
||||
if not isinstance(model_info['eval'], list):
|
||||
evals = [model_info['eval']]
|
||||
else:
|
||||
evals = model_info['eval']
|
||||
eval = ' '.join(evals)
|
||||
return dict(
|
||||
config=config,
|
||||
job_name=job_name,
|
||||
checkpoint=checkpoint,
|
||||
work_dir=work_dir,
|
||||
eval=eval)
|
||||
|
||||
|
||||
def create_test_bash_info(commands, model_test_dict, port, script_name,
|
||||
partition):
|
||||
config = model_test_dict['config']
|
||||
job_name = model_test_dict['job_name']
|
||||
checkpoint = model_test_dict['checkpoint']
|
||||
work_dir = model_test_dict['work_dir']
|
||||
eval = model_test_dict['eval']
|
||||
|
||||
echo_info = f'\necho \'{config}\' &'
|
||||
commands.append(echo_info)
|
||||
commands.append('\n')
|
||||
|
||||
command_info = f'GPUS=4 GPUS_PER_NODE=4 ' \
|
||||
f'CPUS_PER_TASK=2 {script_name} '
|
||||
|
||||
command_info += f'{partition} '
|
||||
command_info += f'{job_name} '
|
||||
command_info += f'{config} '
|
||||
command_info += f'$CHECKPOINT_DIR/{checkpoint} '
|
||||
|
||||
command_info += f'--eval {eval} '
|
||||
command_info += f'--work-dir {work_dir} '
|
||||
command_info += f'--cfg-options dist_params.port={port} '
|
||||
command_info += '&'
|
||||
|
||||
commands.append(command_info)
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
if args.out:
|
||||
out_suffix = args.out.split('.')[-1]
|
||||
assert args.out.endswith('.sh'), \
|
||||
f'Expected out file path suffix is .sh, but get .{out_suffix}'
|
||||
|
||||
commands = []
|
||||
partition_name = 'PARTITION=$1'
|
||||
commands.append(partition_name)
|
||||
commands.append('\n')
|
||||
|
||||
checkpoint_root = 'CHECKPOINT_DIR=$2'
|
||||
commands.append(checkpoint_root)
|
||||
commands.append('\n')
|
||||
|
||||
script_name = osp.join('tools', 'slurm_test.sh')
|
||||
port = args.port
|
||||
work_dir = args.work_dir
|
||||
|
||||
cfg = Config.fromfile(args.config)
|
||||
|
||||
for model_key in cfg:
|
||||
model_infos = cfg[model_key]
|
||||
if not isinstance(model_infos, list):
|
||||
model_infos = [model_infos]
|
||||
for model_info in model_infos:
|
||||
print('processing: ', model_info['config'])
|
||||
model_test_dict = process_model_info(model_info, work_dir)
|
||||
create_test_bash_info(commands, model_test_dict, port, script_name,
|
||||
'$PARTITION')
|
||||
port += 1
|
||||
|
||||
command_str = ''.join(commands)
|
||||
if args.out:
|
||||
with open(args.out, 'w') as f:
|
||||
f.write(command_str + '\n')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@ -1,91 +0,0 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import argparse
|
||||
import os.path as osp
|
||||
|
||||
# Default using 4 gpu when training
|
||||
config_8gpu_list = [
|
||||
'configs/swin/upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K.py', # noqa
|
||||
'configs/vit/upernet_vit-b16_ln_mln_512x512_160k_ade20k.py',
|
||||
'configs/vit/upernet_deit-s16_ln_mln_512x512_160k_ade20k.py',
|
||||
]
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Convert benchmark model json to script')
|
||||
parser.add_argument(
|
||||
'txt_path', type=str, help='txt path output by benchmark_filter')
|
||||
parser.add_argument('--port', type=int, default=24727, help='dist port')
|
||||
parser.add_argument(
|
||||
'--out',
|
||||
type=str,
|
||||
default='.dev/benchmark_train.sh',
|
||||
help='path to save model benchmark script')
|
||||
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def create_train_bash_info(commands, config, script_name, partition, port):
|
||||
cfg = config.strip()
|
||||
|
||||
# print cfg name
|
||||
echo_info = f'echo \'{cfg}\' &'
|
||||
commands.append(echo_info)
|
||||
commands.append('\n')
|
||||
|
||||
_, model_name = osp.split(osp.dirname(cfg))
|
||||
config_name, _ = osp.splitext(osp.basename(cfg))
|
||||
# default setting
|
||||
if cfg in config_8gpu_list:
|
||||
command_info = f'GPUS=8 GPUS_PER_NODE=8 ' \
|
||||
f'CPUS_PER_TASK=2 {script_name} '
|
||||
else:
|
||||
command_info = f'GPUS=4 GPUS_PER_NODE=4 ' \
|
||||
f'CPUS_PER_TASK=2 {script_name} '
|
||||
command_info += f'{partition} '
|
||||
command_info += f'{config_name} '
|
||||
command_info += f'{cfg} '
|
||||
command_info += f'--cfg-options ' \
|
||||
f'checkpoint_config.max_keep_ckpts=1 ' \
|
||||
f'dist_params.port={port} '
|
||||
command_info += f'--work-dir work_dirs/{model_name}/{config_name} '
|
||||
# Let the script shut up
|
||||
command_info += '>/dev/null &'
|
||||
|
||||
commands.append(command_info)
|
||||
commands.append('\n')
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
if args.out:
|
||||
out_suffix = args.out.split('.')[-1]
|
||||
assert args.out.endswith('.sh'), \
|
||||
f'Expected out file path suffix is .sh, but get .{out_suffix}'
|
||||
|
||||
root_name = './tools'
|
||||
script_name = osp.join(root_name, 'slurm_train.sh')
|
||||
port = args.port
|
||||
partition_name = 'PARTITION=$1'
|
||||
|
||||
commands = []
|
||||
commands.append(partition_name)
|
||||
commands.append('\n')
|
||||
commands.append('\n')
|
||||
|
||||
with open(args.txt_path, 'r') as f:
|
||||
model_cfgs = f.readlines()
|
||||
for i, cfg in enumerate(model_cfgs):
|
||||
create_train_bash_info(commands, cfg, script_name, '$PARTITION',
|
||||
port)
|
||||
port += 1
|
||||
|
||||
command_str = ''.join(commands)
|
||||
if args.out:
|
||||
with open(args.out, 'w') as f:
|
||||
f.write(command_str)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@ -1,18 +0,0 @@
|
||||
work_dir = '../../work_dirs'
|
||||
metric = 'mIoU'
|
||||
|
||||
# specify the log files we would like to collect in `log_items`
|
||||
log_items = [
|
||||
'segformer_mit-b5_512x512_160k_ade20k_cnn_lr_with_warmup',
|
||||
'segformer_mit-b5_512x512_160k_ade20k_cnn_no_warmup_lr',
|
||||
'segformer_mit-b5_512x512_160k_ade20k_mit_trans_lr',
|
||||
'segformer_mit-b5_512x512_160k_ade20k_swin_trans_lr'
|
||||
]
|
||||
# or specify ignore_keywords, then the folders whose name contain
|
||||
# `'segformer'` won't be collected
|
||||
# ignore_keywords = ['segformer']
|
||||
|
||||
# should not include metric
|
||||
other_info_keys = ['mAcc']
|
||||
markdown_file = 'markdowns/lr_in_trans.json.md'
|
||||
json_file = 'jsons/trans_in_cnn.json'
|
||||
@ -1,143 +0,0 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import argparse
|
||||
import datetime
|
||||
import json
|
||||
import os
|
||||
import os.path as osp
|
||||
from collections import OrderedDict
|
||||
|
||||
from utils import load_config
|
||||
|
||||
# automatically collect all the results
|
||||
|
||||
# The structure of the directory:
|
||||
# ├── work-dir
|
||||
# │ ├── config_1
|
||||
# │ │ ├── time1.log.json
|
||||
# │ │ ├── time2.log.json
|
||||
# │ │ ├── time3.log.json
|
||||
# │ │ ├── time4.log.json
|
||||
# │ ├── config_2
|
||||
# │ │ ├── time5.log.json
|
||||
# │ │ ├── time6.log.json
|
||||
# │ │ ├── time7.log.json
|
||||
# │ │ ├── time8.log.json
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description='extract info from log.json')
|
||||
parser.add_argument('config_dir')
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def has_keyword(name: str, keywords: list):
|
||||
for a_keyword in keywords:
|
||||
if a_keyword in name:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
cfg = load_config(args.config_dir)
|
||||
work_dir = cfg['work_dir']
|
||||
metric = cfg['metric']
|
||||
log_items = cfg.get('log_items', [])
|
||||
ignore_keywords = cfg.get('ignore_keywords', [])
|
||||
other_info_keys = cfg.get('other_info_keys', [])
|
||||
markdown_file = cfg.get('markdown_file', None)
|
||||
json_file = cfg.get('json_file', None)
|
||||
|
||||
if json_file and osp.split(json_file)[0] != '':
|
||||
os.makedirs(osp.split(json_file)[0], exist_ok=True)
|
||||
if markdown_file and osp.split(markdown_file)[0] != '':
|
||||
os.makedirs(osp.split(markdown_file)[0], exist_ok=True)
|
||||
|
||||
assert not (log_items and ignore_keywords), \
|
||||
'log_items and ignore_keywords cannot be specified at the same time'
|
||||
assert metric not in other_info_keys, \
|
||||
'other_info_keys should not contain metric'
|
||||
|
||||
if ignore_keywords and isinstance(ignore_keywords, str):
|
||||
ignore_keywords = [ignore_keywords]
|
||||
if other_info_keys and isinstance(other_info_keys, str):
|
||||
other_info_keys = [other_info_keys]
|
||||
if log_items and isinstance(log_items, str):
|
||||
log_items = [log_items]
|
||||
|
||||
if not log_items:
|
||||
log_items = [
|
||||
item for item in sorted(os.listdir(work_dir))
|
||||
if not has_keyword(item, ignore_keywords)
|
||||
]
|
||||
|
||||
experiment_info_list = []
|
||||
for config_dir in log_items:
|
||||
preceding_path = os.path.join(work_dir, config_dir)
|
||||
log_list = [
|
||||
item for item in os.listdir(preceding_path)
|
||||
if item.endswith('.log.json')
|
||||
]
|
||||
log_list = sorted(
|
||||
log_list,
|
||||
key=lambda time_str: datetime.datetime.strptime(
|
||||
time_str, '%Y%m%d_%H%M%S.log.json'))
|
||||
val_list = []
|
||||
last_iter = 0
|
||||
for log_name in log_list:
|
||||
with open(os.path.join(preceding_path, log_name), 'r') as f:
|
||||
# ignore the info line
|
||||
f.readline()
|
||||
all_lines = f.readlines()
|
||||
val_list.extend([
|
||||
json.loads(line) for line in all_lines
|
||||
if json.loads(line)['mode'] == 'val'
|
||||
])
|
||||
for index in range(len(all_lines) - 1, -1, -1):
|
||||
line_dict = json.loads(all_lines[index])
|
||||
if line_dict['mode'] == 'train':
|
||||
last_iter = max(last_iter, line_dict['iter'])
|
||||
break
|
||||
|
||||
new_log_dict = dict(
|
||||
method=config_dir, metric_used=metric, last_iter=last_iter)
|
||||
for index, log in enumerate(val_list, 1):
|
||||
new_ordered_dict = OrderedDict()
|
||||
new_ordered_dict['eval_index'] = index
|
||||
new_ordered_dict[metric] = log[metric]
|
||||
for key in other_info_keys:
|
||||
if key in log:
|
||||
new_ordered_dict[key] = log[key]
|
||||
val_list[index - 1] = new_ordered_dict
|
||||
|
||||
assert len(val_list) >= 1, \
|
||||
f"work dir {config_dir} doesn't contain any evaluation."
|
||||
new_log_dict['last eval'] = val_list[-1]
|
||||
new_log_dict['best eval'] = max(val_list, key=lambda x: x[metric])
|
||||
experiment_info_list.append(new_log_dict)
|
||||
print(f'{config_dir} is processed')
|
||||
|
||||
if json_file:
|
||||
with open(json_file, 'w') as f:
|
||||
json.dump(experiment_info_list, f, indent=4)
|
||||
|
||||
if markdown_file:
|
||||
lines_to_write = []
|
||||
for index, log in enumerate(experiment_info_list, 1):
|
||||
lines_to_write.append(
|
||||
f"|{index}|{log['method']}|{log['best eval'][metric]}"
|
||||
f"|{log['best eval']['eval_index']}|"
|
||||
f"{log['last eval'][metric]}|"
|
||||
f"{log['last eval']['eval_index']}|{log['last_iter']}|\n")
|
||||
with open(markdown_file, 'w') as f:
|
||||
f.write(f'|exp_num|method|{metric} best|best index|'
|
||||
f'{metric} last|last index|last iter num|\n')
|
||||
f.write('|:---:|:---:|:---:|:---:|:---:|:---:|:---:|\n')
|
||||
f.writelines(lines_to_write)
|
||||
|
||||
print('processed successfully')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@ -1,144 +0,0 @@
|
||||
# Log Collector
|
||||
|
||||
## Function
|
||||
|
||||
Automatically collect logs and write the result in a json file or markdown file.
|
||||
|
||||
If there are several `.log.json` files in one folder, Log Collector assumes that the `.log.json` files other than the first one are resume from the preceding `.log.json` file. Log Collector returns the result considering all `.log.json` files.
|
||||
|
||||
## Usage:
|
||||
|
||||
To use log collector, you need to write a config file to configure the log collector first.
|
||||
|
||||
For example:
|
||||
|
||||
example_config.py:
|
||||
|
||||
```python
|
||||
# The work directory that contains folders that contains .log.json files.
|
||||
work_dir = '../../work_dirs'
|
||||
# The metric used to find the best evaluation.
|
||||
metric = 'mIoU'
|
||||
|
||||
# **Don't specify the log_items and ignore_keywords at the same time.**
|
||||
# Specify the log files we would like to collect in `log_items`.
|
||||
# The folders specified should be the subdirectories of `work_dir`.
|
||||
log_items = [
|
||||
'segformer_mit-b5_512x512_160k_ade20k_cnn_lr_with_warmup',
|
||||
'segformer_mit-b5_512x512_160k_ade20k_cnn_no_warmup_lr',
|
||||
'segformer_mit-b5_512x512_160k_ade20k_mit_trans_lr',
|
||||
'segformer_mit-b5_512x512_160k_ade20k_swin_trans_lr'
|
||||
]
|
||||
# Or specify `ignore_keywords`. The folders whose name contain one
|
||||
# of the keywords in the `ignore_keywords` list(e.g., `'segformer'`)
|
||||
# won't be collected.
|
||||
# ignore_keywords = ['segformer']
|
||||
|
||||
# Other log items in .log.json that you want to collect.
|
||||
# should not include metric.
|
||||
other_info_keys = ["mAcc"]
|
||||
# The output markdown file's name.
|
||||
markdown_file ='markdowns/lr_in_trans.json.md'
|
||||
# The output json file's name. (optional)
|
||||
json_file = 'jsons/trans_in_cnn.json'
|
||||
```
|
||||
|
||||
The structure of the work-dir directory should be like:
|
||||
|
||||
```text
|
||||
├── work-dir
|
||||
│ ├── folder1
|
||||
│ │ ├── time1.log.json
|
||||
│ │ ├── time2.log.json
|
||||
│ │ ├── time3.log.json
|
||||
│ │ ├── time4.log.json
|
||||
│ ├── folder2
|
||||
│ │ ├── time5.log.json
|
||||
│ │ ├── time6.log.json
|
||||
│ │ ├── time7.log.json
|
||||
│ │ ├── time8.log.json
|
||||
```
|
||||
|
||||
Then , cd to the log collector folder.
|
||||
|
||||
Now you can run log_collector.py by using command:
|
||||
|
||||
```bash
|
||||
python log_collector.py ./example_config.py
|
||||
```
|
||||
|
||||
The output markdown file is like:
|
||||
|
||||
| exp_num | method | mIoU best | best index | mIoU last | last index | last iter num |
|
||||
| :-----: | :-----------------------------------------------------: | :-------: | :--------: | :-------: | :--------: | :-----------: |
|
||||
| 1 | segformer_mit-b5_512x512_160k_ade20k_cnn_lr_with_warmup | 0.2776 | 10 | 0.2776 | 10 | 160000 |
|
||||
| 2 | segformer_mit-b5_512x512_160k_ade20k_cnn_no_warmup_lr | 0.2802 | 10 | 0.2802 | 10 | 160000 |
|
||||
| 3 | segformer_mit-b5_512x512_160k_ade20k_mit_trans_lr | 0.4943 | 11 | 0.4943 | 11 | 160000 |
|
||||
| 4 | segformer_mit-b5_512x512_160k_ade20k_swin_trans_lr | 0.4883 | 11 | 0.4883 | 11 | 160000 |
|
||||
|
||||
The output json file is like:
|
||||
|
||||
```json
|
||||
[
|
||||
{
|
||||
"method": "segformer_mit-b5_512x512_160k_ade20k_cnn_lr_with_warmup",
|
||||
"metric_used": "mIoU",
|
||||
"last_iter": 160000,
|
||||
"last eval": {
|
||||
"eval_index": 10,
|
||||
"mIoU": 0.2776,
|
||||
"mAcc": 0.3779
|
||||
},
|
||||
"best eval": {
|
||||
"eval_index": 10,
|
||||
"mIoU": 0.2776,
|
||||
"mAcc": 0.3779
|
||||
}
|
||||
},
|
||||
{
|
||||
"method": "segformer_mit-b5_512x512_160k_ade20k_cnn_no_warmup_lr",
|
||||
"metric_used": "mIoU",
|
||||
"last_iter": 160000,
|
||||
"last eval": {
|
||||
"eval_index": 10,
|
||||
"mIoU": 0.2802,
|
||||
"mAcc": 0.3764
|
||||
},
|
||||
"best eval": {
|
||||
"eval_index": 10,
|
||||
"mIoU": 0.2802,
|
||||
"mAcc": 0.3764
|
||||
}
|
||||
},
|
||||
{
|
||||
"method": "segformer_mit-b5_512x512_160k_ade20k_mit_trans_lr",
|
||||
"metric_used": "mIoU",
|
||||
"last_iter": 160000,
|
||||
"last eval": {
|
||||
"eval_index": 11,
|
||||
"mIoU": 0.4943,
|
||||
"mAcc": 0.6097
|
||||
},
|
||||
"best eval": {
|
||||
"eval_index": 11,
|
||||
"mIoU": 0.4943,
|
||||
"mAcc": 0.6097
|
||||
}
|
||||
},
|
||||
{
|
||||
"method": "segformer_mit-b5_512x512_160k_ade20k_swin_trans_lr",
|
||||
"metric_used": "mIoU",
|
||||
"last_iter": 160000,
|
||||
"last eval": {
|
||||
"eval_index": 11,
|
||||
"mIoU": 0.4883,
|
||||
"mAcc": 0.6061
|
||||
},
|
||||
"best eval": {
|
||||
"eval_index": 11,
|
||||
"mIoU": 0.4883,
|
||||
"mAcc": 0.6061
|
||||
}
|
||||
}
|
||||
]
|
||||
```
|
||||
@ -1,20 +0,0 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# modified from https://github.dev/open-mmlab/mmcv
|
||||
import os.path as osp
|
||||
import sys
|
||||
from importlib import import_module
|
||||
|
||||
|
||||
def load_config(cfg_dir: str) -> dict:
|
||||
assert cfg_dir.endswith('.py')
|
||||
root_path, file_name = osp.split(cfg_dir)
|
||||
temp_module = osp.splitext(file_name)[0]
|
||||
sys.path.insert(0, root_path)
|
||||
mod = import_module(temp_module)
|
||||
sys.path.pop(0)
|
||||
cfg_dict = {
|
||||
k: v
|
||||
for k, v in mod.__dict__.items() if not k.startswith('__')
|
||||
}
|
||||
del sys.modules[temp_module]
|
||||
return cfg_dict
|
||||
@ -1,317 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# This tool is used to update model-index.yml which is required by MIM, and
|
||||
# will be automatically called as a pre-commit hook. The updating will be
|
||||
# triggered if any change of model information (.md files in configs/) has been
|
||||
# detected before a commit.
|
||||
|
||||
import glob
|
||||
import os
|
||||
import os.path as osp
|
||||
import re
|
||||
import sys
|
||||
|
||||
from lxml import etree
|
||||
from mmcv.fileio import dump
|
||||
|
||||
MMSEG_ROOT = osp.dirname(osp.dirname((osp.dirname(__file__))))
|
||||
|
||||
COLLECTIONS = [
|
||||
'ANN', 'APCNet', 'BiSeNetV1', 'BiSeNetV2', 'CCNet', 'CGNet', 'DANet',
|
||||
'DeepLabV3', 'DeepLabV3+', 'DMNet', 'DNLNet', 'DPT', 'EMANet', 'EncNet',
|
||||
'ERFNet', 'FastFCN', 'FastSCNN', 'FCN', 'GCNet', 'ICNet', 'ISANet', 'KNet',
|
||||
'NonLocalNet', 'OCRNet', 'PointRend', 'PSANet', 'PSPNet', 'Segformer',
|
||||
'Segmenter', 'FPN', 'SETR', 'STDC', 'UNet', 'UPerNet'
|
||||
]
|
||||
COLLECTIONS_TEMP = []
|
||||
|
||||
|
||||
def dump_yaml_and_check_difference(obj, filename, sort_keys=False):
|
||||
"""Dump object to a yaml file, and check if the file content is different
|
||||
from the original.
|
||||
|
||||
Args:
|
||||
obj (any): The python object to be dumped.
|
||||
filename (str): YAML filename to dump the object to.
|
||||
sort_keys (str); Sort key by dictionary order.
|
||||
Returns:
|
||||
Bool: If the target YAML file is different from the original.
|
||||
"""
|
||||
|
||||
str_dump = dump(obj, None, file_format='yaml', sort_keys=sort_keys)
|
||||
if osp.isfile(filename):
|
||||
file_exists = True
|
||||
with open(filename, 'r', encoding='utf-8') as f:
|
||||
str_orig = f.read()
|
||||
else:
|
||||
file_exists = False
|
||||
str_orig = None
|
||||
|
||||
if file_exists and str_orig == str_dump:
|
||||
is_different = False
|
||||
else:
|
||||
is_different = True
|
||||
with open(filename, 'w', encoding='utf-8') as f:
|
||||
f.write(str_dump)
|
||||
|
||||
return is_different
|
||||
|
||||
|
||||
def parse_md(md_file):
|
||||
"""Parse .md file and convert it to a .yml file which can be used for MIM.
|
||||
|
||||
Args:
|
||||
md_file (str): Path to .md file.
|
||||
Returns:
|
||||
Bool: If the target YAML file is different from the original.
|
||||
"""
|
||||
collection_name = osp.split(osp.dirname(md_file))[1]
|
||||
configs = os.listdir(osp.dirname(md_file))
|
||||
|
||||
collection = dict(
|
||||
Name=collection_name,
|
||||
Metadata={'Training Data': []},
|
||||
Paper={
|
||||
'URL': '',
|
||||
'Title': ''
|
||||
},
|
||||
README=md_file,
|
||||
Code={
|
||||
'URL': '',
|
||||
'Version': ''
|
||||
})
|
||||
collection.update({'Converted From': {'Weights': '', 'Code': ''}})
|
||||
models = []
|
||||
datasets = []
|
||||
paper_url = None
|
||||
paper_title = None
|
||||
code_url = None
|
||||
code_version = None
|
||||
repo_url = None
|
||||
|
||||
# To avoid re-counting number of backbone model in OpenMMLab,
|
||||
# if certain model in configs folder is backbone whose name is already
|
||||
# recorded in MMClassification, then the `COLLECTION` dict of this model
|
||||
# in MMSegmentation should be deleted, and `In Collection` in `Models`
|
||||
# should be set with head or neck of this config file.
|
||||
is_backbone = None
|
||||
|
||||
with open(md_file, 'r', encoding='UTF-8') as md:
|
||||
lines = md.readlines()
|
||||
i = 0
|
||||
current_dataset = ''
|
||||
while i < len(lines):
|
||||
line = lines[i].strip()
|
||||
# In latest README.md the title and url are in the third line.
|
||||
if i == 2:
|
||||
paper_url = lines[i].split('](')[1].split(')')[0]
|
||||
paper_title = lines[i].split('](')[0].split('[')[1]
|
||||
if len(line) == 0:
|
||||
i += 1
|
||||
continue
|
||||
elif line[:3] == '<a ':
|
||||
content = etree.HTML(line)
|
||||
node = content.xpath('//a')[0]
|
||||
if node.text == 'Code Snippet':
|
||||
code_url = node.get('href', None)
|
||||
assert code_url is not None, (
|
||||
f'{collection_name} hasn\'t code snippet url.')
|
||||
# version extraction
|
||||
filter_str = r'blob/(.*)/mm'
|
||||
pattern = re.compile(filter_str)
|
||||
code_version = pattern.findall(code_url)
|
||||
assert len(code_version) == 1, (
|
||||
f'false regular expression ({filter_str}) use.')
|
||||
code_version = code_version[0]
|
||||
elif node.text == 'Official Repo':
|
||||
repo_url = node.get('href', None)
|
||||
assert repo_url is not None, (
|
||||
f'{collection_name} hasn\'t official repo url.')
|
||||
i += 1
|
||||
elif line[:4] == '### ':
|
||||
datasets.append(line[4:])
|
||||
current_dataset = line[4:]
|
||||
i += 2
|
||||
elif line[:15] == '<!-- [BACKBONE]':
|
||||
is_backbone = True
|
||||
i += 1
|
||||
elif (line[0] == '|' and (i + 1) < len(lines)
|
||||
and lines[i + 1][:3] == '| -' and 'Method' in line
|
||||
and 'Crop Size' in line and 'Mem (GB)' in line):
|
||||
cols = [col.strip() for col in line.split('|')]
|
||||
method_id = cols.index('Method')
|
||||
backbone_id = cols.index('Backbone')
|
||||
crop_size_id = cols.index('Crop Size')
|
||||
lr_schd_id = cols.index('Lr schd')
|
||||
mem_id = cols.index('Mem (GB)')
|
||||
fps_id = cols.index('Inf time (fps)')
|
||||
try:
|
||||
ss_id = cols.index('mIoU')
|
||||
except ValueError:
|
||||
ss_id = cols.index('Dice')
|
||||
try:
|
||||
ms_id = cols.index('mIoU(ms+flip)')
|
||||
except ValueError:
|
||||
ms_id = False
|
||||
config_id = cols.index('config')
|
||||
download_id = cols.index('download')
|
||||
j = i + 2
|
||||
while j < len(lines) and lines[j][0] == '|':
|
||||
els = [el.strip() for el in lines[j].split('|')]
|
||||
config = ''
|
||||
model_name = ''
|
||||
weight = ''
|
||||
for fn in configs:
|
||||
if fn in els[config_id]:
|
||||
left = els[download_id].index(
|
||||
'https://download.openmmlab.com')
|
||||
right = els[download_id].index('.pth') + 4
|
||||
weight = els[download_id][left:right]
|
||||
config = f'configs/{collection_name}/{fn}'
|
||||
model_name = fn[:-3]
|
||||
fps = els[fps_id] if els[fps_id] != '-' and els[
|
||||
fps_id] != '' else -1
|
||||
mem = els[mem_id].split(
|
||||
'\\'
|
||||
)[0] if els[mem_id] != '-' and els[mem_id] != '' else -1
|
||||
crop_size = els[crop_size_id].split('x')
|
||||
assert len(crop_size) == 2
|
||||
method = els[method_id].split()[0].split('-')[-1]
|
||||
model = {
|
||||
'Name':
|
||||
model_name,
|
||||
'In Collection':
|
||||
method,
|
||||
'Metadata': {
|
||||
'backbone': els[backbone_id],
|
||||
'crop size': f'({crop_size[0]},{crop_size[1]})',
|
||||
'lr schd': int(els[lr_schd_id]),
|
||||
},
|
||||
'Results': [
|
||||
{
|
||||
'Task': 'Semantic Segmentation',
|
||||
'Dataset': current_dataset,
|
||||
'Metrics': {
|
||||
cols[ss_id]: float(els[ss_id]),
|
||||
},
|
||||
},
|
||||
],
|
||||
'Config':
|
||||
config,
|
||||
'Weights':
|
||||
weight,
|
||||
}
|
||||
if fps != -1:
|
||||
try:
|
||||
fps = float(fps)
|
||||
except Exception:
|
||||
j += 1
|
||||
continue
|
||||
model['Metadata']['inference time (ms/im)'] = [{
|
||||
'value':
|
||||
round(1000 / float(fps), 2),
|
||||
'hardware':
|
||||
'V100',
|
||||
'backend':
|
||||
'PyTorch',
|
||||
'batch size':
|
||||
1,
|
||||
'mode':
|
||||
'FP32' if 'fp16' not in config else 'FP16',
|
||||
'resolution':
|
||||
f'({crop_size[0]},{crop_size[1]})'
|
||||
}]
|
||||
if mem != -1:
|
||||
model['Metadata']['Training Memory (GB)'] = float(mem)
|
||||
# Only have semantic segmentation now
|
||||
if ms_id and els[ms_id] != '-' and els[ms_id] != '':
|
||||
model['Results'][0]['Metrics'][
|
||||
'mIoU(ms+flip)'] = float(els[ms_id])
|
||||
models.append(model)
|
||||
j += 1
|
||||
i = j
|
||||
else:
|
||||
i += 1
|
||||
flag = (code_url is not None) and (paper_url is not None) and (repo_url
|
||||
is not None)
|
||||
assert flag, f'{collection_name} readme error'
|
||||
collection['Name'] = method
|
||||
collection['Metadata']['Training Data'] = datasets
|
||||
collection['Code']['URL'] = code_url
|
||||
collection['Code']['Version'] = code_version
|
||||
collection['Paper']['URL'] = paper_url
|
||||
collection['Paper']['Title'] = paper_title
|
||||
collection['Converted From']['Code'] = repo_url
|
||||
# ['Converted From']['Weights] miss
|
||||
# remove empty attribute
|
||||
check_key_list = ['Code', 'Paper', 'Converted From']
|
||||
for check_key in check_key_list:
|
||||
key_list = list(collection[check_key].keys())
|
||||
for key in key_list:
|
||||
if check_key not in collection:
|
||||
break
|
||||
if collection[check_key][key] == '':
|
||||
if len(collection[check_key].keys()) == 1:
|
||||
collection.pop(check_key)
|
||||
else:
|
||||
collection[check_key].pop(key)
|
||||
yml_file = f'{md_file[:-9]}{collection_name}.yml'
|
||||
if is_backbone:
|
||||
if collection['Name'] not in COLLECTIONS:
|
||||
result = {
|
||||
'Collections': [collection],
|
||||
'Models': models,
|
||||
'Yml': yml_file
|
||||
}
|
||||
COLLECTIONS_TEMP.append(result)
|
||||
return False
|
||||
else:
|
||||
result = {'Models': models}
|
||||
else:
|
||||
COLLECTIONS.append(collection['Name'])
|
||||
result = {'Collections': [collection], 'Models': models}
|
||||
return dump_yaml_and_check_difference(result, yml_file)
|
||||
|
||||
|
||||
def update_model_index():
|
||||
"""Update model-index.yml according to model .md files.
|
||||
|
||||
Returns:
|
||||
Bool: If the updated model-index.yml is different from the original.
|
||||
"""
|
||||
configs_dir = osp.join(MMSEG_ROOT, 'configs')
|
||||
yml_files = glob.glob(osp.join(configs_dir, '**', '*.yml'), recursive=True)
|
||||
yml_files.sort()
|
||||
|
||||
# add .replace('\\', '/') to avoid Windows Style path
|
||||
model_index = {
|
||||
'Import': [
|
||||
osp.relpath(yml_file, MMSEG_ROOT).replace('\\', '/')
|
||||
for yml_file in yml_files
|
||||
]
|
||||
}
|
||||
model_index_file = osp.join(MMSEG_ROOT, 'model-index.yml')
|
||||
is_different = dump_yaml_and_check_difference(model_index,
|
||||
model_index_file)
|
||||
|
||||
return is_different
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
file_list = [fn for fn in sys.argv[1:] if osp.basename(fn) == 'README.md']
|
||||
if not file_list:
|
||||
sys.exit(0)
|
||||
file_modified = False
|
||||
for fn in file_list:
|
||||
file_modified |= parse_md(fn)
|
||||
|
||||
for result in COLLECTIONS_TEMP:
|
||||
collection = result['Collections'][0]
|
||||
yml_file = result.pop('Yml', None)
|
||||
if collection['Name'] in COLLECTIONS:
|
||||
result.pop('Collections')
|
||||
file_modified |= dump_yaml_and_check_difference(result, yml_file)
|
||||
|
||||
file_modified |= update_model_index()
|
||||
sys.exit(1 if file_modified else 0)
|
||||
@ -1,45 +0,0 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import argparse
|
||||
import os
|
||||
import os.path as osp
|
||||
|
||||
import oss2
|
||||
|
||||
ACCESS_KEY_ID = os.getenv('OSS_ACCESS_KEY_ID', None)
|
||||
ACCESS_KEY_SECRET = os.getenv('OSS_ACCESS_KEY_SECRET', None)
|
||||
BUCKET_NAME = 'openmmlab'
|
||||
ENDPOINT = 'https://oss-accelerate.aliyuncs.com'
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description='Upload models to OSS')
|
||||
parser.add_argument('model_zoo', type=str, help='model_zoo input')
|
||||
parser.add_argument(
|
||||
'--dst-folder',
|
||||
type=str,
|
||||
default='mmsegmentation/v0.5',
|
||||
help='destination folder')
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
model_zoo = args.model_zoo
|
||||
dst_folder = args.dst_folder
|
||||
bucket = oss2.Bucket(
|
||||
oss2.Auth(ACCESS_KEY_ID, ACCESS_KEY_SECRET), ENDPOINT, BUCKET_NAME)
|
||||
|
||||
for root, dirs, files in os.walk(model_zoo):
|
||||
for file in files:
|
||||
file_path = osp.relpath(osp.join(root, file), model_zoo)
|
||||
print(f'Uploading {file_path}')
|
||||
|
||||
oss2.resumable_upload(bucket, osp.join(dst_folder, file_path),
|
||||
osp.join(model_zoo, file_path))
|
||||
bucket.put_object_acl(
|
||||
osp.join(dst_folder, file_path), oss2.OBJECT_ACL_PUBLIC_READ)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@ -1,76 +0,0 @@
|
||||
# Contributor Covenant Code of Conduct
|
||||
|
||||
## Our Pledge
|
||||
|
||||
In the interest of fostering an open and welcoming environment, we as
|
||||
contributors and maintainers pledge to making participation in our project and
|
||||
our community a harassment-free experience for everyone, regardless of age, body
|
||||
size, disability, ethnicity, sex characteristics, gender identity and expression,
|
||||
level of experience, education, socio-economic status, nationality, personal
|
||||
appearance, race, religion, or sexual identity and orientation.
|
||||
|
||||
## Our Standards
|
||||
|
||||
Examples of behavior that contributes to creating a positive environment
|
||||
include:
|
||||
|
||||
- Using welcoming and inclusive language
|
||||
- Being respectful of differing viewpoints and experiences
|
||||
- Gracefully accepting constructive criticism
|
||||
- Focusing on what is best for the community
|
||||
- Showing empathy towards other community members
|
||||
|
||||
Examples of unacceptable behavior by participants include:
|
||||
|
||||
- The use of sexualized language or imagery and unwelcome sexual attention or
|
||||
advances
|
||||
- Trolling, insulting/derogatory comments, and personal or political attacks
|
||||
- Public or private harassment
|
||||
- Publishing others' private information, such as a physical or electronic
|
||||
address, without explicit permission
|
||||
- Other conduct which could reasonably be considered inappropriate in a
|
||||
professional setting
|
||||
|
||||
## Our Responsibilities
|
||||
|
||||
Project maintainers are responsible for clarifying the standards of acceptable
|
||||
behavior and are expected to take appropriate and fair corrective action in
|
||||
response to any instances of unacceptable behavior.
|
||||
|
||||
Project maintainers have the right and responsibility to remove, edit, or
|
||||
reject comments, commits, code, wiki edits, issues, and other contributions
|
||||
that are not aligned to this Code of Conduct, or to ban temporarily or
|
||||
permanently any contributor for other behaviors that they deem inappropriate,
|
||||
threatening, offensive, or harmful.
|
||||
|
||||
## Scope
|
||||
|
||||
This Code of Conduct applies both within project spaces and in public spaces
|
||||
when an individual is representing the project or its community. Examples of
|
||||
representing a project or community include using an official project e-mail
|
||||
address, posting via an official social media account, or acting as an appointed
|
||||
representative at an online or offline event. Representation of a project may be
|
||||
further defined and clarified by project maintainers.
|
||||
|
||||
## Enforcement
|
||||
|
||||
Instances of abusive, harassing, or otherwise unacceptable behavior may be
|
||||
reported by contacting the project team at chenkaidev@gmail.com. All
|
||||
complaints will be reviewed and investigated and will result in a response that
|
||||
is deemed necessary and appropriate to the circumstances. The project team is
|
||||
obligated to maintain confidentiality with regard to the reporter of an incident.
|
||||
Further details of specific enforcement policies may be posted separately.
|
||||
|
||||
Project maintainers who do not follow or enforce the Code of Conduct in good
|
||||
faith may face temporary or permanent repercussions as determined by other
|
||||
members of the project's leadership.
|
||||
|
||||
## Attribution
|
||||
|
||||
This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
|
||||
available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
|
||||
|
||||
For answers to common questions about this code of conduct, see
|
||||
https://www.contributor-covenant.org/faq
|
||||
|
||||
[homepage]: https://www.contributor-covenant.org
|
||||
@ -1,58 +0,0 @@
|
||||
# Contributing to mmsegmentation
|
||||
|
||||
All kinds of contributions are welcome, including but not limited to the following.
|
||||
|
||||
- Fixes (typo, bugs)
|
||||
- New features and components
|
||||
|
||||
## Workflow
|
||||
|
||||
1. fork and pull the latest mmsegmentation
|
||||
2. checkout a new branch (do not use master branch for PRs)
|
||||
3. commit your changes
|
||||
4. create a PR
|
||||
|
||||
:::{note}
|
||||
|
||||
- If you plan to add some new features that involve large changes, it is encouraged to open an issue for discussion first.
|
||||
- If you are the author of some papers and would like to include your method to mmsegmentation,
|
||||
please contact Kai Chen (chenkaidev\[at\]gmail\[dot\]com). We will much appreciate your contribution.
|
||||
:::
|
||||
|
||||
## Code style
|
||||
|
||||
### Python
|
||||
|
||||
We adopt [PEP8](https://www.python.org/dev/peps/pep-0008/) as the preferred code style.
|
||||
|
||||
We use the following tools for linting and formatting:
|
||||
|
||||
- [flake8](http://flake8.pycqa.org/en/latest/): linter
|
||||
- [yapf](https://github.com/google/yapf): formatter
|
||||
- [isort](https://github.com/timothycrosley/isort): sort imports
|
||||
|
||||
Style configurations of yapf and isort can be found in [setup.cfg](../setup.cfg) and [.isort.cfg](../.isort.cfg).
|
||||
|
||||
We use [pre-commit hook](https://pre-commit.com/) that checks and formats for `flake8`, `yapf`, `isort`, `trailing whitespaces`,
|
||||
fixes `end-of-files`, sorts `requirments.txt` automatically on every commit.
|
||||
The config for a pre-commit hook is stored in [.pre-commit-config](../.pre-commit-config.yaml).
|
||||
|
||||
After you clone the repository, you will need to install initialize pre-commit hook.
|
||||
|
||||
```shell
|
||||
pip install -U pre-commit
|
||||
```
|
||||
|
||||
From the repository folder
|
||||
|
||||
```shell
|
||||
pre-commit install
|
||||
```
|
||||
|
||||
After this on every commit check code linters and formatter will be enforced.
|
||||
|
||||
> Before you create a PR, make sure that your code lints and is formatted by yapf.
|
||||
|
||||
### C++ and CUDA
|
||||
|
||||
We follow the [Google C++ Style Guide](https://google.github.io/styleguide/cppguide.html).
|
||||
@ -1,6 +0,0 @@
|
||||
blank_issues_enabled: false
|
||||
|
||||
contact_links:
|
||||
- name: MMSegmentation Documentation
|
||||
url: https://mmsegmentation.readthedocs.io
|
||||
about: Check the docs and FAQ to see if you question is already answered.
|
||||
@ -1,48 +0,0 @@
|
||||
---
|
||||
name: Error report
|
||||
about: Create a report to help us improve
|
||||
title: ''
|
||||
labels: ''
|
||||
assignees: ''
|
||||
---
|
||||
|
||||
Thanks for your error report and we appreciate it a lot.
|
||||
|
||||
**Checklist**
|
||||
|
||||
1. I have searched related issues but cannot get the expected help.
|
||||
2. The bug has not been fixed in the latest version.
|
||||
|
||||
**Describe the bug**
|
||||
A clear and concise description of what the bug is.
|
||||
|
||||
**Reproduction**
|
||||
|
||||
1. What command or script did you run?
|
||||
|
||||
```none
|
||||
A placeholder for the command.
|
||||
```
|
||||
|
||||
2. Did you make any modifications on the code or config? Did you understand what you have modified?
|
||||
|
||||
3. What dataset did you use?
|
||||
|
||||
**Environment**
|
||||
|
||||
1. Please run `python mmseg/utils/collect_env.py` to collect necessary environment information and paste it here.
|
||||
2. You may add addition that may be helpful for locating the problem, such as
|
||||
- How you installed PyTorch \[e.g., pip, conda, source\]
|
||||
- Other environment variables that may be related (such as `$PATH`, `$LD_LIBRARY_PATH`, `$PYTHONPATH`, etc.)
|
||||
|
||||
**Error traceback**
|
||||
|
||||
If applicable, paste the error trackback here.
|
||||
|
||||
```none
|
||||
A placeholder for trackback.
|
||||
```
|
||||
|
||||
**Bug fix**
|
||||
|
||||
If you have already identified the reason, you can provide the information here. If you are willing to create a PR to fix it, please also leave a comment here and that would be much appreciated!
|
||||
@ -1,21 +0,0 @@
|
||||
---
|
||||
name: Feature request
|
||||
about: Suggest an idea for this project
|
||||
title: ''
|
||||
labels: ''
|
||||
assignees: ''
|
||||
---
|
||||
|
||||
# Describe the feature
|
||||
|
||||
**Motivation**
|
||||
A clear and concise description of the motivation of the feature.
|
||||
Ex1. It is inconvenient when \[....\].
|
||||
Ex2. There is a recent paper \[....\], which is very helpful for \[....\].
|
||||
|
||||
**Related resources**
|
||||
If there is an official code release or third-party implementations, please also provide the information here, which would be very helpful.
|
||||
|
||||
**Additional context**
|
||||
Add any other context or screenshots about the feature request here.
|
||||
If you would like to implement the feature and create a PR, please leave a comment here and that would be much appreciated.
|
||||
@ -1,7 +0,0 @@
|
||||
---
|
||||
name: General questions
|
||||
about: Ask general questions to get help
|
||||
title: ''
|
||||
labels: ''
|
||||
assignees: ''
|
||||
---
|
||||
@ -1,69 +0,0 @@
|
||||
---
|
||||
name: Reimplementation Questions
|
||||
about: Ask about questions during model reimplementation
|
||||
title: ''
|
||||
labels: reimplementation
|
||||
assignees: ''
|
||||
---
|
||||
|
||||
If you feel we have helped you, give us a STAR! :satisfied:
|
||||
|
||||
**Notice**
|
||||
|
||||
There are several common situations in the reimplementation issues as below
|
||||
|
||||
1. Reimplement a model in the model zoo using the provided configs
|
||||
2. Reimplement a model in the model zoo on other datasets (e.g., custom datasets)
|
||||
3. Reimplement a custom model but all the components are implemented in MMSegmentation
|
||||
4. Reimplement a custom model with new modules implemented by yourself
|
||||
|
||||
There are several things to do for different cases as below.
|
||||
|
||||
- For cases 1 & 3, please follow the steps in the following sections thus we could help to quickly identify the issue.
|
||||
- For cases 2 & 4, please understand that we are not able to do much help here because we usually do not know the full code, and the users should be responsible for the code they write.
|
||||
- One suggestion for cases 2 & 4 is that the users should first check whether the bug lies in the self-implemented code or the original code. For example, users can first make sure that the same model runs well on supported datasets. If you still need help, please describe what you have done and what you obtain in the issue, and follow the steps in the following sections, and try as clear as possible so that we can better help you.
|
||||
|
||||
**Checklist**
|
||||
|
||||
1. I have searched related issues but cannot get the expected help.
|
||||
2. The issue has not been fixed in the latest version.
|
||||
|
||||
**Describe the issue**
|
||||
|
||||
A clear and concise description of the problem you meet and what you have done.
|
||||
|
||||
**Reproduction**
|
||||
|
||||
1. What command or script did you run?
|
||||
|
||||
```
|
||||
A placeholder for the command.
|
||||
```
|
||||
|
||||
2. What config dir you run?
|
||||
|
||||
```
|
||||
A placeholder for the config.
|
||||
```
|
||||
|
||||
3. Did you make any modifications to the code or config? Did you understand what you have modified?
|
||||
4. What dataset did you use?
|
||||
|
||||
**Environment**
|
||||
|
||||
1. Please run `PYTHONPATH=${PWD}:$PYTHONPATH python mmseg/utils/collect_env.py` to collect the necessary environment information and paste it here.
|
||||
2. You may add additional information that may be helpful for locating the problem, such as
|
||||
1. How you installed PyTorch \[e.g., pip, conda, source\]
|
||||
2. Other environment variables that may be related (such as `$PATH`, `$LD_LIBRARY_PATH`, `$PYTHONPATH`, etc.)
|
||||
|
||||
**Results**
|
||||
|
||||
If applicable, paste the related results here, e.g., what you expect and what you get.
|
||||
|
||||
```
|
||||
A placeholder for results comparison
|
||||
```
|
||||
|
||||
**Issue fix**
|
||||
|
||||
If you have already identified the reason, you can provide the information here. If you are willing to create a PR to fix it, please also leave a comment here and that would be much appreciated!
|
||||
@ -1,25 +0,0 @@
|
||||
Thanks for your contribution and we appreciate it a lot. The following instructions would make your pull request more healthy and more easily get feedback. If you do not understand some items, don't worry, just make the pull request and seek help from maintainers.
|
||||
|
||||
## Motivation
|
||||
|
||||
Please describe the motivation of this PR and the goal you want to achieve through this PR.
|
||||
|
||||
## Modification
|
||||
|
||||
Please briefly describe what modification is made in this PR.
|
||||
|
||||
## BC-breaking (Optional)
|
||||
|
||||
Does the modification introduce changes that break the backward-compatibility of the downstream repos?
|
||||
If so, please describe how it breaks the compatibility and how the downstream projects should modify their code to keep compatibility with this PR.
|
||||
|
||||
## Use cases (Optional)
|
||||
|
||||
If this PR introduces a new feature, it is better to list some use cases here, and update the documentation.
|
||||
|
||||
## Checklist
|
||||
|
||||
1. Pre-commit or other linting tools are used to fix the potential lint issues.
|
||||
2. The modification is covered by complete unit tests. If not, please add more unit tests to ensure the correctness.
|
||||
3. If the modification has potential influence on downstream projects, this PR should be tested with downstream projects, like MMDet or MMDet3D.
|
||||
4. The documentation has been modified accordingly, like docstring or example tutorials.
|
||||
@ -1,257 +0,0 @@
|
||||
name: build
|
||||
|
||||
on:
|
||||
push:
|
||||
paths-ignore:
|
||||
- 'demo/**'
|
||||
- '.dev/**'
|
||||
- 'docker/**'
|
||||
- 'tools/**'
|
||||
- '**.md'
|
||||
|
||||
pull_request:
|
||||
paths-ignore:
|
||||
- 'demo/**'
|
||||
- '.dev/**'
|
||||
- 'docker/**'
|
||||
- 'tools/**'
|
||||
- 'docs/**'
|
||||
- '**.md'
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
build_cpu:
|
||||
runs-on: ubuntu-18.04
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: [3.7]
|
||||
torch: [1.5.1, 1.6.0, 1.7.0, 1.8.0, 1.9.0]
|
||||
include:
|
||||
- torch: 1.5.1
|
||||
torch_version: torch1.5
|
||||
torchvision: 0.6.1
|
||||
- torch: 1.6.0
|
||||
torch_version: torch1.6
|
||||
torchvision: 0.7.0
|
||||
- torch: 1.7.0
|
||||
torch_version: torch1.7
|
||||
torchvision: 0.8.1
|
||||
- torch: 1.8.0
|
||||
torch_version: torch1.8
|
||||
torchvision: 0.9.0
|
||||
- torch: 1.9.0
|
||||
torch_version: torch1.9
|
||||
torchvision: 0.10.0
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
- name: Upgrade pip
|
||||
run: pip install pip --upgrade
|
||||
- name: Install PyTorch
|
||||
run: pip install torch==${{matrix.torch}}+cpu torchvision==${{matrix.torchvision}}+cpu -f https://download.pytorch.org/whl/torch_stable.html
|
||||
- name: Install MMCV
|
||||
run: |
|
||||
pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cpu/${{matrix.torch_version}}/index.html
|
||||
python -c 'import mmcv; print(mmcv.__version__)'
|
||||
- name: Install unittest dependencies
|
||||
run: |
|
||||
pip install -r requirements.txt
|
||||
- name: Build and install
|
||||
run: rm -rf .eggs && pip install -e .
|
||||
- name: Run unittests and generate coverage report
|
||||
run: |
|
||||
pip install timm
|
||||
coverage run --branch --source mmseg -m pytest tests/
|
||||
coverage xml
|
||||
coverage report -m
|
||||
if: ${{matrix.torch >= '1.5.0'}}
|
||||
- name: Skip timm unittests and generate coverage report
|
||||
run: |
|
||||
coverage run --branch --source mmseg -m pytest tests/ --ignore tests/test_models/test_backbones/test_timm_backbone.py
|
||||
coverage xml
|
||||
coverage report -m
|
||||
if: ${{matrix.torch < '1.5.0'}}
|
||||
|
||||
build_cuda101:
|
||||
runs-on: ubuntu-18.04
|
||||
container:
|
||||
image: pytorch/pytorch:1.6.0-cuda10.1-cudnn7-devel
|
||||
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: [3.7]
|
||||
torch:
|
||||
[
|
||||
1.5.1+cu101,
|
||||
1.6.0+cu101,
|
||||
1.7.0+cu101,
|
||||
1.8.0+cu101
|
||||
]
|
||||
include:
|
||||
- torch: 1.5.1+cu101
|
||||
torch_version: torch1.5
|
||||
torchvision: 0.6.1+cu101
|
||||
- torch: 1.6.0+cu101
|
||||
torch_version: torch1.6
|
||||
torchvision: 0.7.0+cu101
|
||||
- torch: 1.7.0+cu101
|
||||
torch_version: torch1.7
|
||||
torchvision: 0.8.1+cu101
|
||||
- torch: 1.8.0+cu101
|
||||
torch_version: torch1.8
|
||||
torchvision: 0.9.0+cu101
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
- name: Fetch GPG keys
|
||||
run: |
|
||||
apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/3bf863cc.pub
|
||||
apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64/7fa2af80.pub
|
||||
- name: Install system dependencies
|
||||
run: |
|
||||
apt-get update && apt-get install -y libgl1-mesa-glx ffmpeg libsm6 libxext6 git ninja-build libglib2.0-0 libsm6 libxrender-dev libxext6 python${{matrix.python-version}}-dev
|
||||
apt-get clean
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
- name: Install Pillow
|
||||
run: python -m pip install Pillow==6.2.2
|
||||
if: ${{matrix.torchvision < 0.5}}
|
||||
- name: Install PyTorch
|
||||
run: python -m pip install torch==${{matrix.torch}} torchvision==${{matrix.torchvision}} -f https://download.pytorch.org/whl/torch_stable.html
|
||||
- name: Install mmseg dependencies
|
||||
run: |
|
||||
python -V
|
||||
python -m pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cu101/${{matrix.torch_version}}/index.html
|
||||
python -m pip install -r requirements.txt
|
||||
python -c 'import mmcv; print(mmcv.__version__)'
|
||||
- name: Build and install
|
||||
run: |
|
||||
rm -rf .eggs
|
||||
python setup.py check -m -s
|
||||
TORCH_CUDA_ARCH_LIST=7.0 pip install .
|
||||
- name: Run unittests and generate coverage report
|
||||
run: |
|
||||
python -m pip install timm
|
||||
coverage run --branch --source mmseg -m pytest tests/
|
||||
coverage xml
|
||||
coverage report -m
|
||||
if: ${{matrix.torch >= '1.5.0'}}
|
||||
- name: Skip timm unittests and generate coverage report
|
||||
run: |
|
||||
coverage run --branch --source mmseg -m pytest tests/ --ignore tests/test_models/test_backbones/test_timm_backbone.py
|
||||
coverage xml
|
||||
coverage report -m
|
||||
if: ${{matrix.torch < '1.5.0'}}
|
||||
- name: Upload coverage to Codecov
|
||||
uses: codecov/codecov-action@v1.0.10
|
||||
with:
|
||||
file: ./coverage.xml
|
||||
flags: unittests
|
||||
env_vars: OS,PYTHON
|
||||
name: codecov-umbrella
|
||||
fail_ci_if_error: false
|
||||
|
||||
build_cuda102:
|
||||
runs-on: ubuntu-18.04
|
||||
container:
|
||||
image: pytorch/pytorch:1.9.0-cuda10.2-cudnn7-devel
|
||||
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: [3.6, 3.7, 3.8, 3.9]
|
||||
torch: [1.9.0+cu102]
|
||||
include:
|
||||
- torch: 1.9.0+cu102
|
||||
torch_version: torch1.9
|
||||
torchvision: 0.10.0+cu102
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
- name: Fetch GPG keys
|
||||
run: |
|
||||
apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/3bf863cc.pub
|
||||
apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64/7fa2af80.pub
|
||||
- name: Install system dependencies
|
||||
run: |
|
||||
apt-get update && apt-get install -y libgl1-mesa-glx ffmpeg libsm6 libxext6 git ninja-build libglib2.0-0 libsm6 libxrender-dev libxext6
|
||||
apt-get clean
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
- name: Install Pillow
|
||||
run: python -m pip install Pillow==6.2.2
|
||||
if: ${{matrix.torchvision < 0.5}}
|
||||
- name: Install PyTorch
|
||||
run: python -m pip install torch==${{matrix.torch}} torchvision==${{matrix.torchvision}} -f https://download.pytorch.org/whl/torch_stable.html
|
||||
- name: Install mmseg dependencies
|
||||
run: |
|
||||
python -V
|
||||
python -m pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cu102/${{matrix.torch_version}}/index.html
|
||||
python -m pip install -r requirements.txt
|
||||
python -c 'import mmcv; print(mmcv.__version__)'
|
||||
- name: Build and install
|
||||
run: |
|
||||
rm -rf .eggs
|
||||
python setup.py check -m -s
|
||||
TORCH_CUDA_ARCH_LIST=7.0 pip install .
|
||||
- name: Run unittests and generate coverage report
|
||||
run: |
|
||||
python -m pip install timm
|
||||
coverage run --branch --source mmseg -m pytest tests/
|
||||
coverage xml
|
||||
coverage report -m
|
||||
- name: Upload coverage to Codecov
|
||||
uses: codecov/codecov-action@v2
|
||||
with:
|
||||
files: ./coverage.xml
|
||||
flags: unittests
|
||||
env_vars: OS,PYTHON
|
||||
name: codecov-umbrella
|
||||
fail_ci_if_error: false
|
||||
|
||||
test_windows:
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
os: [windows-2022]
|
||||
python: [3.8]
|
||||
platform: [cpu, cu111]
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: Set up Python ${{ matrix.python }}
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: ${{ matrix.python }}
|
||||
- name: Upgrade pip
|
||||
run: python -m pip install pip --upgrade --user
|
||||
- name: Install OpenCV
|
||||
run: pip install opencv-python>=3
|
||||
- name: Install PyTorch
|
||||
# As a complement to Linux CI, we test on PyTorch LTS version
|
||||
run: pip install torch==1.8.2+${{ matrix.platform }} torchvision==0.9.2+${{ matrix.platform }} -f https://download.pytorch.org/whl/lts/1.8/torch_lts.html
|
||||
- name: Install MMCV
|
||||
run: |
|
||||
pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cpu/torch1.8/index.html --only-binary mmcv-full
|
||||
- name: Install unittest dependencies
|
||||
run: pip install -r requirements/tests.txt -r requirements/optional.txt
|
||||
- name: Build and install
|
||||
run: pip install -e .
|
||||
- name: Run unittests
|
||||
run: |
|
||||
python -m pip install timm
|
||||
coverage run --branch --source mmseg -m pytest tests/
|
||||
- name: Generate coverage report
|
||||
run: |
|
||||
coverage xml
|
||||
coverage report -m
|
||||
@ -1,26 +0,0 @@
|
||||
name: deploy
|
||||
|
||||
on: push
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
build-n-publish:
|
||||
runs-on: ubuntu-latest
|
||||
if: startsWith(github.event.ref, 'refs/tags')
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: Set up Python 3.7
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: 3.7
|
||||
- name: Build MMSegmentation
|
||||
run: |
|
||||
pip install wheel
|
||||
python setup.py sdist bdist_wheel
|
||||
- name: Publish distribution to PyPI
|
||||
run: |
|
||||
pip install twine
|
||||
twine upload dist/* -u __token__ -p ${{ secrets.pypi_password }}
|
||||
@ -1,28 +0,0 @@
|
||||
name: lint
|
||||
|
||||
on: [push, pull_request]
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
lint:
|
||||
runs-on: ubuntu-18.04
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: Set up Python 3.7
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: 3.7
|
||||
- name: Install pre-commit hook
|
||||
run: |
|
||||
pip install pre-commit
|
||||
pre-commit install
|
||||
- name: Linting
|
||||
run: |
|
||||
pre-commit run --all-files
|
||||
- name: Check docstring coverage
|
||||
run: |
|
||||
pip install interrogate
|
||||
interrogate -v --ignore-init-method --ignore-module --ignore-nested-functions --exclude mmseg/ops --ignore-regex "__repr__" --fail-under 80 mmseg
|
||||
@ -1,44 +0,0 @@
|
||||
name: test-mim
|
||||
|
||||
on:
|
||||
push:
|
||||
paths:
|
||||
- 'model-index.yml'
|
||||
- 'configs/**'
|
||||
|
||||
pull_request:
|
||||
paths:
|
||||
- 'model-index.yml'
|
||||
- 'configs/**'
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
build_cpu:
|
||||
runs-on: ubuntu-18.04
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: [3.7]
|
||||
torch: [1.8.0]
|
||||
include:
|
||||
- torch: 1.8.0
|
||||
torch_version: torch1.8
|
||||
torchvision: 0.9.0
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
- name: Upgrade pip
|
||||
run: pip install pip --upgrade
|
||||
- name: Install PyTorch
|
||||
run: pip install torch==${{matrix.torch}}+cpu torchvision==${{matrix.torchvision}}+cpu -f https://download.pytorch.org/whl/torch_stable.html
|
||||
- name: Install openmim
|
||||
run: pip install openmim
|
||||
- name: Build and install
|
||||
run: rm -rf .eggs && mim install -e .
|
||||
- name: test commands of mim
|
||||
run: mim search mmsegmentation
|
||||
@ -1,120 +0,0 @@
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/en/_build/
|
||||
docs/zh_cn/_build/
|
||||
|
||||
# PyBuilder
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# pyenv
|
||||
.python-version
|
||||
|
||||
# celery beat schedule file
|
||||
celerybeat-schedule
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
.DS_Store
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
|
||||
data
|
||||
.vscode
|
||||
.idea
|
||||
|
||||
# custom
|
||||
*.pkl
|
||||
*.pkl.json
|
||||
*.log.json
|
||||
work_dirs/
|
||||
mmseg/.mim
|
||||
|
||||
# Pytorch
|
||||
*.pth
|
||||
@ -1,11 +0,0 @@
|
||||
assign:
|
||||
strategy:
|
||||
# random
|
||||
# round-robin
|
||||
daily-shift-based
|
||||
assignees:
|
||||
- MengzhangLI
|
||||
- xiexinch
|
||||
- MeowZheng
|
||||
- MengzhangLI
|
||||
- xiexinch
|
||||
@ -1,60 +0,0 @@
|
||||
repos:
|
||||
- repo: https://gitlab.com/pycqa/flake8.git
|
||||
rev: 3.8.3
|
||||
hooks:
|
||||
- id: flake8
|
||||
- repo: https://github.com/PyCQA/isort
|
||||
rev: 5.10.1
|
||||
hooks:
|
||||
- id: isort
|
||||
- repo: https://github.com/pre-commit/mirrors-yapf
|
||||
rev: v0.30.0
|
||||
hooks:
|
||||
- id: yapf
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v3.1.0
|
||||
hooks:
|
||||
- id: trailing-whitespace
|
||||
- id: check-yaml
|
||||
- id: end-of-file-fixer
|
||||
- id: requirements-txt-fixer
|
||||
- id: double-quote-string-fixer
|
||||
- id: check-merge-conflict
|
||||
- id: fix-encoding-pragma
|
||||
args: ["--remove"]
|
||||
- id: mixed-line-ending
|
||||
args: ["--fix=lf"]
|
||||
- repo: https://github.com/executablebooks/mdformat
|
||||
rev: 0.7.9
|
||||
hooks:
|
||||
- id: mdformat
|
||||
args: ["--number"]
|
||||
additional_dependencies:
|
||||
- mdformat-openmmlab
|
||||
- mdformat_frontmatter
|
||||
- linkify-it-py
|
||||
- repo: https://github.com/codespell-project/codespell
|
||||
rev: v2.1.0
|
||||
hooks:
|
||||
- id: codespell
|
||||
- repo: https://github.com/myint/docformatter
|
||||
rev: v1.3.1
|
||||
hooks:
|
||||
- id: docformatter
|
||||
args: ["--in-place", "--wrap-descriptions", "79"]
|
||||
- repo: local
|
||||
hooks:
|
||||
- id: update-model-index
|
||||
name: update-model-index
|
||||
description: Collect model information and update model-index.yml
|
||||
entry: .dev/md2yml.py
|
||||
additional_dependencies: [mmcv, lxml, opencv-python]
|
||||
language: python
|
||||
files: ^configs/.*\.md$
|
||||
require_serial: true
|
||||
- repo: https://github.com/open-mmlab/pre-commit-hooks
|
||||
rev: v0.2.0 # Use the rev to fix revision
|
||||
hooks:
|
||||
- id: check-algo-readme
|
||||
- id: check-copyright
|
||||
args: ["mmseg", "tools", "tests", "demo"] # the dir_to_check with expected directory to check
|
||||
@ -1,9 +0,0 @@
|
||||
version: 2
|
||||
|
||||
formats: all
|
||||
|
||||
python:
|
||||
version: 3.7
|
||||
install:
|
||||
- requirements: requirements/docs.txt
|
||||
- requirements: requirements/readthedocs.txt
|
||||
@ -1,8 +0,0 @@
|
||||
cff-version: 1.2.0
|
||||
message: "If you use this software, please cite it as below."
|
||||
authors:
|
||||
- name: "MMSegmentation Contributors"
|
||||
title: "OpenMMLab Semantic Segmentation Toolbox and Benchmark"
|
||||
date-released: 2020-07-10
|
||||
url: "https://github.com/open-mmlab/mmsegmentation"
|
||||
license: Apache-2.0
|
||||
@ -1,203 +0,0 @@
|
||||
Copyright 2020 The MMSegmentation Authors. All rights reserved.
|
||||
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright 2020 The MMSegmentation Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
@ -1,7 +0,0 @@
|
||||
# Licenses for special features
|
||||
|
||||
In this file, we list the features with other licenses instead of Apache 2.0. Users should be careful about adopting these features in any commercial matters.
|
||||
|
||||
| Feature | Files | License |
|
||||
| :-------: | :-------------------------------------------------------------------------------------------------------------------------------------------------: | :-----------------------------------------------------------: |
|
||||
| SegFormer | [mmseg/models/decode_heads/segformer_head.py](https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/models/decode_heads/segformer_head.py) | [NVIDIA License](https://github.com/NVlabs/SegFormer#license) |
|
||||
@ -1,4 +0,0 @@
|
||||
include requirements/*.txt
|
||||
include mmseg/.mim/model-index.yml
|
||||
recursive-include mmseg/.mim/configs *.py *.yml
|
||||
recursive-include mmseg/.mim/tools *.py *.sh
|
||||
@ -1,229 +0,0 @@
|
||||
<div align="center">
|
||||
<img src="resources/mmseg-logo.png" width="600"/>
|
||||
<div> </div>
|
||||
<div align="center">
|
||||
<b><font size="5">OpenMMLab website</font></b>
|
||||
<sup>
|
||||
<a href="https://openmmlab.com">
|
||||
<i><font size="4">HOT</font></i>
|
||||
</a>
|
||||
</sup>
|
||||
|
||||
<b><font size="5">OpenMMLab platform</font></b>
|
||||
<sup>
|
||||
<a href="https://platform.openmmlab.com">
|
||||
<i><font size="4">TRY IT OUT</font></i>
|
||||
</a>
|
||||
</sup>
|
||||
</div>
|
||||
<div> </div>
|
||||
|
||||
<br />
|
||||
|
||||
[](https://pypi.org/project/mmsegmentation/)
|
||||
[](https://pypi.org/project/mmsegmentation)
|
||||
[](https://mmsegmentation.readthedocs.io/en/latest/)
|
||||
[](https://github.com/open-mmlab/mmsegmentation/actions)
|
||||
[](https://codecov.io/gh/open-mmlab/mmsegmentation)
|
||||
[](https://github.com/open-mmlab/mmsegmentation/blob/master/LICENSE)
|
||||
[](https://github.com/open-mmlab/mmsegmentation/issues)
|
||||
[](https://github.com/open-mmlab/mmsegmentation/issues)
|
||||
|
||||
[📘Documentation](https://mmsegmentation.readthedocs.io/en/latest/) |
|
||||
[🛠️Installation](https://mmsegmentation.readthedocs.io/en/latest/get_started.html) |
|
||||
[👀Model Zoo](https://mmsegmentation.readthedocs.io/en/latest/model_zoo.html) |
|
||||
[🆕Update News](https://mmsegmentation.readthedocs.io/en/latest/changelog.html) |
|
||||
[🤔Reporting Issues](https://github.com/open-mmlab/mmsegmentation/issues/new/choose)
|
||||
|
||||
</div>
|
||||
|
||||
<div align="center">
|
||||
|
||||
English | [简体中文](README_zh-CN.md)
|
||||
|
||||
</div>
|
||||
|
||||
## Introduction
|
||||
|
||||
MMSegmentation is an open source semantic segmentation toolbox based on PyTorch.
|
||||
It is a part of the [OpenMMLab](https://openmmlab.com/) project.
|
||||
|
||||
The master branch works with **PyTorch 1.5+**.
|
||||
|
||||

|
||||
|
||||
<details open>
|
||||
<summary>Major features</summary>
|
||||
|
||||
- **Unified Benchmark**
|
||||
|
||||
We provide a unified benchmark toolbox for various semantic segmentation methods.
|
||||
|
||||
- **Modular Design**
|
||||
|
||||
We decompose the semantic segmentation framework into different components and one can easily construct a customized semantic segmentation framework by combining different modules.
|
||||
|
||||
- **Support of multiple methods out of box**
|
||||
|
||||
The toolbox directly supports popular and contemporary semantic segmentation frameworks, *e.g.* PSPNet, DeepLabV3, PSANet, DeepLabV3+, etc.
|
||||
|
||||
- **High efficiency**
|
||||
|
||||
The training speed is faster than or comparable to other codebases.
|
||||
|
||||
</details>
|
||||
|
||||
## What's New
|
||||
|
||||
v0.25.0 was released in 6/2/2022:
|
||||
|
||||
- Support PyTorch backend on MLU
|
||||
|
||||
Please refer to [changelog.md](docs/en/changelog.md) for details and release history.
|
||||
|
||||
## Installation
|
||||
|
||||
Please refer to [get_started.md](docs/en/get_started.md#installation) for installation and [dataset_prepare.md](docs/en/dataset_prepare.md#prepare-datasets) for dataset preparation.
|
||||
|
||||
## Get Started
|
||||
|
||||
Please see [train.md](docs/en/train.md) and [inference.md](docs/en/inference.md) for the basic usage of MMSegmentation.
|
||||
There are also tutorials for:
|
||||
|
||||
- [customizing dataset](docs/en/tutorials/customize_datasets.md)
|
||||
- [designing data pipeline](docs/en/tutorials/data_pipeline.md)
|
||||
- [customizing modules](docs/en/tutorials/customize_models.md)
|
||||
- [customizing runtime](docs/en/tutorials/customize_runtime.md)
|
||||
- [training tricks](docs/en/tutorials/training_tricks.md)
|
||||
- [useful tools](docs/en/useful_tools.md)
|
||||
|
||||
A Colab tutorial is also provided. You may preview the notebook [here](demo/MMSegmentation_Tutorial.ipynb) or directly [run](https://colab.research.google.com/github/open-mmlab/mmsegmentation/blob/master/demo/MMSegmentation_Tutorial.ipynb) on Colab.
|
||||
|
||||
## Benchmark and model zoo
|
||||
|
||||
Results and models are available in the [model zoo](docs/en/model_zoo.md).
|
||||
|
||||
Supported backbones:
|
||||
|
||||
- [x] ResNet (CVPR'2016)
|
||||
- [x] ResNeXt (CVPR'2017)
|
||||
- [x] [HRNet (CVPR'2019)](configs/hrnet)
|
||||
- [x] [ResNeSt (ArXiv'2020)](configs/resnest)
|
||||
- [x] [MobileNetV2 (CVPR'2018)](configs/mobilenet_v2)
|
||||
- [x] [MobileNetV3 (ICCV'2019)](configs/mobilenet_v3)
|
||||
- [x] [Vision Transformer (ICLR'2021)](configs/vit)
|
||||
- [x] [Swin Transformer (ICCV'2021)](configs/swin)
|
||||
- [x] [Twins (NeurIPS'2021)](configs/twins)
|
||||
- [x] [BEiT (ICLR'2022)](configs/beit)
|
||||
- [x] [ConvNeXt (CVPR'2022)](configs/convnext)
|
||||
- [x] [MAE (CVPR'2022)](configs/mae)
|
||||
|
||||
Supported methods:
|
||||
|
||||
- [x] [FCN (CVPR'2015/TPAMI'2017)](configs/fcn)
|
||||
- [x] [ERFNet (T-ITS'2017)](configs/erfnet)
|
||||
- [x] [UNet (MICCAI'2016/Nat. Methods'2019)](configs/unet)
|
||||
- [x] [PSPNet (CVPR'2017)](configs/pspnet)
|
||||
- [x] [DeepLabV3 (ArXiv'2017)](configs/deeplabv3)
|
||||
- [x] [BiSeNetV1 (ECCV'2018)](configs/bisenetv1)
|
||||
- [x] [PSANet (ECCV'2018)](configs/psanet)
|
||||
- [x] [DeepLabV3+ (CVPR'2018)](configs/deeplabv3plus)
|
||||
- [x] [UPerNet (ECCV'2018)](configs/upernet)
|
||||
- [x] [ICNet (ECCV'2018)](configs/icnet)
|
||||
- [x] [NonLocal Net (CVPR'2018)](configs/nonlocal_net)
|
||||
- [x] [EncNet (CVPR'2018)](configs/encnet)
|
||||
- [x] [Semantic FPN (CVPR'2019)](configs/sem_fpn)
|
||||
- [x] [DANet (CVPR'2019)](configs/danet)
|
||||
- [x] [APCNet (CVPR'2019)](configs/apcnet)
|
||||
- [x] [EMANet (ICCV'2019)](configs/emanet)
|
||||
- [x] [CCNet (ICCV'2019)](configs/ccnet)
|
||||
- [x] [DMNet (ICCV'2019)](configs/dmnet)
|
||||
- [x] [ANN (ICCV'2019)](configs/ann)
|
||||
- [x] [GCNet (ICCVW'2019/TPAMI'2020)](configs/gcnet)
|
||||
- [x] [FastFCN (ArXiv'2019)](configs/fastfcn)
|
||||
- [x] [Fast-SCNN (ArXiv'2019)](configs/fastscnn)
|
||||
- [x] [ISANet (ArXiv'2019/IJCV'2021)](configs/isanet)
|
||||
- [x] [OCRNet (ECCV'2020)](configs/ocrnet)
|
||||
- [x] [DNLNet (ECCV'2020)](configs/dnlnet)
|
||||
- [x] [PointRend (CVPR'2020)](configs/point_rend)
|
||||
- [x] [CGNet (TIP'2020)](configs/cgnet)
|
||||
- [x] [BiSeNetV2 (IJCV'2021)](configs/bisenetv2)
|
||||
- [x] [STDC (CVPR'2021)](configs/stdc)
|
||||
- [x] [SETR (CVPR'2021)](configs/setr)
|
||||
- [x] [DPT (ArXiv'2021)](configs/dpt)
|
||||
- [x] [Segmenter (ICCV'2021)](configs/segmenter)
|
||||
- [x] [SegFormer (NeurIPS'2021)](configs/segformer)
|
||||
- [x] [K-Net (NeurIPS'2021)](configs/knet)
|
||||
|
||||
Supported datasets:
|
||||
|
||||
- [x] [Cityscapes](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/en/dataset_prepare.md#cityscapes)
|
||||
- [x] [PASCAL VOC](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/en/dataset_prepare.md#pascal-voc)
|
||||
- [x] [ADE20K](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/en/dataset_prepare.md#ade20k)
|
||||
- [x] [Pascal Context](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/en/dataset_prepare.md#pascal-context)
|
||||
- [x] [COCO-Stuff 10k](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/en/dataset_prepare.md#coco-stuff-10k)
|
||||
- [x] [COCO-Stuff 164k](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/en/dataset_prepare.md#coco-stuff-164k)
|
||||
- [x] [CHASE_DB1](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/en/dataset_prepare.md#chase-db1)
|
||||
- [x] [DRIVE](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/en/dataset_prepare.md#drive)
|
||||
- [x] [HRF](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/en/dataset_prepare.md#hrf)
|
||||
- [x] [STARE](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/en/dataset_prepare.md#stare)
|
||||
- [x] [Dark Zurich](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/en/dataset_prepare.md#dark-zurich)
|
||||
- [x] [Nighttime Driving](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/en/dataset_prepare.md#nighttime-driving)
|
||||
- [x] [LoveDA](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/en/dataset_prepare.md#loveda)
|
||||
- [x] [Potsdam](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/en/dataset_prepare.md#isprs-potsdam)
|
||||
- [x] [Vaihingen](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/en/dataset_prepare.md#isprs-vaihingen)
|
||||
- [x] [iSAID](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/en/dataset_prepare.md#isaid)
|
||||
|
||||
## FAQ
|
||||
|
||||
Please refer to [FAQ](docs/en/faq.md) for frequently asked questions.
|
||||
|
||||
## Contributing
|
||||
|
||||
We appreciate all contributions to improve MMSegmentation. Please refer to [CONTRIBUTING.md](.github/CONTRIBUTING.md) for the contributing guideline.
|
||||
|
||||
## Acknowledgement
|
||||
|
||||
MMSegmentation is an open source project that welcome any contribution and feedback.
|
||||
We wish that the toolbox and benchmark could serve the growing research
|
||||
community by providing a flexible as well as standardized toolkit to reimplement existing methods
|
||||
and develop their own new semantic segmentation methods.
|
||||
|
||||
## Citation
|
||||
|
||||
If you find this project useful in your research, please consider cite:
|
||||
|
||||
```bibtex
|
||||
@misc{mmseg2020,
|
||||
title={{MMSegmentation}: OpenMMLab Semantic Segmentation Toolbox and Benchmark},
|
||||
author={MMSegmentation Contributors},
|
||||
howpublished = {\url{https://github.com/open-mmlab/mmsegmentation}},
|
||||
year={2020}
|
||||
}
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
MMSegmentation is released under the Apache 2.0 license, while some specific features in this library are with other licenses. Please refer to [LICENSES.md](LICENSES.md) for the careful check, if you are using our code for commercial matters.
|
||||
|
||||
## Projects in OpenMMLab
|
||||
|
||||
- [MMCV](https://github.com/open-mmlab/mmcv): OpenMMLab foundational library for computer vision.
|
||||
- [MIM](https://github.com/open-mmlab/mim): MIM installs OpenMMLab packages.
|
||||
- [MMClassification](https://github.com/open-mmlab/mmclassification): OpenMMLab image classification toolbox and benchmark.
|
||||
- [MMDetection](https://github.com/open-mmlab/mmdetection): OpenMMLab detection toolbox and benchmark.
|
||||
- [MMDetection3D](https://github.com/open-mmlab/mmdetection3d): OpenMMLab's next-generation platform for general 3D object detection.
|
||||
- [MMRotate](https://github.com/open-mmlab/mmrotate): OpenMMLab rotated object detection toolbox and benchmark.
|
||||
- [MMSegmentation](https://github.com/open-mmlab/mmsegmentation): OpenMMLab semantic segmentation toolbox and benchmark.
|
||||
- [MMOCR](https://github.com/open-mmlab/mmocr): OpenMMLab text detection, recognition, and understanding toolbox.
|
||||
- [MMPose](https://github.com/open-mmlab/mmpose): OpenMMLab pose estimation toolbox and benchmark.
|
||||
- [MMHuman3D](https://github.com/open-mmlab/mmhuman3d): OpenMMLab 3D human parametric model toolbox and benchmark.
|
||||
- [MMSelfSup](https://github.com/open-mmlab/mmselfsup): OpenMMLab self-supervised learning toolbox and benchmark.
|
||||
- [MMRazor](https://github.com/open-mmlab/mmrazor): OpenMMLab model compression toolbox and benchmark.
|
||||
- [MMFewShot](https://github.com/open-mmlab/mmfewshot): OpenMMLab fewshot learning toolbox and benchmark.
|
||||
- [MMAction2](https://github.com/open-mmlab/mmaction2): OpenMMLab's next-generation action understanding toolbox and benchmark.
|
||||
- [MMTracking](https://github.com/open-mmlab/mmtracking): OpenMMLab video perception toolbox and benchmark.
|
||||
- [MMFlow](https://github.com/open-mmlab/mmflow): OpenMMLab optical flow toolbox and benchmark.
|
||||
- [MMEditing](https://github.com/open-mmlab/mmediting): OpenMMLab image and video editing toolbox.
|
||||
- [MMGeneration](https://github.com/open-mmlab/mmgeneration): OpenMMLab image and video generative models toolbox.
|
||||
- [MMDeploy](https://github.com/open-mmlab/mmdeploy): OpenMMLab Model Deployment Framework.
|
||||
@ -1,242 +0,0 @@
|
||||
<div align="center">
|
||||
<img src="resources/mmseg-logo.png" width="600"/>
|
||||
<div> </div>
|
||||
<div align="center">
|
||||
<b><font size="5">OpenMMLab 官网</font></b>
|
||||
<sup>
|
||||
<a href="https://openmmlab.com">
|
||||
<i><font size="4">HOT</font></i>
|
||||
</a>
|
||||
</sup>
|
||||
|
||||
<b><font size="5">OpenMMLab 开放平台</font></b>
|
||||
<sup>
|
||||
<a href="https://platform.openmmlab.com">
|
||||
<i><font size="4">TRY IT OUT</font></i>
|
||||
</a>
|
||||
</sup>
|
||||
</div>
|
||||
<div> </div>
|
||||
|
||||
<br />
|
||||
|
||||
[](https://pypi.org/project/mmsegmentation/)
|
||||
[](https://pypi.org/project/mmsegmentation)
|
||||
[](https://mmsegmentation.readthedocs.io/zh_CN/latest/)
|
||||
[](https://github.com/open-mmlab/mmsegmentation/actions)
|
||||
[](https://codecov.io/gh/open-mmlab/mmsegmentation)
|
||||
[](https://github.com/open-mmlab/mmsegmentation/blob/master/LICENSE)
|
||||
[](https://github.com/open-mmlab/mmsegmentation/issues)
|
||||
[](https://github.com/open-mmlab/mmsegmentation/issues)
|
||||
|
||||
[📘使用文档](https://mmsegmentation.readthedocs.io/en/latest/) |
|
||||
[🛠️安装指南](https://mmsegmentation.readthedocs.io/en/latest/get_started.html) |
|
||||
[👀模型库](https://mmsegmentation.readthedocs.io/en/latest/model_zoo.html) |
|
||||
[🆕更新日志](https://mmsegmentation.readthedocs.io/en/latest/changelog.html) |
|
||||
[🤔报告问题](https://github.com/open-mmlab/mmsegmentation/issues/new/choose)
|
||||
|
||||
[English](README.md) | 简体中文
|
||||
|
||||
</div>
|
||||
|
||||
## 简介
|
||||
|
||||
MMSegmentation 是一个基于 PyTorch 的语义分割开源工具箱。它是 OpenMMLab 项目的一部分。
|
||||
|
||||
主分支代码目前支持 PyTorch 1.5 以上的版本。
|
||||
|
||||

|
||||
|
||||
<details open>
|
||||
<summary>Major features</summary>
|
||||
|
||||
### 主要特性
|
||||
|
||||
- **统一的基准平台**
|
||||
|
||||
我们将各种各样的语义分割算法集成到了一个统一的工具箱,进行基准测试。
|
||||
|
||||
- **模块化设计**
|
||||
|
||||
MMSegmentation 将分割框架解耦成不同的模块组件,通过组合不同的模块组件,用户可以便捷地构建自定义的分割模型。
|
||||
|
||||
- **丰富的即插即用的算法和模型**
|
||||
|
||||
MMSegmentation 支持了众多主流的和最新的检测算法,例如 PSPNet,DeepLabV3,PSANet,DeepLabV3+ 等.
|
||||
|
||||
- **速度快**
|
||||
|
||||
训练速度比其他语义分割代码库更快或者相当。
|
||||
|
||||
</details>
|
||||
|
||||
## 最新进展
|
||||
|
||||
最新版本 v0.25.0 在 2022.6.2 发布:
|
||||
|
||||
- 支持 PyTorch MLU 后端
|
||||
|
||||
如果想了解更多版本更新细节和历史信息,请阅读[更新日志](docs/en/changelog.md)。
|
||||
|
||||
## 安装
|
||||
|
||||
请参考[快速入门文档](docs/zh_cn/get_started.md#installation)进行安装,参考[数据集准备](docs/zh_cn/dataset_prepare.md)处理数据。
|
||||
|
||||
## 快速入门
|
||||
|
||||
请参考[训练教程](docs/zh_cn/train.md)和[测试教程](docs/zh_cn/inference.md)学习 MMSegmentation 的基本使用。
|
||||
我们也提供了一些进阶教程,内容覆盖了:
|
||||
|
||||
- [增加自定义数据集](docs/zh_cn/tutorials/customize_datasets.md)
|
||||
- [设计新的数据预处理流程](docs/zh_cn/tutorials/data_pipeline.md)
|
||||
- [增加自定义模型](docs/zh_cn/tutorials/customize_models.md)
|
||||
- [增加自定义的运行时配置](docs/zh_cn/tutorials/customize_runtime.md)。
|
||||
- [训练技巧说明](docs/zh_cn/tutorials/training_tricks.md)
|
||||
- [有用的工具](docs/zh_cn/useful_tools.md)。
|
||||
|
||||
同时,我们提供了 Colab 教程。你可以在[这里](demo/MMSegmentation_Tutorial.ipynb)浏览教程,或者直接在 Colab 上[运行](https://colab.research.google.com/github/open-mmlab/mmsegmentation/blob/master/demo/MMSegmentation_Tutorial.ipynb)。
|
||||
|
||||
## 基准测试和模型库
|
||||
|
||||
测试结果和模型可以在[模型库](docs/zh_cn/model_zoo.md)中找到。
|
||||
|
||||
已支持的骨干网络:
|
||||
|
||||
- [x] ResNet (CVPR'2016)
|
||||
- [x] ResNeXt (CVPR'2017)
|
||||
- [x] [HRNet (CVPR'2019)](configs/hrnet)
|
||||
- [x] [ResNeSt (ArXiv'2020)](configs/resnest)
|
||||
- [x] [MobileNetV2 (CVPR'2018)](configs/mobilenet_v2)
|
||||
- [x] [MobileNetV3 (ICCV'2019)](configs/mobilenet_v3)
|
||||
- [x] [Vision Transformer (ICLR'2021)](configs/vit)
|
||||
- [x] [Swin Transformer (ICCV'2021)](configs/swin)
|
||||
- [x] [Twins (NeurIPS'2021)](configs/twins)
|
||||
- [x] [BEiT (ICLR'2022)](configs/beit)
|
||||
- [x] [ConvNeXt (CVPR'2022)](configs/convnext)
|
||||
- [x] [MAE (CVPR'2022)](configs/mae)
|
||||
|
||||
已支持的算法:
|
||||
|
||||
- [x] [FCN (CVPR'2015/TPAMI'2017)](configs/fcn)
|
||||
- [x] [ERFNet (T-ITS'2017)](configs/erfnet)
|
||||
- [x] [UNet (MICCAI'2016/Nat. Methods'2019)](configs/unet)
|
||||
- [x] [PSPNet (CVPR'2017)](configs/pspnet)
|
||||
- [x] [DeepLabV3 (ArXiv'2017)](configs/deeplabv3)
|
||||
- [x] [BiSeNetV1 (ECCV'2018)](configs/bisenetv1)
|
||||
- [x] [PSANet (ECCV'2018)](configs/psanet)
|
||||
- [x] [DeepLabV3+ (CVPR'2018)](configs/deeplabv3plus)
|
||||
- [x] [UPerNet (ECCV'2018)](configs/upernet)
|
||||
- [x] [ICNet (ECCV'2018)](configs/icnet)
|
||||
- [x] [NonLocal Net (CVPR'2018)](configs/nonlocal_net)
|
||||
- [x] [EncNet (CVPR'2018)](configs/encnet)
|
||||
- [x] [Semantic FPN (CVPR'2019)](configs/sem_fpn)
|
||||
- [x] [DANet (CVPR'2019)](configs/danet)
|
||||
- [x] [APCNet (CVPR'2019)](configs/apcnet)
|
||||
- [x] [EMANet (ICCV'2019)](configs/emanet)
|
||||
- [x] [CCNet (ICCV'2019)](configs/ccnet)
|
||||
- [x] [DMNet (ICCV'2019)](configs/dmnet)
|
||||
- [x] [ANN (ICCV'2019)](configs/ann)
|
||||
- [x] [GCNet (ICCVW'2019/TPAMI'2020)](configs/gcnet)
|
||||
- [x] [FastFCN (ArXiv'2019)](configs/fastfcn)
|
||||
- [x] [Fast-SCNN (ArXiv'2019)](configs/fastscnn)
|
||||
- [x] [ISANet (ArXiv'2019/IJCV'2021)](configs/isanet)
|
||||
- [x] [OCRNet (ECCV'2020)](configs/ocrnet)
|
||||
- [x] [DNLNet (ECCV'2020)](configs/dnlnet)
|
||||
- [x] [PointRend (CVPR'2020)](configs/point_rend)
|
||||
- [x] [CGNet (TIP'2020)](configs/cgnet)
|
||||
- [x] [BiSeNetV2 (IJCV'2021)](configs/bisenetv2)
|
||||
- [x] [STDC (CVPR'2021)](configs/stdc)
|
||||
- [x] [SETR (CVPR'2021)](configs/setr)
|
||||
- [x] [DPT (ArXiv'2021)](configs/dpt)
|
||||
- [x] [Segmenter (ICCV'2021)](configs/segmenter)
|
||||
- [x] [SegFormer (NeurIPS'2021)](configs/segformer)
|
||||
- [x] [K-Net (NeurIPS'2021)](configs/knet)
|
||||
|
||||
已支持的数据集:
|
||||
|
||||
- [x] [Cityscapes](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/zh_cn/dataset_prepare.md#cityscapes)
|
||||
- [x] [PASCAL VOC](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/zh_cn/dataset_prepare.md#pascal-voc)
|
||||
- [x] [ADE20K](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/zh_cn/dataset_prepare.md#ade20k)
|
||||
- [x] [Pascal Context](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/zh_cn/dataset_prepare.md#pascal-context)
|
||||
- [x] [COCO-Stuff 10k](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/zh_cn/dataset_prepare.md#coco-stuff-10k)
|
||||
- [x] [COCO-Stuff 164k](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/zh_cn/dataset_prepare.md#coco-stuff-164k)
|
||||
- [x] [CHASE_DB1](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/zh_cn/dataset_prepare.md#chase-db1)
|
||||
- [x] [DRIVE](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/zh_cn/dataset_prepare.md#drive)
|
||||
- [x] [HRF](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/zh_cn/dataset_prepare.md#hrf)
|
||||
- [x] [STARE](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/zh_cn/dataset_prepare.md#stare)
|
||||
- [x] [Dark Zurich](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/zh_cn/dataset_prepare.md#dark-zurich)
|
||||
- [x] [Nighttime Driving](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/zh_cn/dataset_prepare.md#nighttime-driving)
|
||||
- [x] [LoveDA](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/zh_cn/dataset_prepare.md#loveda)
|
||||
- [x] [Potsdam](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/zh_cn/dataset_prepare.md#isprs-potsdam)
|
||||
- [x] [Vaihingen](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/zh_cn/dataset_prepare.md#isprs-vaihingen)
|
||||
- [x] [iSAID](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/zh_cn/dataset_prepare.md#isaid)
|
||||
|
||||
## 常见问题
|
||||
|
||||
如果遇到问题,请参考 [常见问题解答](docs/zh_cn/faq.md)。
|
||||
|
||||
## 贡献指南
|
||||
|
||||
我们感谢所有的贡献者为改进和提升 MMSegmentation 所作出的努力。请参考[贡献指南](.github/CONTRIBUTING.md)来了解参与项目贡献的相关指引。
|
||||
|
||||
## 致谢
|
||||
|
||||
MMSegmentation 是一个由来自不同高校和企业的研发人员共同参与贡献的开源项目。我们感谢所有为项目提供算法复现和新功能支持的贡献者,以及提供宝贵反馈的用户。 我们希望这个工具箱和基准测试可以为社区提供灵活的代码工具,供用户复现已有算法并开发自己的新模型,从而不断为开源社区提供贡献。
|
||||
|
||||
## 引用
|
||||
|
||||
如果你觉得本项目对你的研究工作有所帮助,请参考如下 bibtex 引用 MMSegmentation。
|
||||
|
||||
```bibtex
|
||||
@misc{mmseg2020,
|
||||
title={{MMSegmentation}: OpenMMLab Semantic Segmentation Toolbox and Benchmark},
|
||||
author={MMSegmentation Contributors},
|
||||
howpublished = {\url{https://github.com/open-mmlab/mmsegmentation}},
|
||||
year={2020}
|
||||
}
|
||||
```
|
||||
|
||||
## 开源许可证
|
||||
|
||||
`MMSegmentation` 目前以 Apache 2.0 的许可证发布,但是其中有一部分功能并不是使用的 Apache2.0 许可证,我们在 [许可证](LICENSES.md) 中详细地列出了这些功能以及他们对应的许可证,如果您正在从事盈利性活动,请谨慎参考此文档。
|
||||
|
||||
## OpenMMLab 的其他项目
|
||||
|
||||
- [MMCV](https://github.com/open-mmlab/mmcv): OpenMMLab 计算机视觉基础库
|
||||
- [MIM](https://github.com/open-mmlab/mim): MIM 是 OpenMMlab 项目、算法、模型的统一入口
|
||||
- [MMClassification](https://github.com/open-mmlab/mmclassification): OpenMMLab 图像分类工具箱
|
||||
- [MMDetection](https://github.com/open-mmlab/mmdetection): OpenMMLab 目标检测工具箱
|
||||
- [MMDetection3D](https://github.com/open-mmlab/mmdetection3d): OpenMMLab 新一代通用 3D 目标检测平台
|
||||
- [MMRotate](https://github.com/open-mmlab/mmrotate): OpenMMLab 旋转框检测工具箱与测试基准
|
||||
- [MMSegmentation](https://github.com/open-mmlab/mmsegmentation): OpenMMLab 语义分割工具箱
|
||||
- [MMOCR](https://github.com/open-mmlab/mmocr): OpenMMLab 全流程文字检测识别理解工具包
|
||||
- [MMPose](https://github.com/open-mmlab/mmpose): OpenMMLab 姿态估计工具箱
|
||||
- [MMHuman3D](https://github.com/open-mmlab/mmhuman3d): OpenMMLab 人体参数化模型工具箱与测试基准
|
||||
- [MMSelfSup](https://github.com/open-mmlab/mmselfsup): OpenMMLab 自监督学习工具箱与测试基准
|
||||
- [MMRazor](https://github.com/open-mmlab/mmrazor): OpenMMLab 模型压缩工具箱与测试基准
|
||||
- [MMFewShot](https://github.com/open-mmlab/mmfewshot): OpenMMLab 少样本学习工具箱与测试基准
|
||||
- [MMAction2](https://github.com/open-mmlab/mmaction2): OpenMMLab 新一代视频理解工具箱
|
||||
- [MMTracking](https://github.com/open-mmlab/mmtracking): OpenMMLab 一体化视频目标感知平台
|
||||
- [MMFlow](https://github.com/open-mmlab/mmflow): OpenMMLab 光流估计工具箱与测试基准
|
||||
- [MMEditing](https://github.com/open-mmlab/mmediting): OpenMMLab 图像视频编辑工具箱
|
||||
- [MMGeneration](https://github.com/open-mmlab/mmgeneration): OpenMMLab 图片视频生成模型工具箱
|
||||
- [MMDeploy](https://github.com/open-mmlab/mmdeploy): OpenMMLab 模型部署框架
|
||||
|
||||
## 欢迎加入 OpenMMLab 社区
|
||||
|
||||
扫描下方的二维码可关注 OpenMMLab 团队的 [知乎官方账号](https://www.zhihu.com/people/openmmlab),加入 [OpenMMLab 团队](https://jq.qq.com/?_wv=1027&k=aCvMxdr3) 以及 [MMSegmentation](https://jq.qq.com/?_wv=1027&k=9sprS2YO) 的 QQ 群。
|
||||
|
||||
<div align="center">
|
||||
<img src="docs/zh_cn/imgs/zhihu_qrcode.jpg" height="400" /> <img src="docs/zh_cn/imgs/qq_group_qrcode.jpg" height="400" />
|
||||
</div>
|
||||
|
||||
我们会在 OpenMMLab 社区为大家
|
||||
|
||||
- 📢 分享 AI 框架的前沿核心技术
|
||||
- 💻 解读 PyTorch 常用模块源码
|
||||
- 📰 发布 OpenMMLab 的相关新闻
|
||||
- 🚀 介绍 OpenMMLab 开发的前沿算法
|
||||
- 🏃 获取更高效的问题答疑和意见反馈
|
||||
- 🔥 提供与各行各业开发者充分交流的平台
|
||||
|
||||
干货满满 📘,等你来撩 💗,OpenMMLab 社区期待您的加入 👬
|
||||
@ -1,53 +0,0 @@
|
||||
# dataset settings
|
||||
dataset_type = 'CustomDataset' # need to change
|
||||
data_root = 'data/my_dataset_v7' # need to change
|
||||
img_norm_cfg = dict(
|
||||
mean=[127.93135507, 116.76565979, 103.67335042], std=[49.55883976, 47.7692082, 50.7934459], to_rgb=True) # need to calculate
|
||||
crop_size = (512, 512) # need to change
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='LoadAnnotations'),
|
||||
dict(type='Resize', img_scale=(512, 512)),
|
||||
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
|
||||
dict(type='RandomFlip', flip_ratio=0.5),
|
||||
dict(type='PhotoMetricDistortion'),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
|
||||
dict(type='DefaultFormatBundle'),
|
||||
dict(type='Collect', keys=['img', 'gt_semantic_seg']),
|
||||
]
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='MultiScaleFlipAug',
|
||||
img_scale=(512, 512), # need to change
|
||||
flip=False,
|
||||
transforms=[
|
||||
dict(type='Resize', keep_ratio=True),
|
||||
dict(type='RandomFlip'),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='ImageToTensor', keys=['img']),
|
||||
dict(type='Collect', keys=['img']),
|
||||
])
|
||||
]
|
||||
data = dict(
|
||||
samples_per_gpu=2, # need to change
|
||||
workers_per_gpu=1, # need to change
|
||||
train=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
img_dir='img_dir/train', # need to change
|
||||
ann_dir='ann_dir/train', # need to change
|
||||
pipeline=train_pipeline),
|
||||
val=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
img_dir='img_dir/val',# need to change
|
||||
ann_dir='ann_dir/val',# need to change
|
||||
pipeline=test_pipeline),
|
||||
test=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
img_dir='img_dir/test',# need to change
|
||||
ann_dir='ann_dir/test',# need to change
|
||||
pipeline=test_pipeline))
|
||||
@ -1,54 +0,0 @@
|
||||
# dataset settings
|
||||
dataset_type = 'ADE20KDataset'
|
||||
data_root = 'data/ade/ADEChallengeData2016'
|
||||
img_norm_cfg = dict(
|
||||
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
|
||||
crop_size = (512, 512)
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='LoadAnnotations', reduce_zero_label=True),
|
||||
dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)),
|
||||
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
|
||||
dict(type='RandomFlip', prob=0.5),
|
||||
dict(type='PhotoMetricDistortion'),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
|
||||
dict(type='DefaultFormatBundle'),
|
||||
dict(type='Collect', keys=['img', 'gt_semantic_seg']),
|
||||
]
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='MultiScaleFlipAug',
|
||||
img_scale=(2048, 512),
|
||||
# img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
|
||||
flip=False,
|
||||
transforms=[
|
||||
dict(type='Resize', keep_ratio=True),
|
||||
dict(type='RandomFlip'),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='ImageToTensor', keys=['img']),
|
||||
dict(type='Collect', keys=['img']),
|
||||
])
|
||||
]
|
||||
data = dict(
|
||||
samples_per_gpu=4,
|
||||
workers_per_gpu=4,
|
||||
train=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
img_dir='images/training',
|
||||
ann_dir='annotations/training',
|
||||
pipeline=train_pipeline),
|
||||
val=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
img_dir='images/validation',
|
||||
ann_dir='annotations/validation',
|
||||
pipeline=test_pipeline),
|
||||
test=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
img_dir='images/validation',
|
||||
ann_dir='annotations/validation',
|
||||
pipeline=test_pipeline))
|
||||
@ -1,54 +0,0 @@
|
||||
# dataset settings
|
||||
dataset_type = 'ADE20KDataset'
|
||||
data_root = 'data/ade/ADEChallengeData2016'
|
||||
img_norm_cfg = dict(
|
||||
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
|
||||
crop_size = (640, 640)
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='LoadAnnotations', reduce_zero_label=True),
|
||||
dict(type='Resize', img_scale=(2560, 640), ratio_range=(0.5, 2.0)),
|
||||
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
|
||||
dict(type='RandomFlip', prob=0.5),
|
||||
dict(type='PhotoMetricDistortion'),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
|
||||
dict(type='DefaultFormatBundle'),
|
||||
dict(type='Collect', keys=['img', 'gt_semantic_seg']),
|
||||
]
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='MultiScaleFlipAug',
|
||||
img_scale=(2560, 640),
|
||||
# img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
|
||||
flip=False,
|
||||
transforms=[
|
||||
dict(type='Resize', keep_ratio=True),
|
||||
dict(type='RandomFlip'),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='ImageToTensor', keys=['img']),
|
||||
dict(type='Collect', keys=['img']),
|
||||
])
|
||||
]
|
||||
data = dict(
|
||||
samples_per_gpu=4,
|
||||
workers_per_gpu=4,
|
||||
train=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
img_dir='images/training',
|
||||
ann_dir='annotations/training',
|
||||
pipeline=train_pipeline),
|
||||
val=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
img_dir='images/validation',
|
||||
ann_dir='annotations/validation',
|
||||
pipeline=test_pipeline),
|
||||
test=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
img_dir='images/validation',
|
||||
ann_dir='annotations/validation',
|
||||
pipeline=test_pipeline))
|
||||
@ -1,59 +0,0 @@
|
||||
# dataset settings
|
||||
dataset_type = 'ChaseDB1Dataset'
|
||||
data_root = 'data/CHASE_DB1'
|
||||
img_norm_cfg = dict(
|
||||
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
|
||||
img_scale = (960, 999)
|
||||
crop_size = (128, 128)
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='LoadAnnotations'),
|
||||
dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)),
|
||||
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
|
||||
dict(type='RandomFlip', prob=0.5),
|
||||
dict(type='PhotoMetricDistortion'),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
|
||||
dict(type='DefaultFormatBundle'),
|
||||
dict(type='Collect', keys=['img', 'gt_semantic_seg'])
|
||||
]
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='MultiScaleFlipAug',
|
||||
img_scale=img_scale,
|
||||
# img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0],
|
||||
flip=False,
|
||||
transforms=[
|
||||
dict(type='Resize', keep_ratio=True),
|
||||
dict(type='RandomFlip'),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='ImageToTensor', keys=['img']),
|
||||
dict(type='Collect', keys=['img'])
|
||||
])
|
||||
]
|
||||
|
||||
data = dict(
|
||||
samples_per_gpu=4,
|
||||
workers_per_gpu=4,
|
||||
train=dict(
|
||||
type='RepeatDataset',
|
||||
times=40000,
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
img_dir='images/training',
|
||||
ann_dir='annotations/training',
|
||||
pipeline=train_pipeline)),
|
||||
val=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
img_dir='images/validation',
|
||||
ann_dir='annotations/validation',
|
||||
pipeline=test_pipeline),
|
||||
test=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
img_dir='images/validation',
|
||||
ann_dir='annotations/validation',
|
||||
pipeline=test_pipeline))
|
||||
@ -1,54 +0,0 @@
|
||||
# dataset settings
|
||||
dataset_type = 'CityscapesDataset'
|
||||
data_root = 'data/cityscapes/'
|
||||
img_norm_cfg = dict(
|
||||
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
|
||||
crop_size = (512, 1024)
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='LoadAnnotations'),
|
||||
dict(type='Resize', img_scale=(2048, 1024), ratio_range=(0.5, 2.0)),
|
||||
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
|
||||
dict(type='RandomFlip', prob=0.5),
|
||||
dict(type='PhotoMetricDistortion'),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
|
||||
dict(type='DefaultFormatBundle'),
|
||||
dict(type='Collect', keys=['img', 'gt_semantic_seg']),
|
||||
]
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='MultiScaleFlipAug',
|
||||
img_scale=(2048, 1024),
|
||||
# img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
|
||||
flip=False,
|
||||
transforms=[
|
||||
dict(type='Resize', keep_ratio=True),
|
||||
dict(type='RandomFlip'),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='ImageToTensor', keys=['img']),
|
||||
dict(type='Collect', keys=['img']),
|
||||
])
|
||||
]
|
||||
data = dict(
|
||||
samples_per_gpu=2,
|
||||
workers_per_gpu=2,
|
||||
train=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
img_dir='leftImg8bit/train',
|
||||
ann_dir='gtFine/train',
|
||||
pipeline=train_pipeline),
|
||||
val=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
img_dir='leftImg8bit/val',
|
||||
ann_dir='gtFine/val',
|
||||
pipeline=test_pipeline),
|
||||
test=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
img_dir='leftImg8bit/val',
|
||||
ann_dir='gtFine/val',
|
||||
pipeline=test_pipeline))
|
||||
@ -1,35 +0,0 @@
|
||||
_base_ = './cityscapes.py'
|
||||
img_norm_cfg = dict(
|
||||
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
|
||||
crop_size = (1024, 1024)
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='LoadAnnotations'),
|
||||
dict(type='Resize', img_scale=(2048, 1024), ratio_range=(0.5, 2.0)),
|
||||
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
|
||||
dict(type='RandomFlip', prob=0.5),
|
||||
dict(type='PhotoMetricDistortion'),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
|
||||
dict(type='DefaultFormatBundle'),
|
||||
dict(type='Collect', keys=['img', 'gt_semantic_seg']),
|
||||
]
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='MultiScaleFlipAug',
|
||||
img_scale=(2048, 1024),
|
||||
# img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
|
||||
flip=False,
|
||||
transforms=[
|
||||
dict(type='Resize', keep_ratio=True),
|
||||
dict(type='RandomFlip'),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='ImageToTensor', keys=['img']),
|
||||
dict(type='Collect', keys=['img']),
|
||||
])
|
||||
]
|
||||
data = dict(
|
||||
train=dict(pipeline=train_pipeline),
|
||||
val=dict(pipeline=test_pipeline),
|
||||
test=dict(pipeline=test_pipeline))
|
||||
@ -1,35 +0,0 @@
|
||||
_base_ = './cityscapes.py'
|
||||
img_norm_cfg = dict(
|
||||
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
|
||||
crop_size = (768, 768)
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='LoadAnnotations'),
|
||||
dict(type='Resize', img_scale=(2049, 1025), ratio_range=(0.5, 2.0)),
|
||||
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
|
||||
dict(type='RandomFlip', prob=0.5),
|
||||
dict(type='PhotoMetricDistortion'),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
|
||||
dict(type='DefaultFormatBundle'),
|
||||
dict(type='Collect', keys=['img', 'gt_semantic_seg']),
|
||||
]
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='MultiScaleFlipAug',
|
||||
img_scale=(2049, 1025),
|
||||
# img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
|
||||
flip=False,
|
||||
transforms=[
|
||||
dict(type='Resize', keep_ratio=True),
|
||||
dict(type='RandomFlip'),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='ImageToTensor', keys=['img']),
|
||||
dict(type='Collect', keys=['img']),
|
||||
])
|
||||
]
|
||||
data = dict(
|
||||
train=dict(pipeline=train_pipeline),
|
||||
val=dict(pipeline=test_pipeline),
|
||||
test=dict(pipeline=test_pipeline))
|
||||
@ -1,35 +0,0 @@
|
||||
_base_ = './cityscapes.py'
|
||||
img_norm_cfg = dict(
|
||||
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
|
||||
crop_size = (769, 769)
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='LoadAnnotations'),
|
||||
dict(type='Resize', img_scale=(2049, 1025), ratio_range=(0.5, 2.0)),
|
||||
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
|
||||
dict(type='RandomFlip', prob=0.5),
|
||||
dict(type='PhotoMetricDistortion'),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
|
||||
dict(type='DefaultFormatBundle'),
|
||||
dict(type='Collect', keys=['img', 'gt_semantic_seg']),
|
||||
]
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='MultiScaleFlipAug',
|
||||
img_scale=(2049, 1025),
|
||||
# img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
|
||||
flip=False,
|
||||
transforms=[
|
||||
dict(type='Resize', keep_ratio=True),
|
||||
dict(type='RandomFlip'),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='ImageToTensor', keys=['img']),
|
||||
dict(type='Collect', keys=['img']),
|
||||
])
|
||||
]
|
||||
data = dict(
|
||||
train=dict(pipeline=train_pipeline),
|
||||
val=dict(pipeline=test_pipeline),
|
||||
test=dict(pipeline=test_pipeline))
|
||||
@ -1,35 +0,0 @@
|
||||
_base_ = './cityscapes.py'
|
||||
img_norm_cfg = dict(
|
||||
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
|
||||
crop_size = (832, 832)
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='LoadAnnotations'),
|
||||
dict(type='Resize', img_scale=(2048, 1024), ratio_range=(0.5, 2.0)),
|
||||
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
|
||||
dict(type='RandomFlip', prob=0.5),
|
||||
dict(type='PhotoMetricDistortion'),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
|
||||
dict(type='DefaultFormatBundle'),
|
||||
dict(type='Collect', keys=['img', 'gt_semantic_seg']),
|
||||
]
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='MultiScaleFlipAug',
|
||||
img_scale=(2048, 1024),
|
||||
# img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
|
||||
flip=False,
|
||||
transforms=[
|
||||
dict(type='Resize', keep_ratio=True),
|
||||
dict(type='RandomFlip'),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='ImageToTensor', keys=['img']),
|
||||
dict(type='Collect', keys=['img']),
|
||||
])
|
||||
]
|
||||
data = dict(
|
||||
train=dict(pipeline=train_pipeline),
|
||||
val=dict(pipeline=test_pipeline),
|
||||
test=dict(pipeline=test_pipeline))
|
||||
@ -1,57 +0,0 @@
|
||||
# dataset settings
|
||||
dataset_type = 'COCOStuffDataset'
|
||||
data_root = 'data/coco_stuff10k'
|
||||
img_norm_cfg = dict(
|
||||
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
|
||||
crop_size = (512, 512)
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='LoadAnnotations', reduce_zero_label=True),
|
||||
dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)),
|
||||
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
|
||||
dict(type='RandomFlip', prob=0.5),
|
||||
dict(type='PhotoMetricDistortion'),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
|
||||
dict(type='DefaultFormatBundle'),
|
||||
dict(type='Collect', keys=['img', 'gt_semantic_seg']),
|
||||
]
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='MultiScaleFlipAug',
|
||||
img_scale=(2048, 512),
|
||||
# img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
|
||||
flip=False,
|
||||
transforms=[
|
||||
dict(type='Resize', keep_ratio=True),
|
||||
dict(type='RandomFlip'),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='ImageToTensor', keys=['img']),
|
||||
dict(type='Collect', keys=['img']),
|
||||
])
|
||||
]
|
||||
data = dict(
|
||||
samples_per_gpu=4,
|
||||
workers_per_gpu=4,
|
||||
train=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
reduce_zero_label=True,
|
||||
img_dir='images/train2014',
|
||||
ann_dir='annotations/train2014',
|
||||
pipeline=train_pipeline),
|
||||
val=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
reduce_zero_label=True,
|
||||
img_dir='images/test2014',
|
||||
ann_dir='annotations/test2014',
|
||||
pipeline=test_pipeline),
|
||||
test=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
reduce_zero_label=True,
|
||||
img_dir='images/test2014',
|
||||
ann_dir='annotations/test2014',
|
||||
pipeline=test_pipeline))
|
||||
@ -1,54 +0,0 @@
|
||||
# dataset settings
|
||||
dataset_type = 'COCOStuffDataset'
|
||||
data_root = 'data/coco_stuff164k'
|
||||
img_norm_cfg = dict(
|
||||
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
|
||||
crop_size = (512, 512)
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='LoadAnnotations'),
|
||||
dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)),
|
||||
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
|
||||
dict(type='RandomFlip', prob=0.5),
|
||||
dict(type='PhotoMetricDistortion'),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
|
||||
dict(type='DefaultFormatBundle'),
|
||||
dict(type='Collect', keys=['img', 'gt_semantic_seg']),
|
||||
]
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='MultiScaleFlipAug',
|
||||
img_scale=(2048, 512),
|
||||
# img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
|
||||
flip=False,
|
||||
transforms=[
|
||||
dict(type='Resize', keep_ratio=True),
|
||||
dict(type='RandomFlip'),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='ImageToTensor', keys=['img']),
|
||||
dict(type='Collect', keys=['img']),
|
||||
])
|
||||
]
|
||||
data = dict(
|
||||
samples_per_gpu=4,
|
||||
workers_per_gpu=4,
|
||||
train=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
img_dir='images/train2017',
|
||||
ann_dir='annotations/train2017',
|
||||
pipeline=train_pipeline),
|
||||
val=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
img_dir='images/val2017',
|
||||
ann_dir='annotations/val2017',
|
||||
pipeline=test_pipeline),
|
||||
test=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
img_dir='images/val2017',
|
||||
ann_dir='annotations/val2017',
|
||||
pipeline=test_pipeline))
|
||||
@ -1,59 +0,0 @@
|
||||
# dataset settings
|
||||
dataset_type = 'DRIVEDataset'
|
||||
data_root = 'data/DRIVE'
|
||||
img_norm_cfg = dict(
|
||||
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
|
||||
img_scale = (584, 565)
|
||||
crop_size = (64, 64)
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='LoadAnnotations'),
|
||||
dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)),
|
||||
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
|
||||
dict(type='RandomFlip', prob=0.5),
|
||||
dict(type='PhotoMetricDistortion'),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
|
||||
dict(type='DefaultFormatBundle'),
|
||||
dict(type='Collect', keys=['img', 'gt_semantic_seg'])
|
||||
]
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='MultiScaleFlipAug',
|
||||
img_scale=img_scale,
|
||||
# img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0],
|
||||
flip=False,
|
||||
transforms=[
|
||||
dict(type='Resize', keep_ratio=True),
|
||||
dict(type='RandomFlip'),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='ImageToTensor', keys=['img']),
|
||||
dict(type='Collect', keys=['img'])
|
||||
])
|
||||
]
|
||||
|
||||
data = dict(
|
||||
samples_per_gpu=4,
|
||||
workers_per_gpu=4,
|
||||
train=dict(
|
||||
type='RepeatDataset',
|
||||
times=40000,
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
img_dir='images/training',
|
||||
ann_dir='annotations/training',
|
||||
pipeline=train_pipeline)),
|
||||
val=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
img_dir='images/validation',
|
||||
ann_dir='annotations/validation',
|
||||
pipeline=test_pipeline),
|
||||
test=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
img_dir='images/validation',
|
||||
ann_dir='annotations/validation',
|
||||
pipeline=test_pipeline))
|
||||
@ -1,59 +0,0 @@
|
||||
# dataset settings
|
||||
dataset_type = 'HRFDataset'
|
||||
data_root = 'data/HRF'
|
||||
img_norm_cfg = dict(
|
||||
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
|
||||
img_scale = (2336, 3504)
|
||||
crop_size = (256, 256)
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='LoadAnnotations'),
|
||||
dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)),
|
||||
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
|
||||
dict(type='RandomFlip', prob=0.5),
|
||||
dict(type='PhotoMetricDistortion'),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
|
||||
dict(type='DefaultFormatBundle'),
|
||||
dict(type='Collect', keys=['img', 'gt_semantic_seg'])
|
||||
]
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='MultiScaleFlipAug',
|
||||
img_scale=img_scale,
|
||||
# img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0],
|
||||
flip=False,
|
||||
transforms=[
|
||||
dict(type='Resize', keep_ratio=True),
|
||||
dict(type='RandomFlip'),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='ImageToTensor', keys=['img']),
|
||||
dict(type='Collect', keys=['img'])
|
||||
])
|
||||
]
|
||||
|
||||
data = dict(
|
||||
samples_per_gpu=4,
|
||||
workers_per_gpu=4,
|
||||
train=dict(
|
||||
type='RepeatDataset',
|
||||
times=40000,
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
img_dir='images/training',
|
||||
ann_dir='annotations/training',
|
||||
pipeline=train_pipeline)),
|
||||
val=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
img_dir='images/validation',
|
||||
ann_dir='annotations/validation',
|
||||
pipeline=test_pipeline),
|
||||
test=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
img_dir='images/validation',
|
||||
ann_dir='annotations/validation',
|
||||
pipeline=test_pipeline))
|
||||
@ -1,62 +0,0 @@
|
||||
# dataset settings
|
||||
dataset_type = 'iSAIDDataset'
|
||||
data_root = 'data/iSAID'
|
||||
|
||||
img_norm_cfg = dict(
|
||||
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
|
||||
"""
|
||||
This crop_size setting is followed by the implementation of
|
||||
`PointFlow: Flowing Semantics Through Points for Aerial Image
|
||||
Segmentation <https://arxiv.org/pdf/2103.06564.pdf>`_.
|
||||
"""
|
||||
|
||||
crop_size = (896, 896)
|
||||
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='LoadAnnotations'),
|
||||
dict(type='Resize', img_scale=(896, 896), ratio_range=(0.5, 2.0)),
|
||||
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
|
||||
dict(type='RandomFlip', prob=0.5),
|
||||
dict(type='PhotoMetricDistortion'),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
|
||||
dict(type='DefaultFormatBundle'),
|
||||
dict(type='Collect', keys=['img', 'gt_semantic_seg']),
|
||||
]
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='MultiScaleFlipAug',
|
||||
img_scale=(896, 896),
|
||||
# img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
|
||||
flip=False,
|
||||
transforms=[
|
||||
dict(type='Resize', keep_ratio=True),
|
||||
dict(type='RandomFlip'),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='ImageToTensor', keys=['img']),
|
||||
dict(type='Collect', keys=['img']),
|
||||
])
|
||||
]
|
||||
data = dict(
|
||||
samples_per_gpu=4,
|
||||
workers_per_gpu=4,
|
||||
train=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
img_dir='img_dir/train',
|
||||
ann_dir='ann_dir/train',
|
||||
pipeline=train_pipeline),
|
||||
val=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
img_dir='img_dir/val',
|
||||
ann_dir='ann_dir/val',
|
||||
pipeline=test_pipeline),
|
||||
test=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
img_dir='img_dir/val',
|
||||
ann_dir='ann_dir/val',
|
||||
pipeline=test_pipeline))
|
||||
@ -1,54 +0,0 @@
|
||||
# dataset settings
|
||||
dataset_type = 'LoveDADataset'
|
||||
data_root = 'data/loveDA'
|
||||
img_norm_cfg = dict(
|
||||
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
|
||||
crop_size = (512, 512)
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='LoadAnnotations', reduce_zero_label=True),
|
||||
dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)),
|
||||
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
|
||||
dict(type='RandomFlip', prob=0.5),
|
||||
dict(type='PhotoMetricDistortion'),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
|
||||
dict(type='DefaultFormatBundle'),
|
||||
dict(type='Collect', keys=['img', 'gt_semantic_seg']),
|
||||
]
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='MultiScaleFlipAug',
|
||||
img_scale=(1024, 1024),
|
||||
# img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
|
||||
flip=False,
|
||||
transforms=[
|
||||
dict(type='Resize', keep_ratio=True),
|
||||
dict(type='RandomFlip'),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='ImageToTensor', keys=['img']),
|
||||
dict(type='Collect', keys=['img']),
|
||||
])
|
||||
]
|
||||
data = dict(
|
||||
samples_per_gpu=4,
|
||||
workers_per_gpu=4,
|
||||
train=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
img_dir='img_dir/train',
|
||||
ann_dir='ann_dir/train',
|
||||
pipeline=train_pipeline),
|
||||
val=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
img_dir='img_dir/val',
|
||||
ann_dir='ann_dir/val',
|
||||
pipeline=test_pipeline),
|
||||
test=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
img_dir='img_dir/val',
|
||||
ann_dir='ann_dir/val',
|
||||
pipeline=test_pipeline))
|
||||
@ -1,60 +0,0 @@
|
||||
# dataset settings
|
||||
dataset_type = 'PascalContextDataset'
|
||||
data_root = 'data/VOCdevkit/VOC2010/'
|
||||
img_norm_cfg = dict(
|
||||
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
|
||||
|
||||
img_scale = (520, 520)
|
||||
crop_size = (480, 480)
|
||||
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='LoadAnnotations'),
|
||||
dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)),
|
||||
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
|
||||
dict(type='RandomFlip', prob=0.5),
|
||||
dict(type='PhotoMetricDistortion'),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
|
||||
dict(type='DefaultFormatBundle'),
|
||||
dict(type='Collect', keys=['img', 'gt_semantic_seg']),
|
||||
]
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='MultiScaleFlipAug',
|
||||
img_scale=img_scale,
|
||||
# img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
|
||||
flip=False,
|
||||
transforms=[
|
||||
dict(type='Resize', keep_ratio=True),
|
||||
dict(type='RandomFlip'),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='ImageToTensor', keys=['img']),
|
||||
dict(type='Collect', keys=['img']),
|
||||
])
|
||||
]
|
||||
data = dict(
|
||||
samples_per_gpu=4,
|
||||
workers_per_gpu=4,
|
||||
train=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
img_dir='JPEGImages',
|
||||
ann_dir='SegmentationClassContext',
|
||||
split='ImageSets/SegmentationContext/train.txt',
|
||||
pipeline=train_pipeline),
|
||||
val=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
img_dir='JPEGImages',
|
||||
ann_dir='SegmentationClassContext',
|
||||
split='ImageSets/SegmentationContext/val.txt',
|
||||
pipeline=test_pipeline),
|
||||
test=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
img_dir='JPEGImages',
|
||||
ann_dir='SegmentationClassContext',
|
||||
split='ImageSets/SegmentationContext/val.txt',
|
||||
pipeline=test_pipeline))
|
||||
@ -1,60 +0,0 @@
|
||||
# dataset settings
|
||||
dataset_type = 'PascalContextDataset59'
|
||||
data_root = 'data/VOCdevkit/VOC2010/'
|
||||
img_norm_cfg = dict(
|
||||
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
|
||||
|
||||
img_scale = (520, 520)
|
||||
crop_size = (480, 480)
|
||||
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='LoadAnnotations', reduce_zero_label=True),
|
||||
dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)),
|
||||
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
|
||||
dict(type='RandomFlip', prob=0.5),
|
||||
dict(type='PhotoMetricDistortion'),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
|
||||
dict(type='DefaultFormatBundle'),
|
||||
dict(type='Collect', keys=['img', 'gt_semantic_seg']),
|
||||
]
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='MultiScaleFlipAug',
|
||||
img_scale=img_scale,
|
||||
# img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
|
||||
flip=False,
|
||||
transforms=[
|
||||
dict(type='Resize', keep_ratio=True),
|
||||
dict(type='RandomFlip'),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='ImageToTensor', keys=['img']),
|
||||
dict(type='Collect', keys=['img']),
|
||||
])
|
||||
]
|
||||
data = dict(
|
||||
samples_per_gpu=4,
|
||||
workers_per_gpu=4,
|
||||
train=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
img_dir='JPEGImages',
|
||||
ann_dir='SegmentationClassContext',
|
||||
split='ImageSets/SegmentationContext/train.txt',
|
||||
pipeline=train_pipeline),
|
||||
val=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
img_dir='JPEGImages',
|
||||
ann_dir='SegmentationClassContext',
|
||||
split='ImageSets/SegmentationContext/val.txt',
|
||||
pipeline=test_pipeline),
|
||||
test=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
img_dir='JPEGImages',
|
||||
ann_dir='SegmentationClassContext',
|
||||
split='ImageSets/SegmentationContext/val.txt',
|
||||
pipeline=test_pipeline))
|
||||
@ -1,57 +0,0 @@
|
||||
# dataset settings
|
||||
dataset_type = 'PascalVOCDataset'
|
||||
data_root = 'data/VOCdevkit/VOC2012'
|
||||
img_norm_cfg = dict(
|
||||
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
|
||||
crop_size = (512, 512)
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='LoadAnnotations'),
|
||||
dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)),
|
||||
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
|
||||
dict(type='RandomFlip', prob=0.5),
|
||||
dict(type='PhotoMetricDistortion'),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
|
||||
dict(type='DefaultFormatBundle'),
|
||||
dict(type='Collect', keys=['img', 'gt_semantic_seg']),
|
||||
]
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='MultiScaleFlipAug',
|
||||
img_scale=(2048, 512),
|
||||
# img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
|
||||
flip=False,
|
||||
transforms=[
|
||||
dict(type='Resize', keep_ratio=True),
|
||||
dict(type='RandomFlip'),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='ImageToTensor', keys=['img']),
|
||||
dict(type='Collect', keys=['img']),
|
||||
])
|
||||
]
|
||||
data = dict(
|
||||
samples_per_gpu=4,
|
||||
workers_per_gpu=4,
|
||||
train=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
img_dir='JPEGImages',
|
||||
ann_dir='SegmentationClass',
|
||||
split='ImageSets/Segmentation/train.txt',
|
||||
pipeline=train_pipeline),
|
||||
val=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
img_dir='JPEGImages',
|
||||
ann_dir='SegmentationClass',
|
||||
split='ImageSets/Segmentation/val.txt',
|
||||
pipeline=test_pipeline),
|
||||
test=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
img_dir='JPEGImages',
|
||||
ann_dir='SegmentationClass',
|
||||
split='ImageSets/Segmentation/val.txt',
|
||||
pipeline=test_pipeline))
|
||||
@ -1,9 +0,0 @@
|
||||
_base_ = './pascal_voc12.py'
|
||||
# dataset settings
|
||||
data = dict(
|
||||
train=dict(
|
||||
ann_dir=['SegmentationClass', 'SegmentationClassAug'],
|
||||
split=[
|
||||
'ImageSets/Segmentation/train.txt',
|
||||
'ImageSets/Segmentation/aug.txt'
|
||||
]))
|
||||
@ -1,54 +0,0 @@
|
||||
# dataset settings
|
||||
dataset_type = 'PotsdamDataset'
|
||||
data_root = 'data/potsdam'
|
||||
img_norm_cfg = dict(
|
||||
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
|
||||
crop_size = (512, 512)
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='LoadAnnotations', reduce_zero_label=True),
|
||||
dict(type='Resize', img_scale=(512, 512), ratio_range=(0.5, 2.0)),
|
||||
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
|
||||
dict(type='RandomFlip', prob=0.5),
|
||||
dict(type='PhotoMetricDistortion'),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
|
||||
dict(type='DefaultFormatBundle'),
|
||||
dict(type='Collect', keys=['img', 'gt_semantic_seg']),
|
||||
]
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='MultiScaleFlipAug',
|
||||
img_scale=(512, 512),
|
||||
# img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
|
||||
flip=False,
|
||||
transforms=[
|
||||
dict(type='Resize', keep_ratio=True),
|
||||
dict(type='RandomFlip'),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='ImageToTensor', keys=['img']),
|
||||
dict(type='Collect', keys=['img']),
|
||||
])
|
||||
]
|
||||
data = dict(
|
||||
samples_per_gpu=4,
|
||||
workers_per_gpu=4,
|
||||
train=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
img_dir='img_dir/train',
|
||||
ann_dir='ann_dir/train',
|
||||
pipeline=train_pipeline),
|
||||
val=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
img_dir='img_dir/val',
|
||||
ann_dir='ann_dir/val',
|
||||
pipeline=test_pipeline),
|
||||
test=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
img_dir='img_dir/val',
|
||||
ann_dir='ann_dir/val',
|
||||
pipeline=test_pipeline))
|
||||
@ -1,59 +0,0 @@
|
||||
# dataset settings
|
||||
dataset_type = 'STAREDataset'
|
||||
data_root = 'data/STARE'
|
||||
img_norm_cfg = dict(
|
||||
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
|
||||
img_scale = (605, 700)
|
||||
crop_size = (128, 128)
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='LoadAnnotations'),
|
||||
dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)),
|
||||
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
|
||||
dict(type='RandomFlip', prob=0.5),
|
||||
dict(type='PhotoMetricDistortion'),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
|
||||
dict(type='DefaultFormatBundle'),
|
||||
dict(type='Collect', keys=['img', 'gt_semantic_seg'])
|
||||
]
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='MultiScaleFlipAug',
|
||||
img_scale=img_scale,
|
||||
# img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0],
|
||||
flip=False,
|
||||
transforms=[
|
||||
dict(type='Resize', keep_ratio=True),
|
||||
dict(type='RandomFlip'),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='ImageToTensor', keys=['img']),
|
||||
dict(type='Collect', keys=['img'])
|
||||
])
|
||||
]
|
||||
|
||||
data = dict(
|
||||
samples_per_gpu=4,
|
||||
workers_per_gpu=4,
|
||||
train=dict(
|
||||
type='RepeatDataset',
|
||||
times=40000,
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
img_dir='images/training',
|
||||
ann_dir='annotations/training',
|
||||
pipeline=train_pipeline)),
|
||||
val=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
img_dir='images/validation',
|
||||
ann_dir='annotations/validation',
|
||||
pipeline=test_pipeline),
|
||||
test=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
img_dir='images/validation',
|
||||
ann_dir='annotations/validation',
|
||||
pipeline=test_pipeline))
|
||||
@ -1,54 +0,0 @@
|
||||
# dataset settings
|
||||
dataset_type = 'ISPRSDataset'
|
||||
data_root = 'data/vaihingen'
|
||||
img_norm_cfg = dict(
|
||||
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
|
||||
crop_size = (512, 512)
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='LoadAnnotations', reduce_zero_label=True),
|
||||
dict(type='Resize', img_scale=(512, 512), ratio_range=(0.5, 2.0)),
|
||||
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
|
||||
dict(type='RandomFlip', prob=0.5),
|
||||
dict(type='PhotoMetricDistortion'),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
|
||||
dict(type='DefaultFormatBundle'),
|
||||
dict(type='Collect', keys=['img', 'gt_semantic_seg']),
|
||||
]
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='MultiScaleFlipAug',
|
||||
img_scale=(512, 512),
|
||||
# img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
|
||||
flip=False,
|
||||
transforms=[
|
||||
dict(type='Resize', keep_ratio=True),
|
||||
dict(type='RandomFlip'),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='ImageToTensor', keys=['img']),
|
||||
dict(type='Collect', keys=['img']),
|
||||
])
|
||||
]
|
||||
data = dict(
|
||||
samples_per_gpu=4,
|
||||
workers_per_gpu=4,
|
||||
train=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
img_dir='img_dir/train',
|
||||
ann_dir='ann_dir/train',
|
||||
pipeline=train_pipeline),
|
||||
val=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
img_dir='img_dir/val',
|
||||
ann_dir='ann_dir/val',
|
||||
pipeline=test_pipeline),
|
||||
test=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
img_dir='img_dir/val',
|
||||
ann_dir='ann_dir/val',
|
||||
pipeline=test_pipeline))
|
||||
@ -1,9 +0,0 @@
|
||||
# yapf:enable
|
||||
dist_params = dict(backend='nccl')
|
||||
log_level = 'INFO'
|
||||
#load_from = 'checkpoints/danet_r50-d8_512x512_80k_ade20k_20200615_015125-edb18e08.pth'
|
||||
load_from = None
|
||||
resume_from = None
|
||||
#workflow = [('train', 1)]
|
||||
workflow = [('train', 1), ('val', 1)]
|
||||
cudnn_benchmark = True
|
||||
@ -1,46 +0,0 @@
|
||||
# model settings
# Shared normalization config: synchronized BatchNorm with learnable affine params.
norm_cfg = dict(type='SyncBN', requires_grad=True)

# ANN (Asymmetric Non-local Neural Network) segmentor on a dilated ResNet-50.
model = dict(
    type='EncoderDecoder',
    pretrained='open-mmlab://resnet50_v1c',
    backbone=dict(
        type='ResNetV1c',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        # Dilated stages 3-4 keep an output stride of 8.
        dilations=(1, 1, 2, 4),
        strides=(1, 2, 1, 1),
        norm_cfg=norm_cfg,
        norm_eval=False,
        style='pytorch',
        contract_dilation=True),
    decode_head=dict(
        type='ANNHead',
        # Consumes the last two backbone stages.
        in_channels=[1024, 2048],
        in_index=[2, 3],
        channels=512,
        project_channels=256,
        query_scales=(1, ),
        key_pool_scales=(1, 3, 6, 8),
        dropout_ratio=0.1,
        num_classes=19,
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
    # Auxiliary FCN head on stage 3 for deep supervision (weight 0.4).
    auxiliary_head=dict(
        type='FCNHead',
        in_channels=1024,
        in_index=2,
        channels=256,
        num_convs=1,
        concat_input=False,
        dropout_ratio=0.1,
        num_classes=19,
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
    # model training and testing settings
    train_cfg=dict(),
    test_cfg=dict(mode='whole'))
|
||||
@ -1,44 +0,0 @@
|
||||
# model settings
# Shared normalization config: synchronized BatchNorm with learnable affine params.
norm_cfg = dict(type='SyncBN', requires_grad=True)

# APCNet (Adaptive Pyramid Context) segmentor on a dilated ResNet-50.
model = dict(
    type='EncoderDecoder',
    pretrained='open-mmlab://resnet50_v1c',
    backbone=dict(
        type='ResNetV1c',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        # Dilated stages 3-4 keep an output stride of 8.
        dilations=(1, 1, 2, 4),
        strides=(1, 2, 1, 1),
        norm_cfg=norm_cfg,
        norm_eval=False,
        style='pytorch',
        contract_dilation=True),
    decode_head=dict(
        type='APCHead',
        in_channels=2048,
        in_index=3,
        channels=512,
        pool_scales=(1, 2, 3, 6),
        dropout_ratio=0.1,
        num_classes=19,
        # Reuse the shared norm_cfg (same value); the inline duplicate
        # dict(type='SyncBN', requires_grad=True) was inconsistent with the
        # sibling configs in this file.
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
    # Auxiliary FCN head on stage 3 for deep supervision (weight 0.4).
    auxiliary_head=dict(
        type='FCNHead',
        in_channels=1024,
        in_index=2,
        channels=256,
        num_convs=1,
        concat_input=False,
        dropout_ratio=0.1,
        num_classes=19,
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
    # model training and testing settings
    train_cfg=dict(),
    test_cfg=dict(mode='whole'))
|
||||
@ -1,68 +0,0 @@
|
||||
# model settings
# Shared normalization config: synchronized BatchNorm with learnable affine params.
norm_cfg = dict(type='SyncBN', requires_grad=True)

# BiSeNetV1 segmentor with a ResNet-18 context path; trained from scratch
# (no 'pretrained' key on the segmentor).
model = dict(
    type='EncoderDecoder',
    backbone=dict(
        type='BiSeNetV1',
        in_channels=3,
        context_channels=(128, 256, 512),
        spatial_channels=(64, 64, 64, 128),
        # Output 0 feeds the decode head; 1 and 2 feed the auxiliary heads.
        out_indices=(0, 1, 2),
        out_channels=256,
        backbone_cfg=dict(
            type='ResNet',
            in_channels=3,
            depth=18,
            num_stages=4,
            out_indices=(0, 1, 2, 3),
            dilations=(1, 1, 1, 1),
            strides=(1, 2, 2, 2),
            norm_cfg=norm_cfg,
            norm_eval=False,
            style='pytorch',
            contract_dilation=True),
        norm_cfg=norm_cfg,
        align_corners=False,
        init_cfg=None),
    decode_head=dict(
        type='FCNHead',
        in_channels=256,
        in_index=0,
        channels=256,
        num_convs=1,
        concat_input=False,
        dropout_ratio=0.1,
        num_classes=19,
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
    # Two auxiliary FCN heads, one per remaining backbone output, both with
    # full loss weight 1.0 (the standard BiSeNetV1 booster setup).
    auxiliary_head=[
        dict(
            type='FCNHead',
            in_channels=128,
            channels=64,
            num_convs=1,
            num_classes=19,
            in_index=1,
            norm_cfg=norm_cfg,
            concat_input=False,
            align_corners=False,
            loss_decode=dict(
                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
        dict(
            type='FCNHead',
            in_channels=128,
            channels=64,
            num_convs=1,
            num_classes=19,
            in_index=2,
            norm_cfg=norm_cfg,
            concat_input=False,
            align_corners=False,
            loss_decode=dict(
                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
    ],
    # model training and testing settings
    train_cfg=dict(),
    test_cfg=dict(mode='whole'))
|
||||
Some files were not shown because too many files have changed in this diff. Show More
불러오는 중...
Reference in New Issue
Block a user