feat: 예측 개선, 관리자 패널 추가, 보고서 기능 개선

This commit is contained in:
jeonghyo.k 2026-03-18 18:10:41 +09:00
부모 c7c7537dbb
커밋 621d8e3516
1303개의 변경된 파일748개의 추가작업 그리고 88691개의 파일을 삭제

파일 보기

@ -474,6 +474,7 @@ interface TrajectoryResult {
centerPoints: Array<{ lat: number; lon: number; time: number; model: string }>; centerPoints: Array<{ lat: number; lon: number; time: number; model: string }>;
windDataByModel: Record<string, TrajectoryWindPoint[][]>; windDataByModel: Record<string, TrajectoryWindPoint[][]>;
hydrDataByModel: Record<string, ({ value: [number[][], number[][]]; grid: TrajectoryHydrGrid } | null)[]>; hydrDataByModel: Record<string, ({ value: [number[][], number[][]]; grid: TrajectoryHydrGrid } | null)[]>;
summaryByModel: Record<string, SingleModelTrajectoryResult['summary']>;
} }
function transformTrajectoryResult(rawResult: TrajectoryTimeStep[], model: string): SingleModelTrajectoryResult { function transformTrajectoryResult(rawResult: TrajectoryTimeStep[], model: string): SingleModelTrajectoryResult {
@ -531,6 +532,7 @@ export async function getAnalysisTrajectory(acdntSn: number): Promise<Trajectory
let baseResult: SingleModelTrajectoryResult | null = null; let baseResult: SingleModelTrajectoryResult | null = null;
const windDataByModel: Record<string, TrajectoryWindPoint[][]> = {}; const windDataByModel: Record<string, TrajectoryWindPoint[][]> = {};
const hydrDataByModel: Record<string, ({ value: [number[][], number[][]]; grid: TrajectoryHydrGrid } | null)[]> = {}; const hydrDataByModel: Record<string, ({ value: [number[][], number[][]]; grid: TrajectoryHydrGrid } | null)[]> = {};
const summaryByModel: Record<string, SingleModelTrajectoryResult['summary']> = {};
// OpenDrift 우선, 없으면 POSEIDON 선택 (ORDER BY CMPL_DTM DESC이므로 첫 번째 행이 가장 최근) // OpenDrift 우선, 없으면 POSEIDON 선택 (ORDER BY CMPL_DTM DESC이므로 첫 번째 행이 가장 최근)
const opendriftRow = (rows as Array<Record<string, unknown>>).find((r) => r['algo_cd'] === 'OPENDRIFT'); const opendriftRow = (rows as Array<Record<string, unknown>>).find((r) => r['algo_cd'] === 'OPENDRIFT');
@ -546,6 +548,7 @@ export async function getAnalysisTrajectory(acdntSn: number): Promise<Trajectory
allCenterPoints = allCenterPoints.concat(parsed.centerPoints); allCenterPoints = allCenterPoints.concat(parsed.centerPoints);
windDataByModel[modelName] = parsed.windData; windDataByModel[modelName] = parsed.windData;
hydrDataByModel[modelName] = parsed.hydrData; hydrDataByModel[modelName] = parsed.hydrData;
summaryByModel[modelName] = parsed.summary;
if (row === baseRow) { if (row === baseRow) {
baseResult = parsed; baseResult = parsed;
@ -560,6 +563,7 @@ export async function getAnalysisTrajectory(acdntSn: number): Promise<Trajectory
centerPoints: allCenterPoints, centerPoints: allCenterPoints,
windDataByModel, windDataByModel,
hydrDataByModel, hydrDataByModel,
summaryByModel,
}; };
} }

파일 보기

@ -278,7 +278,8 @@ INSERT INTO AUTH_PERM (ROLE_SN, RSRC_CD, OPER_CD, GRANT_YN) VALUES
(1, 'incidents', 'READ', 'Y'), (1, 'incidents', 'CREATE', 'Y'), (1, 'incidents', 'UPDATE', 'Y'), (1, 'incidents', 'DELETE', 'Y'), (1, 'incidents', 'READ', 'Y'), (1, 'incidents', 'CREATE', 'Y'), (1, 'incidents', 'UPDATE', 'Y'), (1, 'incidents', 'DELETE', 'Y'),
(1, 'board', 'READ', 'Y'), (1, 'board', 'CREATE', 'Y'), (1, 'board', 'UPDATE', 'Y'), (1, 'board', 'DELETE', 'Y'), (1, 'board', 'READ', 'Y'), (1, 'board', 'CREATE', 'Y'), (1, 'board', 'UPDATE', 'Y'), (1, 'board', 'DELETE', 'Y'),
(1, 'weather', 'READ', 'Y'), (1, 'weather', 'CREATE', 'Y'), (1, 'weather', 'UPDATE', 'Y'), (1, 'weather', 'DELETE', 'Y'), (1, 'weather', 'READ', 'Y'), (1, 'weather', 'CREATE', 'Y'), (1, 'weather', 'UPDATE', 'Y'), (1, 'weather', 'DELETE', 'Y'),
(1, 'admin', 'READ', 'Y'), (1, 'admin', 'CREATE', 'Y'), (1, 'admin', 'UPDATE', 'Y'), (1, 'admin', 'DELETE', 'Y'); (1, 'admin', 'READ', 'Y'), (1, 'admin', 'CREATE', 'Y'), (1, 'admin', 'UPDATE', 'Y'), (1, 'admin', 'DELETE', 'Y'),
(1, 'monitor', 'READ', 'Y');
-- HQ_CLEANUP (ROLE_SN=2): 방제 관련 탭 RCUD + 기타 탭 READ/CREATE, admin 제외 -- HQ_CLEANUP (ROLE_SN=2): 방제 관련 탭 RCUD + 기타 탭 READ/CREATE, admin 제외
INSERT INTO AUTH_PERM (ROLE_SN, RSRC_CD, OPER_CD, GRANT_YN) VALUES INSERT INTO AUTH_PERM (ROLE_SN, RSRC_CD, OPER_CD, GRANT_YN) VALUES
@ -292,7 +293,8 @@ INSERT INTO AUTH_PERM (ROLE_SN, RSRC_CD, OPER_CD, GRANT_YN) VALUES
(2, 'incidents', 'READ', 'Y'), (2, 'incidents', 'CREATE', 'Y'), (2, 'incidents', 'UPDATE', 'Y'), (2, 'incidents', 'DELETE', 'Y'), (2, 'incidents', 'READ', 'Y'), (2, 'incidents', 'CREATE', 'Y'), (2, 'incidents', 'UPDATE', 'Y'), (2, 'incidents', 'DELETE', 'Y'),
(2, 'board', 'READ', 'Y'), (2, 'board', 'CREATE', 'Y'), (2, 'board', 'UPDATE', 'Y'), (2, 'board', 'READ', 'Y'), (2, 'board', 'CREATE', 'Y'), (2, 'board', 'UPDATE', 'Y'),
(2, 'weather', 'READ', 'Y'), (2, 'weather', 'CREATE', 'Y'), (2, 'weather', 'READ', 'Y'), (2, 'weather', 'CREATE', 'Y'),
(2, 'admin', 'READ', 'N'); (2, 'admin', 'READ', 'N'),
(2, 'monitor', 'READ', 'Y');
-- MANAGER (ROLE_SN=3): admin 탭 제외, RCUD 허용 -- MANAGER (ROLE_SN=3): admin 탭 제외, RCUD 허용
INSERT INTO AUTH_PERM (ROLE_SN, RSRC_CD, OPER_CD, GRANT_YN) VALUES INSERT INTO AUTH_PERM (ROLE_SN, RSRC_CD, OPER_CD, GRANT_YN) VALUES
@ -306,7 +308,8 @@ INSERT INTO AUTH_PERM (ROLE_SN, RSRC_CD, OPER_CD, GRANT_YN) VALUES
(3, 'incidents', 'READ', 'Y'), (3, 'incidents', 'CREATE', 'Y'), (3, 'incidents', 'UPDATE', 'Y'), (3, 'incidents', 'DELETE', 'Y'), (3, 'incidents', 'READ', 'Y'), (3, 'incidents', 'CREATE', 'Y'), (3, 'incidents', 'UPDATE', 'Y'), (3, 'incidents', 'DELETE', 'Y'),
(3, 'board', 'READ', 'Y'), (3, 'board', 'CREATE', 'Y'), (3, 'board', 'UPDATE', 'Y'), (3, 'board', 'DELETE', 'Y'), (3, 'board', 'READ', 'Y'), (3, 'board', 'CREATE', 'Y'), (3, 'board', 'UPDATE', 'Y'), (3, 'board', 'DELETE', 'Y'),
(3, 'weather', 'READ', 'Y'), (3, 'weather', 'CREATE', 'Y'), (3, 'weather', 'UPDATE', 'Y'), (3, 'weather', 'DELETE', 'Y'), (3, 'weather', 'READ', 'Y'), (3, 'weather', 'CREATE', 'Y'), (3, 'weather', 'UPDATE', 'Y'), (3, 'weather', 'DELETE', 'Y'),
(3, 'admin', 'READ', 'N'); (3, 'admin', 'READ', 'N'),
(3, 'monitor', 'READ', 'Y');
-- USER (ROLE_SN=4): assets/admin 제외, 허용 탭은 READ/CREATE/UPDATE, DELETE 없음 -- USER (ROLE_SN=4): assets/admin 제외, 허용 탭은 READ/CREATE/UPDATE, DELETE 없음
INSERT INTO AUTH_PERM (ROLE_SN, RSRC_CD, OPER_CD, GRANT_YN) VALUES INSERT INTO AUTH_PERM (ROLE_SN, RSRC_CD, OPER_CD, GRANT_YN) VALUES
@ -320,7 +323,8 @@ INSERT INTO AUTH_PERM (ROLE_SN, RSRC_CD, OPER_CD, GRANT_YN) VALUES
(4, 'incidents', 'READ', 'Y'), (4, 'incidents', 'CREATE', 'Y'), (4, 'incidents', 'UPDATE', 'Y'), (4, 'incidents', 'READ', 'Y'), (4, 'incidents', 'CREATE', 'Y'), (4, 'incidents', 'UPDATE', 'Y'),
(4, 'board', 'READ', 'Y'), (4, 'board', 'CREATE', 'Y'), (4, 'board', 'UPDATE', 'Y'), (4, 'board', 'READ', 'Y'), (4, 'board', 'CREATE', 'Y'), (4, 'board', 'UPDATE', 'Y'),
(4, 'weather', 'READ', 'Y'), (4, 'weather', 'READ', 'Y'),
(4, 'admin', 'READ', 'N'); (4, 'admin', 'READ', 'N'),
(4, 'monitor', 'READ', 'Y');
-- VIEWER (ROLE_SN=5): 제한적 탭의 READ만 허용 -- VIEWER (ROLE_SN=5): 제한적 탭의 READ만 허용
INSERT INTO AUTH_PERM (ROLE_SN, RSRC_CD, OPER_CD, GRANT_YN) VALUES INSERT INTO AUTH_PERM (ROLE_SN, RSRC_CD, OPER_CD, GRANT_YN) VALUES
@ -334,7 +338,8 @@ INSERT INTO AUTH_PERM (ROLE_SN, RSRC_CD, OPER_CD, GRANT_YN) VALUES
(5, 'incidents', 'READ', 'Y'), (5, 'incidents', 'READ', 'Y'),
(5, 'board', 'READ', 'Y'), (5, 'board', 'READ', 'Y'),
(5, 'weather', 'READ', 'Y'), (5, 'weather', 'READ', 'Y'),
(5, 'admin', 'READ', 'N'); (5, 'admin', 'READ', 'N'),
(5, 'monitor', 'READ', 'Y');
-- ============================================================ -- ============================================================

파일 보기

@ -39,9 +39,13 @@ export function TopBar({ activeTab, onTabChange }: TopBarProps) {
{/* Left Section */} {/* Left Section */}
<div className="flex items-center gap-4"> <div className="flex items-center gap-4">
{/* Logo */} {/* Logo */}
<div className="flex items-center"> <button
onClick={() => tabs[0] && onTabChange(tabs[0].id as MainTab)}
className="flex items-center hover:opacity-80 transition-opacity cursor-pointer"
title="홈으로 이동"
>
<img src="/wing_logo_white.svg" alt="WING 해양환경 위기대응" className="h-3.5" /> <img src="/wing_logo_white.svg" alt="WING 해양환경 위기대응" className="h-3.5" />
</div> </button>
{/* Divider */} {/* Divider */}
<div className="w-px h-6 bg-border-light" /> <div className="w-px h-6 bg-border-light" />

파일 보기

@ -8,6 +8,8 @@ import MenusPanel from './MenusPanel';
import SettingsPanel from './SettingsPanel'; import SettingsPanel from './SettingsPanel';
import BoardMgmtPanel from './BoardMgmtPanel'; import BoardMgmtPanel from './BoardMgmtPanel';
import VesselSignalPanel from './VesselSignalPanel'; import VesselSignalPanel from './VesselSignalPanel';
import CleanupEquipPanel from './CleanupEquipPanel';
import AssetUploadPanel from './AssetUploadPanel';
/** 기존 패널이 있는 메뉴 ID 매핑 */ /** 기존 패널이 있는 메뉴 ID 매핑 */
const PANEL_MAP: Record<string, () => JSX.Element> = { const PANEL_MAP: Record<string, () => JSX.Element> = {
@ -19,6 +21,8 @@ const PANEL_MAP: Record<string, () => JSX.Element> = {
board: () => <BoardMgmtPanel initialCategory="DATA" />, board: () => <BoardMgmtPanel initialCategory="DATA" />,
qna: () => <BoardMgmtPanel initialCategory="QNA" />, qna: () => <BoardMgmtPanel initialCategory="QNA" />,
'collect-vessel-signal': () => <VesselSignalPanel />, 'collect-vessel-signal': () => <VesselSignalPanel />,
'cleanup-equip': () => <CleanupEquipPanel />,
'asset-upload': () => <AssetUploadPanel />,
}; };
export function AdminView() { export function AdminView() {

파일 보기

@ -0,0 +1,257 @@
import { useState, useEffect, useRef } from 'react';
import { fetchUploadLogs } from '@tabs/assets/services/assetsApi';
import type { UploadLogItem } from '@tabs/assets/services/assetsApi';
const ASSET_CATEGORIES = ['전체', '방제선', '유회수기', '이송펌프', '방제차량', '살포장치', '오일붐', '흡착재', '기타'];
const JURISDICTIONS = ['전체', '남해청', '서해청', '중부청', '동해청', '제주청'];
const PERM_ITEMS = [
{ icon: '👑', role: '시스템관리자', desc: '전체 자산 업로드/삭제 가능', bg: 'rgba(245,158,11,0.15)', color: 'text-yellow-400' },
{ icon: '🔧', role: '운영관리자', desc: '관할청 내 자산 업로드 가능', bg: 'rgba(6,182,212,0.15)', color: 'text-primary-cyan' },
{ icon: '👁', role: '조회자', desc: '현황 조회만 가능', bg: 'rgba(148,163,184,0.15)', color: 'text-text-2' },
{ icon: '🚫', role: '게스트', desc: '접근 불가', bg: 'rgba(239,68,68,0.1)', color: 'text-red-400' },
];
function formatDate(dtm: string) {
const d = new Date(dtm);
if (isNaN(d.getTime())) return dtm;
return `${d.getFullYear()}-${String(d.getMonth() + 1).padStart(2, '0')}-${String(d.getDate()).padStart(2, '0')}`;
}
function AssetUploadPanel() {
const [assetCategory, setAssetCategory] = useState('전체');
const [jurisdiction, setJurisdiction] = useState('전체');
const [uploadMode, setUploadMode] = useState<'add' | 'replace'>('add');
const [uploaded, setUploaded] = useState(false);
const [uploadHistory, setUploadHistory] = useState<UploadLogItem[]>([]);
const [dragging, setDragging] = useState(false);
const [selectedFile, setSelectedFile] = useState<File | null>(null);
const fileInputRef = useRef<HTMLInputElement>(null);
const resetTimerRef = useRef<ReturnType<typeof setTimeout> | null>(null);
useEffect(() => {
fetchUploadLogs(10)
.then(setUploadHistory)
.catch(err => console.error('[AssetUploadPanel] 이력 로드 실패:', err));
}, []);
useEffect(() => {
return () => {
if (resetTimerRef.current) clearTimeout(resetTimerRef.current);
};
}, []);
const handleFileSelect = (file: File | null) => {
if (!file) return;
const ext = file.name.split('.').pop()?.toLowerCase();
if (ext !== 'xlsx' && ext !== 'csv') return;
setSelectedFile(file);
};
const handleDrop = (e: React.DragEvent) => {
e.preventDefault();
setDragging(false);
const file = e.dataTransfer.files[0] ?? null;
handleFileSelect(file);
};
const handleUpload = () => {
if (!selectedFile) return;
setUploaded(true);
resetTimerRef.current = setTimeout(() => {
setUploaded(false);
setSelectedFile(null);
}, 3000);
};
return (
<div className="flex flex-col h-full">
{/* 헤더 */}
<div className="px-6 py-4 border-b border-border flex-shrink-0">
<h1 className="text-lg font-bold text-text-1 font-korean"> </h1>
<p className="text-xs text-text-3 mt-1 font-korean"> </p>
</div>
{/* 본문 */}
<div className="flex-1 overflow-auto p-6">
<div className="flex gap-6 h-full">
{/* 좌측: 파일 업로드 */}
<div className="flex-1 max-w-[560px] space-y-4">
<div className="rounded-lg border border-border bg-bg-1 overflow-hidden">
<div className="px-5 py-3 border-b border-border">
<h2 className="text-sm font-bold text-text-1 font-korean"> </h2>
</div>
<div className="px-5 py-4 space-y-4">
{/* 드롭존 */}
<div
onDragOver={e => { e.preventDefault(); setDragging(true); }}
onDragLeave={() => setDragging(false)}
onDrop={handleDrop}
onClick={() => fileInputRef.current?.click()}
className={`rounded-lg border-2 border-dashed py-8 text-center cursor-pointer transition-colors ${
dragging
? 'border-primary-cyan bg-[rgba(6,182,212,0.05)]'
: 'border-border hover:border-primary-cyan/50 bg-bg-2'
}`}
>
<div className="text-3xl mb-2 opacity-40">📁</div>
{selectedFile ? (
<div className="text-xs font-semibold text-primary-cyan font-korean mb-1">{selectedFile.name}</div>
) : (
<>
<div className="text-xs font-semibold text-text-2 font-korean mb-1"> </div>
<div className="text-[10px] text-text-3 font-korean mb-3">(.xlsx), CSV · 10MB</div>
<button
type="button"
className="px-4 py-1.5 text-xs font-semibold rounded-md bg-primary-cyan text-bg-0
hover:shadow-[0_0_12px_rgba(6,182,212,0.3)] transition-all font-korean"
onClick={e => { e.stopPropagation(); fileInputRef.current?.click(); }}
>
</button>
</>
)}
<input
ref={fileInputRef}
type="file"
accept=".xlsx,.csv"
className="hidden"
onChange={e => handleFileSelect(e.target.files?.[0] ?? null)}
/>
</div>
{/* 자산 분류 */}
<div>
<label className="block text-[11px] font-semibold text-text-2 font-korean mb-1.5"> </label>
<select
value={assetCategory}
onChange={e => setAssetCategory(e.target.value)}
className="w-full px-3 py-2 text-xs bg-bg-2 border border-border rounded-md
text-text-1 focus:border-primary-cyan focus:outline-none font-korean"
>
{ASSET_CATEGORIES.map(c => (
<option key={c} value={c}>{c}</option>
))}
</select>
</div>
{/* 대상 관할 */}
<div>
<label className="block text-[11px] font-semibold text-text-2 font-korean mb-1.5"> </label>
<select
value={jurisdiction}
onChange={e => setJurisdiction(e.target.value)}
className="w-full px-3 py-2 text-xs bg-bg-2 border border-border rounded-md
text-text-1 focus:border-primary-cyan focus:outline-none font-korean"
>
{JURISDICTIONS.map(j => (
<option key={j} value={j}>{j}</option>
))}
</select>
</div>
{/* 업로드 방식 */}
<div>
<label className="block text-[11px] font-semibold text-text-2 font-korean mb-1.5"> </label>
<div className="flex gap-4">
<label className="flex items-center gap-1.5 cursor-pointer text-xs text-text-2 font-korean">
<input
type="radio"
checked={uploadMode === 'add'}
onChange={() => setUploadMode('add')}
className="accent-primary-cyan"
/>
( + )
</label>
<label className="flex items-center gap-1.5 cursor-pointer text-xs text-text-2 font-korean">
<input
type="radio"
checked={uploadMode === 'replace'}
onChange={() => setUploadMode('replace')}
className="accent-primary-cyan"
/>
( )
</label>
</div>
</div>
{/* 업로드 버튼 */}
<button
type="button"
onClick={handleUpload}
disabled={!selectedFile || uploaded}
className={`w-full py-2.5 text-xs font-semibold rounded-md transition-all font-korean disabled:opacity-50 ${
uploaded
? 'bg-[rgba(34,197,94,0.15)] text-status-green border border-status-green/30'
: 'bg-primary-cyan text-bg-0 hover:shadow-[0_0_12px_rgba(6,182,212,0.3)]'
}`}
>
{uploaded ? '✅ 업로드 완료!' : '📤 업로드 실행'}
</button>
</div>
</div>
</div>
{/* 우측 */}
<div className="w-[400px] space-y-4 flex-shrink-0">
{/* 수정 권한 체계 */}
<div className="rounded-lg border border-border bg-bg-1 overflow-hidden">
<div className="px-5 py-3 border-b border-border">
<h2 className="text-sm font-bold text-text-1 font-korean"> </h2>
</div>
<div className="px-5 py-4 space-y-2">
{PERM_ITEMS.map(p => (
<div
key={p.role}
className="flex items-center gap-3 px-4 py-3 bg-bg-2 border border-border rounded-md"
>
<div
className="w-8 h-8 rounded-full flex items-center justify-center text-sm flex-shrink-0"
style={{ background: p.bg }}
>
{p.icon}
</div>
<div>
<div className={`text-xs font-bold font-korean ${p.color}`}>{p.role}</div>
<div className="text-[10px] text-text-3 font-korean mt-0.5">{p.desc}</div>
</div>
</div>
))}
</div>
</div>
{/* 최근 업로드 이력 */}
<div className="rounded-lg border border-border bg-bg-1 overflow-hidden">
<div className="px-5 py-3 border-b border-border">
<h2 className="text-sm font-bold text-text-1 font-korean"> </h2>
</div>
<div className="px-5 py-4 space-y-2">
{uploadHistory.length === 0 ? (
<div className="text-[11px] text-text-3 font-korean text-center py-4"> .</div>
) : uploadHistory.map(h => (
<div
key={h.logSn}
className="flex justify-between items-center px-4 py-3 bg-bg-2 border border-border rounded-md"
>
<div>
<div className="text-xs font-semibold text-text-1 font-korean">{h.fileNm}</div>
<div className="text-[10px] text-text-3 mt-0.5 font-korean">
{formatDate(h.regDtm)} · {h.uploaderNm} · {h.uploadCnt.toLocaleString()}
</div>
</div>
<span className="px-2 py-0.5 rounded-full text-[10px] font-semibold
bg-[rgba(34,197,94,0.15)] text-status-green flex-shrink-0">
</span>
</div>
))}
</div>
</div>
</div>
</div>
</div>
</div>
);
}
export default AssetUploadPanel;

파일 보기

@ -0,0 +1,230 @@
import { useState, useEffect, useMemo } from 'react';
import { fetchOrganizations } from '@tabs/assets/services/assetsApi';
import type { AssetOrgCompat } from '@tabs/assets/services/assetsApi';
import { typeTagCls } from '@tabs/assets/components/assetTypes';
const PAGE_SIZE = 20;
const regionShort = (j: string) =>
j.includes('남해') ? '남해청' : j.includes('서해') ? '서해청' :
j.includes('중부') ? '중부청' : j.includes('동해') ? '동해청' :
j.includes('제주') ? '제주청' : j;
function CleanupEquipPanel() {
const [organizations, setOrganizations] = useState<AssetOrgCompat[]>([]);
const [loading, setLoading] = useState(true);
const [searchTerm, setSearchTerm] = useState('');
const [regionFilter, setRegionFilter] = useState('전체');
const [typeFilter, setTypeFilter] = useState('전체');
const [currentPage, setCurrentPage] = useState(1);
const load = () => {
setLoading(true);
fetchOrganizations()
.then(setOrganizations)
.catch(err => console.error('[CleanupEquipPanel] 데이터 로드 실패:', err))
.finally(() => setLoading(false));
};
useEffect(() => {
let cancelled = false;
fetchOrganizations()
.then(data => { if (!cancelled) setOrganizations(data); })
.catch(err => console.error('[CleanupEquipPanel] 데이터 로드 실패:', err))
.finally(() => { if (!cancelled) setLoading(false); });
return () => { cancelled = true; };
}, []);
const typeOptions = useMemo(() => {
const set = new Set(organizations.map(o => o.type));
return Array.from(set).sort();
}, [organizations]);
const filtered = useMemo(() =>
organizations
.filter(o => regionFilter === '전체' || o.jurisdiction.includes(regionFilter))
.filter(o => typeFilter === '전체' || o.type === typeFilter)
.filter(o => !searchTerm || o.name.includes(searchTerm) || o.address.includes(searchTerm)),
[organizations, regionFilter, typeFilter, searchTerm]
);
const totalPages = Math.max(1, Math.ceil(filtered.length / PAGE_SIZE));
const safePage = Math.min(currentPage, totalPages);
const paged = filtered.slice((safePage - 1) * PAGE_SIZE, safePage * PAGE_SIZE);
const handleFilterChange = (setter: (v: string) => void) => (e: React.ChangeEvent<HTMLSelectElement>) => {
setter(e.target.value);
setCurrentPage(1);
};
const pageNumbers = (() => {
const range: number[] = [];
const start = Math.max(1, safePage - 2);
const end = Math.min(totalPages, safePage + 2);
for (let i = start; i <= end; i++) range.push(i);
return range;
})();
return (
<div className="flex flex-col h-full">
{/* 헤더 */}
<div className="flex items-center justify-between px-6 py-4 border-b border-border">
<div>
<h1 className="text-lg font-bold text-text-1 font-korean"> </h1>
<p className="text-xs text-text-3 mt-1 font-korean"> {filtered.length} </p>
</div>
<div className="flex items-center gap-3">
<select
value={regionFilter}
onChange={handleFilterChange(setRegionFilter)}
className="px-3 py-2 text-xs bg-bg-2 border border-border rounded-md text-text-1 focus:border-primary-cyan focus:outline-none font-korean"
>
<option value="전체"> </option>
<option value="남해"></option>
<option value="서해"></option>
<option value="중부"></option>
<option value="동해"></option>
<option value="제주"></option>
</select>
<select
value={typeFilter}
onChange={handleFilterChange(setTypeFilter)}
className="px-3 py-2 text-xs bg-bg-2 border border-border rounded-md text-text-1 focus:border-primary-cyan focus:outline-none font-korean"
>
<option value="전체"> </option>
{typeOptions.map(t => (
<option key={t} value={t}>{t}</option>
))}
</select>
<input
type="text"
placeholder="기관명, 주소 검색..."
value={searchTerm}
onChange={e => { setSearchTerm(e.target.value); setCurrentPage(1); }}
className="w-56 px-3 py-2 text-xs bg-bg-2 border border-border rounded-md text-text-1 placeholder-text-3 focus:border-primary-cyan focus:outline-none font-korean"
/>
<button
onClick={load}
className="px-4 py-2 text-xs font-semibold rounded-md bg-bg-2 border border-border text-text-2 hover:border-primary-cyan hover:text-primary-cyan transition-all font-korean"
>
</button>
</div>
</div>
{/* 테이블 */}
<div className="flex-1 overflow-auto">
{loading ? (
<div className="flex items-center justify-center h-32 text-text-3 text-sm font-korean">
...
</div>
) : (
<table className="w-full">
<thead>
<tr className="border-b border-border bg-bg-1">
<th className="px-4 py-3 text-left text-[11px] font-semibold text-text-3 font-korean w-10"></th>
<th className="px-4 py-3 text-left text-[11px] font-semibold text-text-3 font-korean"></th>
<th className="px-4 py-3 text-left text-[11px] font-semibold text-text-3 font-korean"></th>
<th className="px-4 py-3 text-left text-[11px] font-semibold text-text-3 font-korean"></th>
<th className="px-4 py-3 text-left text-[11px] font-semibold text-text-3 font-korean"></th>
<th className="px-4 py-3 text-center text-[11px] font-semibold text-text-3 font-korean"></th>
<th className="px-4 py-3 text-center text-[11px] font-semibold text-text-3 font-korean"></th>
<th className="px-4 py-3 text-center text-[11px] font-semibold text-text-3 font-korean"></th>
<th className="px-4 py-3 text-center text-[11px] font-semibold text-text-3 font-korean"></th>
<th className="px-4 py-3 text-center text-[11px] font-semibold text-text-3 font-korean"></th>
<th className="px-4 py-3 text-center text-[11px] font-semibold text-text-3 font-korean"></th>
</tr>
</thead>
<tbody>
{paged.length === 0 ? (
<tr>
<td colSpan={11} className="px-6 py-10 text-center text-xs text-text-3 font-korean">
.
</td>
</tr>
) : paged.map((org, idx) => (
<tr key={org.id} className="border-b border-border hover:bg-[rgba(255,255,255,0.02)] transition-colors">
<td className="px-4 py-3 text-[11px] text-text-3 font-mono text-center">
{(safePage - 1) * PAGE_SIZE + idx + 1}
</td>
<td className="px-4 py-3">
<span className={`text-[10px] px-1.5 py-0.5 rounded font-bold font-korean ${typeTagCls(org.type)}`}>
{org.type}
</span>
</td>
<td className="px-4 py-3 text-[11px] text-text-2 font-korean">
{regionShort(org.jurisdiction)}
</td>
<td className="px-4 py-3 text-[11px] text-text-1 font-korean font-semibold">
{org.name}
</td>
<td className="px-4 py-3 text-[11px] text-text-3 font-korean max-w-[200px] truncate">
{org.address}
</td>
<td className="px-4 py-3 text-[11px] font-mono text-center text-text-2">
{org.vessel > 0 ? <span className="text-text-1">{org.vessel}</span> : <span className="text-text-3"></span>}
</td>
<td className="px-4 py-3 text-[11px] font-mono text-center text-text-2">
{org.skimmer > 0 ? <span className="text-text-1">{org.skimmer}</span> : <span className="text-text-3"></span>}
</td>
<td className="px-4 py-3 text-[11px] font-mono text-center text-text-2">
{org.pump > 0 ? <span className="text-text-1">{org.pump}</span> : <span className="text-text-3"></span>}
</td>
<td className="px-4 py-3 text-[11px] font-mono text-center text-text-2">
{org.vehicle > 0 ? <span className="text-text-1">{org.vehicle}</span> : <span className="text-text-3"></span>}
</td>
<td className="px-4 py-3 text-[11px] font-mono text-center text-text-2">
{org.sprayer > 0 ? <span className="text-text-1">{org.sprayer}</span> : <span className="text-text-3"></span>}
</td>
<td className="px-4 py-3 text-[11px] font-mono text-center font-bold text-primary-cyan">
{org.totalAssets.toLocaleString()}
</td>
</tr>
))}
</tbody>
</table>
)}
</div>
{/* 페이지네이션 */}
{!loading && filtered.length > 0 && (
<div className="flex items-center justify-between px-6 py-3 border-t border-border">
<span className="text-[11px] text-text-3 font-korean">
{(safePage - 1) * PAGE_SIZE + 1}{Math.min(safePage * PAGE_SIZE, filtered.length)} / {filtered.length}
</span>
<div className="flex items-center gap-1">
<button
onClick={() => setCurrentPage(p => Math.max(1, p - 1))}
disabled={safePage === 1}
className="px-2.5 py-1 text-[11px] border border-border rounded text-text-2 hover:border-primary-cyan hover:text-primary-cyan disabled:opacity-40 transition-colors"
>
&lt;
</button>
{pageNumbers.map(p => (
<button
key={p}
onClick={() => setCurrentPage(p)}
className="px-2.5 py-1 text-[11px] border rounded transition-colors"
style={p === safePage
? { borderColor: 'var(--cyan)', color: 'var(--cyan)', background: 'rgba(6,182,212,0.1)' }
: { borderColor: 'var(--border)', color: 'var(--text-2)' }
}
>
{p}
</button>
))}
<button
onClick={() => setCurrentPage(p => Math.min(totalPages, p + 1))}
disabled={safePage === totalPages}
className="px-2.5 py-1 text-[11px] border border-border rounded text-text-2 hover:border-primary-cyan hover:text-primary-cyan disabled:opacity-40 transition-colors"
>
&gt;
</button>
</div>
</div>
)}
</div>
);
}
export default CleanupEquipPanel;

파일 보기

@ -294,6 +294,7 @@ interface RolePermTabProps {
setSelectedRoleSn: (sn: number | null) => void setSelectedRoleSn: (sn: number | null) => void
dirty: boolean dirty: boolean
saving: boolean saving: boolean
saveError: string | null
handleSave: () => Promise<void> handleSave: () => Promise<void>
handleToggleExpand: (code: string) => void handleToggleExpand: (code: string) => void
handleTogglePerm: (code: string, oper: OperCode, currentState: PermState) => void handleTogglePerm: (code: string, oper: OperCode, currentState: PermState) => void
@ -328,6 +329,7 @@ function RolePermTab({
setSelectedRoleSn, setSelectedRoleSn,
dirty, dirty,
saving, saving,
saveError,
handleSave, handleSave,
handleToggleExpand, handleToggleExpand,
handleTogglePerm, handleTogglePerm,
@ -378,6 +380,9 @@ function RolePermTab({
> >
{saving ? '저장 중...' : '변경사항 저장'} {saving ? '저장 중...' : '변경사항 저장'}
</button> </button>
{saveError && (
<span className="text-[11px] text-status-red font-korean">{saveError}</span>
)}
</div> </div>
{/* 역할 탭 바 */} {/* 역할 탭 바 */}
@ -861,6 +866,7 @@ function PermissionsPanel() {
const [permTree, setPermTree] = useState<PermTreeNode[]>([]) const [permTree, setPermTree] = useState<PermTreeNode[]>([])
const [loading, setLoading] = useState(true) const [loading, setLoading] = useState(true)
const [saving, setSaving] = useState(false) const [saving, setSaving] = useState(false)
const [saveError, setSaveError] = useState<string | null>(null)
const [dirty, setDirty] = useState(false) const [dirty, setDirty] = useState(false)
const [showCreateForm, setShowCreateForm] = useState(false) const [showCreateForm, setShowCreateForm] = useState(false)
const [newRoleCode, setNewRoleCode] = useState('') const [newRoleCode, setNewRoleCode] = useState('')
@ -962,6 +968,7 @@ function PermissionsPanel() {
const handleSave = async () => { const handleSave = async () => {
setSaving(true) setSaving(true)
setSaveError(null)
try { try {
for (const role of roles) { for (const role of roles) {
const perms = rolePerms.get(role.sn) const perms = rolePerms.get(role.sn)
@ -981,6 +988,7 @@ function PermissionsPanel() {
setDirty(false) setDirty(false)
} catch (err) { } catch (err) {
console.error('권한 저장 실패:', err) console.error('권한 저장 실패:', err)
setSaveError('권한 저장에 실패했습니다. 다시 시도해주세요.')
} finally { } finally {
setSaving(false) setSaving(false)
} }
@ -1096,6 +1104,7 @@ function PermissionsPanel() {
setSelectedRoleSn={setSelectedRoleSn} setSelectedRoleSn={setSelectedRoleSn}
dirty={dirty} dirty={dirty}
saving={saving} saving={saving}
saveError={saveError}
handleSave={handleSave} handleSave={handleSave}
handleToggleExpand={handleToggleExpand} handleToggleExpand={handleToggleExpand}
handleTogglePerm={handleTogglePerm} handleTogglePerm={handleTogglePerm}

파일 보기

@ -51,6 +51,7 @@ export const ADMIN_MENU: AdminMenuItem[] = [
id: 'coast-guard-assets', label: '해경자산', id: 'coast-guard-assets', label: '해경자산',
children: [ children: [
{ id: 'cleanup-equip', label: '방제장비' }, { id: 'cleanup-equip', label: '방제장비' },
{ id: 'asset-upload', label: '자산현행화' },
{ id: 'dispersant-zone', label: '유처리제 제한구역' }, { id: 'dispersant-zone', label: '유처리제 제한구역' },
{ id: 'vessel-materials', label: '방제선 보유자재' }, { id: 'vessel-materials', label: '방제선 보유자재' },
{ id: 'cleanup-resource', label: '방제자원' }, { id: 'cleanup-resource', label: '방제자원' },

파일 보기

@ -192,6 +192,7 @@ export function OilSpillView() {
// 재계산 상태 // 재계산 상태
const [recalcModalOpen, setRecalcModalOpen] = useState(false) const [recalcModalOpen, setRecalcModalOpen] = useState(false)
const [simulationSummary, setSimulationSummary] = useState<SimulationSummary | null>(null) const [simulationSummary, setSimulationSummary] = useState<SimulationSummary | null>(null)
const [summaryByModel, setSummaryByModel] = useState<Record<string, SimulationSummary>>({})
// 오염분석 상태 // 오염분석 상태
const [analysisTab, setAnalysisTab] = useState<'polygon' | 'circle'>('polygon') const [analysisTab, setAnalysisTab] = useState<'polygon' | 'circle'>('polygon')
@ -501,13 +502,14 @@ export function OilSpillView() {
analysis.opendriftStatus === 'completed' || analysis.poseidonStatus === 'completed'; analysis.opendriftStatus === 'completed' || analysis.poseidonStatus === 'completed';
if (hasCompletedModel) { if (hasCompletedModel) {
try { try {
const { trajectory, summary, centerPoints: cp, windDataByModel: wdByModel, hydrDataByModel: hdByModel } = await fetchAnalysisTrajectory(analysis.acdntSn) const { trajectory, summary, centerPoints: cp, windDataByModel: wdByModel, hydrDataByModel: hdByModel, summaryByModel: sbModel } = await fetchAnalysisTrajectory(analysis.acdntSn)
if (trajectory && trajectory.length > 0) { if (trajectory && trajectory.length > 0) {
setOilTrajectory(trajectory) setOilTrajectory(trajectory)
if (summary) setSimulationSummary(summary) if (summary) setSimulationSummary(summary)
setCenterPoints(cp ?? []) setCenterPoints(cp ?? [])
setWindDataByModel(wdByModel ?? {}); setWindDataByModel(wdByModel ?? {});
setHydrDataByModel(hdByModel ?? {}); setHydrDataByModel(hdByModel ?? {});
if (sbModel) setSummaryByModel(sbModel);
if (coord) setBoomLines(generateAIBoomLines(trajectory, coord, algorithmSettings)) if (coord) setBoomLines(generateAIBoomLines(trajectory, coord, algorithmSettings))
setSensitiveResources(DEMO_SENSITIVE_RESOURCES) setSensitiveResources(DEMO_SENSITIVE_RESOURCES)
// incidentCoord가 변경된 경우 flyTo 완료 후 재생, 그렇지 않으면 즉시 재생 // incidentCoord가 변경된 경우 flyTo 완료 후 재생, 그렇지 않으면 즉시 재생
@ -526,6 +528,7 @@ export function OilSpillView() {
// 데모 궤적 생성 (fallback) — stale wind/current 데이터 초기화 // 데모 궤적 생성 (fallback) — stale wind/current 데이터 초기화
setWindDataByModel({}) setWindDataByModel({})
setHydrDataByModel({}) setHydrDataByModel({})
setSummaryByModel({})
const demoTrajectory = generateDemoTrajectory(coord ?? { lat: 37.39, lon: 126.64 }, demoModels, parseInt(analysis.duration) || 48) const demoTrajectory = generateDemoTrajectory(coord ?? { lat: 37.39, lon: 126.64 }, demoModels, parseInt(analysis.duration) || 48)
setOilTrajectory(demoTrajectory) setOilTrajectory(demoTrajectory)
if (coord) setBoomLines(generateAIBoomLines(demoTrajectory, coord, algorithmSettings)) if (coord) setBoomLines(generateAIBoomLines(demoTrajectory, coord, algorithmSettings))
@ -735,6 +738,7 @@ export function OilSpillView() {
let latestCenterPoints: CenterPoint[] = []; let latestCenterPoints: CenterPoint[] = [];
const newWindDataByModel: Record<string, WindPoint[][]> = {}; const newWindDataByModel: Record<string, WindPoint[][]> = {};
const newHydrDataByModel: Record<string, (HydrDataStep | null)[]> = {}; const newHydrDataByModel: Record<string, (HydrDataStep | null)[]> = {};
const newSummaryByModel: Record<string, SimulationSummary> = {};
const errors: string[] = []; const errors: string[] = [];
data.results.forEach(({ model, status, trajectory, summary, centerPoints, windData, hydrData, error }) => { data.results.forEach(({ model, status, trajectory, summary, centerPoints, windData, hydrData, error }) => {
@ -745,8 +749,9 @@ export function OilSpillView() {
if (trajectory) { if (trajectory) {
merged.push(...trajectory.map(p => ({ ...p, model }))); merged.push(...trajectory.map(p => ({ ...p, model })));
} }
if (model === 'OpenDrift' || !latestSummary) { if (summary) {
if (summary) latestSummary = summary; newSummaryByModel[model] = summary;
if (model === 'OpenDrift' || !latestSummary) latestSummary = summary;
} }
if (windData) newWindDataByModel[model] = windData; if (windData) newWindDataByModel[model] = windData;
if (hydrData) newHydrDataByModel[model] = hydrData; if (hydrData) newHydrDataByModel[model] = hydrData;
@ -775,6 +780,7 @@ export function OilSpillView() {
setWindDataByModel(newWindDataByModel); setWindDataByModel(newWindDataByModel);
setHydrDataByModel(newHydrDataByModel); setHydrDataByModel(newHydrDataByModel);
setSummaryByModel(newSummaryByModel);
const booms = generateAIBoomLines(merged, effectiveCoord, algorithmSettings); const booms = generateAIBoomLines(merged, effectiveCoord, algorithmSettings);
setBoomLines(booms); setBoomLines(booms);
setSensitiveResources(DEMO_SENSITIVE_RESOURCES); setSensitiveResources(DEMO_SENSITIVE_RESOURCES);
@ -838,7 +844,13 @@ export function OilSpillView() {
weather: wx weather: wx
? { windDir: wx.wind, windSpeed: wx.wind, waveHeight: wx.wave, temp: wx.temp } ? { windDir: wx.wind, windSpeed: wx.wind, waveHeight: wx.wave, temp: wx.temp }
: null, : null,
spread: { kosps: '—', openDrift: '—', poseidon: '—' }, spread: (() => {
const fmt = (model: string) => {
const s = summaryByModel[model];
return s ? `${s.pollutionArea.toFixed(2)} km²` : '—';
};
return { kosps: fmt('KOSPS'), openDrift: fmt('OpenDrift'), poseidon: fmt('POSEIDON') };
})(),
coastal: { coastal: {
firstTime: (() => { firstTime: (() => {
const beachedTimes = oilTrajectory.filter(p => p.stranded === 1).map(p => p.time); const beachedTimes = oilTrajectory.filter(p => p.stranded === 1).map(p => p.time);

파일 보기

@ -209,6 +209,7 @@ export interface TrajectoryResponse {
centerPoints?: CenterPoint[]; centerPoints?: CenterPoint[];
windDataByModel?: Record<string, WindPoint[][]>; windDataByModel?: Record<string, WindPoint[][]>;
hydrDataByModel?: Record<string, (HydrDataStep | null)[]>; hydrDataByModel?: Record<string, (HydrDataStep | null)[]>;
summaryByModel?: Record<string, SimulationSummary>;
} }
export const fetchAnalysisTrajectory = async (acdntSn: number): Promise<TrajectoryResponse> => { export const fetchAnalysisTrajectory = async (acdntSn: number): Promise<TrajectoryResponse> => {

파일 보기

@ -47,6 +47,7 @@ const OilSpreadMapPanel = ({ mapData, capturedImage, onCapture, onReset }: OilSp
simulationStartTime={mapData.simulationStartTime || undefined} simulationStartTime={mapData.simulationStartTime || undefined}
mapCaptureRef={captureRef} mapCaptureRef={captureRef}
showOverlays={false} showOverlays={false}
lightMode
/> />
{/* 캡처 이미지 오버레이 — 우측 상단 */} {/* 캡처 이미지 오버레이 — 우측 상단 */}

파일 보기

@ -7,9 +7,6 @@ import OilSpreadMapPanel from './OilSpreadMapPanel';
import { saveReport } from '../services/reportsApi'; import { saveReport } from '../services/reportsApi';
import { import {
CATEGORIES, CATEGORIES,
sampleOilData,
sampleHnsData,
sampleRescueData,
type ReportCategory, type ReportCategory,
type ReportSection, type ReportSection,
} from './reportTypes'; } from './reportTypes';
@ -83,8 +80,8 @@ function ReportGenerator({ onSave }: ReportGeneratorProps) {
report.incident.pollutant = oilPayload.pollution.oilType; report.incident.pollutant = oilPayload.pollution.oilType;
report.incident.spillAmount = oilPayload.pollution.spillAmount; report.incident.spillAmount = oilPayload.pollution.spillAmount;
} else { } else {
report.incident.pollutant = sampleOilData.pollution.oilType; report.incident.pollutant = '';
report.incident.spillAmount = sampleOilData.pollution.spillAmount; report.incident.spillAmount = '';
} }
} }
if (activeCat === 0 && oilMapCaptured) { if (activeCat === 0 && oilMapCaptured) {
@ -102,7 +99,7 @@ function ReportGenerator({ onSave }: ReportGeneratorProps) {
const handleDownload = () => { const handleDownload = () => {
const secColor = cat.color === 'var(--cyan)' ? '#06b6d4' : cat.color === 'var(--orange)' ? '#f97316' : '#ef4444'; const secColor = cat.color === 'var(--cyan)' ? '#06b6d4' : cat.color === 'var(--orange)' ? '#f97316' : '#ef4444';
const sectionHTML = activeSections.map(sec => { const sectionHTML = activeSections.map(sec => {
let content = `<p style="font-size:12px;color:#666;">${sec.desc}</p>`; let content = `<p style="font-size:12px;color:#999;">—</p>`;
// OIL 섹션에 실 데이터 삽입 // OIL 섹션에 실 데이터 삽입
if (activeCat === 0) { if (activeCat === 0) {
@ -123,6 +120,15 @@ function ReportGenerator({ onSave }: ReportGeneratorProps) {
content = `${mapImg}<table style="width:100%;border-collapse:collapse;font-size:12px;"><tr>${tds}</tr></table>`; content = `${mapImg}<table style="width:100%;border-collapse:collapse;font-size:12px;"><tr>${tds}</tr></table>`;
} }
} }
if (activeCat === 0 && sec.id === 'oil-coastal') {
if (oilPayload) {
const coastLength = oilPayload.pollution.coastLength;
const hasNoCoastal = !coastLength || coastLength === '—' || coastLength.startsWith('0.00');
content = hasNoCoastal
? `<p style="font-size:12px;">유출유의 해안 부착이 없습니다.</p>`
: `<p style="font-size:12px;">최초 부착시간: <b>${oilPayload.coastal?.firstTime ?? '—'}</b> / 부착 해안길이: <b>${coastLength}</b></p>`;
}
}
if (activeCat === 0 && oilPayload) { if (activeCat === 0 && oilPayload) {
if (sec.id === 'oil-pollution') { if (sec.id === 'oil-pollution') {
const rows = [ const rows = [
@ -322,9 +328,9 @@ function ReportGenerator({ onSave }: ReportGeneratorProps) {
/> />
<div className="grid grid-cols-3 gap-3"> <div className="grid grid-cols-3 gap-3">
{[ {[
{ label: 'KOSPS', value: oilPayload?.spread.kosps || sampleOilData.spread.kosps, color: '#06b6d4' }, { label: 'KOSPS', value: oilPayload?.spread.kosps || '—', color: '#06b6d4' },
{ label: 'OpenDrift', value: oilPayload?.spread.openDrift || sampleOilData.spread.openDrift, color: '#ef4444' }, { label: 'OpenDrift', value: oilPayload?.spread.openDrift || '—', color: '#ef4444' },
{ label: 'POSEIDON', value: oilPayload?.spread.poseidon || sampleOilData.spread.poseidon, color: '#f97316' }, { label: 'POSEIDON', value: oilPayload?.spread.poseidon || '—', color: '#f97316' },
].map((m, i) => ( ].map((m, i) => (
<div key={i} className="bg-bg-1 border border-border rounded-lg p-4 text-center"> <div key={i} className="bg-bg-1 border border-border rounded-lg p-4 text-center">
<p className="text-[10px] text-text-3 font-korean mb-1">{m.label}</p> <p className="text-[10px] text-text-3 font-korean mb-1">{m.label}</p>
@ -345,9 +351,9 @@ function ReportGenerator({ onSave }: ReportGeneratorProps) {
<colgroup><col style={{ width: '25%' }} /><col style={{ width: '25%' }} /><col style={{ width: '25%' }} /><col style={{ width: '25%' }} /></colgroup> <colgroup><col style={{ width: '25%' }} /><col style={{ width: '25%' }} /><col style={{ width: '25%' }} /><col style={{ width: '25%' }} /></colgroup>
<tbody> <tbody>
{[ {[
['유출량', oilPayload?.pollution.spillAmount || sampleOilData.pollution.spillAmount, '풍화량', oilPayload?.pollution.weathered || sampleOilData.pollution.weathered], ['유출량', oilPayload?.pollution.spillAmount || '—', '풍화량', oilPayload?.pollution.weathered || '—'],
['해상잔유량', oilPayload?.pollution.seaRemain || sampleOilData.pollution.seaRemain, '오염해역면적', oilPayload?.pollution.pollutionArea || sampleOilData.pollution.pollutionArea], ['해상잔유량', oilPayload?.pollution.seaRemain || '—', '오염해역면적', oilPayload?.pollution.pollutionArea || '—'],
['연안부착량', oilPayload?.pollution.coastAttach || sampleOilData.pollution.coastAttach, '오염해안길이', oilPayload?.pollution.coastLength || sampleOilData.pollution.coastLength], ['연안부착량', oilPayload?.pollution.coastAttach || '—', '오염해안길이', oilPayload?.pollution.coastLength || '—'],
].map((row, i) => ( ].map((row, i) => (
<tr key={i} className="border-b border-border"> <tr key={i} className="border-b border-border">
<td className="px-4 py-3 text-[11px] text-text-3 font-korean bg-[rgba(255,255,255,0.02)]">{row[0]}</td> <td className="px-4 py-3 text-[11px] text-text-3 font-korean bg-[rgba(255,255,255,0.02)]">{row[0]}</td>
@ -361,20 +367,20 @@ function ReportGenerator({ onSave }: ReportGeneratorProps) {
</> </>
)} )}
{sec.id === 'oil-sensitive' && ( {sec.id === 'oil-sensitive' && (
<> <p className="text-[12px] text-text-3 font-korean italic">
<p className="text-[11px] text-text-3 font-korean mb-3"> 10 NM </p> .
<div className="flex flex-wrap gap-2"> </p>
{sampleOilData.sensitive.map((item, i) => (
<span key={i} className="px-3 py-1.5 text-[11px] font-semibold rounded-md bg-bg-3 border border-border text-text-2 font-korean">{item.label}</span>
))}
</div>
</>
)} )}
{sec.id === 'oil-coastal' && (() => { {sec.id === 'oil-coastal' && (() => {
const coastLength = oilPayload?.pollution.coastLength; if (!oilPayload) {
const hasNoCoastal = oilPayload && ( return (
!coastLength || coastLength === '—' || coastLength.startsWith('0.00') <p className="text-[12px] text-text-3 font-korean italic">
); .
</p>
);
}
const coastLength = oilPayload.pollution.coastLength;
const hasNoCoastal = !coastLength || coastLength === '—' || coastLength.startsWith('0.00');
if (hasNoCoastal) { if (hasNoCoastal) {
return ( return (
<p className="text-[12px] text-text-2 font-korean"> <p className="text-[12px] text-text-2 font-korean">
@ -384,9 +390,9 @@ function ReportGenerator({ onSave }: ReportGeneratorProps) {
} }
return ( return (
<p className="text-[12px] text-text-2 font-korean"> <p className="text-[12px] text-text-2 font-korean">
: <span className="font-semibold text-text-1">{oilPayload?.coastal?.firstTime ?? sampleOilData.coastal.firstTime}</span> : <span className="font-semibold text-text-1">{oilPayload.coastal?.firstTime ?? '—'}</span>
{' / '} {' / '}
: <span className="font-semibold text-text-1">{coastLength || sampleOilData.coastal.coastLength}</span> : <span className="font-semibold text-text-1">{coastLength}</span>
</p> </p>
); );
})()} })()}
@ -399,20 +405,9 @@ function ReportGenerator({ onSave }: ReportGeneratorProps) {
</div> </div>
)} )}
{sec.id === 'oil-tide' && ( {sec.id === 'oil-tide' && (
<> <p className="text-[12px] text-text-3 font-korean italic">
<p className="text-[12px] text-text-2 font-korean"> · .
: <span className="font-semibold text-text-1">{sampleOilData.tide.highTide1}</span> </p>
{' / '}: <span className="font-semibold text-text-1">{sampleOilData.tide.lowTide}</span>
{' / '}: <span className="font-semibold text-text-1">{sampleOilData.tide.highTide2}</span>
</p>
{oilPayload?.weather && (
<p className="text-[11px] text-text-3 font-korean mt-2">
기상: 풍향/ <span className="text-text-2 font-semibold">{oilPayload.weather.windDir}</span>
{' / '} <span className="text-text-2 font-semibold">{oilPayload.weather.waveHeight}</span>
{' / '} <span className="text-text-2 font-semibold">{oilPayload.weather.temp}</span>
</p>
)}
</>
)} )}
{/* ── HNS 대기확산 섹션들 ── */} {/* ── HNS 대기확산 섹션들 ── */}
@ -432,7 +427,7 @@ function ReportGenerator({ onSave }: ReportGeneratorProps) {
)} )}
<div className="grid grid-cols-3 gap-3"> <div className="grid grid-cols-3 gap-3">
{[ {[
{ label: hnsPayload?.atm.model || 'ALOHA', value: hnsPayload?.atm.maxDistance || sampleHnsData.atm.aloha, color: '#f97316', desc: '최대 확산거리' }, { label: hnsPayload?.atm.model || 'ALOHA', value: hnsPayload?.atm.maxDistance || '—', color: '#f97316', desc: '최대 확산거리' },
{ label: '최대 농도', value: hnsPayload?.maxConcentration || '—', color: '#ef4444', desc: '지상 1.5m 기준' }, { label: '최대 농도', value: hnsPayload?.maxConcentration || '—', color: '#ef4444', desc: '지상 1.5m 기준' },
{ label: 'AEGL-1 면적', value: hnsPayload?.aeglAreas.aegl1 || '—', color: '#06b6d4', desc: '확산 영향 면적' }, { label: 'AEGL-1 면적', value: hnsPayload?.aeglAreas.aegl1 || '—', color: '#06b6d4', desc: '확산 영향 면적' },
].map((m, i) => ( ].map((m, i) => (
@ -448,9 +443,9 @@ function ReportGenerator({ onSave }: ReportGeneratorProps) {
{sec.id === 'hns-hazard' && ( {sec.id === 'hns-hazard' && (
<div className="grid grid-cols-3 gap-3"> <div className="grid grid-cols-3 gap-3">
{[ {[
{ label: 'AEGL-3 구역', value: hnsPayload?.hazard.aegl3 || sampleHnsData.hazard.erpg3, area: hnsPayload?.aeglAreas.aegl3, color: '#ef4444', desc: '생명 위협' }, { label: 'AEGL-3 구역', value: hnsPayload?.hazard.aegl3 || '—', area: hnsPayload?.aeglAreas.aegl3, color: '#ef4444', desc: '생명 위협' },
{ label: 'AEGL-2 구역', value: hnsPayload?.hazard.aegl2 || sampleHnsData.hazard.erpg2, area: hnsPayload?.aeglAreas.aegl2, color: '#f97316', desc: '건강 피해' }, { label: 'AEGL-2 구역', value: hnsPayload?.hazard.aegl2 || '—', area: hnsPayload?.aeglAreas.aegl2, color: '#f97316', desc: '건강 피해' },
{ label: 'AEGL-1 구역', value: hnsPayload?.hazard.aegl1 || sampleHnsData.hazard.evacuation, area: hnsPayload?.aeglAreas.aegl1, color: '#eab308', desc: '불쾌감' }, { label: 'AEGL-1 구역', value: hnsPayload?.hazard.aegl1 || '—', area: hnsPayload?.aeglAreas.aegl1, color: '#eab308', desc: '불쾌감' },
].map((h, i) => ( ].map((h, i) => (
<div key={i} className="bg-bg-1 border border-border rounded-lg p-4 text-center"> <div key={i} className="bg-bg-1 border border-border rounded-lg p-4 text-center">
<p className="text-[9px] font-bold font-korean mb-1" style={{ color: h.color }}>{h.label}</p> <p className="text-[9px] font-bold font-korean mb-1" style={{ color: h.color }}>{h.label}</p>
@ -464,10 +459,10 @@ function ReportGenerator({ onSave }: ReportGeneratorProps) {
{sec.id === 'hns-substance' && ( {sec.id === 'hns-substance' && (
<div className="grid grid-cols-2 gap-2 text-[11px]"> <div className="grid grid-cols-2 gap-2 text-[11px]">
{[ {[
{ k: '물질명', v: hnsPayload?.substance.name || sampleHnsData.substance.name }, { k: '물질명', v: hnsPayload?.substance.name || '—' },
{ k: 'UN번호', v: hnsPayload?.substance.un || sampleHnsData.substance.un }, { k: 'UN번호', v: hnsPayload?.substance.un || '—' },
{ k: 'CAS번호', v: hnsPayload?.substance.cas || sampleHnsData.substance.cas }, { k: 'CAS번호', v: hnsPayload?.substance.cas || '—' },
{ k: '위험등급', v: hnsPayload?.substance.class || sampleHnsData.substance.class }, { k: '위험등급', v: hnsPayload?.substance.class || '—' },
].map((r, i) => ( ].map((r, i) => (
<div key={i} className="flex justify-between px-3 py-2 bg-bg-1 rounded border border-border"> <div key={i} className="flex justify-between px-3 py-2 bg-bg-1 rounded border border-border">
<span className="text-text-3 font-korean">{r.k}</span> <span className="text-text-3 font-korean">{r.k}</span>
@ -476,25 +471,21 @@ function ReportGenerator({ onSave }: ReportGeneratorProps) {
))} ))}
<div className="col-span-2 flex justify-between px-3 py-2 bg-bg-1 rounded border border-[rgba(239,68,68,0.3)]"> <div className="col-span-2 flex justify-between px-3 py-2 bg-bg-1 rounded border border-[rgba(239,68,68,0.3)]">
<span className="text-text-3 font-korean"></span> <span className="text-text-3 font-korean"></span>
<span className="text-[var(--red)] font-semibold font-mono text-[10px]">{hnsPayload?.substance.toxicity || sampleHnsData.substance.toxicity}</span> <span className="text-[var(--red)] font-semibold font-mono text-[10px]">{hnsPayload?.substance.toxicity || '—'}</span>
</div> </div>
</div> </div>
)} )}
{sec.id === 'hns-ppe' && ( {sec.id === 'hns-ppe' && (
<div className="flex flex-wrap gap-2"> <div className="flex flex-wrap gap-2">
{sampleHnsData.ppe.map((item, i) => ( <span className="text-text-3 font-korean text-[11px]"></span>
<span key={i} className="px-3 py-1.5 text-[11px] font-semibold rounded-md border text-text-2 font-korean" style={{ background: 'rgba(249,115,22,0.06)', borderColor: 'rgba(249,115,22,0.2)' }}>
🛡 {item}
</span>
))}
</div> </div>
)} )}
{sec.id === 'hns-facility' && ( {sec.id === 'hns-facility' && (
<div className="grid grid-cols-3 gap-3"> <div className="grid grid-cols-3 gap-3">
{[ {[
{ label: '인근 학교', value: `${sampleHnsData.facility.schools}개소`, icon: '🏫' }, { label: '인근 학교', value: '—', icon: '🏫' },
{ label: '의료시설', value: `${sampleHnsData.facility.hospitals}개소`, icon: '🏥' }, { label: '의료시설', value: '—', icon: '🏥' },
{ label: '주변 인구', value: sampleHnsData.facility.population, icon: '👥' }, { label: '주변 인구', value: '—', icon: '👥' },
].map((f, i) => ( ].map((f, i) => (
<div key={i} className="bg-bg-1 border border-border rounded-lg p-4 text-center"> <div key={i} className="bg-bg-1 border border-border rounded-lg p-4 text-center">
<div className="text-[18px] mb-1">{f.icon}</div> <div className="text-[18px] mb-1">{f.icon}</div>
@ -512,10 +503,10 @@ function ReportGenerator({ onSave }: ReportGeneratorProps) {
{sec.id === 'hns-weather' && ( {sec.id === 'hns-weather' && (
<div className="grid grid-cols-4 gap-3"> <div className="grid grid-cols-4 gap-3">
{[ {[
{ label: '풍향', value: hnsPayload?.weather.windDir || 'NE 42°', icon: '🌬' }, { label: '풍향', value: hnsPayload?.weather.windDir || '', icon: '🌬' },
{ label: '풍속', value: hnsPayload?.weather.windSpeed || '5.2 m/s', icon: '💨' }, { label: '풍속', value: hnsPayload?.weather.windSpeed || '', icon: '💨' },
{ label: '대기안정도', value: hnsPayload?.weather.stability || 'D (중립)', icon: '🌡' }, { label: '대기안정도', value: hnsPayload?.weather.stability || '', icon: '🌡' },
{ label: '기온', value: hnsPayload?.weather.temperature || '8.5°C', icon: '☀️' }, { label: '기온', value: hnsPayload?.weather.temperature || '', icon: '☀️' },
].map((w, i) => ( ].map((w, i) => (
<div key={i} className="bg-bg-1 border border-border rounded-lg p-3 text-center"> <div key={i} className="bg-bg-1 border border-border rounded-lg p-3 text-center">
<div className="text-[16px] mb-0.5">{w.icon}</div> <div className="text-[16px] mb-0.5">{w.icon}</div>
@ -530,10 +521,10 @@ function ReportGenerator({ onSave }: ReportGeneratorProps) {
{sec.id === 'rescue-safety' && ( {sec.id === 'rescue-safety' && (
<div className="grid grid-cols-4 gap-3"> <div className="grid grid-cols-4 gap-3">
{[ {[
{ label: 'GM (복원력)', value: sampleRescueData.safety.gm, color: '#f97316' }, { label: 'GM (복원력)', value: '—', color: '#f97316' },
{ label: '경사각 (Heel)', value: sampleRescueData.safety.heel, color: '#ef4444' }, { label: '경사각 (Heel)', value: '—', color: '#ef4444' },
{ label: '트림 (Trim)', value: sampleRescueData.safety.trim, color: '#06b6d4' }, { label: '트림 (Trim)', value: '—', color: '#06b6d4' },
{ label: '안전 상태', value: sampleRescueData.safety.status, color: '#f97316' }, { label: '안전 상태', value: '—', color: '#f97316' },
].map((s, i) => ( ].map((s, i) => (
<div key={i} className="bg-bg-1 border border-border rounded-lg p-3 text-center"> <div key={i} className="bg-bg-1 border border-border rounded-lg p-3 text-center">
<p className="text-[9px] text-text-3 font-korean mb-1">{s.label}</p> <p className="text-[9px] text-text-3 font-korean mb-1">{s.label}</p>
@ -544,26 +535,18 @@ function ReportGenerator({ onSave }: ReportGeneratorProps) {
)} )}
{sec.id === 'rescue-timeline' && ( {sec.id === 'rescue-timeline' && (
<div className="flex flex-col gap-2"> <div className="flex flex-col gap-2">
{[ <div className="flex items-center gap-3 px-3 py-2 bg-bg-1 rounded border border-border">
{ time: '06:28', event: '충돌 발생 — ORIENTAL GLORY ↔ HAI FENG 168', color: '#ef4444' }, <span className="text-[11px] text-text-3 font-korean"></span>
{ time: '06:30', event: 'No.1P 탱크 파공, 벙커C유 유출 개시', color: '#f97316' }, </div>
{ time: '06:35', event: 'VHF Ch.16 조난통신, 해경 출동 요청', color: '#eab308' },
{ time: '07:15', event: '해경 3009함 현장 도착, 방제 개시', color: '#06b6d4' },
].map((e, i) => (
<div key={i} className="flex items-center gap-3 px-3 py-2 bg-bg-1 rounded border border-border">
<span className="font-mono text-[11px] font-bold min-w-[40px]" style={{ color: e.color }}>{e.time}</span>
<span className="text-[11px] text-text-2 font-korean">{e.event}</span>
</div>
))}
</div> </div>
)} )}
{sec.id === 'rescue-casualty' && ( {sec.id === 'rescue-casualty' && (
<div className="grid grid-cols-4 gap-3"> <div className="grid grid-cols-4 gap-3">
{[ {[
{ label: '총원', value: sampleRescueData.casualty.total }, { label: '총원', value: '—' },
{ label: '구조완료', value: sampleRescueData.casualty.rescued, color: '#22c55e' }, { label: '구조완료', value: '—', color: '#22c55e' },
{ label: '실종', value: sampleRescueData.casualty.missing, color: '#ef4444' }, { label: '실종', value: '—', color: '#ef4444' },
{ label: '부상', value: sampleRescueData.casualty.injured, color: '#f97316' }, { label: '부상', value: '—', color: '#f97316' },
].map((c, i) => ( ].map((c, i) => (
<div key={i} className="bg-bg-1 border border-border rounded-lg p-4 text-center"> <div key={i} className="bg-bg-1 border border-border rounded-lg p-4 text-center">
<p className="text-[9px] text-text-3 font-korean mb-1">{c.label}</p> <p className="text-[9px] text-text-3 font-korean mb-1">{c.label}</p>
@ -584,30 +567,18 @@ function ReportGenerator({ onSave }: ReportGeneratorProps) {
</tr> </tr>
</thead> </thead>
<tbody> <tbody>
{sampleRescueData.resources.map((r, i) => ( <tr className="border-b border-border">
<tr key={i} className="border-b border-border"> <td colSpan={4} className="px-3 py-3 text-center text-text-3 font-korean text-[11px]"></td>
<td className="px-3 py-2 text-text-2 font-korean">{r.type}</td> </tr>
<td className="px-3 py-2 text-text-1 font-mono font-semibold">{r.name}</td>
<td className="px-3 py-2 text-text-2 text-center font-mono">{r.eta}</td>
<td className="px-3 py-2 text-center">
<span className="px-2 py-0.5 rounded text-[10px] font-semibold font-korean" style={{
background: r.status === '투입중' ? 'rgba(34,197,94,0.15)' : r.status === '이동중' ? 'rgba(249,115,22,0.15)' : 'rgba(138,150,168,0.15)',
color: r.status === '투입중' ? '#22c55e' : r.status === '이동중' ? '#f97316' : '#8a96a8',
}}>
{r.status}
</span>
</td>
</tr>
))}
</tbody> </tbody>
</table> </table>
)} )}
{sec.id === 'rescue-grounding' && ( {sec.id === 'rescue-grounding' && (
<div className="grid grid-cols-3 gap-3"> <div className="grid grid-cols-3 gap-3">
{[ {[
{ label: '좌초 위험도', value: sampleRescueData.grounding.risk, color: '#ef4444' }, { label: '좌초 위험도', value: '—', color: '#ef4444' },
{ label: '최근 천해', value: sampleRescueData.grounding.nearestShallow, color: '#f97316' }, { label: '최근 천해', value: '—', color: '#f97316' },
{ label: '현재 수심', value: sampleRescueData.grounding.depth, color: '#06b6d4' }, { label: '현재 수심', value: '—', color: '#06b6d4' },
].map((g, i) => ( ].map((g, i) => (
<div key={i} className="bg-bg-1 border border-border rounded-lg p-4 text-center"> <div key={i} className="bg-bg-1 border border-border rounded-lg p-4 text-center">
<p className="text-[9px] text-text-3 font-korean mb-1">{g.label}</p> <p className="text-[9px] text-text-3 font-korean mb-1">{g.label}</p>
@ -619,10 +590,10 @@ function ReportGenerator({ onSave }: ReportGeneratorProps) {
{sec.id === 'rescue-weather' && ( {sec.id === 'rescue-weather' && (
<div className="grid grid-cols-4 gap-3"> <div className="grid grid-cols-4 gap-3">
{[ {[
{ label: '파고', value: '1.5 m', icon: '🌊' }, { label: '파고', value: '', icon: '🌊' },
{ label: '풍속', value: '5.2 m/s', icon: '🌬' }, { label: '풍속', value: '', icon: '🌬' },
{ label: '조류', value: '1.2 kts NE', icon: '🌀' }, { label: '조류', value: '', icon: '🌀' },
{ label: '시정', value: '8 km', icon: '👁' }, { label: '시정', value: '', icon: '👁' },
].map((w, i) => ( ].map((w, i) => (
<div key={i} className="bg-bg-1 border border-border rounded-lg p-3 text-center"> <div key={i} className="bg-bg-1 border border-border rounded-lg p-3 text-center">
<div className="text-[16px] mb-0.5">{w.icon}</div> <div className="text-[16px] mb-0.5">{w.icon}</div>

파일 보기

@ -11,6 +11,7 @@ import {
generateReportHTML, generateReportHTML,
exportAsPDF, exportAsPDF,
exportAsHWP, exportAsHWP,
buildReportGetVal,
typeColors, typeColors,
statusColors, statusColors,
analysisCatColors, analysisCatColors,
@ -284,16 +285,7 @@ export function ReportsView() {
onClick={() => { onClick={() => {
const tpl = templateTypes.find(t => t.id === previewReport.reportType) const tpl = templateTypes.find(t => t.id === previewReport.reportType)
if (tpl) { if (tpl) {
const getVal = (key: string) => { const getVal = buildReportGetVal(previewReport)
if (key === 'author') return previewReport.author
if (key.startsWith('incident.')) {
const f = key.split('.')[1]
// eslint-disable-next-line @typescript-eslint/no-explicit-any
return (previewReport.incident as any)[f] || ''
}
// eslint-disable-next-line @typescript-eslint/no-explicit-any
return (previewReport as any)[key] || ''
}
const html = generateReportHTML(tpl.label, { writeTime: previewReport.incident.writeTime, author: previewReport.author, jurisdiction: previewReport.jurisdiction }, tpl.sections, getVal) const html = generateReportHTML(tpl.label, { writeTime: previewReport.incident.writeTime, author: previewReport.author, jurisdiction: previewReport.jurisdiction }, tpl.sections, getVal)
exportAsPDF(html, previewReport.title || tpl.label) exportAsPDF(html, previewReport.title || tpl.label)
} }
@ -307,16 +299,7 @@ export function ReportsView() {
onClick={() => { onClick={() => {
const tpl = templateTypes.find(t => t.id === previewReport.reportType) as TemplateType | undefined const tpl = templateTypes.find(t => t.id === previewReport.reportType) as TemplateType | undefined
if (tpl) { if (tpl) {
const getVal = (key: string) => { const getVal = buildReportGetVal(previewReport)
if (key === 'author') return previewReport.author
if (key.startsWith('incident.')) {
const f = key.split('.')[1]
// eslint-disable-next-line @typescript-eslint/no-explicit-any
return (previewReport.incident as any)[f] || ''
}
// eslint-disable-next-line @typescript-eslint/no-explicit-any
return (previewReport as any)[key] || ''
}
const meta = { writeTime: previewReport.incident.writeTime, author: previewReport.author, jurisdiction: previewReport.jurisdiction } const meta = { writeTime: previewReport.incident.writeTime, author: previewReport.author, jurisdiction: previewReport.jurisdiction }
const filename = previewReport.title || tpl.label const filename = previewReport.title || tpl.label
exportAsHWP(tpl.label, meta, tpl.sections, getVal, filename) exportAsHWP(tpl.label, meta, tpl.sections, getVal, filename)

파일 보기

@ -66,24 +66,23 @@ export const templateTypes: TemplateType[] = [
sections: [ sections: [
{ title: '1. 기본정보', fields: [ { title: '1. 기본정보', fields: [
{ key: 'incident.writeTime', label: '보고일시', type: 'text' }, { key: 'incident.writeTime', label: '보고일시', type: 'text' },
{ key: 'author', label: '작성자', type: 'text' }, { key: 'author', label: '작성자', type: 'text' },
]}, ]},
{ title: '2. 사고개요', fields: [ { title: '2. 사고개요', fields: [
{ key: 'incident.name', label: '사고명', type: 'text' }, { key: 'incident.name', label: '사고명', type: 'text' },
{ key: 'incident.occurTime', label: '발생일시', type: 'text' }, { key: 'incident.occurTime', label: '발생일시', type: 'text' },
{ key: 'incident.location', label: '발생위치', type: 'text' }, { key: 'incident.location', label: '발생위치', type: 'text' },
{ key: 'incident.pollutant', label: '유출유종', type: 'text' }, { key: 'incident.pollutant', label: '유출유종', type: 'text' },
{ key: 'incident.spillAmount', label: '유출량', type: 'text' }, { key: 'incident.spillAmount', label: '유출량', type: 'text' },
]},
{ title: '3. 해양기상 현황', fields: [
{ key: 'weatherSummary', label: '', type: 'textarea' },
]},
{ title: '4. 확산예측 결과', fields: [
{ key: 'spreadResult', label: '', type: 'textarea' },
]},
{ title: '5. 민감자원 영향', fields: [
{ key: 'sensitiveImpact', label: '', type: 'textarea' },
]}, ]},
{ title: '3. 조석 현황', fields: [{ key: '__tide', label: '', type: 'textarea' }] },
{ title: '4. 해양기상 현황', fields: [{ key: '__weather', label: '', type: 'textarea' }] },
{ title: '5. 확산예측 결과', fields: [{ key: '__spread', label: '', type: 'textarea' }] },
{ title: '6. 민감자원 현황', fields: [{ key: '__sensitive', label: '', type: 'textarea' }] },
{ title: '7. 방제 자원', fields: [{ key: '__vessels', label: '', type: 'textarea' }] },
{ title: '8. 회수·처리 현황',fields: [{ key: '__recovery', label: '', type: 'textarea' }] },
{ title: '9. 최종 결과', fields: [{ key: '__result', label: '', type: 'textarea' }] },
{ title: '10. 분석 의견', fields: [{ key: 'analysis', label: '', type: 'textarea' }] },
] ]
}, },
{ {

파일 보기

@ -86,3 +86,122 @@ export function inferAnalysisCategory(report: OilSpillReportData): string {
if (t.includes('유출유') || t.includes('확산예측') || t.includes('민감자원') || t.includes('유출사고') || t.includes('오염') || t.includes('방제') || rt === '유출유 보고' || rt === '예측보고서') return '유출유 확산예측' if (t.includes('유출유') || t.includes('확산예측') || t.includes('민감자원') || t.includes('유출사고') || t.includes('오염') || t.includes('방제') || rt === '유출유 보고' || rt === '예측보고서') return '유출유 확산예측'
return '' return ''
} }
// ─── PDF/HWP 섹션 포맷 헬퍼 ─────────────────────────────────
const TH = 'padding:6px 8px;border:1px solid #d1d5db;background:#f0f4f8;font-weight:600;font-size:11px;'
const TD = 'padding:6px 8px;border:1px solid #d1d5db;font-size:11px;'
const TABLE = 'width:100%;border-collapse:collapse;'
function formatTideTable(tide: OilSpillReportData['tide']): string {
if (!tide?.length) return ''
const header = `<tr><th style="${TH}">날짜</th><th style="${TH}">조형</th><th style="${TH}">간조1</th><th style="${TH}">만조1</th><th style="${TH}">간조2</th><th style="${TH}">만조2</th></tr>`
const rows = tide.map(t =>
`<tr><td style="${TD}">${t.date}</td><td style="${TD}">${t.tideType}</td><td style="${TD}">${t.lowTide1}</td><td style="${TD}">${t.highTide1}</td><td style="${TD}">${t.lowTide2}</td><td style="${TD}">${t.highTide2}</td></tr>`
).join('')
return `<table style="${TABLE}">${header}${rows}</table>`
}
function formatWeatherTable(weather: OilSpillReportData['weather']): string {
if (!weather?.length) return ''
const header = `<tr><th style="${TH}">시각</th><th style="${TH}">풍향</th><th style="${TH}">풍속</th><th style="${TH}">유향</th><th style="${TH}">유속</th><th style="${TH}">파고</th></tr>`
const rows = weather.map(w =>
`<tr><td style="${TD}">${w.time}</td><td style="${TD}">${w.windDir}</td><td style="${TD}">${w.windSpeed}</td><td style="${TD}">${w.currentDir}</td><td style="${TD}">${w.currentSpeed}</td><td style="${TD}">${w.waveHeight}</td></tr>`
).join('')
return `<table style="${TABLE}">${header}${rows}</table>`
}
function formatSpreadTable(spread: OilSpillReportData['spread']): string {
if (!spread?.length) return ''
const header = `<tr><th style="${TH}">경과시간</th><th style="${TH}">풍화량</th><th style="${TH}">해상잔유량</th><th style="${TH}">연안부착량</th><th style="${TH}">면적</th></tr>`
const rows = spread.map(s =>
`<tr><td style="${TD}">${s.elapsed}</td><td style="${TD}">${s.weathered}</td><td style="${TD}">${s.seaRemain}</td><td style="${TD}">${s.coastAttach}</td><td style="${TD}">${s.area}</td></tr>`
).join('')
return `<table style="${TABLE}">${header}${rows}</table>`
}
function formatSensitiveTable(r: OilSpillReportData): string {
const parts: string[] = []
if (r.aquaculture?.length) {
const h = `<tr><th style="${TH}">종류</th><th style="${TH}">면적</th><th style="${TH}">거리</th></tr>`
const rows = r.aquaculture.map(a => `<tr><td style="${TD}">${a.type}</td><td style="${TD}">${a.area}</td><td style="${TD}">${a.distance}</td></tr>`).join('')
parts.push(`<p style="font-size:11px;font-weight:600;margin:8px 0 4px;">양식업</p><table style="${TABLE}">${h}${rows}</table>`)
}
if (r.beaches?.length) {
const h = `<tr><th style="${TH}">해수욕장명</th><th style="${TH}">거리</th></tr>`
const rows = r.beaches.map(b => `<tr><td style="${TD}">${b.name}</td><td style="${TD}">${b.distance}</td></tr>`).join('')
parts.push(`<p style="font-size:11px;font-weight:600;margin:8px 0 4px;">해수욕장</p><table style="${TABLE}">${h}${rows}</table>`)
}
if (r.markets?.length) {
const h = `<tr><th style="${TH}">수산시장명</th><th style="${TH}">거리</th></tr>`
const rows = r.markets.map(m => `<tr><td style="${TD}">${m.name}</td><td style="${TD}">${m.distance}</td></tr>`).join('')
parts.push(`<p style="font-size:11px;font-weight:600;margin:8px 0 4px;">수산시장</p><table style="${TABLE}">${h}${rows}</table>`)
}
if (r.esi?.length) {
const h = `<tr><th style="${TH}">코드</th><th style="${TH}">유형</th><th style="${TH}">길이</th></tr>`
const rows = r.esi.map(e => `<tr><td style="${TD}">${e.code}</td><td style="${TD}">${e.type}</td><td style="${TD}">${e.length}</td></tr>`).join('')
parts.push(`<p style="font-size:11px;font-weight:600;margin:8px 0 4px;">ESI 해안선</p><table style="${TABLE}">${h}${rows}</table>`)
}
if (r.species?.length) {
const h = `<tr><th style="${TH}">분류</th><th style="${TH}">종명</th></tr>`
const rows = r.species.map(s => `<tr><td style="${TD}">${s.category}</td><td style="${TD}">${s.species}</td></tr>`).join('')
parts.push(`<p style="font-size:11px;font-weight:600;margin:8px 0 4px;">보호생물</p><table style="${TABLE}">${h}${rows}</table>`)
}
if (r.habitat?.length) {
const h = `<tr><th style="${TH}">유형</th><th style="${TH}">면적</th></tr>`
const rows = r.habitat.map(h2 => `<tr><td style="${TD}">${h2.type}</td><td style="${TD}">${h2.area}</td></tr>`).join('')
parts.push(`<p style="font-size:11px;font-weight:600;margin:8px 0 4px;">서식지</p><table style="${TABLE}">${h}${rows}</table>`)
}
if (r.sensitivity?.length) {
const h = `<tr><th style="${TH}">민감도</th><th style="${TH}">면적</th></tr>`
const rows = r.sensitivity.map(s => `<tr><td style="${TD}">${s.level}</td><td style="${TD}">${s.area}</td></tr>`).join('')
parts.push(`<p style="font-size:11px;font-weight:600;margin:8px 0 4px;">민감도 등급</p><table style="${TABLE}">${h}${rows}</table>`)
}
return parts.join('')
}
function formatVesselsTable(vessels: OilSpillReportData['vessels']): string {
if (!vessels?.length) return ''
const header = `<tr><th style="${TH}">선명</th><th style="${TH}">기관</th><th style="${TH}">거리</th><th style="${TH}">속력</th><th style="${TH}">톤수</th><th style="${TH}">회수장비</th><th style="${TH}">오일붐</th></tr>`
const rows = vessels.map(v =>
`<tr><td style="${TD}">${v.name}</td><td style="${TD}">${v.org}</td><td style="${TD}">${v.dist}</td><td style="${TD}">${v.speed}</td><td style="${TD}">${v.ton}</td><td style="${TD}">${v.collectorType} ${v.collectorCap}</td><td style="${TD}">${v.boomType} ${v.boomLength}</td></tr>`
).join('')
return `<table style="${TABLE}">${header}${rows}</table>`
}
function formatRecoveryTable(recovery: OilSpillReportData['recovery']): string {
if (!recovery?.length) return ''
const header = `<tr><th style="${TH}">선박명</th><th style="${TH}">회수 기간</th></tr>`
const rows = recovery.map(r =>
`<tr><td style="${TD}">${r.shipName}</td><td style="${TD}">${r.period}</td></tr>`
).join('')
return `<table style="${TABLE}">${header}${rows}</table>`
}
function formatResultTable(result: OilSpillReportData['result']): string {
if (!result) return ''
return `<table style="${TABLE}">
<tr><td style="${TH}"></td><td style="${TD}">${result.spillTotal}</td><td style="${TH}"></td><td style="${TD}">${result.weatheredTotal}</td></tr>
<tr><td style="${TH}"></td><td style="${TD}">${result.recoveredTotal}</td><td style="${TH}"></td><td style="${TD}">${result.seaRemainTotal}</td></tr>
<tr><td style="${TH}"></td><td style="${TD}" colspan="3">${result.coastAttachTotal}</td></tr>
</table>`
}
export function buildReportGetVal(report: OilSpillReportData) {
return (key: string): string => {
if (key === 'author') return report.author ?? ''
if (key.startsWith('incident.')) {
const f = key.split('.')[1]
// eslint-disable-next-line @typescript-eslint/no-explicit-any
return (report.incident as any)[f] || ''
}
if (key === '__tide') return formatTideTable(report.tide)
if (key === '__weather') return formatWeatherTable(report.weather)
if (key === '__spread') return formatSpreadTable(report.spread)
if (key === '__sensitive') return formatSensitiveTable(report)
if (key === '__vessels') return formatVesselsTable(report.vessels)
if (key === '__recovery') return formatRecoveryTable(report.recovery)
if (key === '__result') return formatResultTable(report.result)
// eslint-disable-next-line @typescript-eslint/no-explicit-any
return (report as any)[key] || ''
}
}

파일 보기

@ -1,13 +0,0 @@
__pycache__/
stitch/
mx15hdi/Detect/Mask_result/
mx15hdi/Detect/result/
mx15hdi/Georeference/Mask_Tif/
mx15hdi/Georeference/Tif/
mx15hdi/Metadata/CSV/
mx15hdi/Metadata/Image/Original_Images/
mx15hdi/Polygon/Shp/

파일 보기

@ -1,376 +0,0 @@
# wing-image-analysis Docker 사용 가이드
드론 영상 기반 유류 오염 분석 FastAPI 서버를 Docker 컨테이너로 빌드하고 실행하는 방법을 설명한다.
---
## 목차
1. [사전 요구사항](#1-사전-요구사항)
2. [빠른 시작](#2-빠른-시작)
3. [빌드 명령어](#3-빌드-명령어)
4. [실행 명령어](#4-실행-명령어)
5. [환경변수 설정](#5-환경변수-설정)
6. [볼륨 구조](#6-볼륨-구조)
7. [API 엔드포인트 사용 예시](#7-api-엔드포인트-사용-예시)
8. [로그 확인 및 디버깅](#8-로그-확인-및-디버깅)
9. [컨테이너 관리](#9-컨테이너-관리)
10. [주의사항](#10-주의사항)
11. [CPU 전용 환경 실행](#11-cpu-전용-환경-실행)
---
## 1. 사전 요구사항
| 항목 | 최소 버전 | 확인 명령어 |
|------|----------|-------------|
| Docker Engine | 24.0 이상 | `docker --version` |
| Docker Compose | 2.20 이상 | `docker compose version` |
| NVIDIA 드라이버 | 525 이상 (CUDA 12.1 지원) | `nvidia-smi` |
| nvidia-container-toolkit | 최신 | `nvidia-ctk --version` |
### nvidia-container-toolkit 설치 (Ubuntu 기준)
```bash
# GPG 키 및 저장소 추가
curl -fsSL https://nvidia.github.io/libnvidia-container/gpgkey \
| sudo gpg --dearmor -o /usr/share/keyrings/nvidia-container-toolkit-keyring.gpg
curl -s -L https://nvidia.github.io/libnvidia-container/stable/deb/nvidia-container-toolkit.list \
| sed 's#deb https://#deb [signed-by=/usr/share/keyrings/nvidia-container-toolkit-keyring.gpg] https://#g' \
| sudo tee /etc/apt/sources.list.d/nvidia-container-toolkit.list
sudo apt-get update && sudo apt-get install -y nvidia-container-toolkit
# Docker 런타임 설정 및 재시작
sudo nvidia-ctk runtime configure --runtime=docker
sudo systemctl restart docker
```
### GPU 동작 확인
```bash
docker run --rm --gpus all nvidia/cuda:12.1-base-ubuntu22.04 nvidia-smi
```
---
## 2. 빠른 시작
```bash
# 1. prediction/image/ 디렉토리로 이동
cd prediction/image
# 2. 환경변수 파일 준비 (필요 시)
cp .env.example .env
# 3. 빌드 + 실행 (백그라운드)
docker compose up -d --build
# 4. 서버 상태 확인
curl http://localhost:5001/docs
```
---
## 3. 빌드 명령어
### docker compose (권장)
```bash
# 이미지 빌드만 수행 (실행 안 함)
docker compose build
# 빌드 로그를 상세하게 출력
docker compose build --progress=plain
# 캐시 없이 처음부터 빌드 (의존성 변경 시)
docker compose build --no-cache
```
### docker build (단독)
```bash
# prediction/image/ 디렉토리에서 실행
docker build -t wing-image-analysis:latest .
# 빌드 태그 지정
docker build -t wing-image-analysis:1.0.0 .
# 캐시 없이 빌드
docker build --no-cache -t wing-image-analysis:latest .
```
> **참고**: 첫 빌드는 PyTorch base 이미지(약 8GB) + GDAL/Python 패키지 설치로 **30~60분** 소요될 수 있다.
> 이후 빌드는 레이어 캐시로 수 분 내 완료된다.
---
## 4. 실행 명령어
### docker compose (권장)
```bash
# 백그라운드 실행
docker compose up -d
# 빌드 후 즉시 실행
docker compose up -d --build
# 포그라운드 실행 (로그 바로 출력)
docker compose up
# 중지
docker compose down
# 중지 + 볼륨 삭제 (데이터 초기화 시)
docker compose down -v
```
### docker run (단독 — 테스트용)
```bash
docker run --rm \
--gpus all \
-p 5001:5001 \
--env-file .env \
-v "$(pwd)/mx15hdi/Metadata/Image/Original_Images:/app/mx15hdi/Metadata/Image/Original_Images" \
wing-image-analysis:latest
```
---
## 5. 환경변수 설정
`.env.example`을 복사하여 `.env`를 생성한다.
```bash
cp .env.example .env
```
| 변수 | 설명 | 기본값 |
|------|------|--------|
| `API_HOST` | 서버 바인드 주소 | `0.0.0.0` |
| `API_PORT` | 서버 포트 | `5001` |
---
## 6. 볼륨 구조
컨테이너 내부 경로와 호스트 경로의 매핑이다. 이미지/결과 데이터는 컨테이너 외부에 저장되어 컨테이너를 재시작해도 유지된다.
```
호스트 (prediction/image/) 컨테이너 (/app/)
─────────────────────────────────────────────────────────────────────
mx15hdi/Metadata/Image/Original_Images/ → mx15hdi/Metadata/Image/Original_Images/ ← 원본 이미지 입력
mx15hdi/Metadata/CSV/ → mx15hdi/Metadata/CSV/ ← 메타데이터 출력
mx15hdi/Georeference/Tif/ → mx15hdi/Georeference/Tif/ ← GeoTIFF 출력
mx15hdi/Georeference/Mask_Tif/ → mx15hdi/Georeference/Mask_Tif/ ← 마스크 GeoTIFF
mx15hdi/Polygon/Shp/ → mx15hdi/Polygon/Shp/ ← Shapefile 출력
mx15hdi/Detect/result/ → mx15hdi/Detect/result/ ← 블렌딩 결과
mx15hdi/Detect/Mask_result/ → mx15hdi/Detect/Mask_result/ ← 마스크 결과
starsafire/Metadata/Image/Original_Images → starsafire/Metadata/Image/Original_Images ← 열화상 입력
starsafire/{기타}/ → starsafire/{기타}/ ← 열화상 출력
stitch/ → stitch/ ← 스티칭 결과
```
---
## 7. API 엔드포인트 사용 예시
서버 기동 후 `http://localhost:5001/docs`에서 Swagger UI로 전체 API를 확인할 수 있다.
### 7.1 전체 분석 파이프라인 실행
```bash
curl -X POST http://localhost:5001/run-script/ \
-F "files=@/path/to/drone_image.jpg" \
-F "camTy=mx15hdi" \
-F "fileId=20240310_001"
```
**응답 예시**:
```json
{
"meta": "drone_image.jpg,37,30,0,126,55,0,...",
"data": [
{
"classId": 2,
"area": 1234.56,
"volume": 0.1234,
"note": "갈색",
"thickness": 0.0001,
"wkt": "POLYGON((...))"
}
]
}
```
### 7.2 메타데이터 조회
```bash
curl http://localhost:5001/get-metadata/mx15hdi/20240310_001
```
### 7.3 원본 이미지 조회 (Base64)
```bash
curl http://localhost:5001/get-original-image/mx15hdi/20240310_001
```
### 7.4 GeoTIFF + 좌표 조회
```bash
curl http://localhost:5001/get-image/mx15hdi/20240310_001
```
### 7.5 이미지 스티칭
```bash
curl -X POST http://localhost:5001/stitch \
-F "files=@photo1.jpg" \
-F "files=@photo2.jpg" \
-F "mode=drone"
```
---
## 8. 로그 확인 및 디버깅
```bash
# 실시간 로그 출력
docker logs wing-image-analysis -f
# 최근 100줄만 출력
docker logs wing-image-analysis --tail 100
# 컨테이너 내부 쉘 접속
docker exec -it wing-image-analysis bash
# GPU 사용 현황 확인 (컨테이너 내부)
docker exec wing-image-analysis nvidia-smi
# Python 패키지 목록 확인
docker exec wing-image-analysis pip list
```
---
## 9. 컨테이너 관리
```bash
# 상태 확인
docker compose ps
# 재시작
docker compose restart
# 중지 (볼륨 유지)
docker compose down
# 이미지 삭제
docker rmi wing-image-analysis:latest
# 사용하지 않는 리소스 정리
docker system prune -f
```
---
## 10. 주의사항
### GPU 자동 감지
- 서버 기동 시 `torch.cuda.is_available()`로 GPU 유무를 자동 감지한다.
- GPU가 있으면 `cuda:0`, 없으면 `cpu`로 자동 폴백된다.
- 환경변수 `DEVICE`로 device를 명시 지정할 수 있다 (예: `DEVICE=cpu`, `DEVICE=cuda:1`).
### 첫 기동 시간
- AI 모델 로드: 약 **10~30초** 소요 (GPU 메모리에 로딩)
- 준비 완료 후 로그에 `Application startup complete` 메시지가 출력된다.
### workers=1 고정
- GPU 모델은 프로세스 간 공유가 불가하므로 uvicorn workers는 반드시 `1`로 유지해야 한다.
- 병렬 처리는 내부 `ThreadPoolExecutor`(max_workers=4)로 처리된다.
### 포트 충돌
- 기본 포트 `5001`이 다른 서비스와 충돌하면 `docker-compose.yml``ports` 항목을 수정한다:
```yaml
ports:
- "5002:5001" # 호스트 5002 → 컨테이너 5001
```
---
## 11. CPU 전용 환경 실행
GPU(NVIDIA)가 없는 환경에서는 CPU 전용 설정을 사용한다.
### 사전 요구사항 (CPU 모드)
| 항목 | 최소 버전 | 확인 명령어 |
|------|----------|-------------|
| Docker Engine | 24.0 이상 | `docker --version` |
| Docker Compose | 2.20 이상 | `docker compose version` |
| NVIDIA 드라이버 | **불필요** | — |
### 빠른 시작 (CPU)
```bash
# prediction/image/ 디렉토리로 이동
cd prediction/image
# 환경변수 파일 준비 (필요 시)
cp .env.example .env
# CPU 이미지 빌드 + 실행
docker compose -f docker-compose.cpu.yml up -d --build
# 서버 상태 확인
curl http://localhost:5001/docs
```
### 빌드 명령어 (CPU)
```bash
# CPU 이미지만 빌드
docker compose -f docker-compose.cpu.yml build
# 캐시 없이 빌드
docker compose -f docker-compose.cpu.yml build --no-cache
```
> **참고**: CPU 기반 PyTorch 이미지는 GPU 이미지(~8GB) 대비 약 70% 용량이 절감된다.
> 단, CPU 추론은 GPU 대비 처리 속도가 느리므로 대용량 이미지 분석 시 시간이 더 소요된다.
### 실행 명령어 (CPU)
```bash
# 백그라운드 실행
docker compose -f docker-compose.cpu.yml up -d
# 포그라운드 실행 (로그 바로 출력)
docker compose -f docker-compose.cpu.yml up
# 중지
docker compose -f docker-compose.cpu.yml down
```
### 로컬 직접 실행 (Docker 없이)
```bash
# GPU 있으면 자동으로 cuda:0 사용, 없으면 cpu로 폴백
python api.py
# device 강제 지정
DEVICE=cpu python api.py
DEVICE=cuda:1 python api.py
```
### GPU/CPU 모드 확인
서버 기동 로그에서 사용 device를 확인할 수 있다:
```
[Inference] 사용 device: cpu ← CPU 모드
[Inference] 사용 device: cuda:0 ← GPU 모드
```

파일 보기

@ -1,84 +0,0 @@
# ==============================================================================
# wing-image-analysis — 드론 영상 유류 분석 FastAPI 서버
#
# Base: PyTorch 1.9.1 + CUDA 11.1 + cuDNN 8
# (mmsegmentation 0.25.0 / mmcv-full 1.4.3 호환 환경)
# GPU: NVIDIA GPU 필수 (MMSegmentation 추론)
# Port: 5001
# ==============================================================================
FROM pytorch/pytorch:1.9.1-cuda11.1-cudnn8-devel
ENV DEBIAN_FRONTEND=noninteractive \
PYTHONDONTWRITEBYTECODE=1 \
PYTHONUNBUFFERED=1
WORKDIR /app
# ------------------------------------------------------------------------------
# 시스템 패키지: GDAL / PROJ / GEOS (rasterio, geopandas 빌드 의존성)
# libpq-dev: psycopg2-binary 런타임 의존성
# libspatialindex-dev: geopandas 공간 인덱스
# ------------------------------------------------------------------------------
RUN apt-get update && apt-get install -y --no-install-recommends \
gdal-bin \
libgdal-dev \
libproj-dev \
libgeos-dev \
libspatialindex-dev \
gcc \
g++ \
git \
&& rm -rf /var/lib/apt/lists/*
# rasterio는 GDAL 헤더 버전을 맞춰 빌드해야 한다
ENV GDAL_VERSION=3.4.1
# ------------------------------------------------------------------------------
# mmcv-full 1.4.3 — CUDA 11.1 + PyTorch 1.9.0 pre-built 휠
# (소스 컴파일 없이 수 초 내 설치)
# ------------------------------------------------------------------------------
RUN pip install --no-cache-dir \
mmcv-full==1.4.3 \
-f https://download.openmmlab.com/mmcv/dist/cu111/torch1.9.0/index.html
# ------------------------------------------------------------------------------
# Python 의존성 설치
# ------------------------------------------------------------------------------
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
# ------------------------------------------------------------------------------
# 로컬 mmsegmentation 설치 (mx15hdi/Detect/mmsegmentation/)
# 번들 소스를 먼저 복사한 뒤 editable 설치한다
# ------------------------------------------------------------------------------
COPY mx15hdi/Detect/mmsegmentation/ /tmp/mmsegmentation/
RUN pip install --no-cache-dir -e /tmp/mmsegmentation/
# ------------------------------------------------------------------------------
# 소스 코드 전체 복사
# 대용량 데이터 디렉토리(Original_Images, result 등)는
# docker-compose.yml의 볼륨 마운트로 외부에서 주입된다
# ------------------------------------------------------------------------------
COPY . .
# ------------------------------------------------------------------------------
# .dockerignore로 제외된 런타임 출력 디렉토리를 빈 폴더로 생성
# (볼륨 마운트 전에도 경로가 존재해야 한다)
# ------------------------------------------------------------------------------
RUN mkdir -p \
/app/stitch \
/app/mx15hdi/Detect/Mask_result \
/app/mx15hdi/Detect/result \
/app/mx15hdi/Georeference/Mask_Tif \
/app/mx15hdi/Georeference/Tif \
/app/mx15hdi/Metadata/CSV \
/app/mx15hdi/Metadata/Image/Original_Images \
/app/mx15hdi/Polygon/Shp
# ------------------------------------------------------------------------------
# 런타임 설정
# ------------------------------------------------------------------------------
EXPOSE 5001
# workers=1: GPU 모델을 프로세스 하나에서만 로드 (메모리 공유 불가)
CMD ["uvicorn", "api:app", "--host", "0.0.0.0", "--port", "5001", "--workers", "1"]

파일 보기

@ -1,112 +0,0 @@
# ==============================================================================
# wing-image-analysis — 드론 영상 유류 분석 FastAPI 서버 (CPU 전용)
#
# Base: python:3.9-slim + PyTorch 1.9.0 CPU 빌드
# (mmsegmentation 0.25.0 / mmcv-full 1.4.3 호환 환경)
# python:3.9 필수 — numpy 1.26.4, geopandas 0.14.4가 Python >=3.9 요구
# GPU: 불필요 (CPU 추론)
# Port: 5001
# ==============================================================================
FROM python:3.9-slim
ENV DEBIAN_FRONTEND=noninteractive \
PYTHONDONTWRITEBYTECODE=1 \
PYTHONUNBUFFERED=1 \
DEVICE=cpu
WORKDIR /app
# ------------------------------------------------------------------------------
# 시스템 패키지: GDAL / PROJ / GEOS (rasterio, geopandas 빌드 의존성)
# libspatialindex-dev: geopandas 공간 인덱스
# opencv-contrib-python-headless 런타임 SO 의존성 (python:3.9-slim에 미포함):
# libgl1 — libGL.so.1
# libglib2.0-0 — libgthread-2.0.so.0, libgobject-2.0.so.0, libglib-2.0.so.0
# libsm6 — libSM.so.6
# libxext6 — libXext.so.6
# libxrender1 — libXrender.so.1
# libgomp1 — libgomp.so.1 (OpenMP, numpy/opencv 병렬 처리)
# ------------------------------------------------------------------------------
RUN apt-get update && apt-get install -y --no-install-recommends \
gdal-bin \
libgdal-dev \
libproj-dev \
libgeos-dev \
libspatialindex-dev \
libgl1 \
libglib2.0-0 \
libsm6 \
libxext6 \
libxrender1 \
libgomp1 \
gcc \
g++ \
git \
&& rm -rf /var/lib/apt/lists/*
# rasterio는 GDAL 헤더 버전을 맞춰 빌드해야 한다
ENV GDAL_VERSION=3.4.1
# ------------------------------------------------------------------------------
# GDAL Python 바인딩 (osgeo 모듈) — 시스템 GDAL 버전과 일치해야 한다
# python:3.9-slim은 conda 없이 pip 환경이므로 명시적 설치 필요
# ------------------------------------------------------------------------------
RUN pip install --no-cache-dir GDAL=="$(gdal-config --version)"
# ------------------------------------------------------------------------------
# PyTorch 1.9.0 CPU 버전 설치
# (mmsegmentation 0.25.0 / mmcv-full 1.4.3 호환)
# ------------------------------------------------------------------------------
RUN pip install --no-cache-dir \
torch==1.9.0+cpu \
torchvision==0.10.0+cpu \
-f https://download.pytorch.org/whl/torch_stable.html
# ------------------------------------------------------------------------------
# mmcv-full 1.4.3 CPU 휠 (CUDA ops 없는 경량 빌드, 추론에 충분)
# ------------------------------------------------------------------------------
RUN pip install --no-cache-dir \
mmcv-full==1.4.3 \
-f https://download.openmmlab.com/mmcv/dist/cpu/torch1.9.0/index.html
# ------------------------------------------------------------------------------
# Python 의존성 설치
# ------------------------------------------------------------------------------
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
# ------------------------------------------------------------------------------
# 로컬 mmsegmentation 설치 (mx15hdi/Detect/mmsegmentation/)
# 번들 소스를 먼저 복사한 뒤 editable 설치한다
# ------------------------------------------------------------------------------
COPY mx15hdi/Detect/mmsegmentation/ /tmp/mmsegmentation/
RUN pip install --no-cache-dir -e /tmp/mmsegmentation/
# ------------------------------------------------------------------------------
# 소스 코드 전체 복사
# 대용량 데이터 디렉토리(Original_Images, result 등)는
# docker-compose.cpu.yml의 볼륨 마운트로 외부에서 주입된다
# ------------------------------------------------------------------------------
COPY . .
# ------------------------------------------------------------------------------
# .dockerignore로 제외된 런타임 출력 디렉토리를 빈 폴더로 생성
# (볼륨 마운트 전에도 경로가 존재해야 한다)
# ------------------------------------------------------------------------------
RUN mkdir -p \
/app/stitch \
/app/mx15hdi/Detect/Mask_result \
/app/mx15hdi/Detect/result \
/app/mx15hdi/Georeference/Mask_Tif \
/app/mx15hdi/Georeference/Tif \
/app/mx15hdi/Metadata/CSV \
/app/mx15hdi/Metadata/Image/Original_Images \
/app/mx15hdi/Polygon/Shp
# ------------------------------------------------------------------------------
# 런타임 설정
# ------------------------------------------------------------------------------
EXPOSE 5001
# workers=1: 모델을 프로세스 하나에서만 로드 (메모리 공유 불가)
CMD ["uvicorn", "api:app", "--host", "0.0.0.0", "--port", "5001", "--workers", "1"]

파일 보기

@ -1,340 +0,0 @@
import sys
import os
from pathlib import Path
from contextlib import asynccontextmanager
import asyncio
from concurrent.futures import ThreadPoolExecutor
from fastapi import FastAPI, HTTPException, File, UploadFile, Form
from fastapi.responses import Response, FileResponse
import subprocess
import rasterio
import numpy as np
from PIL import Image
from PIL.ExifTags import TAGS
import io
import base64
from pyproj import Transformer
from extract_data import get_metadata as get_meta
from extract_data import get_oil_type as get_oil
import time
from typing import List, Optional
import shutil
from datetime import datetime
from collections import Counter
# mx15hdi 파이프라인 모듈 임포트를 위한 sys.path 설정
_BASE_DIR = Path(__file__).parent
sys.path.insert(0, str(_BASE_DIR / 'mx15hdi' / 'Detect'))
sys.path.insert(0, str(_BASE_DIR / 'mx15hdi' / 'Metadata' / 'Scripts'))
sys.path.insert(0, str(_BASE_DIR / 'mx15hdi' / 'Georeference' / 'Scripts'))
sys.path.insert(0, str(_BASE_DIR / 'mx15hdi' / 'Polygon' / 'Scripts'))
from Inference import load_model, run_inference
from Export_Metadata_mx15hdi import run_metadata_export
from Create_Georeferenced_Images_nadir import run_georeference
from Oilshape import run_oilshape
# AI 모델 (서버 시작 시 1회 로드)
_model = None
# CPU/GPU 바운드 작업용 스레드 풀
_executor = ThreadPoolExecutor(max_workers=4)
@asynccontextmanager
async def lifespan(app: FastAPI):
"""서버 시작 시 AI 모델을 1회 로드하고, 종료 시 해제한다."""
global _model
print("AI 모델 로딩 중 (epoch_165.pth)...")
_model = load_model()
print("AI 모델 로드 완료")
yield
_model = None
app = FastAPI(lifespan=lifespan)
def check_gps_info(image_path: str):
# Pillow로 이미지 열기
image = Image.open(image_path)
# EXIF 데이터 추출
exifdata = image.getexif()
if not exifdata:
print("EXIF 정보를 찾을 수 없습니다.")
return False
# GPS 정보 추출
gps_ifd = exifdata.get_ifd(0x8825) # GPS IFD 태그
if not gps_ifd:
print("GPS 정보를 찾을 수 없습니다.")
return False
return True
def check_camera_info(image_file):
# Pillow로 이미지 열기
image = Image.open(image_file)
# EXIF 데이터 추출
exifdata = image.getexif()
if not exifdata:
print("EXIF 정보를 찾을 수 없습니다.")
return False
for tag_id, value in exifdata.items():
tag_name = TAGS.get(tag_id, tag_id)
if tag_name == "Model":
return value.strip() if isinstance(value, str) else value
async def _run_mx15hdi_pipeline(file_id: str):
"""
mx15hdi 파이프라인을 in-process로 실행한다.
- Step 1 (AI 추론) + Step 2 (메타데이터 추출) 병렬 실행
- Step 3 (지리참조) Step 4 (폴리곤 추출) 순차 실행
- 중간 파일 I/O 없이 numpy 배열을 메모리로 전달
"""
loop = asyncio.get_event_loop()
# Step 1 + Step 2 병렬 실행 — inference_cache 캡처
inference_cache, _ = await asyncio.gather(
loop.run_in_executor(_executor, run_inference, _model, file_id),
loop.run_in_executor(_executor, run_metadata_export, file_id),
)
# Step 3: Georeference — inference_cache 메모리로 전달, georef_cache 반환
georef_cache = await loop.run_in_executor(
_executor, run_georeference, file_id, inference_cache
)
# Step 4: Polygon 추출 — georef_cache 메모리로 전달 (Mask_Tif 디스크 읽기 없음)
await loop.run_in_executor(_executor, run_oilshape, file_id, georef_cache)
# 전체 과정을 구동하는 api
@app.post("/run-script/")
async def run_script(
# pollId: int = Form(...),
camTy: str = Form(...),
fileId: str = Form(...),
image: UploadFile = File(...)
):
try:
print("start")
start_time = time.perf_counter()
if camTy not in ["mx15hdi", "starsafire"]:
raise HTTPException(status_code=400, detail="string1 must be 'mx15hdi' or 'starsafire'")
# 저장할 이미지 경로 설정
upload_dir = os.path.join(camTy, "Metadata/Image/Original_Images", fileId)
os.makedirs(upload_dir, exist_ok=True)
# 이미지 파일 저장
image_path = os.path.join(upload_dir, image.filename)
with open(image_path, "wb") as f:
f.write(await image.read())
gps_flage = check_gps_info(image_path)
if not gps_flage:
return {"detail": "GPS Infomation Not Found"}
if camTy == "mx15hdi":
# in-process 파이프라인 실행 (모델 재로딩 없음, Step1+2 병렬)
await _run_mx15hdi_pipeline(fileId)
else:
# starsafire: 기존 subprocess 방식 유지
script_dir = os.path.join(os.getcwd(), camTy, "Main")
script_file = "Combine_module.py"
script_path = os.path.join(script_dir, script_file)
if not os.path.exists(script_path):
raise HTTPException(status_code=404, detail="Script not found")
result = subprocess.run(
["python", script_file, fileId],
cwd=script_dir,
capture_output=True,
text=True,
timeout=300
)
print(f"Subprocess stdout: {result.stdout}")
print(f"Subprocess stderr: {result.stderr}")
meta_string = get_meta(camTy, fileId)
oil_data = get_oil(camTy, fileId)
end_time = time.perf_counter()
print(f"Run time: {end_time - start_time:.4f} sec")
return {
"meta": meta_string,
"data": oil_data
}
except subprocess.TimeoutExpired:
raise HTTPException(status_code=500, detail="Script execution timed out")
except Exception as e:
import traceback
traceback.print_exc()
raise HTTPException(status_code=500, detail=str(e))
@app.get("/get-metadata/{camTy}/{fileId}")
async def get_metadata(camTy: str, fileId: str):
try:
meta_string = get_meta(camTy, fileId)
oil_data = get_oil(camTy, fileId)
return {
"meta": meta_string,
"data": oil_data
}
except subprocess.TimeoutExpired:
raise HTTPException(status_code=500, detail="Script execution timed out")
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.get("/get-original-image/{camTy}/{fileId}")
async def get_original_image(camTy: str, fileId: str):
try:
image_path = os.path.join(camTy, "Metadata/Image/Original_Images", fileId)
files = os.listdir(image_path)
target_file = [f for f in files if f.endswith(".png") or f.endswith(".jpg")]
image_file = os.path.join(image_path, target_file[0])
with open(image_file, "rb") as origin_image:
base64_string = base64.b64encode(origin_image.read()).decode("utf-8")
print(base64_string[:100])
return base64_string
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.get("/get-image/{camTy}/{fileId}")
async def get_image(camTy: str, fileId: str):
try:
tif_file_path = os.path.join(camTy, "Georeference/Tif", fileId)
files = os.listdir(tif_file_path)
target_file = [f for f in files if f.endswith(".tif")]
tif_file = os.path.join(tif_file_path, target_file[0])
with rasterio.open(tif_file) as dataset:
crs = dataset.crs
bounds = dataset.bounds
if crs != "EPSG:4326":
transformer = Transformer.from_crs(crs, "EPSG:4326", always_xy=True)
minx, miny = transformer.transform(bounds.left, bounds.bottom)
maxx, maxy = transformer.transform(bounds.right, bounds.top)
print(minx, miny, maxx, maxy)
data = dataset.read()
if data.shape[0] == 1:
image_data = data[0]
else:
image_data = np.moveaxis(data, 0, -1)
image = Image.fromarray(image_data)
buffer = io.BytesIO()
image.save(buffer, format="PNG")
base64_string = base64.b64encode(buffer.getvalue()).decode("utf-8")
print(base64_string[:100])
return {
"minLon": minx,
"minLat": miny,
"maxLon": maxx,
"maxLat": maxy,
"image": base64_string
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
BASE_DIR = Path(__file__).parent
PIC_GPS_SCRIPT = BASE_DIR / "pic_gps.py"
@app.post("/stitch")
async def stitch(
files: List[UploadFile] = File(..., description="합성할 이미지 파일들 (2장 이상)"),
fileId: str = Form(...)
):
if len(files) < 2:
raise HTTPException(
status_code=400,
detail="최소 2장 이상의 이미지가 필요합니다."
)
try:
today = datetime.now().strftime("%Y%m%d")
upload_dir = BASE_DIR / "stitch" / fileId
upload_dir.mkdir(parents=True, exist_ok=True)
model_list = []
for idx, file in enumerate(files):
model = check_camera_info(file.file)
model_list.append(model)
original_filename = file.filename or f"image_{idx}.jpg"
filename = f"{model}_{idx:03d}_{original_filename}"
file_path = upload_dir / filename
output_filename = f"stitched_{fileId}.jpg"
output_path = upload_dir / output_filename
# 파일 저장
with open(file_path, "wb") as buffer:
shutil.copyfileobj(file.file, buffer)
model_counter = Counter(model_list)
most_common_model = model_counter.most_common(1)
cmd = [
"python",
str(PIC_GPS_SCRIPT),
"--mode", "drone",
"--input", str(upload_dir),
"--out", str(output_path),
"--model", most_common_model[0][0],
"--enhance"
]
print(cmd)
result = subprocess.run(
cmd,
capture_output=True,
text=True,
timeout=300
)
print(f"Subprocess stdout: {result.stdout}")
if result.returncode != 0:
print(f"Subprocess stderr: {result.stderr}")
raise HTTPException(status_code=500, detail=f"Script failed: {result.stderr}")
return FileResponse(
path=str(output_path),
media_type="image/jpeg",
filename=output_filename
)
except HTTPException:
raise
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=5001)

파일 보기

@ -1,46 +0,0 @@
version: "3.9"
# CPU 전용 docker-compose 설정
# GPU(nvidia-container-toolkit) 없이도 실행 가능
# 실행: docker compose -f docker-compose.cpu.yml up -d --build
services:
image-analysis:
build:
context: .
dockerfile: Dockerfile.cpu
image: wing-image-analysis:cpu
container_name: wing-image-analysis
ports:
- "5001:5001"
environment:
- DEVICE=cpu
volumes:
# ── mx15hdi (EO 드론 카메라) ────────────────────────────────────────
# 입력: 업로드된 원본 이미지
- ./mx15hdi/Metadata/Image/Original_Images:/app/mx15hdi/Metadata/Image/Original_Images
# 출력: 메타데이터 CSV
- ./mx15hdi/Metadata/CSV:/app/mx15hdi/Metadata/CSV
# 출력: 지리참조 GeoTIFF (컬러 / 마스크)
- ./mx15hdi/Georeference/Tif:/app/mx15hdi/Georeference/Tif
- ./mx15hdi/Georeference/Mask_Tif:/app/mx15hdi/Georeference/Mask_Tif
# 출력: 유류 폴리곤 Shapefile
- ./mx15hdi/Polygon/Shp:/app/mx15hdi/Polygon/Shp
# 출력: 블렌딩 추론 결과 / 마스크 이미지
- ./mx15hdi/Detect/result:/app/mx15hdi/Detect/result
- ./mx15hdi/Detect/Mask_result:/app/mx15hdi/Detect/Mask_result
# ── starsafire (열화상 카메라) ──────────────────────────────────────
- ./starsafire/Metadata/Image/Original_Images:/app/starsafire/Metadata/Image/Original_Images
- ./starsafire/Metadata/CSV:/app/starsafire/Metadata/CSV
- ./starsafire/Georeference/Tif:/app/starsafire/Georeference/Tif
- ./starsafire/Georeference/Mask_Tif:/app/starsafire/Georeference/Mask_Tif
- ./starsafire/Polygon/Shp:/app/starsafire/Polygon/Shp
- ./starsafire/Detect/result:/app/starsafire/Detect/result
- ./starsafire/Detect/Mask_result:/app/starsafire/Detect/Mask_result
# ── 스티칭 결과 ─────────────────────────────────────────────────────
- ./stitch:/app/stitch
# GPU deploy 섹션 없음 — CPU 전용 실행
restart: unless-stopped

파일 보기

@ -1,47 +0,0 @@
version: "3.9"
services:
image-analysis:
build:
context: .
dockerfile: Dockerfile
image: wing-image-analysis:latest
container_name: wing-image-analysis
ports:
- "5001:5001"
volumes:
# ── mx15hdi (EO 드론 카메라) ────────────────────────────────────────
# 입력: 업로드된 원본 이미지
- ./mx15hdi/Metadata/Image/Original_Images:/app/mx15hdi/Metadata/Image/Original_Images
# 출력: 메타데이터 CSV
- ./mx15hdi/Metadata/CSV:/app/mx15hdi/Metadata/CSV
# 출력: 지리참조 GeoTIFF (컬러 / 마스크)
- ./mx15hdi/Georeference/Tif:/app/mx15hdi/Georeference/Tif
- ./mx15hdi/Georeference/Mask_Tif:/app/mx15hdi/Georeference/Mask_Tif
# 출력: 유류 폴리곤 Shapefile
- ./mx15hdi/Polygon/Shp:/app/mx15hdi/Polygon/Shp
# 출력: 블렌딩 추론 결과 / 마스크 이미지
- ./mx15hdi/Detect/result:/app/mx15hdi/Detect/result
- ./mx15hdi/Detect/Mask_result:/app/mx15hdi/Detect/Mask_result
# ── starsafire (열화상 카메라) ──────────────────────────────────────
- ./starsafire/Metadata/Image/Original_Images:/app/starsafire/Metadata/Image/Original_Images
- ./starsafire/Metadata/CSV:/app/starsafire/Metadata/CSV
- ./starsafire/Georeference/Tif:/app/starsafire/Georeference/Tif
- ./starsafire/Georeference/Mask_Tif:/app/starsafire/Georeference/Mask_Tif
- ./starsafire/Polygon/Shp:/app/starsafire/Polygon/Shp
- ./starsafire/Detect/result:/app/starsafire/Detect/result
- ./starsafire/Detect/Mask_result:/app/starsafire/Detect/Mask_result
# ── 스티칭 결과 ─────────────────────────────────────────────────────
- ./stitch:/app/stitch
# NVIDIA GPU 할당 (nvidia-container-toolkit 필수)
deploy:
resources:
reservations:
devices:
- driver: nvidia
count: 1
capabilities: [gpu]
restart: unless-stopped

파일 보기

@ -1,97 +0,0 @@
import csv
from datetime import datetime
from pathlib import Path
import geopandas as gpd
import json
def get_metadata(camTy: str, fileId: str):
# CSV 파일 경로 설정
# base_dir = "mx15hdi" if pollId == "1" else "starsafire"
if camTy == "mx15hdi":
csv_path = f"{camTy}/Metadata/CSV/{fileId}/mx15hdi_interpolation.csv"
elif camTy == "starsafire":
csv_path = f"{camTy}/Metadata/CSV/{fileId}/Metadata_Extracted.csv"
try:
# CSV 파일 읽기
with open(csv_path, 'r', newline='', encoding='utf-8-sig') as csvfile:
reader = csv.reader(csvfile)
next(reader, None)
row = next(reader, None)
return ','.join(row)
except FileNotFoundError:
print(f"CSV file not found: {csv_path}")
raise
except ValueError as e:
print(f"Value error: {str(e)}")
raise
except Exception as e:
print(f"Error processing CSV: {e}")
raise
def get_oil_type(camTy: str, fileId: str):
# Shapefile 경로 설정
path = f"{camTy}/Polygon/Shp/{fileId}"
shp_file = list(Path(path).glob("*.shp"))
if not shp_file:
return []
shp_path = f"{camTy}/Polygon/Shp/{fileId}/{shp_file[0].name}"
print(shp_path)
# if camTy == "mx15hdi":
# fileSub = f"{Path(fileName).stem}_gsd"
# elif camTy == "starsafire":
# fileSub = f"{Path(fileName).stem}"
# shp_path = f"{camTy}/Polygon/Shp/{fileId}/{fileSub}.shp"
# 두께 정보
class_thickness_mm = {
1: 1.0, # Black oil (Emulsion)
2: 0.1, # Brown oil (Crude)
3: 0.0003, # Rainbow oil (Slick)
4: 0.0001 # Silver oil (Slick)
}
# 알고리즘 정보
algorithm = {
1: "검정",
2: "갈색",
3: "무지개",
4: "은색"
}
try:
# Shapefile 읽기
gdf = gpd.read_file(shp_path)
if gdf.crs != "epsg:4326":
gdf = gdf.to_crs("epsg:4326")
# 데이터 준비
data = []
for _, row in gdf.iterrows():
class_id = row.get('class_id', None)
area_m2 = row.get('area_m2', None)
volume_m3 = row.get('volume_m3', None)
note = row.get('note', None)
thickness_m = class_thickness_mm.get(class_id, 0) / 1000.0
geom_wkt = row.geometry.wkt if row.geometry else None
result = {
"classId": algorithm.get(class_id, 0),
"area": area_m2,
"volume": volume_m3,
"note": note,
"thickness": thickness_m,
"wkt": geom_wkt
}
data.append(result)
return data
except FileNotFoundError:
print(f"Shapefile not found: {shp_path}")
raise
except Exception as e:
print(f"Error processing shapefile or database: {str(e)}")
raise

파일 보기

@ -1,238 +0,0 @@
# 이미지 업로드 유류 분석 기능 구현 계획
## Context
드론/항공 촬영 이미지를 업로드하면 AI 세그멘테이션으로 유류 확산 정보(위치·유종·면적·부피)를 자동 추출하고, 결과를 DB에 저장한 뒤 예측정보 입력 폼에 자동 채워주는 기능이다.
이미지 분석 서버(`prediction/image/api.py`, FastAPI, 포트 5001)는 이미 구현되어 있으며, 프론트↔백엔드↔이미지 분석 서버 연동 및 결과 자동 채우기를 구현한다.
---
## 전체 흐름
```
[프론트] 이미지 선택 → 분석 요청 버튼
↓ POST /api/prediction/image-analyze (multipart: image)
[백엔드]
├─ fileId = UUID 생성
├─ camTy = "mx15hdi" (하드코딩, 추후 이미지 EXIF 카메라 정보로 자동 판별 예정)
├─ 이미지 분석 서버로 전달 POST http://IMAGE_API_URL/run-script/
├─ 응답 파싱: meta(위경도 DMS→십진수 변환), data[0].classId→유종
├─ ACDNT INSERT (lat/lon/임시사고명)
├─ SPIL_DATA INSERT (유종/면적/img_rslt_data JSONB)
└─ 응답: { acdntSn, lat, lon, oilType, area, volume }
[프론트] 폼 자동 채우기 (좌표·유종·유출량)
→ 사용자가 나머지 입력 후 "확산예측 실행"
```
---
## 구현 단계
### Step 1 — DB 마이그레이션 (`database/migration/017_spil_img_rslt.sql`)
`SPIL_DATA` 테이블에 이미지 분석 결과 컬럼 추가.
```sql
ALTER TABLE wing.spil_data
ADD COLUMN IF NOT EXISTS img_rslt_data JSONB;
```
---
### Step 2 — 백엔드: 이미지 분석 엔드포인트
**파일**: `backend/src/prediction/predictionRouter.ts` (라우트 등록)
**신규 파일**: `backend/src/prediction/imageAnalyzeService.ts`
#### 엔드포인트
```
POST /api/prediction/image-analyze
Content-Type: multipart/form-data
Body: image (file)
```
#### `imageAnalyzeService.ts` 핵심 로직
```typescript
// 1. fileId 생성 (crypto.randomUUID)
// 2. 이미지 분석 서버 호출
// camTy는 현재 "mx15hdi"로 하드코딩한다.
// TODO: 추후 이미지 EXIF에서 카메라 모델명을 읽어 camTy를 자동 판별하는 로직을
// 이미지 분석 서버(api.py)에 추가할 예정이다. (check_camera_info 함수 활용)
// FormData: { camTy: 'mx15hdi', fileId, image }
// → POST ${IMAGE_API_URL}/run-script/
// 응답: { meta: string, data: OilPolygon[] }
// 3. meta 문자열 파싱 (mx15hdi CSV 컬럼 순서 사용)
// [Filename, Tlat_d, Tlat_m, Tlat_s, Tlon_d, Tlon_m, Tlon_s, ...]
// DMS → 십진수: d + m/60 + s/3600
// 4. 유종 매핑 (data[0].classId → UI 유종명)
// classId → oilType: { '검정': '벙커C유', '갈색': '벙커C유', '무지개': '경유', '은색': '등유' }
// 5. ACDNT INSERT (임시 사고명 = "이미지분석_YYYY-MM-DD HH:mm", lat, lon, occurredAt = 촬영시각)
// 6. SPIL_DATA INSERT (acdntSn, matTyCd, matVol=data[0].volume, imgRsltData=JSON.stringify(response))
// 7. 반환
interface ImageAnalyzeResult {
acdntSn: number;
lat: number;
lon: number;
oilType: string; // UI 유종명 (벙커C유 등)
area: number; // m²
volume: number; // m³
fileId: string;
}
```
#### 환경변수 추가 (`backend/.env`)
```
IMAGE_API_URL=http://localhost:5001
```
#### 에러 처리
| 조건 | 응답 |
|------|------|
| 이미지에 GPS EXIF 없음 | 422 `{ error: 'GPS_NOT_FOUND' }` |
| 이미지 서버 타임아웃(300s) | 504 |
---
### Step 3 — 프론트엔드: API 서비스
**파일**: `frontend/src/tabs/prediction/services/predictionApi.ts`
```typescript
interface ImageAnalyzeResult {
acdntSn: number;
lat: number;
lon: number;
oilType: string;
area: number;
volume: number;
fileId: string;
}
export const analyzeImage = async (
file: File
): Promise<ImageAnalyzeResult> => {
const formData = new FormData();
formData.append('image', file);
const { data } = await api.post<ImageAnalyzeResult>(
'/prediction/image-analyze',
formData,
{ headers: { 'Content-Type': 'multipart/form-data' }, timeout: 330000 }
);
return data;
};
```
---
### Step 4 — 프론트엔드: Props 타입 확장
**파일**: `frontend/src/tabs/prediction/components/leftPanelTypes.ts`
```typescript
// 기존 Props에 추가
onImageAnalysisResult?: (result: ImageAnalyzeResult) => void;
```
---
### Step 5 — 프론트엔드: PredictionInputSection 수정
**파일**: `frontend/src/tabs/prediction/components/PredictionInputSection.tsx`
#### 변경 사항
1. **"이미지 분석 실행" 버튼** (이미지 선택 후 활성화)
- 클릭 시 `analyzeImage(file)` 호출 (camTy는 백엔드에서 처리)
- 로딩 스피너 표시 (분석 소요시간 수십 초~수 분)
2. **분석 결과 표시** (성공 시)
- "분석 완료: 위도 XX.XXXX / 경도 XXX.XXXX / 유종: OO" 요약 메시지
3. **`onImageAnalysisResult` 콜백 호출**
- 분석 성공 시 부모로 결과 전달
4. **에러 처리**
- GPS_NOT_FOUND: "GPS 정보가 없는 이미지입니다" 메시지 표시
- 타임아웃: "분석 서버 응답 없음" 메시지 표시
5. **로컬 상태 교체**: `uploadedImage` (Base64 DataURL) 제거, `uploadedFile: File | null`로 교체
---
### Step 6 — 프론트엔드: OilSpillView 결과 처리
**파일**: `frontend/src/tabs/prediction/components/OilSpillView.tsx`
```typescript
const handleImageAnalysisResult = useCallback((result: ImageAnalyzeResult) => {
// 1. 사고 좌표 자동 채우기
setIncidentCoord({ lat: result.lat, lon: result.lon })
setFlyToCoord({ lat: result.lat, lon: result.lon })
// 2. 유종/유출량 자동 채우기
setOilType(result.oilType)
setSpillAmount(parseFloat(result.volume.toFixed(4)))
setSpillUnit('m³')
// 3. 분석 선택 상태 갱신 (acdntSn 연결 — 시뮬레이션 실행 시 기존 사고 사용)
setSelectedAnalysis({
acdntSn: result.acdntSn,
acdntNm: '',
// ... 나머지 기본값
})
}, [])
```
`LeftPanel``onImageAnalysisResult={handleImageAnalysisResult}` 전달.
---
## 수정 대상 파일 요약
| 파일 | 변경 유형 |
|------|---------|
| `database/migration/017_spil_img_rslt.sql` | **신규** — SPIL_DATA 컬럼 추가 |
| `backend/src/prediction/imageAnalyzeService.ts` | **신규** — 이미지 분석 서비스 |
| `backend/src/prediction/predictionRouter.ts` | **수정** — 라우트 추가 |
| `backend/.env` | **수정** — IMAGE_API_URL 추가 |
| `frontend/src/tabs/prediction/services/predictionApi.ts` | **수정** — analyzeImage 함수 추가 |
| `frontend/src/tabs/prediction/components/leftPanelTypes.ts` | **수정** — Props 타입 추가 |
| `frontend/src/tabs/prediction/components/PredictionInputSection.tsx` | **수정** — 분석 실행 UI |
| `frontend/src/tabs/prediction/components/OilSpillView.tsx` | **수정** — 결과 처리 핸들러 |
---
## 검증 방법
1. **이미지 분석 서버 직접 테스트**
```bash
curl -X POST http://localhost:5001/run-script/ \
-F "camTy=mx15hdi" -F "fileId=test001" -F "image=@drone_image.jpg"
```
2. **백엔드 엔드포인트 테스트**
```bash
curl -X POST http://localhost:3001/api/prediction/image-analyze \
-F "image=@drone_image.jpg" \
-H "Cookie: <auth_cookie>"
```
- 응답: `{ acdntSn, lat, lon, oilType, area, volume, fileId }`
- DB 확인: ACDNT, SPIL_DATA 레코드 생성 여부
3. **프론트엔드 E2E 테스트**
- 이미지 업로드 모드 선택 → GPS EXIF 있는 이미지 업로드 → "이미지 분석 실행" 클릭
- 로딩 표시 → 완료 시: 지도 이동, 유종/좌표 폼 자동 채워짐 확인
- 나머지 필드(예상시각·유출시간 등) 직접 입력 후 "확산예측 실행" → 정상 시뮬레이션 확인
4. **에러 케이스 확인**
- GPS 없는 이미지 → "GPS 정보가 없는 이미지입니다" 메시지

파일 보기

@ -1,131 +0,0 @@
import os, mmcv, cv2, json
import numpy as np
import torch
from pathlib import Path
from PIL import Image
from tqdm import tqdm
from mmseg.apis import init_segmentor, inference_segmentor
from shapely.geometry import Polygon, mapping
import sys
_DETECT_DIR = Path(__file__).parent # mx15hdi/Detect/
_MX15HDI_DIR = _DETECT_DIR.parent # mx15hdi/
def load_model():
"""서버 시작 시 1회 호출. 로드된 모델 객체를 반환한다."""
# 우선순위: 환경변수 DEVICE > GPU 자동감지 > CPU 폴백
env_device = os.environ.get('DEVICE', '').strip()
if env_device:
device = env_device
elif torch.cuda.is_available():
device = 'cuda:0'
else:
device = 'cpu'
print(f'[Inference] 사용 device: {device}')
config = str(_DETECT_DIR / 'V7_SPECIAL.py')
checkpoint = str(_DETECT_DIR / 'epoch_165.pth')
model = init_segmentor(config, checkpoint, device=device)
model.PALETTE = [
[0, 0, 0], # background
[0, 0, 204], # black
[180, 180, 180], # brown
[255, 255, 0], # rainbow
[178, 102, 255] # silver
]
return model
def blend_images(original_img, color_mask, alpha=0.6):
"""
Blend original image and color mask with alpha transparency.
Inputs: numpy arrays HWC uint8
"""
blended = cv2.addWeighted(original_img, 1 - alpha, color_mask, alpha, 0)
return blended
def run_inference(model, file_id: str, write_files: bool = False) -> dict:
"""
사전 로드된 모델로 file_id 폴더 이미지를 세그멘테이션한다.
Args:
model: load_model() 로드된 모델 객체.
file_id: 처리할 이미지 폴더명.
write_files: True이면 Detect/result/ Detect/Mask_result/ 중간 파일 저장.
False이면 디스크 쓰기 생략 (기본값).
Returns:
inference_cache: {image_filename: {'blended': ndarray, 'mask': ndarray, 'ext': str}}
값을 run_georeference() 전달하면 중간 파일 읽기를 생략할 있다.
"""
img_path = str(_MX15HDI_DIR / 'Metadata' / 'Image' / 'Original_Images' / file_id)
output_folder = str(_DETECT_DIR / 'result' / file_id)
mask_folder = str(_DETECT_DIR / 'Mask_result' / file_id)
if not os.path.exists(img_path):
raise FileNotFoundError(f"이미지 폴더가 존재하지 않습니다: {img_path}")
if write_files:
os.makedirs(output_folder, exist_ok=True)
os.makedirs(mask_folder, exist_ok=True)
image_files = [
f for f in os.listdir(img_path)
if f.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff'))
]
# palette_array는 이미지마다 동일하므로 루프 외부에서 1회 생성
palette_array = np.array(model.PALETTE, dtype=np.uint8)
inference_cache = {}
for image_file in tqdm(image_files, desc="Processing images"):
image_path = os.path.join(img_path, image_file)
image_name, image_ext = os.path.splitext(image_file)
image_ext = image_ext.lower()
# 이미지를 1회만 읽어 inference와 blending 모두에 재사용
img_bgr = cv2.imread(image_path)
img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
# 이미 로드된 배열을 inference_segmentor에 전달 (경로 전달 시 내부에서 재읽기 발생)
result = inference_segmentor(model, img_bgr)
seg_map = result[0]
# Create color mask from palette
color_mask = palette_array[seg_map]
# blended image
blended = blend_images(img_rgb, color_mask, alpha=0.6)
blended_bgr = cv2.cvtColor(blended, cv2.COLOR_RGB2BGR)
# mask — numpy 슬라이싱으로 cv2.cvtColor 호출 1회 제거
mask_bgr = color_mask[:, :, ::-1].copy()
# 결과를 메모리 캐시에 저장 (georeference 단계에서 재사용)
# mask는 palette 원본(RGB) 그대로 저장 — Oilshape의 class_colors가 RGB 기준이므로 BGR로 저장 시 색상 매칭 실패
inference_cache[image_file] = {
'blended': blended_bgr,
'mask': color_mask,
'ext': image_ext,
}
if write_files:
cv2.imwrite(
os.path.join(output_folder, f"{image_name}{image_ext}"),
blended_bgr,
[cv2.IMWRITE_JPEG_QUALITY, 85]
)
cv2.imwrite(os.path.join(mask_folder, f"{image_name}{image_ext}"), mask_bgr)
return inference_cache
if __name__ == '__main__':
if len(sys.argv) < 2:
raise ValueError("파라미터가 제공되지 않았습니다. 폴더 이름을 명령줄 인자로 입력해주세요.")
_model = load_model()
# CLI 단독 실행 시에는 중간 파일도 디스크에 저장
run_inference(_model, sys.argv[1], write_files=True)

파일 보기

@ -1,196 +0,0 @@
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
type='EncoderDecoder',
pretrained='open-mmlab://resnet101_v1c',
backbone=dict(
type='ResNetV1c',
depth=101,
num_stages=4,
out_indices=(0, 1, 2, 3),
dilations=(1, 1, 2, 4),
strides=(1, 2, 1, 1),
norm_cfg=dict(type='SyncBN', requires_grad=True),
norm_eval=False,
style='pytorch',
contract_dilation=True),
decode_head=dict(
type='DAHead',
in_channels=2048,
in_index=3,
channels=512,
pam_channels=64,
dropout_ratio=0.1,
num_classes=5,
norm_cfg=dict(type='SyncBN', requires_grad=True),
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
auxiliary_head=dict(
type='FCNHead',
in_channels=1024,
in_index=2,
channels=256,
num_convs=1,
concat_input=False,
dropout_ratio=0.1,
num_classes=5,
norm_cfg=dict(type='SyncBN', requires_grad=True),
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
train_cfg=dict(),
test_cfg=dict(mode='whole'))
dataset_type = 'CustomDataset'
data_root = 'data/my_dataset_v7'
img_norm_cfg = dict(
mean=[119.54541993, 107.13545011, 96.71320316],
std=[60.3273945, 56.33692515, 55.71005772],
to_rgb=True)
crop_size = (512, 512)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations'),
dict(type='Resize', img_scale=(512, 512)),
dict(type='RandomCrop', crop_size=(512, 512), cat_max_ratio=0.75),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='PhotoMetricDistortion'),
dict(
type='Normalize',
mean=[119.54541993, 107.13545011, 96.71320316],
std=[60.3273945, 56.33692515, 55.71005772],
to_rgb=True),
dict(type='Pad', size=(512, 512), pad_val=0, seg_pad_val=255),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_semantic_seg'])
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(512, 512),
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(
type='Normalize',
mean=[119.54541993, 107.13545011, 96.71320316],
std=[60.3273945, 56.33692515, 55.71005772],
to_rgb=True),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img'])
])
]
data = dict(
samples_per_gpu=4,
workers_per_gpu=4,
train=dict(
type='CustomDataset',
data_root='data/my_dataset_v7',
img_dir='img_dir/train',
ann_dir='ann_dir/train',
pipeline=[
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations'),
dict(type='Resize', img_scale=(512, 512)),
dict(type='RandomCrop', crop_size=(512, 512), cat_max_ratio=0.75),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='PhotoMetricDistortion'),
dict(
type='Normalize',
mean=[119.54541993, 107.13545011, 96.71320316],
std=[60.3273945, 56.33692515, 55.71005772],
to_rgb=True),
dict(type='Pad', size=(512, 512), pad_val=0, seg_pad_val=255),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_semantic_seg'])
]),
val=dict(
type='CustomDataset',
data_root='data/my_dataset_v7',
img_dir='img_dir/val',
ann_dir='ann_dir/val',
pipeline=[
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(512, 512),
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(
type='Normalize',
mean=[119.54541993, 107.13545011, 96.71320316],
std=[60.3273945, 56.33692515, 55.71005772],
to_rgb=True),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img'])
])
]),
test=dict(
type='CustomDataset',
data_root='data/my_dataset_v7',
img_dir='img_dir/test',
ann_dir='ann_dir/test',
pipeline=[
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(512, 512),
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(
type='Normalize',
mean=[119.54541993, 107.13545011, 96.71320316],
std=[60.3273945, 56.33692515, 55.71005772],
to_rgb=True),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img'])
])
],
split=None,
img_suffix='.png',
seg_map_suffix='.png'))
dist_params = dict(backend='nccl')
log_level = 'INFO'
load_from = None
resume_from = None
#workflow = [('train', 1), ('val', 1)]
workflow = [('test', 1)]
cudnn_benchmark = True
optimizer = dict(
type='AdamW',
lr=3e-05,
betas=(0.9, 0.999),
weight_decay=0.01,
paramwise_cfg=dict(
custom_keys=dict(
pos_block=dict(decay_mult=0.0),
norm=dict(decay_mult=0.0),
head=dict(lr_mult=10.0))))
optimizer_config = dict()
lr_config = dict(
policy='poly',
warmup='linear',
warmup_iters=1500,
warmup_ratio=1e-06,
power=1.0,
min_lr=0.0,
by_epoch=False)
runner = dict(type='EpochBasedRunner', max_epochs=200)
checkpoint_config = dict(by_epoch=True, interval=1)
evaluation = dict(by_epoch=True, interval=1, metric='mIoU')
log_config = dict(
interval=1000,
hooks=[
dict(type='TextLoggerHook'),
dict(
type='WandbLoggerHook',
init_kwargs=dict(project='Oil_Spill', name='V7_SPECIAL'))
])
auto_resume = False
work_dir = 'work_dirs/V7_SPECIAL'
gpu_ids = [0]

파일 보기

@ -1,161 +0,0 @@
version: 2.1
jobs:
lint:
docker:
- image: cimg/python:3.7.4
steps:
- checkout
- run:
name: Install dependencies
command: |
sudo apt-add-repository ppa:brightbox/ruby-ng -y
sudo apt-get update
sudo apt-get install -y ruby2.7
- run:
name: Install pre-commit hook
command: |
pip install pre-commit
pre-commit install
- run:
name: Linting
command: pre-commit run --all-files
- run:
name: Check docstring coverage
command: |
pip install interrogate
interrogate -v --ignore-init-method --ignore-module --ignore-nested-functions --ignore-regex "__repr__" --fail-under 50 mmseg
build_cpu:
parameters:
# The python version must match available image tags in
# https://circleci.com/developer/images/image/cimg/python
python:
type: string
default: "3.7.4"
torch:
type: string
torchvision:
type: string
docker:
- image: cimg/python:<< parameters.python >>
resource_class: large
steps:
- checkout
- run:
name: Install Libraries
command: |
sudo apt-get update
sudo apt-get install -y ninja-build libglib2.0-0 libsm6 libxrender-dev libxext6 libgl1-mesa-glx libjpeg-dev zlib1g-dev libtinfo-dev libncurses5
- run:
name: Configure Python & pip
command: |
python -m pip install --upgrade pip
python -m pip install wheel
- run:
name: Install PyTorch
command: |
python -V
python -m pip install torch==<< parameters.torch >>+cpu torchvision==<< parameters.torchvision >>+cpu -f https://download.pytorch.org/whl/torch_stable.html
- run:
name: Install mmseg dependencies
command: |
python -m pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cpu/torch<< parameters.torch >>/index.html
python -m pip install mmdet
python -m pip install -r requirements.txt
- run:
name: Build and install
command: |
python -m pip install -e .
- run:
name: Run unittests
command: |
python -m pip install timm
python -m coverage run --branch --source mmseg -m pytest tests/
python -m coverage xml
python -m coverage report -m
build_cu101:
machine:
image: ubuntu-1604-cuda-10.1:201909-23
resource_class: gpu.nvidia.small
steps:
- checkout
- run:
name: Install Libraries
command: |
sudo apt-get update
sudo apt-get install -y git ninja-build libglib2.0-0 libsm6 libxrender-dev libxext6 libgl1-mesa-glx
- run:
name: Configure Python & pip
command: |
pyenv global 3.7.0
python -m pip install --upgrade pip
python -m pip install wheel
- run:
name: Install PyTorch
command: |
python -V
python -m pip install torch==1.6.0+cu101 torchvision==0.7.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html
- run:
name: Install mmseg dependencies
# python -m pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cu101/torch${{matrix.torch_version}}/index.html
command: |
python -m pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cu101/torch1.6.0/index.html
python -m pip install mmdet
python -m pip install -r requirements.txt
- run:
name: Build and install
command: |
python setup.py check -m -s
TORCH_CUDA_ARCH_LIST=7.0 python -m pip install -e .
- run:
name: Run unittests
command: |
python -m pip install timm
python -m pytest tests/
workflows:
unit_tests:
jobs:
- lint
- build_cpu:
name: build_cpu_th1.6
torch: 1.6.0
torchvision: 0.7.0
requires:
- lint
- build_cpu:
name: build_cpu_th1.7
torch: 1.7.0
torchvision: 0.8.1
requires:
- lint
- build_cpu:
name: build_cpu_th1.8_py3.9
torch: 1.8.0
torchvision: 0.9.0
python: "3.9.0"
requires:
- lint
- build_cpu:
name: build_cpu_th1.9_py3.8
torch: 1.9.0
torchvision: 0.10.0
python: "3.8.0"
requires:
- lint
- build_cpu:
name: build_cpu_th1.9_py3.9
torch: 1.9.0
torchvision: 0.10.0
python: "3.9.0"
requires:
- lint
- build_cu101:
requires:
- build_cpu_th1.6
- build_cpu_th1.7
- build_cpu_th1.8_py3.9
- build_cpu_th1.9_py3.8
- build_cpu_th1.9_py3.9

파일 보기

@ -1,133 +0,0 @@
# yapf: disable
# Inference Speed is tested on NVIDIA V100
hrnet = [
dict(
config='configs/hrnet/fcn_hr18s_512x512_160k_ade20k.py',
checkpoint='fcn_hr18s_512x512_160k_ade20k_20200614_214413-870f65ac.pth', # noqa
eval='mIoU',
metric=dict(mIoU=33.0),
),
dict(
config='configs/hrnet/fcn_hr18s_512x1024_160k_cityscapes.py',
checkpoint='fcn_hr18s_512x1024_160k_cityscapes_20200602_190901-4a0797ea.pth', # noqa
eval='mIoU',
metric=dict(mIoU=76.31),
),
dict(
config='configs/hrnet/fcn_hr48_512x512_160k_ade20k.py',
checkpoint='fcn_hr48_512x512_160k_ade20k_20200614_214407-a52fc02c.pth',
eval='mIoU',
metric=dict(mIoU=42.02),
),
dict(
config='configs/hrnet/fcn_hr48_512x1024_160k_cityscapes.py',
checkpoint='fcn_hr48_512x1024_160k_cityscapes_20200602_190946-59b7973e.pth', # noqa
eval='mIoU',
metric=dict(mIoU=80.65),
),
]
pspnet = [
dict(
config='configs/pspnet/pspnet_r50-d8_512x1024_80k_cityscapes.py',
checkpoint='pspnet_r50-d8_512x1024_80k_cityscapes_20200606_112131-2376f12b.pth', # noqa
eval='mIoU',
metric=dict(mIoU=78.55),
),
dict(
config='configs/pspnet/pspnet_r101-d8_512x1024_80k_cityscapes.py',
checkpoint='pspnet_r101-d8_512x1024_80k_cityscapes_20200606_112211-e1e1100f.pth', # noqa
eval='mIoU',
metric=dict(mIoU=79.76),
),
dict(
config='configs/pspnet/pspnet_r101-d8_512x512_160k_ade20k.py',
checkpoint='pspnet_r101-d8_512x512_160k_ade20k_20200615_100650-967c316f.pth', # noqa
eval='mIoU',
metric=dict(mIoU=44.39),
),
dict(
config='configs/pspnet/pspnet_r50-d8_512x512_160k_ade20k.py',
checkpoint='pspnet_r50-d8_512x512_160k_ade20k_20200615_184358-1890b0bd.pth', # noqa
eval='mIoU',
metric=dict(mIoU=42.48),
),
]
resnest = [
dict(
config='configs/resnest/pspnet_s101-d8_512x512_160k_ade20k.py',
checkpoint='pspnet_s101-d8_512x512_160k_ade20k_20200807_145416-a6daa92a.pth', # noqa
eval='mIoU',
metric=dict(mIoU=45.44),
),
dict(
config='configs/resnest/pspnet_s101-d8_512x1024_80k_cityscapes.py',
checkpoint='pspnet_s101-d8_512x1024_80k_cityscapes_20200807_140631-c75f3b99.pth', # noqa
eval='mIoU',
metric=dict(mIoU=78.57),
),
]
fastscnn = [
dict(
config='configs/fastscnn/fast_scnn_lr0.12_8x4_160k_cityscapes.py',
checkpoint='fast_scnn_8x4_160k_lr0.12_cityscapes-0cec9937.pth',
eval='mIoU',
metric=dict(mIoU=70.96),
)
]
deeplabv3plus = [
dict(
config='configs/deeplabv3plus/deeplabv3plus_r101-d8_769x769_80k_cityscapes.py', # noqa
checkpoint='deeplabv3plus_r101-d8_769x769_80k_cityscapes_20200607_000405-a7573d20.pth', # noqa
eval='mIoU',
metric=dict(mIoU=80.98),
),
dict(
config='configs/deeplabv3plus/deeplabv3plus_r101-d8_512x1024_80k_cityscapes.py', # noqa
checkpoint='deeplabv3plus_r101-d8_512x1024_80k_cityscapes_20200606_114143-068fcfe9.pth', # noqa
eval='mIoU',
metric=dict(mIoU=80.97),
),
dict(
config='configs/deeplabv3plus/deeplabv3plus_r50-d8_512x1024_80k_cityscapes.py', # noqa
checkpoint='deeplabv3plus_r50-d8_512x1024_80k_cityscapes_20200606_114049-f9fb496d.pth', # noqa
eval='mIoU',
metric=dict(mIoU=80.09),
),
dict(
config='configs/deeplabv3plus/deeplabv3plus_r50-d8_769x769_80k_cityscapes.py', # noqa
checkpoint='deeplabv3plus_r50-d8_769x769_80k_cityscapes_20200606_210233-0e9dfdc4.pth', # noqa
eval='mIoU',
metric=dict(mIoU=79.83),
),
]
vit = [
dict(
config='configs/vit/upernet_vit-b16_ln_mln_512x512_160k_ade20k.py',
checkpoint='upernet_vit-b16_ln_mln_512x512_160k_ade20k-f444c077.pth',
eval='mIoU',
metric=dict(mIoU=47.73),
),
dict(
config='configs/vit/upernet_deit-s16_ln_mln_512x512_160k_ade20k.py',
checkpoint='upernet_deit-s16_ln_mln_512x512_160k_ade20k-c0cd652f.pth',
eval='mIoU',
metric=dict(mIoU=43.52),
),
]
fp16 = [
dict(
config='configs/deeplabv3plus/deeplabv3plus_r101-d8_fp16_512x1024_80k_cityscapes.py', # noqa
checkpoint='deeplabv3plus_r101-d8_fp16_512x1024_80k_cityscapes_20200717_230920-f1104f4b.pth', # noqa
eval='mIoU',
metric=dict(mIoU=80.46),
)
]
swin = [
dict(
config='configs/swin/upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K.py', # noqa
checkpoint='upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K_20210531_112542-e380ad3e.pth', # noqa
eval='mIoU',
metric=dict(mIoU=44.41),
)
]
# yapf: enable

파일 보기

@ -1,19 +0,0 @@
configs/hrnet/fcn_hr18s_512x512_160k_ade20k.py
configs/hrnet/fcn_hr18s_512x1024_160k_cityscapes.py
configs/hrnet/fcn_hr48_512x512_160k_ade20k.py
configs/hrnet/fcn_hr48_512x1024_160k_cityscapes.py
configs/pspnet/pspnet_r50-d8_512x1024_80k_cityscapes.py
configs/pspnet/pspnet_r101-d8_512x1024_80k_cityscapes.py
configs/pspnet/pspnet_r101-d8_512x512_160k_ade20k.py
configs/pspnet/pspnet_r50-d8_512x512_160k_ade20k.py
configs/resnest/pspnet_s101-d8_512x512_160k_ade20k.py
configs/resnest/pspnet_s101-d8_512x1024_80k_cityscapes.py
configs/fastscnn/fast_scnn_lr0.12_8x4_160k_cityscapes.py
configs/deeplabv3plus/deeplabv3plus_r101-d8_769x769_80k_cityscapes.py
configs/deeplabv3plus/deeplabv3plus_r101-d8_512x1024_80k_cityscapes.py
configs/deeplabv3plus/deeplabv3plus_r50-d8_512x1024_80k_cityscapes.py
configs/deeplabv3plus/deeplabv3plus_r50-d8_769x769_80k_cityscapes.py
configs/vit/upernet_vit-b16_ln_mln_512x512_160k_ade20k.py
configs/vit/upernet_deit-s16_ln_mln_512x512_160k_ade20k.py
configs/deeplabv3plus/deeplabv3plus_r101-d8_fp16_512x1024_80k_cityscapes.py
configs/swin/upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K.py

파일 보기

@ -1,41 +0,0 @@
PARTITION=$1
CHECKPOINT_DIR=$2
echo 'configs/hrnet/fcn_hr18s_512x512_160k_ade20k.py' &
GPUS=4 GPUS_PER_NODE=4 CPUS_PER_TASK=2 tools/slurm_test.sh $PARTITION fcn_hr18s_512x512_160k_ade20k configs/hrnet/fcn_hr18s_512x512_160k_ade20k.py $CHECKPOINT_DIR/fcn_hr18s_512x512_160k_ade20k_20200614_214413-870f65ac.pth --eval mIoU --work-dir work_dirs/benchmark_evaluation/fcn_hr18s_512x512_160k_ade20k --cfg-options dist_params.port=28171 &
echo 'configs/hrnet/fcn_hr18s_512x1024_160k_cityscapes.py' &
GPUS=4 GPUS_PER_NODE=4 CPUS_PER_TASK=2 tools/slurm_test.sh $PARTITION fcn_hr18s_512x1024_160k_cityscapes configs/hrnet/fcn_hr18s_512x1024_160k_cityscapes.py $CHECKPOINT_DIR/fcn_hr18s_512x1024_160k_cityscapes_20200602_190901-4a0797ea.pth --eval mIoU --work-dir work_dirs/benchmark_evaluation/fcn_hr18s_512x1024_160k_cityscapes --cfg-options dist_params.port=28172 &
echo 'configs/hrnet/fcn_hr48_512x512_160k_ade20k.py' &
GPUS=4 GPUS_PER_NODE=4 CPUS_PER_TASK=2 tools/slurm_test.sh $PARTITION fcn_hr48_512x512_160k_ade20k configs/hrnet/fcn_hr48_512x512_160k_ade20k.py $CHECKPOINT_DIR/fcn_hr48_512x512_160k_ade20k_20200614_214407-a52fc02c.pth --eval mIoU --work-dir work_dirs/benchmark_evaluation/fcn_hr48_512x512_160k_ade20k --cfg-options dist_params.port=28173 &
echo 'configs/hrnet/fcn_hr48_512x1024_160k_cityscapes.py' &
GPUS=4 GPUS_PER_NODE=4 CPUS_PER_TASK=2 tools/slurm_test.sh $PARTITION fcn_hr48_512x1024_160k_cityscapes configs/hrnet/fcn_hr48_512x1024_160k_cityscapes.py $CHECKPOINT_DIR/fcn_hr48_512x1024_160k_cityscapes_20200602_190946-59b7973e.pth --eval mIoU --work-dir work_dirs/benchmark_evaluation/fcn_hr48_512x1024_160k_cityscapes --cfg-options dist_params.port=28174 &
echo 'configs/pspnet/pspnet_r50-d8_512x1024_80k_cityscapes.py' &
GPUS=4 GPUS_PER_NODE=4 CPUS_PER_TASK=2 tools/slurm_test.sh $PARTITION pspnet_r50-d8_512x1024_80k_cityscapes configs/pspnet/pspnet_r50-d8_512x1024_80k_cityscapes.py $CHECKPOINT_DIR/pspnet_r50-d8_512x1024_80k_cityscapes_20200606_112131-2376f12b.pth --eval mIoU --work-dir work_dirs/benchmark_evaluation/pspnet_r50-d8_512x1024_80k_cityscapes --cfg-options dist_params.port=28175 &
echo 'configs/pspnet/pspnet_r101-d8_512x1024_80k_cityscapes.py' &
GPUS=4 GPUS_PER_NODE=4 CPUS_PER_TASK=2 tools/slurm_test.sh $PARTITION pspnet_r101-d8_512x1024_80k_cityscapes configs/pspnet/pspnet_r101-d8_512x1024_80k_cityscapes.py $CHECKPOINT_DIR/pspnet_r101-d8_512x1024_80k_cityscapes_20200606_112211-e1e1100f.pth --eval mIoU --work-dir work_dirs/benchmark_evaluation/pspnet_r101-d8_512x1024_80k_cityscapes --cfg-options dist_params.port=28176 &
echo 'configs/pspnet/pspnet_r101-d8_512x512_160k_ade20k.py' &
GPUS=4 GPUS_PER_NODE=4 CPUS_PER_TASK=2 tools/slurm_test.sh $PARTITION pspnet_r101-d8_512x512_160k_ade20k configs/pspnet/pspnet_r101-d8_512x512_160k_ade20k.py $CHECKPOINT_DIR/pspnet_r101-d8_512x512_160k_ade20k_20200615_100650-967c316f.pth --eval mIoU --work-dir work_dirs/benchmark_evaluation/pspnet_r101-d8_512x512_160k_ade20k --cfg-options dist_params.port=28177 &
echo 'configs/pspnet/pspnet_r50-d8_512x512_160k_ade20k.py' &
GPUS=4 GPUS_PER_NODE=4 CPUS_PER_TASK=2 tools/slurm_test.sh $PARTITION pspnet_r50-d8_512x512_160k_ade20k configs/pspnet/pspnet_r50-d8_512x512_160k_ade20k.py $CHECKPOINT_DIR/pspnet_r50-d8_512x512_160k_ade20k_20200615_184358-1890b0bd.pth --eval mIoU --work-dir work_dirs/benchmark_evaluation/pspnet_r50-d8_512x512_160k_ade20k --cfg-options dist_params.port=28178 &
echo 'configs/resnest/pspnet_s101-d8_512x512_160k_ade20k.py' &
GPUS=4 GPUS_PER_NODE=4 CPUS_PER_TASK=2 tools/slurm_test.sh $PARTITION pspnet_s101-d8_512x512_160k_ade20k configs/resnest/pspnet_s101-d8_512x512_160k_ade20k.py $CHECKPOINT_DIR/pspnet_s101-d8_512x512_160k_ade20k_20200807_145416-a6daa92a.pth --eval mIoU --work-dir work_dirs/benchmark_evaluation/pspnet_s101-d8_512x512_160k_ade20k --cfg-options dist_params.port=28179 &
echo 'configs/resnest/pspnet_s101-d8_512x1024_80k_cityscapes.py' &
GPUS=4 GPUS_PER_NODE=4 CPUS_PER_TASK=2 tools/slurm_test.sh $PARTITION pspnet_s101-d8_512x1024_80k_cityscapes configs/resnest/pspnet_s101-d8_512x1024_80k_cityscapes.py $CHECKPOINT_DIR/pspnet_s101-d8_512x1024_80k_cityscapes_20200807_140631-c75f3b99.pth --eval mIoU --work-dir work_dirs/benchmark_evaluation/pspnet_s101-d8_512x1024_80k_cityscapes --cfg-options dist_params.port=28180 &
echo 'configs/fastscnn/fast_scnn_lr0.12_8x4_160k_cityscapes.py' &
GPUS=4 GPUS_PER_NODE=4 CPUS_PER_TASK=2 tools/slurm_test.sh $PARTITION fast_scnn_lr0.12_8x4_160k_cityscapes configs/fastscnn/fast_scnn_lr0.12_8x4_160k_cityscapes.py $CHECKPOINT_DIR/fast_scnn_8x4_160k_lr0.12_cityscapes-0cec9937.pth --eval mIoU --work-dir work_dirs/benchmark_evaluation/fast_scnn_lr0.12_8x4_160k_cityscapes --cfg-options dist_params.port=28181 &
echo 'configs/deeplabv3plus/deeplabv3plus_r101-d8_769x769_80k_cityscapes.py' &
GPUS=4 GPUS_PER_NODE=4 CPUS_PER_TASK=2 tools/slurm_test.sh $PARTITION deeplabv3plus_r101-d8_769x769_80k_cityscapes configs/deeplabv3plus/deeplabv3plus_r101-d8_769x769_80k_cityscapes.py $CHECKPOINT_DIR/deeplabv3plus_r101-d8_769x769_80k_cityscapes_20200607_000405-a7573d20.pth --eval mIoU --work-dir work_dirs/benchmark_evaluation/deeplabv3plus_r101-d8_769x769_80k_cityscapes --cfg-options dist_params.port=28182 &
echo 'configs/deeplabv3plus/deeplabv3plus_r101-d8_512x1024_80k_cityscapes.py' &
GPUS=4 GPUS_PER_NODE=4 CPUS_PER_TASK=2 tools/slurm_test.sh $PARTITION deeplabv3plus_r101-d8_512x1024_80k_cityscapes configs/deeplabv3plus/deeplabv3plus_r101-d8_512x1024_80k_cityscapes.py $CHECKPOINT_DIR/deeplabv3plus_r101-d8_512x1024_80k_cityscapes_20200606_114143-068fcfe9.pth --eval mIoU --work-dir work_dirs/benchmark_evaluation/deeplabv3plus_r101-d8_512x1024_80k_cityscapes --cfg-options dist_params.port=28183 &
echo 'configs/deeplabv3plus/deeplabv3plus_r50-d8_512x1024_80k_cityscapes.py' &
GPUS=4 GPUS_PER_NODE=4 CPUS_PER_TASK=2 tools/slurm_test.sh $PARTITION deeplabv3plus_r50-d8_512x1024_80k_cityscapes configs/deeplabv3plus/deeplabv3plus_r50-d8_512x1024_80k_cityscapes.py $CHECKPOINT_DIR/deeplabv3plus_r50-d8_512x1024_80k_cityscapes_20200606_114049-f9fb496d.pth --eval mIoU --work-dir work_dirs/benchmark_evaluation/deeplabv3plus_r50-d8_512x1024_80k_cityscapes --cfg-options dist_params.port=28184 &
echo 'configs/deeplabv3plus/deeplabv3plus_r50-d8_769x769_80k_cityscapes.py' &
GPUS=4 GPUS_PER_NODE=4 CPUS_PER_TASK=2 tools/slurm_test.sh $PARTITION deeplabv3plus_r50-d8_769x769_80k_cityscapes configs/deeplabv3plus/deeplabv3plus_r50-d8_769x769_80k_cityscapes.py $CHECKPOINT_DIR/deeplabv3plus_r50-d8_769x769_80k_cityscapes_20200606_210233-0e9dfdc4.pth --eval mIoU --work-dir work_dirs/benchmark_evaluation/deeplabv3plus_r50-d8_769x769_80k_cityscapes --cfg-options dist_params.port=28185 &
echo 'configs/vit/upernet_vit-b16_ln_mln_512x512_160k_ade20k.py' &
GPUS=4 GPUS_PER_NODE=4 CPUS_PER_TASK=2 tools/slurm_test.sh $PARTITION upernet_vit-b16_ln_mln_512x512_160k_ade20k configs/vit/upernet_vit-b16_ln_mln_512x512_160k_ade20k.py $CHECKPOINT_DIR/upernet_vit-b16_ln_mln_512x512_160k_ade20k-f444c077.pth --eval mIoU --work-dir work_dirs/benchmark_evaluation/upernet_vit-b16_ln_mln_512x512_160k_ade20k --cfg-options dist_params.port=28186 &
echo 'configs/vit/upernet_deit-s16_ln_mln_512x512_160k_ade20k.py' &
GPUS=4 GPUS_PER_NODE=4 CPUS_PER_TASK=2 tools/slurm_test.sh $PARTITION upernet_deit-s16_ln_mln_512x512_160k_ade20k configs/vit/upernet_deit-s16_ln_mln_512x512_160k_ade20k.py $CHECKPOINT_DIR/upernet_deit-s16_ln_mln_512x512_160k_ade20k-c0cd652f.pth --eval mIoU --work-dir work_dirs/benchmark_evaluation/upernet_deit-s16_ln_mln_512x512_160k_ade20k --cfg-options dist_params.port=28187 &
echo 'configs/deeplabv3plus/deeplabv3plus_r101-d8_fp16_512x1024_80k_cityscapes.py' &
GPUS=4 GPUS_PER_NODE=4 CPUS_PER_TASK=2 tools/slurm_test.sh $PARTITION deeplabv3plus_r101-d8_fp16_512x1024_80k_cityscapes configs/deeplabv3plus/deeplabv3plus_r101-d8_fp16_512x1024_80k_cityscapes.py $CHECKPOINT_DIR/deeplabv3plus_r101-d8_512x1024_80k_fp16_cityscapes-cc58bc8d.pth --eval mIoU --work-dir work_dirs/benchmark_evaluation/deeplabv3plus_r101-d8_512x1024_80k_fp16_cityscapes --cfg-options dist_params.port=28188 &
echo 'configs/swin/upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K.py' &
GPUS=4 GPUS_PER_NODE=4 CPUS_PER_TASK=2 tools/slurm_test.sh $PARTITION upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K configs/swin/upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K.py $CHECKPOINT_DIR/upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K_20210531_112542-e380ad3e.pth --eval mIoU --work-dir work_dirs/benchmark_evaluation/upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K --cfg-options dist_params.port=28189 &

파일 보기

@ -1,149 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import hashlib
import logging
import os
import os.path as osp
import warnings
from argparse import ArgumentParser
import requests
from mmcv import Config
from mmseg.apis import inference_segmentor, init_segmentor, show_result_pyplot
from mmseg.utils import get_root_logger
# ignore warnings when segmentors inference
warnings.filterwarnings('ignore')
def download_checkpoint(checkpoint_name, model_name, config_name, collect_dir):
"""Download checkpoint and check if hash code is true."""
url = f'https://download.openmmlab.com/mmsegmentation/v0.5/{model_name}/{config_name}/{checkpoint_name}' # noqa
r = requests.get(url)
assert r.status_code != 403, f'{url} Access denied.'
with open(osp.join(collect_dir, checkpoint_name), 'wb') as code:
code.write(r.content)
true_hash_code = osp.splitext(checkpoint_name)[0].split('-')[1]
# check hash code
with open(osp.join(collect_dir, checkpoint_name), 'rb') as fp:
sha256_cal = hashlib.sha256()
sha256_cal.update(fp.read())
cur_hash_code = sha256_cal.hexdigest()[:8]
assert true_hash_code == cur_hash_code, f'{url} download failed, '
'incomplete downloaded file or url invalid.'
if cur_hash_code != true_hash_code:
os.remove(osp.join(collect_dir, checkpoint_name))
def parse_args():
parser = ArgumentParser()
parser.add_argument('config', help='test config file path')
parser.add_argument('checkpoint_root', help='Checkpoint file root path')
parser.add_argument(
'-i', '--img', default='demo/demo.png', help='Image file')
parser.add_argument('-a', '--aug', action='store_true', help='aug test')
parser.add_argument('-m', '--model-name', help='model name to inference')
parser.add_argument(
'-s', '--show', action='store_true', help='show results')
parser.add_argument(
'-d', '--device', default='cuda:0', help='Device used for inference')
args = parser.parse_args()
return args
def inference_model(config_name, checkpoint, args, logger=None):
cfg = Config.fromfile(config_name)
if args.aug:
if 'flip' in cfg.data.test.pipeline[
1] and 'img_scale' in cfg.data.test.pipeline[1]:
cfg.data.test.pipeline[1].img_ratios = [
0.5, 0.75, 1.0, 1.25, 1.5, 1.75
]
cfg.data.test.pipeline[1].flip = True
else:
if logger is not None:
logger.error(f'{config_name}: unable to start aug test')
else:
print(f'{config_name}: unable to start aug test', flush=True)
model = init_segmentor(cfg, checkpoint, device=args.device)
# test a single image
result = inference_segmentor(model, args.img)
# show the results
if args.show:
show_result_pyplot(model, args.img, result)
return result
# Sample test whether the inference code is correct
def main(args):
config = Config.fromfile(args.config)
if not os.path.exists(args.checkpoint_root):
os.makedirs(args.checkpoint_root, 0o775)
# test single model
if args.model_name:
if args.model_name in config:
model_infos = config[args.model_name]
if not isinstance(model_infos, list):
model_infos = [model_infos]
for model_info in model_infos:
config_name = model_info['config'].strip()
print(f'processing: {config_name}', flush=True)
checkpoint = osp.join(args.checkpoint_root,
model_info['checkpoint'].strip())
try:
# build the model from a config file and a checkpoint file
inference_model(config_name, checkpoint, args)
except Exception:
print(f'{config_name} test failed!')
continue
return
else:
raise RuntimeError('model name input error.')
# test all model
logger = get_root_logger(
log_file='benchmark_inference_image.log', log_level=logging.ERROR)
for model_name in config:
model_infos = config[model_name]
if not isinstance(model_infos, list):
model_infos = [model_infos]
for model_info in model_infos:
print('processing: ', model_info['config'], flush=True)
config_path = model_info['config'].strip()
config_name = osp.splitext(osp.basename(config_path))[0]
checkpoint_name = model_info['checkpoint'].strip()
checkpoint = osp.join(args.checkpoint_root, checkpoint_name)
# ensure checkpoint exists
try:
if not osp.exists(checkpoint):
download_checkpoint(checkpoint_name, model_name,
config_name.rstrip('.py'),
args.checkpoint_root)
except Exception:
logger.error(f'{checkpoint_name} download error')
continue
# test model inference with checkpoint
try:
# build the model from a config file and a checkpoint file
inference_model(config_path, checkpoint, args, logger)
except Exception as e:
logger.error(f'{config_path} " : {repr(e)}')
if __name__ == '__main__':
args = parse_args()
main(args)

파일 보기

@ -1,40 +0,0 @@
PARTITION=$1
echo 'configs/hrnet/fcn_hr18s_512x512_160k_ade20k.py' &
GPUS=4 GPUS_PER_NODE=4 CPUS_PER_TASK=2 ./tools/slurm_train.sh $PARTITION fcn_hr18s_512x512_160k_ade20k configs/hrnet/fcn_hr18s_512x512_160k_ade20k.py --cfg-options checkpoint_config.max_keep_ckpts=1 dist_params.port=24727 --work-dir work_dirs/hrnet/fcn_hr18s_512x512_160k_ade20k >/dev/null &
echo 'configs/hrnet/fcn_hr18s_512x1024_160k_cityscapes.py' &
GPUS=4 GPUS_PER_NODE=4 CPUS_PER_TASK=2 ./tools/slurm_train.sh $PARTITION fcn_hr18s_512x1024_160k_cityscapes configs/hrnet/fcn_hr18s_512x1024_160k_cityscapes.py --cfg-options checkpoint_config.max_keep_ckpts=1 dist_params.port=24728 --work-dir work_dirs/hrnet/fcn_hr18s_512x1024_160k_cityscapes >/dev/null &
echo 'configs/hrnet/fcn_hr48_512x512_160k_ade20k.py' &
GPUS=4 GPUS_PER_NODE=4 CPUS_PER_TASK=2 ./tools/slurm_train.sh $PARTITION fcn_hr48_512x512_160k_ade20k configs/hrnet/fcn_hr48_512x512_160k_ade20k.py --cfg-options checkpoint_config.max_keep_ckpts=1 dist_params.port=24729 --work-dir work_dirs/hrnet/fcn_hr48_512x512_160k_ade20k >/dev/null &
echo 'configs/hrnet/fcn_hr48_512x1024_160k_cityscapes.py' &
GPUS=4 GPUS_PER_NODE=4 CPUS_PER_TASK=2 ./tools/slurm_train.sh $PARTITION fcn_hr48_512x1024_160k_cityscapes configs/hrnet/fcn_hr48_512x1024_160k_cityscapes.py --cfg-options checkpoint_config.max_keep_ckpts=1 dist_params.port=24730 --work-dir work_dirs/hrnet/fcn_hr48_512x1024_160k_cityscapes >/dev/null &
echo 'configs/pspnet/pspnet_r50-d8_512x1024_80k_cityscapes.py' &
GPUS=4 GPUS_PER_NODE=4 CPUS_PER_TASK=2 ./tools/slurm_train.sh $PARTITION pspnet_r50-d8_512x1024_80k_cityscapes configs/pspnet/pspnet_r50-d8_512x1024_80k_cityscapes.py --cfg-options checkpoint_config.max_keep_ckpts=1 dist_params.port=24731 --work-dir work_dirs/pspnet/pspnet_r50-d8_512x1024_80k_cityscapes >/dev/null &
echo 'configs/pspnet/pspnet_r101-d8_512x1024_80k_cityscapes.py' &
GPUS=4 GPUS_PER_NODE=4 CPUS_PER_TASK=2 ./tools/slurm_train.sh $PARTITION pspnet_r101-d8_512x1024_80k_cityscapes configs/pspnet/pspnet_r101-d8_512x1024_80k_cityscapes.py --cfg-options checkpoint_config.max_keep_ckpts=1 dist_params.port=24732 --work-dir work_dirs/pspnet/pspnet_r101-d8_512x1024_80k_cityscapes >/dev/null &
echo 'configs/pspnet/pspnet_r101-d8_512x512_160k_ade20k.py' &
GPUS=4 GPUS_PER_NODE=4 CPUS_PER_TASK=2 ./tools/slurm_train.sh $PARTITION pspnet_r101-d8_512x512_160k_ade20k configs/pspnet/pspnet_r101-d8_512x512_160k_ade20k.py --cfg-options checkpoint_config.max_keep_ckpts=1 dist_params.port=24733 --work-dir work_dirs/pspnet/pspnet_r101-d8_512x512_160k_ade20k >/dev/null &
echo 'configs/pspnet/pspnet_r50-d8_512x512_160k_ade20k.py' &
GPUS=4 GPUS_PER_NODE=4 CPUS_PER_TASK=2 ./tools/slurm_train.sh $PARTITION pspnet_r50-d8_512x512_160k_ade20k configs/pspnet/pspnet_r50-d8_512x512_160k_ade20k.py --cfg-options checkpoint_config.max_keep_ckpts=1 dist_params.port=24734 --work-dir work_dirs/pspnet/pspnet_r50-d8_512x512_160k_ade20k >/dev/null &
echo 'configs/resnest/pspnet_s101-d8_512x512_160k_ade20k.py' &
GPUS=4 GPUS_PER_NODE=4 CPUS_PER_TASK=2 ./tools/slurm_train.sh $PARTITION pspnet_s101-d8_512x512_160k_ade20k configs/resnest/pspnet_s101-d8_512x512_160k_ade20k.py --cfg-options checkpoint_config.max_keep_ckpts=1 dist_params.port=24735 --work-dir work_dirs/resnest/pspnet_s101-d8_512x512_160k_ade20k >/dev/null &
echo 'configs/resnest/pspnet_s101-d8_512x1024_80k_cityscapes.py' &
GPUS=4 GPUS_PER_NODE=4 CPUS_PER_TASK=2 ./tools/slurm_train.sh $PARTITION pspnet_s101-d8_512x1024_80k_cityscapes configs/resnest/pspnet_s101-d8_512x1024_80k_cityscapes.py --cfg-options checkpoint_config.max_keep_ckpts=1 dist_params.port=24736 --work-dir work_dirs/resnest/pspnet_s101-d8_512x1024_80k_cityscapes >/dev/null &
echo 'configs/fastscnn/fast_scnn_lr0.12_8x4_160k_cityscapes.py' &
GPUS=4 GPUS_PER_NODE=4 CPUS_PER_TASK=2 ./tools/slurm_train.sh $PARTITION fast_scnn_lr0.12_8x4_160k_cityscapes configs/fastscnn/fast_scnn_lr0.12_8x4_160k_cityscapes.py --cfg-options checkpoint_config.max_keep_ckpts=1 dist_params.port=24737 --work-dir work_dirs/fastscnn/fast_scnn_lr0.12_8x4_160k_cityscapes >/dev/null &
echo 'configs/deeplabv3plus/deeplabv3plus_r101-d8_769x769_80k_cityscapes.py' &
GPUS=4 GPUS_PER_NODE=4 CPUS_PER_TASK=2 ./tools/slurm_train.sh $PARTITION deeplabv3plus_r101-d8_769x769_80k_cityscapes configs/deeplabv3plus/deeplabv3plus_r101-d8_769x769_80k_cityscapes.py --cfg-options checkpoint_config.max_keep_ckpts=1 dist_params.port=24738 --work-dir work_dirs/deeplabv3plus/deeplabv3plus_r101-d8_769x769_80k_cityscapes >/dev/null &
echo 'configs/deeplabv3plus/deeplabv3plus_r101-d8_512x1024_80k_cityscapes.py' &
GPUS=4 GPUS_PER_NODE=4 CPUS_PER_TASK=2 ./tools/slurm_train.sh $PARTITION deeplabv3plus_r101-d8_512x1024_80k_cityscapes configs/deeplabv3plus/deeplabv3plus_r101-d8_512x1024_80k_cityscapes.py --cfg-options checkpoint_config.max_keep_ckpts=1 dist_params.port=24739 --work-dir work_dirs/deeplabv3plus/deeplabv3plus_r101-d8_512x1024_80k_cityscapes >/dev/null &
echo 'configs/deeplabv3plus/deeplabv3plus_r50-d8_512x1024_80k_cityscapes.py' &
GPUS=4 GPUS_PER_NODE=4 CPUS_PER_TASK=2 ./tools/slurm_train.sh $PARTITION deeplabv3plus_r50-d8_512x1024_80k_cityscapes configs/deeplabv3plus/deeplabv3plus_r50-d8_512x1024_80k_cityscapes.py --cfg-options checkpoint_config.max_keep_ckpts=1 dist_params.port=24740 --work-dir work_dirs/deeplabv3plus/deeplabv3plus_r50-d8_512x1024_80k_cityscapes >/dev/null &
echo 'configs/deeplabv3plus/deeplabv3plus_r50-d8_769x769_80k_cityscapes.py' &
GPUS=4 GPUS_PER_NODE=4 CPUS_PER_TASK=2 ./tools/slurm_train.sh $PARTITION deeplabv3plus_r50-d8_769x769_80k_cityscapes configs/deeplabv3plus/deeplabv3plus_r50-d8_769x769_80k_cityscapes.py --cfg-options checkpoint_config.max_keep_ckpts=1 dist_params.port=24741 --work-dir work_dirs/deeplabv3plus/deeplabv3plus_r50-d8_769x769_80k_cityscapes >/dev/null &
echo 'configs/vit/upernet_vit-b16_ln_mln_512x512_160k_ade20k.py' &
GPUS=8 GPUS_PER_NODE=8 CPUS_PER_TASK=2 ./tools/slurm_train.sh $PARTITION upernet_vit-b16_ln_mln_512x512_160k_ade20k configs/vit/upernet_vit-b16_ln_mln_512x512_160k_ade20k.py --cfg-options checkpoint_config.max_keep_ckpts=1 dist_params.port=24742 --work-dir work_dirs/vit/upernet_vit-b16_ln_mln_512x512_160k_ade20k >/dev/null &
echo 'configs/vit/upernet_deit-s16_ln_mln_512x512_160k_ade20k.py' &
GPUS=8 GPUS_PER_NODE=8 CPUS_PER_TASK=2 ./tools/slurm_train.sh $PARTITION upernet_deit-s16_ln_mln_512x512_160k_ade20k configs/vit/upernet_deit-s16_ln_mln_512x512_160k_ade20k.py --cfg-options checkpoint_config.max_keep_ckpts=1 dist_params.port=24743 --work-dir work_dirs/vit/upernet_deit-s16_ln_mln_512x512_160k_ade20k >/dev/null &
echo 'configs/deeplabv3plus/deeplabv3plus_r101-d8_fp16_512x1024_80k_cityscapes.py' &
GPUS=4 GPUS_PER_NODE=4 CPUS_PER_TASK=2 ./tools/slurm_train.sh $PARTITION deeplabv3plus_r101-d8_512x1024_80k_fp16_cityscapes configs/deeplabv3plus/deeplabv3plus_r101-d8_fp16_512x1024_80k_cityscapes.py --cfg-options checkpoint_config.max_keep_ckpts=1 dist_params.port=24744 --work-dir work_dirs/deeplabv3plus/deeplabv3plus_r101-d8_512x1024_80k_fp16_cityscapes >/dev/null &
echo 'configs/swin/upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K.py' &
GPUS=8 GPUS_PER_NODE=8 CPUS_PER_TASK=2 ./tools/slurm_train.sh $PARTITION upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K configs/swin/upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K.py --cfg-options checkpoint_config.max_keep_ckpts=1 dist_params.port=24745 --work-dir work_dirs/swin/upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K >/dev/null &

파일 보기

@ -1,101 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import logging
import os
from argparse import ArgumentParser
import requests
import yaml as yml
from mmseg.utils import get_root_logger
def check_url(url):
"""Check url response status.
Args:
url (str): url needed to check.
Returns:
int, bool: status code and check flag.
"""
flag = True
r = requests.head(url)
status_code = r.status_code
if status_code == 403 or status_code == 404:
flag = False
return status_code, flag
def parse_args():
parser = ArgumentParser('url valid check.')
parser.add_argument(
'-m',
'--model-name',
type=str,
help='Select the model needed to check')
args = parser.parse_args()
return args
def main():
args = parse_args()
model_name = args.model_name
# yml path generate.
# If model_name is not set, script will check all of the models.
if model_name is not None:
yml_list = [(model_name, f'configs/{model_name}/{model_name}.yml')]
else:
# check all
yml_list = [(x, f'configs/{x}/{x}.yml') for x in os.listdir('configs/')
if x != '_base_']
logger = get_root_logger(log_file='url_check.log', log_level=logging.ERROR)
for model_name, yml_path in yml_list:
# Default yaml loader unsafe.
model_infos = yml.load(
open(yml_path, 'r'), Loader=yml.CLoader)['Models']
for model_info in model_infos:
config_name = model_info['Name']
checkpoint_url = model_info['Weights']
# checkpoint url check
status_code, flag = check_url(checkpoint_url)
if flag:
logger.info(f'checkpoint | {config_name} | {checkpoint_url} | '
f'{status_code} valid')
else:
logger.error(
f'checkpoint | {config_name} | {checkpoint_url} | '
f'{status_code} | error')
# log_json check
checkpoint_name = checkpoint_url.split('/')[-1]
model_time = '-'.join(checkpoint_name.split('-')[:-1]).replace(
f'{config_name}_', '')
# two style of log_json name
# use '_' to link model_time (will be deprecated)
log_json_url_1 = f'https://download.openmmlab.com/mmsegmentation/v0.5/{model_name}/{config_name}/{config_name}_{model_time}.log.json' # noqa
status_code_1, flag_1 = check_url(log_json_url_1)
# use '-' to link model_time
log_json_url_2 = f'https://download.openmmlab.com/mmsegmentation/v0.5/{model_name}/{config_name}/{config_name}-{model_time}.log.json' # noqa
status_code_2, flag_2 = check_url(log_json_url_2)
if flag_1 or flag_2:
if flag_1:
logger.info(
f'log.json | {config_name} | {log_json_url_1} | '
f'{status_code_1} | valid')
else:
logger.info(
f'log.json | {config_name} | {log_json_url_2} | '
f'{status_code_2} | valid')
else:
logger.error(
f'log.json | {config_name} | {log_json_url_1} & '
f'{log_json_url_2} | {status_code_1} & {status_code_2} | '
'error')
if __name__ == '__main__':
main()

파일 보기

@ -1,91 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import glob
import os.path as osp
import mmcv
from mmcv import Config
def parse_args():
parser = argparse.ArgumentParser(
description='Gather benchmarked model evaluation results')
parser.add_argument('config', help='test config file path')
parser.add_argument(
'root',
type=str,
help='root path of benchmarked models to be gathered')
parser.add_argument(
'--out',
type=str,
default='benchmark_evaluation_info.json',
help='output path of gathered metrics and compared '
'results to be stored')
args = parser.parse_args()
return args
if __name__ == '__main__':
args = parse_args()
root_path = args.root
metrics_out = args.out
result_dict = {}
cfg = Config.fromfile(args.config)
for model_key in cfg:
model_infos = cfg[model_key]
if not isinstance(model_infos, list):
model_infos = [model_infos]
for model_info in model_infos:
previous_metrics = model_info['metric']
config = model_info['config'].strip()
fname, _ = osp.splitext(osp.basename(config))
# Load benchmark evaluation json
metric_json_dir = osp.join(root_path, fname)
if not osp.exists(metric_json_dir):
print(f'{metric_json_dir} not existed.')
continue
json_list = glob.glob(osp.join(metric_json_dir, '*.json'))
if len(json_list) == 0:
print(f'There is no eval json in {metric_json_dir}.')
continue
log_json_path = list(sorted(json_list))[-1]
metric = mmcv.load(log_json_path)
if config not in metric.get('config', {}):
print(f'{config} not included in {log_json_path}')
continue
# Compare between new benchmark results and previous metrics
differential_results = dict()
new_metrics = dict()
for record_metric_key in previous_metrics:
if record_metric_key not in metric['metric']:
raise KeyError('record_metric_key not exist, please '
'check your config')
old_metric = previous_metrics[record_metric_key]
new_metric = round(metric['metric'][record_metric_key] * 100,
2)
differential = new_metric - old_metric
flag = '+' if differential > 0 else '-'
differential_results[
record_metric_key] = f'{flag}{abs(differential):.2f}'
new_metrics[record_metric_key] = new_metric
result_dict[config] = dict(
differential=differential_results,
previous=previous_metrics,
new=new_metrics)
if metrics_out:
mmcv.dump(result_dict, metrics_out, indent=4)
print('===================================')
for config_name, metrics in result_dict.items():
print(config_name, metrics)
print('===================================')

파일 보기

@ -1,100 +0,0 @@
import argparse
import glob
import os.path as osp
import mmcv
from gather_models import get_final_results
from mmcv import Config
def parse_args():
parser = argparse.ArgumentParser(
description='Gather benchmarked models train results')
parser.add_argument('config', help='test config file path')
parser.add_argument(
'root',
type=str,
help='root path of benchmarked models to be gathered')
parser.add_argument(
'--out',
type=str,
default='benchmark_train_info.json',
help='output path of gathered metrics to be stored')
args = parser.parse_args()
return args
if __name__ == '__main__':
args = parse_args()
root_path = args.root
metrics_out = args.out
evaluation_cfg = Config.fromfile(args.config)
result_dict = {}
for model_key in evaluation_cfg:
model_infos = evaluation_cfg[model_key]
if not isinstance(model_infos, list):
model_infos = [model_infos]
for model_info in model_infos:
config = model_info['config']
# benchmark train dir
model_name = osp.split(osp.dirname(config))[1]
config_name = osp.splitext(osp.basename(config))[0]
exp_dir = osp.join(root_path, model_name, config_name)
if not osp.exists(exp_dir):
print(f'{config} hasn\'t {exp_dir}')
continue
# parse config
cfg = mmcv.Config.fromfile(config)
total_iters = cfg.runner.max_iters
exp_metric = cfg.evaluation.metric
if not isinstance(exp_metric, list):
exp_metrics = [exp_metric]
# determine whether total_iters ckpt exists
ckpt_path = f'iter_{total_iters}.pth'
if not osp.exists(osp.join(exp_dir, ckpt_path)):
print(f'{config} hasn\'t {ckpt_path}')
continue
# only the last log json counts
log_json_path = list(
sorted(glob.glob(osp.join(exp_dir, '*.log.json'))))[-1]
# extract metric value
model_performance = get_final_results(log_json_path, total_iters)
if model_performance is None:
print(f'log file error: {log_json_path}')
continue
differential_results = dict()
old_results = dict()
new_results = dict()
for metric_key in model_performance:
if metric_key in ['mIoU']:
metric = round(model_performance[metric_key] * 100, 2)
old_metric = model_info['metric'][metric_key]
old_results[metric_key] = old_metric
new_results[metric_key] = metric
differential = metric - old_metric
flag = '+' if differential > 0 else '-'
differential_results[
metric_key] = f'{flag}{abs(differential):.2f}'
result_dict[config] = dict(
differential_results=differential_results,
old_results=old_results,
new_results=new_results,
)
# 4 save or print results
if metrics_out:
mmcv.dump(result_dict, metrics_out, indent=4)
print('===================================')
for config_name, metrics in result_dict.items():
print(config_name, metrics)
print('===================================')

파일 보기

@ -1,211 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import glob
import hashlib
import json
import os
import os.path as osp
import shutil
import mmcv
import torch
# build schedule look-up table to automatically find the final model
RESULTS_LUT = ['mIoU', 'mAcc', 'aAcc']
def calculate_file_sha256(file_path):
"""calculate file sha256 hash code."""
with open(file_path, 'rb') as fp:
sha256_cal = hashlib.sha256()
sha256_cal.update(fp.read())
return sha256_cal.hexdigest()
def process_checkpoint(in_file, out_file):
checkpoint = torch.load(in_file, map_location='cpu')
# remove optimizer for smaller file size
if 'optimizer' in checkpoint:
del checkpoint['optimizer']
# if it is necessary to remove some sensitive data in checkpoint['meta'],
# add the code here.
torch.save(checkpoint, out_file)
# The hash code calculation and rename command differ on different system
# platform.
sha = calculate_file_sha256(out_file)
final_file = out_file.rstrip('.pth') + '-{}.pth'.format(sha[:8])
os.rename(out_file, final_file)
# Remove prefix and suffix
final_file_name = osp.split(final_file)[1]
final_file_name = osp.splitext(final_file_name)[0]
return final_file_name
def get_final_iter(config):
iter_num = config.split('_')[-2]
assert iter_num.endswith('k')
return int(iter_num[:-1]) * 1000
def get_final_results(log_json_path, iter_num):
result_dict = dict()
last_iter = 0
with open(log_json_path, 'r') as f:
for line in f.readlines():
log_line = json.loads(line)
if 'mode' not in log_line.keys():
continue
# When evaluation, the 'iter' of new log json is the evaluation
# steps on single gpu.
flag1 = ('aAcc' in log_line) or (log_line['mode'] == 'val')
flag2 = (last_iter == iter_num - 50) or (last_iter == iter_num)
if flag1 and flag2:
result_dict.update({
key: log_line[key]
for key in RESULTS_LUT if key in log_line
})
return result_dict
last_iter = log_line['iter']
def parse_args():
parser = argparse.ArgumentParser(description='Gather benchmarked models')
parser.add_argument(
'-f', '--config-name', type=str, help='Process the selected config.')
parser.add_argument(
'-w',
'--work-dir',
default='work_dirs/',
type=str,
help='Ckpt storage root folder of benchmarked models to be gathered.')
parser.add_argument(
'-c',
'--collect-dir',
default='work_dirs/gather',
type=str,
help='Ckpt collect root folder of gathered models.')
parser.add_argument(
'--all', action='store_true', help='whether include .py and .log')
args = parser.parse_args()
return args
def main():
args = parse_args()
work_dir = args.work_dir
collect_dir = args.collect_dir
selected_config_name = args.config_name
mmcv.mkdir_or_exist(collect_dir)
# find all models in the root directory to be gathered
raw_configs = list(mmcv.scandir('./configs', '.py', recursive=True))
# filter configs that is not trained in the experiments dir
used_configs = []
for raw_config in raw_configs:
config_name = osp.splitext(osp.basename(raw_config))[0]
if osp.exists(osp.join(work_dir, config_name)):
if (selected_config_name is None
or selected_config_name == config_name):
used_configs.append(raw_config)
print(f'Find {len(used_configs)} models to be gathered')
# find final_ckpt and log file for trained each config
# and parse the best performance
model_infos = []
for used_config in used_configs:
config_name = osp.splitext(osp.basename(used_config))[0]
exp_dir = osp.join(work_dir, config_name)
# check whether the exps is finished
final_iter = get_final_iter(used_config)
final_model = 'iter_{}.pth'.format(final_iter)
model_path = osp.join(exp_dir, final_model)
# skip if the model is still training
if not osp.exists(model_path):
print(f'{used_config} train not finished yet')
continue
# get logs
log_json_paths = glob.glob(osp.join(exp_dir, '*.log.json'))
log_json_path = log_json_paths[0]
model_performance = None
for idx, _log_json_path in enumerate(log_json_paths):
model_performance = get_final_results(_log_json_path, final_iter)
if model_performance is not None:
log_json_path = _log_json_path
break
if model_performance is None:
print(f'{used_config} model_performance is None')
continue
model_time = osp.split(log_json_path)[-1].split('.')[0]
model_infos.append(
dict(
config_name=config_name,
results=model_performance,
iters=final_iter,
model_time=model_time,
log_json_path=osp.split(log_json_path)[-1]))
# publish model for each checkpoint
publish_model_infos = []
for model in model_infos:
config_name = model['config_name']
model_publish_dir = osp.join(collect_dir, config_name)
publish_model_path = osp.join(model_publish_dir,
config_name + '_' + model['model_time'])
trained_model_path = osp.join(work_dir, config_name,
'iter_{}.pth'.format(model['iters']))
if osp.exists(model_publish_dir):
for file in os.listdir(model_publish_dir):
if file.endswith('.pth'):
print(f'model {file} found')
model['model_path'] = osp.abspath(
osp.join(model_publish_dir, file))
break
if 'model_path' not in model:
print(f'dir {model_publish_dir} exists, no model found')
else:
mmcv.mkdir_or_exist(model_publish_dir)
# convert model
final_model_path = process_checkpoint(trained_model_path,
publish_model_path)
model['model_path'] = final_model_path
new_json_path = f'{config_name}_{model["log_json_path"]}'
# copy log
shutil.copy(
osp.join(work_dir, config_name, model['log_json_path']),
osp.join(model_publish_dir, new_json_path))
if args.all:
new_txt_path = new_json_path.rstrip('.json')
shutil.copy(
osp.join(work_dir, config_name,
model['log_json_path'].rstrip('.json')),
osp.join(model_publish_dir, new_txt_path))
if args.all:
# copy config to guarantee reproducibility
raw_config = osp.join('./configs', f'{config_name}.py')
mmcv.Config.fromfile(raw_config).dump(
osp.join(model_publish_dir, osp.basename(raw_config)))
publish_model_infos.append(model)
models = dict(models=publish_model_infos)
mmcv.dump(models, osp.join(collect_dir, 'model_infos.json'), indent=4)
if __name__ == '__main__':
main()

파일 보기

@ -1,114 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
from mmcv import Config
def parse_args():
parser = argparse.ArgumentParser(
description='Convert benchmark test model list to script')
parser.add_argument('config', help='test config file path')
parser.add_argument('--port', type=int, default=28171, help='dist port')
parser.add_argument(
'--work-dir',
default='work_dirs/benchmark_evaluation',
help='the dir to save metric')
parser.add_argument(
'--out',
type=str,
default='.dev/benchmark_evaluation.sh',
help='path to save model benchmark script')
args = parser.parse_args()
return args
def process_model_info(model_info, work_dir):
config = model_info['config'].strip()
fname, _ = osp.splitext(osp.basename(config))
job_name = fname
checkpoint = model_info['checkpoint'].strip()
work_dir = osp.join(work_dir, fname)
if not isinstance(model_info['eval'], list):
evals = [model_info['eval']]
else:
evals = model_info['eval']
eval = ' '.join(evals)
return dict(
config=config,
job_name=job_name,
checkpoint=checkpoint,
work_dir=work_dir,
eval=eval)
def create_test_bash_info(commands, model_test_dict, port, script_name,
partition):
config = model_test_dict['config']
job_name = model_test_dict['job_name']
checkpoint = model_test_dict['checkpoint']
work_dir = model_test_dict['work_dir']
eval = model_test_dict['eval']
echo_info = f'\necho \'{config}\' &'
commands.append(echo_info)
commands.append('\n')
command_info = f'GPUS=4 GPUS_PER_NODE=4 ' \
f'CPUS_PER_TASK=2 {script_name} '
command_info += f'{partition} '
command_info += f'{job_name} '
command_info += f'{config} '
command_info += f'$CHECKPOINT_DIR/{checkpoint} '
command_info += f'--eval {eval} '
command_info += f'--work-dir {work_dir} '
command_info += f'--cfg-options dist_params.port={port} '
command_info += '&'
commands.append(command_info)
def main():
args = parse_args()
if args.out:
out_suffix = args.out.split('.')[-1]
assert args.out.endswith('.sh'), \
f'Expected out file path suffix is .sh, but get .{out_suffix}'
commands = []
partition_name = 'PARTITION=$1'
commands.append(partition_name)
commands.append('\n')
checkpoint_root = 'CHECKPOINT_DIR=$2'
commands.append(checkpoint_root)
commands.append('\n')
script_name = osp.join('tools', 'slurm_test.sh')
port = args.port
work_dir = args.work_dir
cfg = Config.fromfile(args.config)
for model_key in cfg:
model_infos = cfg[model_key]
if not isinstance(model_infos, list):
model_infos = [model_infos]
for model_info in model_infos:
print('processing: ', model_info['config'])
model_test_dict = process_model_info(model_info, work_dir)
create_test_bash_info(commands, model_test_dict, port, script_name,
'$PARTITION')
port += 1
command_str = ''.join(commands)
if args.out:
with open(args.out, 'w') as f:
f.write(command_str + '\n')
if __name__ == '__main__':
main()

파일 보기

@ -1,91 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
# Default using 4 gpu when training
config_8gpu_list = [
'configs/swin/upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K.py', # noqa
'configs/vit/upernet_vit-b16_ln_mln_512x512_160k_ade20k.py',
'configs/vit/upernet_deit-s16_ln_mln_512x512_160k_ade20k.py',
]
def parse_args():
parser = argparse.ArgumentParser(
description='Convert benchmark model json to script')
parser.add_argument(
'txt_path', type=str, help='txt path output by benchmark_filter')
parser.add_argument('--port', type=int, default=24727, help='dist port')
parser.add_argument(
'--out',
type=str,
default='.dev/benchmark_train.sh',
help='path to save model benchmark script')
args = parser.parse_args()
return args
def create_train_bash_info(commands, config, script_name, partition, port):
cfg = config.strip()
# print cfg name
echo_info = f'echo \'{cfg}\' &'
commands.append(echo_info)
commands.append('\n')
_, model_name = osp.split(osp.dirname(cfg))
config_name, _ = osp.splitext(osp.basename(cfg))
# default setting
if cfg in config_8gpu_list:
command_info = f'GPUS=8 GPUS_PER_NODE=8 ' \
f'CPUS_PER_TASK=2 {script_name} '
else:
command_info = f'GPUS=4 GPUS_PER_NODE=4 ' \
f'CPUS_PER_TASK=2 {script_name} '
command_info += f'{partition} '
command_info += f'{config_name} '
command_info += f'{cfg} '
command_info += f'--cfg-options ' \
f'checkpoint_config.max_keep_ckpts=1 ' \
f'dist_params.port={port} '
command_info += f'--work-dir work_dirs/{model_name}/{config_name} '
# Let the script shut up
command_info += '>/dev/null &'
commands.append(command_info)
commands.append('\n')
def main():
args = parse_args()
if args.out:
out_suffix = args.out.split('.')[-1]
assert args.out.endswith('.sh'), \
f'Expected out file path suffix is .sh, but get .{out_suffix}'
root_name = './tools'
script_name = osp.join(root_name, 'slurm_train.sh')
port = args.port
partition_name = 'PARTITION=$1'
commands = []
commands.append(partition_name)
commands.append('\n')
commands.append('\n')
with open(args.txt_path, 'r') as f:
model_cfgs = f.readlines()
for i, cfg in enumerate(model_cfgs):
create_train_bash_info(commands, cfg, script_name, '$PARTITION',
port)
port += 1
command_str = ''.join(commands)
if args.out:
with open(args.out, 'w') as f:
f.write(command_str)
if __name__ == '__main__':
main()

파일 보기

@ -1,18 +0,0 @@
work_dir = '../../work_dirs'
metric = 'mIoU'
# specify the log files we would like to collect in `log_items`
log_items = [
'segformer_mit-b5_512x512_160k_ade20k_cnn_lr_with_warmup',
'segformer_mit-b5_512x512_160k_ade20k_cnn_no_warmup_lr',
'segformer_mit-b5_512x512_160k_ade20k_mit_trans_lr',
'segformer_mit-b5_512x512_160k_ade20k_swin_trans_lr'
]
# or specify ignore_keywords, then the folders whose name contain
# `'segformer'` won't be collected
# ignore_keywords = ['segformer']
# should not include metric
other_info_keys = ['mAcc']
markdown_file = 'markdowns/lr_in_trans.json.md'
json_file = 'jsons/trans_in_cnn.json'

파일 보기

@ -1,143 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import datetime
import json
import os
import os.path as osp
from collections import OrderedDict
from utils import load_config
# automatically collect all the results
# The structure of the directory:
# ├── work-dir
# │ ├── config_1
# │ │ ├── time1.log.json
# │ │ ├── time2.log.json
# │ │ ├── time3.log.json
# │ │ ├── time4.log.json
# │ ├── config_2
# │ │ ├── time5.log.json
# │ │ ├── time6.log.json
# │ │ ├── time7.log.json
# │ │ ├── time8.log.json
def parse_args():
parser = argparse.ArgumentParser(description='extract info from log.json')
parser.add_argument('config_dir')
args = parser.parse_args()
return args
def has_keyword(name: str, keywords: list):
for a_keyword in keywords:
if a_keyword in name:
return True
return False
def main():
args = parse_args()
cfg = load_config(args.config_dir)
work_dir = cfg['work_dir']
metric = cfg['metric']
log_items = cfg.get('log_items', [])
ignore_keywords = cfg.get('ignore_keywords', [])
other_info_keys = cfg.get('other_info_keys', [])
markdown_file = cfg.get('markdown_file', None)
json_file = cfg.get('json_file', None)
if json_file and osp.split(json_file)[0] != '':
os.makedirs(osp.split(json_file)[0], exist_ok=True)
if markdown_file and osp.split(markdown_file)[0] != '':
os.makedirs(osp.split(markdown_file)[0], exist_ok=True)
assert not (log_items and ignore_keywords), \
'log_items and ignore_keywords cannot be specified at the same time'
assert metric not in other_info_keys, \
'other_info_keys should not contain metric'
if ignore_keywords and isinstance(ignore_keywords, str):
ignore_keywords = [ignore_keywords]
if other_info_keys and isinstance(other_info_keys, str):
other_info_keys = [other_info_keys]
if log_items and isinstance(log_items, str):
log_items = [log_items]
if not log_items:
log_items = [
item for item in sorted(os.listdir(work_dir))
if not has_keyword(item, ignore_keywords)
]
experiment_info_list = []
for config_dir in log_items:
preceding_path = os.path.join(work_dir, config_dir)
log_list = [
item for item in os.listdir(preceding_path)
if item.endswith('.log.json')
]
log_list = sorted(
log_list,
key=lambda time_str: datetime.datetime.strptime(
time_str, '%Y%m%d_%H%M%S.log.json'))
val_list = []
last_iter = 0
for log_name in log_list:
with open(os.path.join(preceding_path, log_name), 'r') as f:
# ignore the info line
f.readline()
all_lines = f.readlines()
val_list.extend([
json.loads(line) for line in all_lines
if json.loads(line)['mode'] == 'val'
])
for index in range(len(all_lines) - 1, -1, -1):
line_dict = json.loads(all_lines[index])
if line_dict['mode'] == 'train':
last_iter = max(last_iter, line_dict['iter'])
break
new_log_dict = dict(
method=config_dir, metric_used=metric, last_iter=last_iter)
for index, log in enumerate(val_list, 1):
new_ordered_dict = OrderedDict()
new_ordered_dict['eval_index'] = index
new_ordered_dict[metric] = log[metric]
for key in other_info_keys:
if key in log:
new_ordered_dict[key] = log[key]
val_list[index - 1] = new_ordered_dict
assert len(val_list) >= 1, \
f"work dir {config_dir} doesn't contain any evaluation."
new_log_dict['last eval'] = val_list[-1]
new_log_dict['best eval'] = max(val_list, key=lambda x: x[metric])
experiment_info_list.append(new_log_dict)
print(f'{config_dir} is processed')
if json_file:
with open(json_file, 'w') as f:
json.dump(experiment_info_list, f, indent=4)
if markdown_file:
lines_to_write = []
for index, log in enumerate(experiment_info_list, 1):
lines_to_write.append(
f"|{index}|{log['method']}|{log['best eval'][metric]}"
f"|{log['best eval']['eval_index']}|"
f"{log['last eval'][metric]}|"
f"{log['last eval']['eval_index']}|{log['last_iter']}|\n")
with open(markdown_file, 'w') as f:
f.write(f'|exp_num|method|{metric} best|best index|'
f'{metric} last|last index|last iter num|\n')
f.write('|:---:|:---:|:---:|:---:|:---:|:---:|:---:|\n')
f.writelines(lines_to_write)
print('processed successfully')
if __name__ == '__main__':
main()

파일 보기

@ -1,144 +0,0 @@
# Log Collector
## Function
Automatically collect logs and write the result in a json file or markdown file.
If there are several `.log.json` files in one folder, Log Collector assumes that the `.log.json` files other than the first one are resume from the preceding `.log.json` file. Log Collector returns the result considering all `.log.json` files.
## Usage:
To use log collector, you need to write a config file to configure the log collector first.
For example:
example_config.py:
```python
# The work directory that contains folders that contains .log.json files.
work_dir = '../../work_dirs'
# The metric used to find the best evaluation.
metric = 'mIoU'
# **Don't specify the log_items and ignore_keywords at the same time.**
# Specify the log files we would like to collect in `log_items`.
# The folders specified should be the subdirectories of `work_dir`.
log_items = [
'segformer_mit-b5_512x512_160k_ade20k_cnn_lr_with_warmup',
'segformer_mit-b5_512x512_160k_ade20k_cnn_no_warmup_lr',
'segformer_mit-b5_512x512_160k_ade20k_mit_trans_lr',
'segformer_mit-b5_512x512_160k_ade20k_swin_trans_lr'
]
# Or specify `ignore_keywords`. The folders whose name contain one
# of the keywords in the `ignore_keywords` list(e.g., `'segformer'`)
# won't be collected.
# ignore_keywords = ['segformer']
# Other log items in .log.json that you want to collect.
# should not include metric.
other_info_keys = ["mAcc"]
# The output markdown file's name.
markdown_file ='markdowns/lr_in_trans.json.md'
# The output json file's name. (optional)
json_file = 'jsons/trans_in_cnn.json'
```
The structure of the work-dir directory should be like
```text
├── work-dir
│ ├── folder1
│ │ ├── time1.log.json
│ │ ├── time2.log.json
│ │ ├── time3.log.json
│ │ ├── time4.log.json
│ ├── folder2
│ │ ├── time5.log.json
│ │ ├── time6.log.json
│ │ ├── time7.log.json
│ │ ├── time8.log.json
```
Then , cd to the log collector folder.
Now you can run log_collector.py by using command:
```bash
python log_collector.py ./example_config.py
```
The output markdown file is like:
| exp_num | method | mIoU best | best index | mIoU last | last index | last iter num |
| :-----: | :-----------------------------------------------------: | :-------: | :--------: | :-------: | :--------: | :-----------: |
| 1 | segformer_mit-b5_512x512_160k_ade20k_cnn_lr_with_warmup | 0.2776 | 10 | 0.2776 | 10 | 160000 |
| 2 | segformer_mit-b5_512x512_160k_ade20k_cnn_no_warmup_lr | 0.2802 | 10 | 0.2802 | 10 | 160000 |
| 3 | segformer_mit-b5_512x512_160k_ade20k_mit_trans_lr | 0.4943 | 11 | 0.4943 | 11 | 160000 |
| 4 | segformer_mit-b5_512x512_160k_ade20k_swin_trans_lr | 0.4883 | 11 | 0.4883 | 11 | 160000 |
The output json file is like:
```json
[
{
"method": "segformer_mit-b5_512x512_160k_ade20k_cnn_lr_with_warmup",
"metric_used": "mIoU",
"last_iter": 160000,
"last eval": {
"eval_index": 10,
"mIoU": 0.2776,
"mAcc": 0.3779
},
"best eval": {
"eval_index": 10,
"mIoU": 0.2776,
"mAcc": 0.3779
}
},
{
"method": "segformer_mit-b5_512x512_160k_ade20k_cnn_no_warmup_lr",
"metric_used": "mIoU",
"last_iter": 160000,
"last eval": {
"eval_index": 10,
"mIoU": 0.2802,
"mAcc": 0.3764
},
"best eval": {
"eval_index": 10,
"mIoU": 0.2802,
"mAcc": 0.3764
}
},
{
"method": "segformer_mit-b5_512x512_160k_ade20k_mit_trans_lr",
"metric_used": "mIoU",
"last_iter": 160000,
"last eval": {
"eval_index": 11,
"mIoU": 0.4943,
"mAcc": 0.6097
},
"best eval": {
"eval_index": 11,
"mIoU": 0.4943,
"mAcc": 0.6097
}
},
{
"method": "segformer_mit-b5_512x512_160k_ade20k_swin_trans_lr",
"metric_used": "mIoU",
"last_iter": 160000,
"last eval": {
"eval_index": 11,
"mIoU": 0.4883,
"mAcc": 0.6061
},
"best eval": {
"eval_index": 11,
"mIoU": 0.4883,
"mAcc": 0.6061
}
}
]
```

파일 보기

@ -1,20 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
# modified from https://github.dev/open-mmlab/mmcv
import os.path as osp
import sys
from importlib import import_module
def load_config(cfg_dir: str) -> dict:
assert cfg_dir.endswith('.py')
root_path, file_name = osp.split(cfg_dir)
temp_module = osp.splitext(file_name)[0]
sys.path.insert(0, root_path)
mod = import_module(temp_module)
sys.path.pop(0)
cfg_dict = {
k: v
for k, v in mod.__dict__.items() if not k.startswith('__')
}
del sys.modules[temp_module]
return cfg_dict

파일 보기

@ -1,317 +0,0 @@
#!/usr/bin/env python
# Copyright (c) OpenMMLab. All rights reserved.
# This tool is used to update model-index.yml which is required by MIM, and
# will be automatically called as a pre-commit hook. The updating will be
# triggered if any change of model information (.md files in configs/) has been
# detected before a commit.
import glob
import os
import os.path as osp
import re
import sys
from lxml import etree
from mmcv.fileio import dump
MMSEG_ROOT = osp.dirname(osp.dirname((osp.dirname(__file__))))
COLLECTIONS = [
'ANN', 'APCNet', 'BiSeNetV1', 'BiSeNetV2', 'CCNet', 'CGNet', 'DANet',
'DeepLabV3', 'DeepLabV3+', 'DMNet', 'DNLNet', 'DPT', 'EMANet', 'EncNet',
'ERFNet', 'FastFCN', 'FastSCNN', 'FCN', 'GCNet', 'ICNet', 'ISANet', 'KNet',
'NonLocalNet', 'OCRNet', 'PointRend', 'PSANet', 'PSPNet', 'Segformer',
'Segmenter', 'FPN', 'SETR', 'STDC', 'UNet', 'UPerNet'
]
COLLECTIONS_TEMP = []
def dump_yaml_and_check_difference(obj, filename, sort_keys=False):
"""Dump object to a yaml file, and check if the file content is different
from the original.
Args:
obj (any): The python object to be dumped.
filename (str): YAML filename to dump the object to.
sort_keys (str); Sort key by dictionary order.
Returns:
Bool: If the target YAML file is different from the original.
"""
str_dump = dump(obj, None, file_format='yaml', sort_keys=sort_keys)
if osp.isfile(filename):
file_exists = True
with open(filename, 'r', encoding='utf-8') as f:
str_orig = f.read()
else:
file_exists = False
str_orig = None
if file_exists and str_orig == str_dump:
is_different = False
else:
is_different = True
with open(filename, 'w', encoding='utf-8') as f:
f.write(str_dump)
return is_different
def parse_md(md_file):
"""Parse .md file and convert it to a .yml file which can be used for MIM.
Args:
md_file (str): Path to .md file.
Returns:
Bool: If the target YAML file is different from the original.
"""
collection_name = osp.split(osp.dirname(md_file))[1]
configs = os.listdir(osp.dirname(md_file))
collection = dict(
Name=collection_name,
Metadata={'Training Data': []},
Paper={
'URL': '',
'Title': ''
},
README=md_file,
Code={
'URL': '',
'Version': ''
})
collection.update({'Converted From': {'Weights': '', 'Code': ''}})
models = []
datasets = []
paper_url = None
paper_title = None
code_url = None
code_version = None
repo_url = None
# To avoid re-counting number of backbone model in OpenMMLab,
# if certain model in configs folder is backbone whose name is already
# recorded in MMClassification, then the `COLLECTION` dict of this model
# in MMSegmentation should be deleted, and `In Collection` in `Models`
# should be set with head or neck of this config file.
is_backbone = None
with open(md_file, 'r', encoding='UTF-8') as md:
lines = md.readlines()
i = 0
current_dataset = ''
while i < len(lines):
line = lines[i].strip()
# In latest README.md the title and url are in the third line.
if i == 2:
paper_url = lines[i].split('](')[1].split(')')[0]
paper_title = lines[i].split('](')[0].split('[')[1]
if len(line) == 0:
i += 1
continue
elif line[:3] == '<a ':
content = etree.HTML(line)
node = content.xpath('//a')[0]
if node.text == 'Code Snippet':
code_url = node.get('href', None)
assert code_url is not None, (
f'{collection_name} hasn\'t code snippet url.')
# version extraction
filter_str = r'blob/(.*)/mm'
pattern = re.compile(filter_str)
code_version = pattern.findall(code_url)
assert len(code_version) == 1, (
f'false regular expression ({filter_str}) use.')
code_version = code_version[0]
elif node.text == 'Official Repo':
repo_url = node.get('href', None)
assert repo_url is not None, (
f'{collection_name} hasn\'t official repo url.')
i += 1
elif line[:4] == '### ':
datasets.append(line[4:])
current_dataset = line[4:]
i += 2
elif line[:15] == '<!-- [BACKBONE]':
is_backbone = True
i += 1
elif (line[0] == '|' and (i + 1) < len(lines)
and lines[i + 1][:3] == '| -' and 'Method' in line
and 'Crop Size' in line and 'Mem (GB)' in line):
cols = [col.strip() for col in line.split('|')]
method_id = cols.index('Method')
backbone_id = cols.index('Backbone')
crop_size_id = cols.index('Crop Size')
lr_schd_id = cols.index('Lr schd')
mem_id = cols.index('Mem (GB)')
fps_id = cols.index('Inf time (fps)')
try:
ss_id = cols.index('mIoU')
except ValueError:
ss_id = cols.index('Dice')
try:
ms_id = cols.index('mIoU(ms+flip)')
except ValueError:
ms_id = False
config_id = cols.index('config')
download_id = cols.index('download')
j = i + 2
while j < len(lines) and lines[j][0] == '|':
els = [el.strip() for el in lines[j].split('|')]
config = ''
model_name = ''
weight = ''
for fn in configs:
if fn in els[config_id]:
left = els[download_id].index(
'https://download.openmmlab.com')
right = els[download_id].index('.pth') + 4
weight = els[download_id][left:right]
config = f'configs/{collection_name}/{fn}'
model_name = fn[:-3]
fps = els[fps_id] if els[fps_id] != '-' and els[
fps_id] != '' else -1
mem = els[mem_id].split(
'\\'
)[0] if els[mem_id] != '-' and els[mem_id] != '' else -1
crop_size = els[crop_size_id].split('x')
assert len(crop_size) == 2
method = els[method_id].split()[0].split('-')[-1]
model = {
'Name':
model_name,
'In Collection':
method,
'Metadata': {
'backbone': els[backbone_id],
'crop size': f'({crop_size[0]},{crop_size[1]})',
'lr schd': int(els[lr_schd_id]),
},
'Results': [
{
'Task': 'Semantic Segmentation',
'Dataset': current_dataset,
'Metrics': {
cols[ss_id]: float(els[ss_id]),
},
},
],
'Config':
config,
'Weights':
weight,
}
if fps != -1:
try:
fps = float(fps)
except Exception:
j += 1
continue
model['Metadata']['inference time (ms/im)'] = [{
'value':
round(1000 / float(fps), 2),
'hardware':
'V100',
'backend':
'PyTorch',
'batch size':
1,
'mode':
'FP32' if 'fp16' not in config else 'FP16',
'resolution':
f'({crop_size[0]},{crop_size[1]})'
}]
if mem != -1:
model['Metadata']['Training Memory (GB)'] = float(mem)
# Only have semantic segmentation now
if ms_id and els[ms_id] != '-' and els[ms_id] != '':
model['Results'][0]['Metrics'][
'mIoU(ms+flip)'] = float(els[ms_id])
models.append(model)
j += 1
i = j
else:
i += 1
flag = (code_url is not None) and (paper_url is not None) and (repo_url
is not None)
assert flag, f'{collection_name} readme error'
collection['Name'] = method
collection['Metadata']['Training Data'] = datasets
collection['Code']['URL'] = code_url
collection['Code']['Version'] = code_version
collection['Paper']['URL'] = paper_url
collection['Paper']['Title'] = paper_title
collection['Converted From']['Code'] = repo_url
# ['Converted From']['Weights] miss
# remove empty attribute
check_key_list = ['Code', 'Paper', 'Converted From']
for check_key in check_key_list:
key_list = list(collection[check_key].keys())
for key in key_list:
if check_key not in collection:
break
if collection[check_key][key] == '':
if len(collection[check_key].keys()) == 1:
collection.pop(check_key)
else:
collection[check_key].pop(key)
yml_file = f'{md_file[:-9]}{collection_name}.yml'
if is_backbone:
if collection['Name'] not in COLLECTIONS:
result = {
'Collections': [collection],
'Models': models,
'Yml': yml_file
}
COLLECTIONS_TEMP.append(result)
return False
else:
result = {'Models': models}
else:
COLLECTIONS.append(collection['Name'])
result = {'Collections': [collection], 'Models': models}
return dump_yaml_and_check_difference(result, yml_file)
def update_model_index():
"""Update model-index.yml according to model .md files.
Returns:
Bool: If the updated model-index.yml is different from the original.
"""
configs_dir = osp.join(MMSEG_ROOT, 'configs')
yml_files = glob.glob(osp.join(configs_dir, '**', '*.yml'), recursive=True)
yml_files.sort()
# add .replace('\\', '/') to avoid Windows Style path
model_index = {
'Import': [
osp.relpath(yml_file, MMSEG_ROOT).replace('\\', '/')
for yml_file in yml_files
]
}
model_index_file = osp.join(MMSEG_ROOT, 'model-index.yml')
is_different = dump_yaml_and_check_difference(model_index,
model_index_file)
return is_different
if __name__ == '__main__':
file_list = [fn for fn in sys.argv[1:] if osp.basename(fn) == 'README.md']
if not file_list:
sys.exit(0)
file_modified = False
for fn in file_list:
file_modified |= parse_md(fn)
for result in COLLECTIONS_TEMP:
collection = result['Collections'][0]
yml_file = result.pop('Yml', None)
if collection['Name'] in COLLECTIONS:
result.pop('Collections')
file_modified |= dump_yaml_and_check_difference(result, yml_file)
file_modified |= update_model_index()
sys.exit(1 if file_modified else 0)

파일 보기

@ -1,45 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import os.path as osp
import oss2
ACCESS_KEY_ID = os.getenv('OSS_ACCESS_KEY_ID', None)
ACCESS_KEY_SECRET = os.getenv('OSS_ACCESS_KEY_SECRET', None)
BUCKET_NAME = 'openmmlab'
ENDPOINT = 'https://oss-accelerate.aliyuncs.com'
def parse_args():
parser = argparse.ArgumentParser(description='Upload models to OSS')
parser.add_argument('model_zoo', type=str, help='model_zoo input')
parser.add_argument(
'--dst-folder',
type=str,
default='mmsegmentation/v0.5',
help='destination folder')
args = parser.parse_args()
return args
def main():
args = parse_args()
model_zoo = args.model_zoo
dst_folder = args.dst_folder
bucket = oss2.Bucket(
oss2.Auth(ACCESS_KEY_ID, ACCESS_KEY_SECRET), ENDPOINT, BUCKET_NAME)
for root, dirs, files in os.walk(model_zoo):
for file in files:
file_path = osp.relpath(osp.join(root, file), model_zoo)
print(f'Uploading {file_path}')
oss2.resumable_upload(bucket, osp.join(dst_folder, file_path),
osp.join(model_zoo, file_path))
bucket.put_object_acl(
osp.join(dst_folder, file_path), oss2.OBJECT_ACL_PUBLIC_READ)
if __name__ == '__main__':
main()

파일 보기

@ -1,76 +0,0 @@
# Contributor Covenant Code of Conduct
## Our Pledge
In the interest of fostering an open and welcoming environment, we as
contributors and maintainers pledge to making participation in our project and
our community a harassment-free experience for everyone, regardless of age, body
size, disability, ethnicity, sex characteristics, gender identity and expression,
level of experience, education, socio-economic status, nationality, personal
appearance, race, religion, or sexual identity and orientation.
## Our Standards
Examples of behavior that contributes to creating a positive environment
include:
- Using welcoming and inclusive language
- Being respectful of differing viewpoints and experiences
- Gracefully accepting constructive criticism
- Focusing on what is best for the community
- Showing empathy towards other community members
Examples of unacceptable behavior by participants include:
- The use of sexualized language or imagery and unwelcome sexual attention or
advances
- Trolling, insulting/derogatory comments, and personal or political attacks
- Public or private harassment
- Publishing others' private information, such as a physical or electronic
address, without explicit permission
- Other conduct which could reasonably be considered inappropriate in a
professional setting
## Our Responsibilities
Project maintainers are responsible for clarifying the standards of acceptable
behavior and are expected to take appropriate and fair corrective action in
response to any instances of unacceptable behavior.
Project maintainers have the right and responsibility to remove, edit, or
reject comments, commits, code, wiki edits, issues, and other contributions
that are not aligned to this Code of Conduct, or to ban temporarily or
permanently any contributor for other behaviors that they deem inappropriate,
threatening, offensive, or harmful.
## Scope
This Code of Conduct applies both within project spaces and in public spaces
when an individual is representing the project or its community. Examples of
representing a project or community include using an official project e-mail
address, posting via an official social media account, or acting as an appointed
representative at an online or offline event. Representation of a project may be
further defined and clarified by project maintainers.
## Enforcement
Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported by contacting the project team at chenkaidev@gmail.com. All
complaints will be reviewed and investigated and will result in a response that
is deemed necessary and appropriate to the circumstances. The project team is
obligated to maintain confidentiality with regard to the reporter of an incident.
Further details of specific enforcement policies may be posted separately.
Project maintainers who do not follow or enforce the Code of Conduct in good
faith may face temporary or permanent repercussions as determined by other
members of the project's leadership.
## Attribution
This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
For answers to common questions about this code of conduct, see
https://www.contributor-covenant.org/faq
[homepage]: https://www.contributor-covenant.org

파일 보기

@ -1,58 +0,0 @@
# Contributing to mmsegmentation
All kinds of contributions are welcome, including but not limited to the following.
- Fixes (typo, bugs)
- New features and components
## Workflow
1. fork and pull the latest mmsegmentation
2. checkout a new branch (do not use master branch for PRs)
3. commit your changes
4. create a PR
:::{note}
- If you plan to add some new features that involve large changes, it is encouraged to open an issue for discussion first.
- If you are the author of some papers and would like to include your method to mmsegmentation,
please contact Kai Chen (chenkaidev\[at\]gmail\[dot\]com). We will much appreciate your contribution.
:::
## Code style
### Python
We adopt [PEP8](https://www.python.org/dev/peps/pep-0008/) as the preferred code style.
We use the following tools for linting and formatting:
- [flake8](http://flake8.pycqa.org/en/latest/): linter
- [yapf](https://github.com/google/yapf): formatter
- [isort](https://github.com/timothycrosley/isort): sort imports
Style configurations of yapf and isort can be found in [setup.cfg](../setup.cfg) and [.isort.cfg](../.isort.cfg).
We use [pre-commit hook](https://pre-commit.com/) that checks and formats for `flake8`, `yapf`, `isort`, `trailing whitespaces`,
fixes `end-of-files`, sorts `requirments.txt` automatically on every commit.
The config for a pre-commit hook is stored in [.pre-commit-config](../.pre-commit-config.yaml).
After you clone the repository, you will need to install initialize pre-commit hook.
```shell
pip install -U pre-commit
```
From the repository folder
```shell
pre-commit install
```
After this on every commit check code linters and formatter will be enforced.
> Before you create a PR, make sure that your code lints and is formatted by yapf.
### C++ and CUDA
We follow the [Google C++ Style Guide](https://google.github.io/styleguide/cppguide.html).

파일 보기

@ -1,6 +0,0 @@
blank_issues_enabled: false
contact_links:
- name: MMSegmentation Documentation
url: https://mmsegmentation.readthedocs.io
about: Check the docs and FAQ to see if you question is already answered.

파일 보기

@ -1,48 +0,0 @@
---
name: Error report
about: Create a report to help us improve
title: ''
labels: ''
assignees: ''
---
Thanks for your error report and we appreciate it a lot.
**Checklist**
1. I have searched related issues but cannot get the expected help.
2. The bug has not been fixed in the latest version.
**Describe the bug**
A clear and concise description of what the bug is.
**Reproduction**
1. What command or script did you run?
```none
A placeholder for the command.
```
2. Did you make any modifications on the code or config? Did you understand what you have modified?
3. What dataset did you use?
**Environment**
1. Please run `python mmseg/utils/collect_env.py` to collect necessary environment information and paste it here.
2. You may add addition that may be helpful for locating the problem, such as
- How you installed PyTorch \[e.g., pip, conda, source\]
- Other environment variables that may be related (such as `$PATH`, `$LD_LIBRARY_PATH`, `$PYTHONPATH`, etc.)
**Error traceback**
If applicable, paste the error trackback here.
```none
A placeholder for trackback.
```
**Bug fix**
If you have already identified the reason, you can provide the information here. If you are willing to create a PR to fix it, please also leave a comment here and that would be much appreciated!

파일 보기

@ -1,21 +0,0 @@
---
name: Feature request
about: Suggest an idea for this project
title: ''
labels: ''
assignees: ''
---
# Describe the feature
**Motivation**
A clear and concise description of the motivation of the feature.
Ex1. It is inconvenient when \[....\].
Ex2. There is a recent paper \[....\], which is very helpful for \[....\].
**Related resources**
If there is an official code release or third-party implementations, please also provide the information here, which would be very helpful.
**Additional context**
Add any other context or screenshots about the feature request here.
If you would like to implement the feature and create a PR, please leave a comment here and that would be much appreciated.

파일 보기

@ -1,7 +0,0 @@
---
name: General questions
about: Ask general questions to get help
title: ''
labels: ''
assignees: ''
---

파일 보기

@ -1,69 +0,0 @@
---
name: Reimplementation Questions
about: Ask about questions during model reimplementation
title: ''
labels: reimplementation
assignees: ''
---
If you feel we have helped you, give us a STAR! :satisfied:
**Notice**
There are several common situations in the reimplementation issues as below
1. Reimplement a model in the model zoo using the provided configs
2. Reimplement a model in the model zoo on other datasets (e.g., custom datasets)
3. Reimplement a custom model but all the components are implemented in MMSegmentation
4. Reimplement a custom model with new modules implemented by yourself
There are several things to do for different cases as below.
- For cases 1 & 3, please follow the steps in the following sections thus we could help to quickly identify the issue.
- For cases 2 & 4, please understand that we are not able to do much help here because we usually do not know the full code, and the users should be responsible for the code they write.
- One suggestion for cases 2 & 4 is that the users should first check whether the bug lies in the self-implemented code or the original code. For example, users can first make sure that the same model runs well on supported datasets. If you still need help, please describe what you have done and what you obtain in the issue, and follow the steps in the following sections, and try as clear as possible so that we can better help you.
**Checklist**
1. I have searched related issues but cannot get the expected help.
2. The issue has not been fixed in the latest version.
**Describe the issue**
A clear and concise description of the problem you meet and what you have done.
**Reproduction**
1. What command or script did you run?
```
A placeholder for the command.
```
2. What config dir you run?
```
A placeholder for the config.
```
3. Did you make any modifications to the code or config? Did you understand what you have modified?
4. What dataset did you use?
**Environment**
1. Please run `PYTHONPATH=${PWD}:$PYTHONPATH python mmseg/utils/collect_env.py` to collect the necessary environment information and paste it here.
2. You may add an addition that may be helpful for locating the problem, such as
1. How you installed PyTorch \[e.g., pip, conda, source\]
2. Other environment variables that may be related (such as `$PATH`, `$LD_LIBRARY_PATH`, `$PYTHONPATH`, etc.)
**Results**
If applicable, paste the related results here, e.g., what you expect and what you get.
```
A placeholder for results comparison
```
**Issue fix**
If you have already identified the reason, you can provide the information here. If you are willing to create a PR to fix it, please also leave a comment here and that would be much appreciated!

파일 보기

@ -1,25 +0,0 @@
Thanks for your contribution and we appreciate it a lot. The following instructions would make your pull request more healthy and more easily get feedback. If you do not understand some items, don't worry, just make the pull request and seek help from maintainers.
## Motivation
Please describe the motivation of this PR and the goal you want to achieve through this PR.
## Modification
Please briefly describe what modification is made in this PR.
## BC-breaking (Optional)
Does the modification introduce changes that break the backward-compatibility of the downstream repos?
If so, please describe how it breaks the compatibility and how the downstream projects should modify their code to keep compatibility with this PR.
## Use cases (Optional)
If this PR introduces a new feature, it is better to list some use cases here, and update the documentation.
## Checklist
1. Pre-commit or other linting tools are used to fix the potential lint issues.
2. The modification is covered by complete unit tests. If not, please add more unit test to ensure the correctness.
3. If the modification has potential influence on downstream projects, this PR should be tested with downstream projects, like MMDet or MMDet3D.
4. The documentation has been modified accordingly, like docstring or example tutorials.

파일 보기

@ -1,257 +0,0 @@
name: build
on:
push:
paths-ignore:
- 'demo/**'
- '.dev/**'
- 'docker/**'
- 'tools/**'
- '**.md'
pull_request:
paths-ignore:
- 'demo/**'
- '.dev/**'
- 'docker/**'
- 'tools/**'
- 'docs/**'
- '**.md'
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
jobs:
build_cpu:
runs-on: ubuntu-18.04
strategy:
matrix:
python-version: [3.7]
torch: [1.5.1, 1.6.0, 1.7.0, 1.8.0, 1.9.0]
include:
- torch: 1.5.1
torch_version: torch1.5
torchvision: 0.6.1
- torch: 1.6.0
torch_version: torch1.6
torchvision: 0.7.0
- torch: 1.7.0
torch_version: torch1.7
torchvision: 0.8.1
- torch: 1.8.0
torch_version: torch1.8
torchvision: 0.9.0
- torch: 1.9.0
torch_version: torch1.9
torchvision: 0.10.0
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Upgrade pip
run: pip install pip --upgrade
- name: Install PyTorch
run: pip install torch==${{matrix.torch}}+cpu torchvision==${{matrix.torchvision}}+cpu -f https://download.pytorch.org/whl/torch_stable.html
- name: Install MMCV
run: |
pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cpu/${{matrix.torch_version}}/index.html
python -c 'import mmcv; print(mmcv.__version__)'
- name: Install unittest dependencies
run: |
pip install -r requirements.txt
- name: Build and install
run: rm -rf .eggs && pip install -e .
- name: Run unittests and generate coverage report
run: |
pip install timm
coverage run --branch --source mmseg -m pytest tests/
coverage xml
coverage report -m
if: ${{matrix.torch >= '1.5.0'}}
- name: Skip timm unittests and generate coverage report
run: |
coverage run --branch --source mmseg -m pytest tests/ --ignore tests/test_models/test_backbones/test_timm_backbone.py
coverage xml
coverage report -m
if: ${{matrix.torch < '1.5.0'}}
build_cuda101:
runs-on: ubuntu-18.04
container:
image: pytorch/pytorch:1.6.0-cuda10.1-cudnn7-devel
strategy:
matrix:
python-version: [3.7]
torch:
[
1.5.1+cu101,
1.6.0+cu101,
1.7.0+cu101,
1.8.0+cu101
]
include:
- torch: 1.5.1+cu101
torch_version: torch1.5
torchvision: 0.6.1+cu101
- torch: 1.6.0+cu101
torch_version: torch1.6
torchvision: 0.7.0+cu101
- torch: 1.7.0+cu101
torch_version: torch1.7
torchvision: 0.8.1+cu101
- torch: 1.8.0+cu101
torch_version: torch1.8
torchvision: 0.9.0+cu101
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Fetch GPG keys
run: |
apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/3bf863cc.pub
apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64/7fa2af80.pub
- name: Install system dependencies
run: |
apt-get update && apt-get install -y libgl1-mesa-glx ffmpeg libsm6 libxext6 git ninja-build libglib2.0-0 libsm6 libxrender-dev libxext6 python${{matrix.python-version}}-dev
apt-get clean
rm -rf /var/lib/apt/lists/*
- name: Install Pillow
run: python -m pip install Pillow==6.2.2
if: ${{matrix.torchvision < 0.5}}
- name: Install PyTorch
run: python -m pip install torch==${{matrix.torch}} torchvision==${{matrix.torchvision}} -f https://download.pytorch.org/whl/torch_stable.html
- name: Install mmseg dependencies
run: |
python -V
python -m pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cu101/${{matrix.torch_version}}/index.html
python -m pip install -r requirements.txt
python -c 'import mmcv; print(mmcv.__version__)'
- name: Build and install
run: |
rm -rf .eggs
python setup.py check -m -s
TORCH_CUDA_ARCH_LIST=7.0 pip install .
- name: Run unittests and generate coverage report
run: |
python -m pip install timm
coverage run --branch --source mmseg -m pytest tests/
coverage xml
coverage report -m
if: ${{matrix.torch >= '1.5.0'}}
- name: Skip timm unittests and generate coverage report
run: |
coverage run --branch --source mmseg -m pytest tests/ --ignore tests/test_models/test_backbones/test_timm_backbone.py
coverage xml
coverage report -m
if: ${{matrix.torch < '1.5.0'}}
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v1.0.10
with:
file: ./coverage.xml
flags: unittests
env_vars: OS,PYTHON
name: codecov-umbrella
fail_ci_if_error: false
build_cuda102:
runs-on: ubuntu-18.04
container:
image: pytorch/pytorch:1.9.0-cuda10.2-cudnn7-devel
strategy:
matrix:
python-version: [3.6, 3.7, 3.8, 3.9]
torch: [1.9.0+cu102]
include:
- torch: 1.9.0+cu102
torch_version: torch1.9
torchvision: 0.10.0+cu102
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Fetch GPG keys
run: |
apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/3bf863cc.pub
apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64/7fa2af80.pub
- name: Install system dependencies
run: |
apt-get update && apt-get install -y libgl1-mesa-glx ffmpeg libsm6 libxext6 git ninja-build libglib2.0-0 libsm6 libxrender-dev libxext6
apt-get clean
rm -rf /var/lib/apt/lists/*
- name: Install Pillow
run: python -m pip install Pillow==6.2.2
if: ${{matrix.torchvision < 0.5}}
- name: Install PyTorch
run: python -m pip install torch==${{matrix.torch}} torchvision==${{matrix.torchvision}} -f https://download.pytorch.org/whl/torch_stable.html
- name: Install mmseg dependencies
run: |
python -V
python -m pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cu102/${{matrix.torch_version}}/index.html
python -m pip install -r requirements.txt
python -c 'import mmcv; print(mmcv.__version__)'
- name: Build and install
run: |
rm -rf .eggs
python setup.py check -m -s
TORCH_CUDA_ARCH_LIST=7.0 pip install .
- name: Run unittests and generate coverage report
run: |
python -m pip install timm
coverage run --branch --source mmseg -m pytest tests/
coverage xml
coverage report -m
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v2
with:
files: ./coverage.xml
flags: unittests
env_vars: OS,PYTHON
name: codecov-umbrella
fail_ci_if_error: false
test_windows:
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [windows-2022]
python: [3.8]
platform: [cpu, cu111]
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python }}
- name: Upgrade pip
run: python -m pip install pip --upgrade --user
- name: Install OpenCV
run: pip install opencv-python>=3
- name: Install PyTorch
# As a complement to Linux CI, we test on PyTorch LTS version
run: pip install torch==1.8.2+${{ matrix.platform }} torchvision==0.9.2+${{ matrix.platform }} -f https://download.pytorch.org/whl/lts/1.8/torch_lts.html
- name: Install MMCV
run: |
pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cpu/torch1.8/index.html --only-binary mmcv-full
- name: Install unittest dependencies
run: pip install -r requirements/tests.txt -r requirements/optional.txt
- name: Build and install
run: pip install -e .
- name: Run unittests
run: |
python -m pip install timm
coverage run --branch --source mmseg -m pytest tests/
- name: Generate coverage report
run: |
coverage xml
coverage report -m

파일 보기

@ -1,26 +0,0 @@
name: deploy
on: push
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
jobs:
build-n-publish:
runs-on: ubuntu-latest
if: startsWith(github.event.ref, 'refs/tags')
steps:
- uses: actions/checkout@v2
- name: Set up Python 3.7
uses: actions/setup-python@v2
with:
python-version: 3.7
- name: Build MMSegmentation
run: |
pip install wheel
python setup.py sdist bdist_wheel
- name: Publish distribution to PyPI
run: |
pip install twine
twine upload dist/* -u __token__ -p ${{ secrets.pypi_password }}

파일 보기

@ -1,28 +0,0 @@
name: lint
on: [push, pull_request]
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
jobs:
lint:
runs-on: ubuntu-18.04
steps:
- uses: actions/checkout@v2
- name: Set up Python 3.7
uses: actions/setup-python@v2
with:
python-version: 3.7
- name: Install pre-commit hook
run: |
pip install pre-commit
pre-commit install
- name: Linting
run: |
pre-commit run --all-files
- name: Check docstring coverage
run: |
pip install interrogate
interrogate -v --ignore-init-method --ignore-module --ignore-nested-functions --exclude mmseg/ops --ignore-regex "__repr__" --fail-under 80 mmseg

파일 보기

@ -1,44 +0,0 @@
name: test-mim
on:
push:
paths:
- 'model-index.yml'
- 'configs/**'
pull_request:
paths:
- 'model-index.yml'
- 'configs/**'
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
jobs:
build_cpu:
runs-on: ubuntu-18.04
strategy:
matrix:
python-version: [3.7]
torch: [1.8.0]
include:
- torch: 1.8.0
torch_version: torch1.8
torchvision: 0.9.0
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Upgrade pip
run: pip install pip --upgrade
- name: Install PyTorch
run: pip install torch==${{matrix.torch}}+cpu torchvision==${{matrix.torchvision}}+cpu -f https://download.pytorch.org/whl/torch_stable.html
- name: Install openmim
run: pip install openmim
- name: Build and install
run: rm -rf .eggs && mim install -e .
- name: test commands of mim
run: mim search mmsegmentation

파일 보기

@ -1,120 +0,0 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/en/_build/
docs/zh_cn/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# pyenv
.python-version
# celery beat schedule file
celerybeat-schedule
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
.DS_Store
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
data
.vscode
.idea
# custom
*.pkl
*.pkl.json
*.log.json
work_dirs/
mmseg/.mim
# Pytorch
*.pth

파일 보기

@ -1,11 +0,0 @@
assign:
strategy:
# random
# round-robin
daily-shift-based
assignees:
- MengzhangLI
- xiexinch
- MeowZheng
- MengzhangLI
- xiexinch

파일 보기

@ -1,60 +0,0 @@
repos:
- repo: https://gitlab.com/pycqa/flake8.git
rev: 3.8.3
hooks:
- id: flake8
- repo: https://github.com/PyCQA/isort
rev: 5.10.1
hooks:
- id: isort
- repo: https://github.com/pre-commit/mirrors-yapf
rev: v0.30.0
hooks:
- id: yapf
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v3.1.0
hooks:
- id: trailing-whitespace
- id: check-yaml
- id: end-of-file-fixer
- id: requirements-txt-fixer
- id: double-quote-string-fixer
- id: check-merge-conflict
- id: fix-encoding-pragma
args: ["--remove"]
- id: mixed-line-ending
args: ["--fix=lf"]
- repo: https://github.com/executablebooks/mdformat
rev: 0.7.9
hooks:
- id: mdformat
args: ["--number"]
additional_dependencies:
- mdformat-openmmlab
- mdformat_frontmatter
- linkify-it-py
- repo: https://github.com/codespell-project/codespell
rev: v2.1.0
hooks:
- id: codespell
- repo: https://github.com/myint/docformatter
rev: v1.3.1
hooks:
- id: docformatter
args: ["--in-place", "--wrap-descriptions", "79"]
- repo: local
hooks:
- id: update-model-index
name: update-model-index
description: Collect model information and update model-index.yml
entry: .dev/md2yml.py
additional_dependencies: [mmcv, lxml, opencv-python]
language: python
files: ^configs/.*\.md$
require_serial: true
- repo: https://github.com/open-mmlab/pre-commit-hooks
rev: v0.2.0 # Use the rev to fix revision
hooks:
- id: check-algo-readme
- id: check-copyright
args: ["mmseg", "tools", "tests", "demo"] # the dir_to_check with expected directory to check

파일 보기

@ -1,9 +0,0 @@
version: 2
formats: all
python:
version: 3.7
install:
- requirements: requirements/docs.txt
- requirements: requirements/readthedocs.txt

파일 보기

@ -1,8 +0,0 @@
cff-version: 1.2.0
message: "If you use this software, please cite it as below."
authors:
- name: "MMSegmentation Contributors"
title: "OpenMMLab Semantic Segmentation Toolbox and Benchmark"
date-released: 2020-07-10
url: "https://github.com/open-mmlab/mmsegmentation"
license: Apache-2.0

파일 보기

@ -1,203 +0,0 @@
Copyright 2020 The MMSegmentation Authors. All rights reserved.
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright 2020 The MMSegmentation Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

파일 보기

@ -1,7 +0,0 @@
# Licenses for special features
In this file, we list the features with other licenses instead of Apache 2.0. Users should be careful about adopting these features in any commercial matters.
| Feature | Files | License |
| :-------: | :-------------------------------------------------------------------------------------------------------------------------------------------------: | :-----------------------------------------------------------: |
| SegFormer | [mmseg/models/decode_heads/segformer_head.py](https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/models/decode_heads/segformer_head.py) | [NVIDIA License](https://github.com/NVlabs/SegFormer#license) |

파일 보기

@ -1,4 +0,0 @@
include requirements/*.txt
include mmseg/.mim/model-index.yml
recursive-include mmseg/.mim/configs *.py *.yml
recursive-include mmseg/.mim/tools *.py *.sh

파일 보기

@ -1,229 +0,0 @@
<div align="center">
<img src="resources/mmseg-logo.png" width="600"/>
<div>&nbsp;</div>
<div align="center">
<b><font size="5">OpenMMLab website</font></b>
<sup>
<a href="https://openmmlab.com">
<i><font size="4">HOT</font></i>
</a>
</sup>
&nbsp;&nbsp;&nbsp;&nbsp;
<b><font size="5">OpenMMLab platform</font></b>
<sup>
<a href="https://platform.openmmlab.com">
<i><font size="4">TRY IT OUT</font></i>
</a>
</sup>
</div>
<div>&nbsp;</div>
<br />
[![PyPI - Python Version](https://img.shields.io/pypi/pyversions/mmsegmentation)](https://pypi.org/project/mmsegmentation/)
[![PyPI](https://img.shields.io/pypi/v/mmsegmentation)](https://pypi.org/project/mmsegmentation)
[![docs](https://img.shields.io/badge/docs-latest-blue)](https://mmsegmentation.readthedocs.io/en/latest/)
[![badge](https://github.com/open-mmlab/mmsegmentation/workflows/build/badge.svg)](https://github.com/open-mmlab/mmsegmentation/actions)
[![codecov](https://codecov.io/gh/open-mmlab/mmsegmentation/branch/master/graph/badge.svg)](https://codecov.io/gh/open-mmlab/mmsegmentation)
[![license](https://img.shields.io/github/license/open-mmlab/mmsegmentation.svg)](https://github.com/open-mmlab/mmsegmentation/blob/master/LICENSE)
[![issue resolution](https://isitmaintained.com/badge/resolution/open-mmlab/mmsegmentation.svg)](https://github.com/open-mmlab/mmsegmentation/issues)
[![open issues](https://isitmaintained.com/badge/open/open-mmlab/mmsegmentation.svg)](https://github.com/open-mmlab/mmsegmentation/issues)
[📘Documentation](https://mmsegmentation.readthedocs.io/en/latest/) |
[🛠Installation](https://mmsegmentation.readthedocs.io/en/latest/get_started.html) |
[👀Model Zoo](https://mmsegmentation.readthedocs.io/en/latest/model_zoo.html) |
[🆕Update News](https://mmsegmentation.readthedocs.io/en/latest/changelog.html) |
[🤔Reporting Issues](https://github.com/open-mmlab/mmsegmentation/issues/new/choose)
</div>
<div align="center">
English | [简体中文](README_zh-CN.md)
</div>
## Introduction
MMSegmentation is an open source semantic segmentation toolbox based on PyTorch.
It is a part of the [OpenMMLab](https://openmmlab.com/) project.
The master branch works with **PyTorch 1.5+**.
![demo image](resources/seg_demo.gif)
<details open>
<summary>Major features</summary>
- **Unified Benchmark**
We provide a unified benchmark toolbox for various semantic segmentation methods.
- **Modular Design**
We decompose the semantic segmentation framework into different components and one can easily construct a customized semantic segmentation framework by combining different modules.
- **Support of multiple methods out of box**
The toolbox directly supports popular and contemporary semantic segmentation frameworks, *e.g.* PSPNet, DeepLabV3, PSANet, DeepLabV3+, etc.
- **High efficiency**
The training speed is faster than or comparable to other codebases.
</details>
## What's New
v0.25.0 was released in 6/2/2022:
- Support PyTorch backend on MLU
Please refer to [changelog.md](docs/en/changelog.md) for details and release history.
## Installation
Please refer to [get_started.md](docs/en/get_started.md#installation) for installation and [dataset_prepare.md](docs/en/dataset_prepare.md#prepare-datasets) for dataset preparation.
## Get Started
Please see [train.md](docs/en/train.md) and [inference.md](docs/en/inference.md) for the basic usage of MMSegmentation.
There are also tutorials for:
- [customizing dataset](docs/en/tutorials/customize_datasets.md)
- [designing data pipeline](docs/en/tutorials/data_pipeline.md)
- [customizing modules](docs/en/tutorials/customize_models.md)
- [customizing runtime](docs/en/tutorials/customize_runtime.md)
- [training tricks](docs/en/tutorials/training_tricks.md)
- [useful tools](docs/en/useful_tools.md)
A Colab tutorial is also provided. You may preview the notebook [here](demo/MMSegmentation_Tutorial.ipynb) or directly [run](https://colab.research.google.com/github/open-mmlab/mmsegmentation/blob/master/demo/MMSegmentation_Tutorial.ipynb) on Colab.
## Benchmark and model zoo
Results and models are available in the [model zoo](docs/en/model_zoo.md).
Supported backbones:
- [x] ResNet (CVPR'2016)
- [x] ResNeXt (CVPR'2017)
- [x] [HRNet (CVPR'2019)](configs/hrnet)
- [x] [ResNeSt (ArXiv'2020)](configs/resnest)
- [x] [MobileNetV2 (CVPR'2018)](configs/mobilenet_v2)
- [x] [MobileNetV3 (ICCV'2019)](configs/mobilenet_v3)
- [x] [Vision Transformer (ICLR'2021)](configs/vit)
- [x] [Swin Transformer (ICCV'2021)](configs/swin)
- [x] [Twins (NeurIPS'2021)](configs/twins)
- [x] [BEiT (ICLR'2022)](configs/beit)
- [x] [ConvNeXt (CVPR'2022)](configs/convnext)
- [x] [MAE (CVPR'2022)](configs/mae)
Supported methods:
- [x] [FCN (CVPR'2015/TPAMI'2017)](configs/fcn)
- [x] [ERFNet (T-ITS'2017)](configs/erfnet)
- [x] [UNet (MICCAI'2016/Nat. Methods'2019)](configs/unet)
- [x] [PSPNet (CVPR'2017)](configs/pspnet)
- [x] [DeepLabV3 (ArXiv'2017)](configs/deeplabv3)
- [x] [BiSeNetV1 (ECCV'2018)](configs/bisenetv1)
- [x] [PSANet (ECCV'2018)](configs/psanet)
- [x] [DeepLabV3+ (CVPR'2018)](configs/deeplabv3plus)
- [x] [UPerNet (ECCV'2018)](configs/upernet)
- [x] [ICNet (ECCV'2018)](configs/icnet)
- [x] [NonLocal Net (CVPR'2018)](configs/nonlocal_net)
- [x] [EncNet (CVPR'2018)](configs/encnet)
- [x] [Semantic FPN (CVPR'2019)](configs/sem_fpn)
- [x] [DANet (CVPR'2019)](configs/danet)
- [x] [APCNet (CVPR'2019)](configs/apcnet)
- [x] [EMANet (ICCV'2019)](configs/emanet)
- [x] [CCNet (ICCV'2019)](configs/ccnet)
- [x] [DMNet (ICCV'2019)](configs/dmnet)
- [x] [ANN (ICCV'2019)](configs/ann)
- [x] [GCNet (ICCVW'2019/TPAMI'2020)](configs/gcnet)
- [x] [FastFCN (ArXiv'2019)](configs/fastfcn)
- [x] [Fast-SCNN (ArXiv'2019)](configs/fastscnn)
- [x] [ISANet (ArXiv'2019/IJCV'2021)](configs/isanet)
- [x] [OCRNet (ECCV'2020)](configs/ocrnet)
- [x] [DNLNet (ECCV'2020)](configs/dnlnet)
- [x] [PointRend (CVPR'2020)](configs/point_rend)
- [x] [CGNet (TIP'2020)](configs/cgnet)
- [x] [BiSeNetV2 (IJCV'2021)](configs/bisenetv2)
- [x] [STDC (CVPR'2021)](configs/stdc)
- [x] [SETR (CVPR'2021)](configs/setr)
- [x] [DPT (ArXiv'2021)](configs/dpt)
- [x] [Segmenter (ICCV'2021)](configs/segmenter)
- [x] [SegFormer (NeurIPS'2021)](configs/segformer)
- [x] [K-Net (NeurIPS'2021)](configs/knet)
Supported datasets:
- [x] [Cityscapes](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/en/dataset_prepare.md#cityscapes)
- [x] [PASCAL VOC](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/en/dataset_prepare.md#pascal-voc)
- [x] [ADE20K](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/en/dataset_prepare.md#ade20k)
- [x] [Pascal Context](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/en/dataset_prepare.md#pascal-context)
- [x] [COCO-Stuff 10k](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/en/dataset_prepare.md#coco-stuff-10k)
- [x] [COCO-Stuff 164k](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/en/dataset_prepare.md#coco-stuff-164k)
- [x] [CHASE_DB1](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/en/dataset_prepare.md#chase-db1)
- [x] [DRIVE](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/en/dataset_prepare.md#drive)
- [x] [HRF](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/en/dataset_prepare.md#hrf)
- [x] [STARE](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/en/dataset_prepare.md#stare)
- [x] [Dark Zurich](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/en/dataset_prepare.md#dark-zurich)
- [x] [Nighttime Driving](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/en/dataset_prepare.md#nighttime-driving)
- [x] [LoveDA](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/en/dataset_prepare.md#loveda)
- [x] [Potsdam](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/en/dataset_prepare.md#isprs-potsdam)
- [x] [Vaihingen](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/en/dataset_prepare.md#isprs-vaihingen)
- [x] [iSAID](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/en/dataset_prepare.md#isaid)
## FAQ
Please refer to [FAQ](docs/en/faq.md) for frequently asked questions.
## Contributing
We appreciate all contributions to improve MMSegmentation. Please refer to [CONTRIBUTING.md](.github/CONTRIBUTING.md) for the contributing guideline.
## Acknowledgement
MMSegmentation is an open source project that welcome any contribution and feedback.
We wish that the toolbox and benchmark could serve the growing research
community by providing a flexible as well as standardized toolkit to reimplement existing methods
and develop their own new semantic segmentation methods.
## Citation
If you find this project useful in your research, please consider cite:
```bibtex
@misc{mmseg2020,
title={{MMSegmentation}: OpenMMLab Semantic Segmentation Toolbox and Benchmark},
author={MMSegmentation Contributors},
howpublished = {\url{https://github.com/open-mmlab/mmsegmentation}},
year={2020}
}
```
## License
MMSegmentation is released under the Apache 2.0 license, while some specific features in this library are with other licenses. Please refer to [LICENSES.md](LICENSES.md) for the careful check, if you are using our code for commercial matters.
## Projects in OpenMMLab
- [MMCV](https://github.com/open-mmlab/mmcv): OpenMMLab foundational library for computer vision.
- [MIM](https://github.com/open-mmlab/mim): MIM installs OpenMMLab packages.
- [MMClassification](https://github.com/open-mmlab/mmclassification): OpenMMLab image classification toolbox and benchmark.
- [MMDetection](https://github.com/open-mmlab/mmdetection): OpenMMLab detection toolbox and benchmark.
- [MMDetection3D](https://github.com/open-mmlab/mmdetection3d): OpenMMLab's next-generation platform for general 3D object detection.
- [MMRotate](https://github.com/open-mmlab/mmrotate): OpenMMLab rotated object detection toolbox and benchmark.
- [MMSegmentation](https://github.com/open-mmlab/mmsegmentation): OpenMMLab semantic segmentation toolbox and benchmark.
- [MMOCR](https://github.com/open-mmlab/mmocr): OpenMMLab text detection, recognition, and understanding toolbox.
- [MMPose](https://github.com/open-mmlab/mmpose): OpenMMLab pose estimation toolbox and benchmark.
- [MMHuman3D](https://github.com/open-mmlab/mmhuman3d): OpenMMLab 3D human parametric model toolbox and benchmark.
- [MMSelfSup](https://github.com/open-mmlab/mmselfsup): OpenMMLab self-supervised learning toolbox and benchmark.
- [MMRazor](https://github.com/open-mmlab/mmrazor): OpenMMLab model compression toolbox and benchmark.
- [MMFewShot](https://github.com/open-mmlab/mmfewshot): OpenMMLab fewshot learning toolbox and benchmark.
- [MMAction2](https://github.com/open-mmlab/mmaction2): OpenMMLab's next-generation action understanding toolbox and benchmark.
- [MMTracking](https://github.com/open-mmlab/mmtracking): OpenMMLab video perception toolbox and benchmark.
- [MMFlow](https://github.com/open-mmlab/mmflow): OpenMMLab optical flow toolbox and benchmark.
- [MMEditing](https://github.com/open-mmlab/mmediting): OpenMMLab image and video editing toolbox.
- [MMGeneration](https://github.com/open-mmlab/mmgeneration): OpenMMLab image and video generative models toolbox.
- [MMDeploy](https://github.com/open-mmlab/mmdeploy): OpenMMLab Model Deployment Framework.

파일 보기

@ -1,242 +0,0 @@
<div align="center">
<img src="resources/mmseg-logo.png" width="600"/>
<div>&nbsp;</div>
<div align="center">
<b><font size="5">OpenMMLab 官网</font></b>
<sup>
<a href="https://openmmlab.com">
<i><font size="4">HOT</font></i>
</a>
</sup>
&nbsp;&nbsp;&nbsp;&nbsp;
<b><font size="5">OpenMMLab 开放平台</font></b>
<sup>
<a href="https://platform.openmmlab.com">
<i><font size="4">TRY IT OUT</font></i>
</a>
</sup>
</div>
<div>&nbsp;</div>
<br />
[![PyPI - Python Version](https://img.shields.io/pypi/pyversions/mmsegmentation)](https://pypi.org/project/mmsegmentation/)
[![PyPI](https://img.shields.io/pypi/v/mmsegmentation)](https://pypi.org/project/mmsegmentation)
[![docs](https://img.shields.io/badge/docs-latest-blue)](https://mmsegmentation.readthedocs.io/zh_CN/latest/)
[![badge](https://github.com/open-mmlab/mmsegmentation/workflows/build/badge.svg)](https://github.com/open-mmlab/mmsegmentation/actions)
[![codecov](https://codecov.io/gh/open-mmlab/mmsegmentation/branch/master/graph/badge.svg)](https://codecov.io/gh/open-mmlab/mmsegmentation)
[![license](https://img.shields.io/github/license/open-mmlab/mmsegmentation.svg)](https://github.com/open-mmlab/mmsegmentation/blob/master/LICENSE)
[![issue resolution](https://isitmaintained.com/badge/resolution/open-mmlab/mmsegmentation.svg)](https://github.com/open-mmlab/mmsegmentation/issues)
[![open issues](https://isitmaintained.com/badge/open/open-mmlab/mmsegmentation.svg)](https://github.com/open-mmlab/mmsegmentation/issues)
[📘使用文档](https://mmsegmentation.readthedocs.io/en/latest/) |
[🛠️安装指南](https://mmsegmentation.readthedocs.io/en/latest/get_started.html) |
[👀模型库](https://mmsegmentation.readthedocs.io/en/latest/model_zoo.html) |
[🆕更新日志](https://mmsegmentation.readthedocs.io/en/latest/changelog.html) |
[🤔报告问题](https://github.com/open-mmlab/mmsegmentation/issues/new/choose)
[English](README.md) | 简体中文
</div>
## 简介
MMSegmentation 是一个基于 PyTorch 的语义分割开源工具箱。它是 OpenMMLab 项目的一部分。
主分支代码目前支持 PyTorch 1.5 以上的版本。
![示例图片](resources/seg_demo.gif)
<details open>
<summary>Major features</summary>
### 主要特性
- **统一的基准平台**
我们将各种各样的语义分割算法集成到了一个统一的工具箱,进行基准测试。
- **模块化设计**
MMSegmentation 将分割框架解耦成不同的模块组件,通过组合不同的模块组件,用户可以便捷地构建自定义的分割模型。
- **丰富的即插即用的算法和模型**
MMSegmentation 支持了众多主流的和最新的检测算法,例如 PSPNetDeepLabV3PSANetDeepLabV3+ 等.
- **速度快**
训练速度比其他语义分割代码库更快或者相当。
</details>
## 最新进展
最新版本 v0.25.0 在 2022.6.2 发布:
- 支持 PyTorch MLU 后端
如果想了解更多版本更新细节和历史信息,请阅读[更新日志](docs/en/changelog.md)。
## 安装
请参考[快速入门文档](docs/zh_cn/get_started.md#installation)进行安装,参考[数据集准备](docs/zh_cn/dataset_prepare.md)处理数据。
## 快速入门
请参考[训练教程](docs/zh_cn/train.md)和[测试教程](docs/zh_cn/inference.md)学习 MMSegmentation 的基本使用。
我们也提供了一些进阶教程,内容覆盖了:
- [增加自定义数据集](docs/zh_cn/tutorials/customize_datasets.md)
- [设计新的数据预处理流程](docs/zh_cn/tutorials/data_pipeline.md)
- [增加自定义模型](docs/zh_cn/tutorials/customize_models.md)
- [增加自定义的运行时配置](docs/zh_cn/tutorials/customize_runtime.md)。
- [训练技巧说明](docs/zh_cn/tutorials/training_tricks.md)
- [有用的工具](docs/zh_cn/useful_tools.md)。
同时,我们提供了 Colab 教程。你可以在[这里](demo/MMSegmentation_Tutorial.ipynb)浏览教程,或者直接在 Colab 上[运行](https://colab.research.google.com/github/open-mmlab/mmsegmentation/blob/master/demo/MMSegmentation_Tutorial.ipynb)。
## 基准测试和模型库
测试结果和模型可以在[模型库](docs/zh_cn/model_zoo.md)中找到。
已支持的骨干网络:
- [x] ResNet (CVPR'2016)
- [x] ResNeXt (CVPR'2017)
- [x] [HRNet (CVPR'2019)](configs/hrnet)
- [x] [ResNeSt (ArXiv'2020)](configs/resnest)
- [x] [MobileNetV2 (CVPR'2018)](configs/mobilenet_v2)
- [x] [MobileNetV3 (ICCV'2019)](configs/mobilenet_v3)
- [x] [Vision Transformer (ICLR'2021)](configs/vit)
- [x] [Swin Transformer (ICCV'2021)](configs/swin)
- [x] [Twins (NeurIPS'2021)](configs/twins)
- [x] [BEiT (ICLR'2022)](configs/beit)
- [x] [ConvNeXt (CVPR'2022)](configs/convnext)
- [x] [MAE (CVPR'2022)](configs/mae)
已支持的算法:
- [x] [FCN (CVPR'2015/TPAMI'2017)](configs/fcn)
- [x] [ERFNet (T-ITS'2017)](configs/erfnet)
- [x] [UNet (MICCAI'2016/Nat. Methods'2019)](configs/unet)
- [x] [PSPNet (CVPR'2017)](configs/pspnet)
- [x] [DeepLabV3 (ArXiv'2017)](configs/deeplabv3)
- [x] [BiSeNetV1 (ECCV'2018)](configs/bisenetv1)
- [x] [PSANet (ECCV'2018)](configs/psanet)
- [x] [DeepLabV3+ (CVPR'2018)](configs/deeplabv3plus)
- [x] [UPerNet (ECCV'2018)](configs/upernet)
- [x] [ICNet (ECCV'2018)](configs/icnet)
- [x] [NonLocal Net (CVPR'2018)](configs/nonlocal_net)
- [x] [EncNet (CVPR'2018)](configs/encnet)
- [x] [Semantic FPN (CVPR'2019)](configs/sem_fpn)
- [x] [DANet (CVPR'2019)](configs/danet)
- [x] [APCNet (CVPR'2019)](configs/apcnet)
- [x] [EMANet (ICCV'2019)](configs/emanet)
- [x] [CCNet (ICCV'2019)](configs/ccnet)
- [x] [DMNet (ICCV'2019)](configs/dmnet)
- [x] [ANN (ICCV'2019)](configs/ann)
- [x] [GCNet (ICCVW'2019/TPAMI'2020)](configs/gcnet)
- [x] [FastFCN (ArXiv'2019)](configs/fastfcn)
- [x] [Fast-SCNN (ArXiv'2019)](configs/fastscnn)
- [x] [ISANet (ArXiv'2019/IJCV'2021)](configs/isanet)
- [x] [OCRNet (ECCV'2020)](configs/ocrnet)
- [x] [DNLNet (ECCV'2020)](configs/dnlnet)
- [x] [PointRend (CVPR'2020)](configs/point_rend)
- [x] [CGNet (TIP'2020)](configs/cgnet)
- [x] [BiSeNetV2 (IJCV'2021)](configs/bisenetv2)
- [x] [STDC (CVPR'2021)](configs/stdc)
- [x] [SETR (CVPR'2021)](configs/setr)
- [x] [DPT (ArXiv'2021)](configs/dpt)
- [x] [Segmenter (ICCV'2021)](configs/segmenter)
- [x] [SegFormer (NeurIPS'2021)](configs/segformer)
- [x] [K-Net (NeurIPS'2021)](configs/knet)
已支持的数据集:
- [x] [Cityscapes](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/zh_cn/dataset_prepare.md#cityscapes)
- [x] [PASCAL VOC](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/zh_cn/dataset_prepare.md#pascal-voc)
- [x] [ADE20K](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/zh_cn/dataset_prepare.md#ade20k)
- [x] [Pascal Context](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/zh_cn/dataset_prepare.md#pascal-context)
- [x] [COCO-Stuff 10k](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/zh_cn/dataset_prepare.md#coco-stuff-10k)
- [x] [COCO-Stuff 164k](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/zh_cn/dataset_prepare.md#coco-stuff-164k)
- [x] [CHASE_DB1](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/zh_cn/dataset_prepare.md#chase-db1)
- [x] [DRIVE](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/zh_cn/dataset_prepare.md#drive)
- [x] [HRF](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/zh_cn/dataset_prepare.md#hrf)
- [x] [STARE](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/zh_cn/dataset_prepare.md#stare)
- [x] [Dark Zurich](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/zh_cn/dataset_prepare.md#dark-zurich)
- [x] [Nighttime Driving](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/zh_cn/dataset_prepare.md#nighttime-driving)
- [x] [LoveDA](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/zh_cn/dataset_prepare.md#loveda)
- [x] [Potsdam](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/zh_cn/dataset_prepare.md#isprs-potsdam)
- [x] [Vaihingen](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/zh_cn/dataset_prepare.md#isprs-vaihingen)
- [x] [iSAID](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/zh_cn/dataset_prepare.md#isaid)
## 常见问题
如果遇到问题,请参考 [常见问题解答](docs/zh_cn/faq.md)。
## 贡献指南
我们感谢所有的贡献者为改进和提升 MMSegmentation 所作出的努力。请参考[贡献指南](.github/CONTRIBUTING.md)来了解参与项目贡献的相关指引。
## 致谢
MMSegmentation 是一个由来自不同高校和企业的研发人员共同参与贡献的开源项目。我们感谢所有为项目提供算法复现和新功能支持的贡献者,以及提供宝贵反馈的用户。 我们希望这个工具箱和基准测试可以为社区提供灵活的代码工具,供用户复现已有算法并开发自己的新模型,从而不断为开源社区提供贡献。
## 引用
如果你觉得本项目对你的研究工作有所帮助,请参考如下 bibtex 引用 MMSegmentation。
```bibtex
@misc{mmseg2020,
title={{MMSegmentation}: OpenMMLab Semantic Segmentation Toolbox and Benchmark},
author={MMSegmentation Contributors},
howpublished = {\url{https://github.com/open-mmlab/mmsegmentation}},
year={2020}
}
```
## 开源许可证
`MMSegmentation` 目前以 Apache 2.0 的许可证发布,但是其中有一部分功能并不是使用的 Apache2.0 许可证,我们在 [许可证](LICENSES.md) 中详细地列出了这些功能以及他们对应的许可证,如果您正在从事盈利性活动,请谨慎参考此文档。
## OpenMMLab 的其他项目
- [MMCV](https://github.com/open-mmlab/mmcv): OpenMMLab 计算机视觉基础库
- [MIM](https://github.com/open-mmlab/mim): MIM 是 OpenMMlab 项目、算法、模型的统一入口
- [MMClassification](https://github.com/open-mmlab/mmclassification): OpenMMLab 图像分类工具箱
- [MMDetection](https://github.com/open-mmlab/mmdetection): OpenMMLab 目标检测工具箱
- [MMDetection3D](https://github.com/open-mmlab/mmdetection3d): OpenMMLab 新一代通用 3D 目标检测平台
- [MMRotate](https://github.com/open-mmlab/mmrotate): OpenMMLab 旋转框检测工具箱与测试基准
- [MMSegmentation](https://github.com/open-mmlab/mmsegmentation): OpenMMLab 语义分割工具箱
- [MMOCR](https://github.com/open-mmlab/mmocr): OpenMMLab 全流程文字检测识别理解工具包
- [MMPose](https://github.com/open-mmlab/mmpose): OpenMMLab 姿态估计工具箱
- [MMHuman3D](https://github.com/open-mmlab/mmhuman3d): OpenMMLab 人体参数化模型工具箱与测试基准
- [MMSelfSup](https://github.com/open-mmlab/mmselfsup): OpenMMLab 自监督学习工具箱与测试基准
- [MMRazor](https://github.com/open-mmlab/mmrazor): OpenMMLab 模型压缩工具箱与测试基准
- [MMFewShot](https://github.com/open-mmlab/mmfewshot): OpenMMLab 少样本学习工具箱与测试基准
- [MMAction2](https://github.com/open-mmlab/mmaction2): OpenMMLab 新一代视频理解工具箱
- [MMTracking](https://github.com/open-mmlab/mmtracking): OpenMMLab 一体化视频目标感知平台
- [MMFlow](https://github.com/open-mmlab/mmflow): OpenMMLab 光流估计工具箱与测试基准
- [MMEditing](https://github.com/open-mmlab/mmediting): OpenMMLab 图像视频编辑工具箱
- [MMGeneration](https://github.com/open-mmlab/mmgeneration): OpenMMLab 图片视频生成模型工具箱
- [MMDeploy](https://github.com/open-mmlab/mmdeploy): OpenMMLab 模型部署框架
## 欢迎加入 OpenMMLab 社区
扫描下方的二维码可关注 OpenMMLab 团队的 [知乎官方账号](https://www.zhihu.com/people/openmmlab),加入 [OpenMMLab 团队](https://jq.qq.com/?_wv=1027&k=aCvMxdr3) 以及 [MMSegmentation](https://jq.qq.com/?_wv=1027&k=9sprS2YO) 的 QQ 群。
<div align="center">
<img src="docs/zh_cn/imgs/zhihu_qrcode.jpg" height="400" /> <img src="docs/zh_cn/imgs/qq_group_qrcode.jpg" height="400" />
</div>
我们会在 OpenMMLab 社区为大家
- 📢 分享 AI 框架的前沿核心技术
- 💻 解读 PyTorch 常用模块源码
- 📰 发布 OpenMMLab 的相关新闻
- 🚀 介绍 OpenMMLab 开发的前沿算法
- 🏃 获取更高效的问题答疑和意见反馈
- 🔥 提供与各行各业开发者充分交流的平台
干货满满 📘,等你来撩 💗OpenMMLab 社区期待您的加入 👬

파일 보기

@ -1,53 +0,0 @@
# dataset settings
dataset_type = 'CustomDataset' # need to change
data_root = 'data/my_dataset_v7' # need to change
img_norm_cfg = dict(
mean=[127.93135507, 116.76565979, 103.67335042], std=[49.55883976, 47.7692082, 50.7934459], to_rgb=True) # need to calculate
crop_size = (512, 512) # need to change
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations'),
dict(type='Resize', img_scale=(512, 512)),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='PhotoMetricDistortion'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_semantic_seg']),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(512, 512), # need to change
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
data = dict(
samples_per_gpu=2, # need to change
workers_per_gpu=1, # need to change
train=dict(
type=dataset_type,
data_root=data_root,
img_dir='img_dir/train', # need to change
ann_dir='ann_dir/train', # need to change
pipeline=train_pipeline),
val=dict(
type=dataset_type,
data_root=data_root,
img_dir='img_dir/val',# need to change
ann_dir='ann_dir/val',# need to change
pipeline=test_pipeline),
test=dict(
type=dataset_type,
data_root=data_root,
img_dir='img_dir/test',# need to change
ann_dir='ann_dir/test',# need to change
pipeline=test_pipeline))

파일 보기

@ -1,54 +0,0 @@
# dataset settings
dataset_type = 'ADE20KDataset'
data_root = 'data/ade/ADEChallengeData2016'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
crop_size = (512, 512)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', reduce_zero_label=True),
dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_semantic_seg']),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(2048, 512),
# img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
data = dict(
samples_per_gpu=4,
workers_per_gpu=4,
train=dict(
type=dataset_type,
data_root=data_root,
img_dir='images/training',
ann_dir='annotations/training',
pipeline=train_pipeline),
val=dict(
type=dataset_type,
data_root=data_root,
img_dir='images/validation',
ann_dir='annotations/validation',
pipeline=test_pipeline),
test=dict(
type=dataset_type,
data_root=data_root,
img_dir='images/validation',
ann_dir='annotations/validation',
pipeline=test_pipeline))

파일 보기

@ -1,54 +0,0 @@
# dataset settings
dataset_type = 'ADE20KDataset'
data_root = 'data/ade/ADEChallengeData2016'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
crop_size = (640, 640)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', reduce_zero_label=True),
dict(type='Resize', img_scale=(2560, 640), ratio_range=(0.5, 2.0)),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_semantic_seg']),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(2560, 640),
# img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
data = dict(
samples_per_gpu=4,
workers_per_gpu=4,
train=dict(
type=dataset_type,
data_root=data_root,
img_dir='images/training',
ann_dir='annotations/training',
pipeline=train_pipeline),
val=dict(
type=dataset_type,
data_root=data_root,
img_dir='images/validation',
ann_dir='annotations/validation',
pipeline=test_pipeline),
test=dict(
type=dataset_type,
data_root=data_root,
img_dir='images/validation',
ann_dir='annotations/validation',
pipeline=test_pipeline))

파일 보기

@ -1,59 +0,0 @@
# dataset settings
dataset_type = 'ChaseDB1Dataset'
data_root = 'data/CHASE_DB1'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
img_scale = (960, 999)
crop_size = (128, 128)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations'),
dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_semantic_seg'])
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=img_scale,
# img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0],
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img'])
])
]
data = dict(
samples_per_gpu=4,
workers_per_gpu=4,
train=dict(
type='RepeatDataset',
times=40000,
dataset=dict(
type=dataset_type,
data_root=data_root,
img_dir='images/training',
ann_dir='annotations/training',
pipeline=train_pipeline)),
val=dict(
type=dataset_type,
data_root=data_root,
img_dir='images/validation',
ann_dir='annotations/validation',
pipeline=test_pipeline),
test=dict(
type=dataset_type,
data_root=data_root,
img_dir='images/validation',
ann_dir='annotations/validation',
pipeline=test_pipeline))

파일 보기

@ -1,54 +0,0 @@
# dataset settings
dataset_type = 'CityscapesDataset'
data_root = 'data/cityscapes/'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
crop_size = (512, 1024)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations'),
dict(type='Resize', img_scale=(2048, 1024), ratio_range=(0.5, 2.0)),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_semantic_seg']),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(2048, 1024),
# img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
data = dict(
samples_per_gpu=2,
workers_per_gpu=2,
train=dict(
type=dataset_type,
data_root=data_root,
img_dir='leftImg8bit/train',
ann_dir='gtFine/train',
pipeline=train_pipeline),
val=dict(
type=dataset_type,
data_root=data_root,
img_dir='leftImg8bit/val',
ann_dir='gtFine/val',
pipeline=test_pipeline),
test=dict(
type=dataset_type,
data_root=data_root,
img_dir='leftImg8bit/val',
ann_dir='gtFine/val',
pipeline=test_pipeline))

파일 보기

@ -1,35 +0,0 @@
_base_ = './cityscapes.py'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
crop_size = (1024, 1024)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations'),
dict(type='Resize', img_scale=(2048, 1024), ratio_range=(0.5, 2.0)),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_semantic_seg']),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(2048, 1024),
# img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
data = dict(
train=dict(pipeline=train_pipeline),
val=dict(pipeline=test_pipeline),
test=dict(pipeline=test_pipeline))

파일 보기

@ -1,35 +0,0 @@
_base_ = './cityscapes.py'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
crop_size = (768, 768)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations'),
dict(type='Resize', img_scale=(2049, 1025), ratio_range=(0.5, 2.0)),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_semantic_seg']),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(2049, 1025),
# img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
data = dict(
train=dict(pipeline=train_pipeline),
val=dict(pipeline=test_pipeline),
test=dict(pipeline=test_pipeline))

파일 보기

@ -1,35 +0,0 @@
_base_ = './cityscapes.py'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
crop_size = (769, 769)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations'),
dict(type='Resize', img_scale=(2049, 1025), ratio_range=(0.5, 2.0)),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_semantic_seg']),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(2049, 1025),
# img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
data = dict(
train=dict(pipeline=train_pipeline),
val=dict(pipeline=test_pipeline),
test=dict(pipeline=test_pipeline))

파일 보기

@ -1,35 +0,0 @@
_base_ = './cityscapes.py'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
crop_size = (832, 832)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations'),
dict(type='Resize', img_scale=(2048, 1024), ratio_range=(0.5, 2.0)),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_semantic_seg']),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(2048, 1024),
# img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
data = dict(
train=dict(pipeline=train_pipeline),
val=dict(pipeline=test_pipeline),
test=dict(pipeline=test_pipeline))

파일 보기

@ -1,57 +0,0 @@
# dataset settings
dataset_type = 'COCOStuffDataset'
data_root = 'data/coco_stuff10k'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
crop_size = (512, 512)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', reduce_zero_label=True),
dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_semantic_seg']),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(2048, 512),
# img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
data = dict(
samples_per_gpu=4,
workers_per_gpu=4,
train=dict(
type=dataset_type,
data_root=data_root,
reduce_zero_label=True,
img_dir='images/train2014',
ann_dir='annotations/train2014',
pipeline=train_pipeline),
val=dict(
type=dataset_type,
data_root=data_root,
reduce_zero_label=True,
img_dir='images/test2014',
ann_dir='annotations/test2014',
pipeline=test_pipeline),
test=dict(
type=dataset_type,
data_root=data_root,
reduce_zero_label=True,
img_dir='images/test2014',
ann_dir='annotations/test2014',
pipeline=test_pipeline))

파일 보기

@ -1,54 +0,0 @@
# dataset settings
dataset_type = 'COCOStuffDataset'
data_root = 'data/coco_stuff164k'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
crop_size = (512, 512)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations'),
dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_semantic_seg']),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(2048, 512),
# img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
data = dict(
samples_per_gpu=4,
workers_per_gpu=4,
train=dict(
type=dataset_type,
data_root=data_root,
img_dir='images/train2017',
ann_dir='annotations/train2017',
pipeline=train_pipeline),
val=dict(
type=dataset_type,
data_root=data_root,
img_dir='images/val2017',
ann_dir='annotations/val2017',
pipeline=test_pipeline),
test=dict(
type=dataset_type,
data_root=data_root,
img_dir='images/val2017',
ann_dir='annotations/val2017',
pipeline=test_pipeline))

파일 보기

@ -1,59 +0,0 @@
# dataset settings
dataset_type = 'DRIVEDataset'
data_root = 'data/DRIVE'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
img_scale = (584, 565)
crop_size = (64, 64)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations'),
dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_semantic_seg'])
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=img_scale,
# img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0],
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img'])
])
]
data = dict(
samples_per_gpu=4,
workers_per_gpu=4,
train=dict(
type='RepeatDataset',
times=40000,
dataset=dict(
type=dataset_type,
data_root=data_root,
img_dir='images/training',
ann_dir='annotations/training',
pipeline=train_pipeline)),
val=dict(
type=dataset_type,
data_root=data_root,
img_dir='images/validation',
ann_dir='annotations/validation',
pipeline=test_pipeline),
test=dict(
type=dataset_type,
data_root=data_root,
img_dir='images/validation',
ann_dir='annotations/validation',
pipeline=test_pipeline))

파일 보기

@ -1,59 +0,0 @@
# dataset settings
dataset_type = 'HRFDataset'
data_root = 'data/HRF'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
img_scale = (2336, 3504)
crop_size = (256, 256)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations'),
dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_semantic_seg'])
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=img_scale,
# img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0],
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img'])
])
]
data = dict(
samples_per_gpu=4,
workers_per_gpu=4,
train=dict(
type='RepeatDataset',
times=40000,
dataset=dict(
type=dataset_type,
data_root=data_root,
img_dir='images/training',
ann_dir='annotations/training',
pipeline=train_pipeline)),
val=dict(
type=dataset_type,
data_root=data_root,
img_dir='images/validation',
ann_dir='annotations/validation',
pipeline=test_pipeline),
test=dict(
type=dataset_type,
data_root=data_root,
img_dir='images/validation',
ann_dir='annotations/validation',
pipeline=test_pipeline))

파일 보기

@ -1,62 +0,0 @@
# dataset settings
dataset_type = 'iSAIDDataset'
data_root = 'data/iSAID'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
"""
This crop_size setting is followed by the implementation of
`PointFlow: Flowing Semantics Through Points for Aerial Image
Segmentation <https://arxiv.org/pdf/2103.06564.pdf>`_.
"""
crop_size = (896, 896)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations'),
dict(type='Resize', img_scale=(896, 896), ratio_range=(0.5, 2.0)),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_semantic_seg']),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(896, 896),
# img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
data = dict(
samples_per_gpu=4,
workers_per_gpu=4,
train=dict(
type=dataset_type,
data_root=data_root,
img_dir='img_dir/train',
ann_dir='ann_dir/train',
pipeline=train_pipeline),
val=dict(
type=dataset_type,
data_root=data_root,
img_dir='img_dir/val',
ann_dir='ann_dir/val',
pipeline=test_pipeline),
test=dict(
type=dataset_type,
data_root=data_root,
img_dir='img_dir/val',
ann_dir='ann_dir/val',
pipeline=test_pipeline))

파일 보기

@ -1,54 +0,0 @@
# dataset settings
dataset_type = 'LoveDADataset'
data_root = 'data/loveDA'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
crop_size = (512, 512)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', reduce_zero_label=True),
dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_semantic_seg']),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(1024, 1024),
# img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
data = dict(
samples_per_gpu=4,
workers_per_gpu=4,
train=dict(
type=dataset_type,
data_root=data_root,
img_dir='img_dir/train',
ann_dir='ann_dir/train',
pipeline=train_pipeline),
val=dict(
type=dataset_type,
data_root=data_root,
img_dir='img_dir/val',
ann_dir='ann_dir/val',
pipeline=test_pipeline),
test=dict(
type=dataset_type,
data_root=data_root,
img_dir='img_dir/val',
ann_dir='ann_dir/val',
pipeline=test_pipeline))

파일 보기

@ -1,60 +0,0 @@
# dataset settings
dataset_type = 'PascalContextDataset'
data_root = 'data/VOCdevkit/VOC2010/'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
img_scale = (520, 520)
crop_size = (480, 480)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations'),
dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_semantic_seg']),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=img_scale,
# img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
data = dict(
samples_per_gpu=4,
workers_per_gpu=4,
train=dict(
type=dataset_type,
data_root=data_root,
img_dir='JPEGImages',
ann_dir='SegmentationClassContext',
split='ImageSets/SegmentationContext/train.txt',
pipeline=train_pipeline),
val=dict(
type=dataset_type,
data_root=data_root,
img_dir='JPEGImages',
ann_dir='SegmentationClassContext',
split='ImageSets/SegmentationContext/val.txt',
pipeline=test_pipeline),
test=dict(
type=dataset_type,
data_root=data_root,
img_dir='JPEGImages',
ann_dir='SegmentationClassContext',
split='ImageSets/SegmentationContext/val.txt',
pipeline=test_pipeline))

파일 보기

@ -1,60 +0,0 @@
# dataset settings
dataset_type = 'PascalContextDataset59'
data_root = 'data/VOCdevkit/VOC2010/'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
img_scale = (520, 520)
crop_size = (480, 480)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', reduce_zero_label=True),
dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_semantic_seg']),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=img_scale,
# img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
data = dict(
samples_per_gpu=4,
workers_per_gpu=4,
train=dict(
type=dataset_type,
data_root=data_root,
img_dir='JPEGImages',
ann_dir='SegmentationClassContext',
split='ImageSets/SegmentationContext/train.txt',
pipeline=train_pipeline),
val=dict(
type=dataset_type,
data_root=data_root,
img_dir='JPEGImages',
ann_dir='SegmentationClassContext',
split='ImageSets/SegmentationContext/val.txt',
pipeline=test_pipeline),
test=dict(
type=dataset_type,
data_root=data_root,
img_dir='JPEGImages',
ann_dir='SegmentationClassContext',
split='ImageSets/SegmentationContext/val.txt',
pipeline=test_pipeline))

파일 보기

@ -1,57 +0,0 @@
# dataset settings
dataset_type = 'PascalVOCDataset'
data_root = 'data/VOCdevkit/VOC2012'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
crop_size = (512, 512)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations'),
dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_semantic_seg']),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(2048, 512),
# img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
data = dict(
samples_per_gpu=4,
workers_per_gpu=4,
train=dict(
type=dataset_type,
data_root=data_root,
img_dir='JPEGImages',
ann_dir='SegmentationClass',
split='ImageSets/Segmentation/train.txt',
pipeline=train_pipeline),
val=dict(
type=dataset_type,
data_root=data_root,
img_dir='JPEGImages',
ann_dir='SegmentationClass',
split='ImageSets/Segmentation/val.txt',
pipeline=test_pipeline),
test=dict(
type=dataset_type,
data_root=data_root,
img_dir='JPEGImages',
ann_dir='SegmentationClass',
split='ImageSets/Segmentation/val.txt',
pipeline=test_pipeline))

파일 보기

@ -1,9 +0,0 @@
_base_ = './pascal_voc12.py'
# dataset settings
data = dict(
train=dict(
ann_dir=['SegmentationClass', 'SegmentationClassAug'],
split=[
'ImageSets/Segmentation/train.txt',
'ImageSets/Segmentation/aug.txt'
]))

파일 보기

@ -1,54 +0,0 @@
# dataset settings
dataset_type = 'PotsdamDataset'
data_root = 'data/potsdam'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
crop_size = (512, 512)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', reduce_zero_label=True),
dict(type='Resize', img_scale=(512, 512), ratio_range=(0.5, 2.0)),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_semantic_seg']),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(512, 512),
# img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
data = dict(
samples_per_gpu=4,
workers_per_gpu=4,
train=dict(
type=dataset_type,
data_root=data_root,
img_dir='img_dir/train',
ann_dir='ann_dir/train',
pipeline=train_pipeline),
val=dict(
type=dataset_type,
data_root=data_root,
img_dir='img_dir/val',
ann_dir='ann_dir/val',
pipeline=test_pipeline),
test=dict(
type=dataset_type,
data_root=data_root,
img_dir='img_dir/val',
ann_dir='ann_dir/val',
pipeline=test_pipeline))

파일 보기

@ -1,59 +0,0 @@
# dataset settings
dataset_type = 'STAREDataset'
data_root = 'data/STARE'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
img_scale = (605, 700)
crop_size = (128, 128)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations'),
dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_semantic_seg'])
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=img_scale,
# img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0],
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img'])
])
]
data = dict(
samples_per_gpu=4,
workers_per_gpu=4,
train=dict(
type='RepeatDataset',
times=40000,
dataset=dict(
type=dataset_type,
data_root=data_root,
img_dir='images/training',
ann_dir='annotations/training',
pipeline=train_pipeline)),
val=dict(
type=dataset_type,
data_root=data_root,
img_dir='images/validation',
ann_dir='annotations/validation',
pipeline=test_pipeline),
test=dict(
type=dataset_type,
data_root=data_root,
img_dir='images/validation',
ann_dir='annotations/validation',
pipeline=test_pipeline))

파일 보기

@ -1,54 +0,0 @@
# dataset settings
dataset_type = 'ISPRSDataset'
data_root = 'data/vaihingen'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
crop_size = (512, 512)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', reduce_zero_label=True),
dict(type='Resize', img_scale=(512, 512), ratio_range=(0.5, 2.0)),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_semantic_seg']),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(512, 512),
# img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
data = dict(
samples_per_gpu=4,
workers_per_gpu=4,
train=dict(
type=dataset_type,
data_root=data_root,
img_dir='img_dir/train',
ann_dir='ann_dir/train',
pipeline=train_pipeline),
val=dict(
type=dataset_type,
data_root=data_root,
img_dir='img_dir/val',
ann_dir='ann_dir/val',
pipeline=test_pipeline),
test=dict(
type=dataset_type,
data_root=data_root,
img_dir='img_dir/val',
ann_dir='ann_dir/val',
pipeline=test_pipeline))

파일 보기

@ -1,9 +0,0 @@
# yapf:enable
dist_params = dict(backend='nccl')
log_level = 'INFO'
#load_from = 'checkpoints/danet_r50-d8_512x512_80k_ade20k_20200615_015125-edb18e08.pth'
load_from = None
resume_from = None
#workflow = [('train', 1)]
workflow = [('train', 1), ('val', 1)]
cudnn_benchmark = True

파일 보기

@ -1,46 +0,0 @@
# model settings
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
type='EncoderDecoder',
pretrained='open-mmlab://resnet50_v1c',
backbone=dict(
type='ResNetV1c',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
dilations=(1, 1, 2, 4),
strides=(1, 2, 1, 1),
norm_cfg=norm_cfg,
norm_eval=False,
style='pytorch',
contract_dilation=True),
decode_head=dict(
type='ANNHead',
in_channels=[1024, 2048],
in_index=[2, 3],
channels=512,
project_channels=256,
query_scales=(1, ),
key_pool_scales=(1, 3, 6, 8),
dropout_ratio=0.1,
num_classes=19,
norm_cfg=norm_cfg,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
auxiliary_head=dict(
type='FCNHead',
in_channels=1024,
in_index=2,
channels=256,
num_convs=1,
concat_input=False,
dropout_ratio=0.1,
num_classes=19,
norm_cfg=norm_cfg,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
# model training and testing settings
train_cfg=dict(),
test_cfg=dict(mode='whole'))

파일 보기

@ -1,44 +0,0 @@
# model settings
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
type='EncoderDecoder',
pretrained='open-mmlab://resnet50_v1c',
backbone=dict(
type='ResNetV1c',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
dilations=(1, 1, 2, 4),
strides=(1, 2, 1, 1),
norm_cfg=norm_cfg,
norm_eval=False,
style='pytorch',
contract_dilation=True),
decode_head=dict(
type='APCHead',
in_channels=2048,
in_index=3,
channels=512,
pool_scales=(1, 2, 3, 6),
dropout_ratio=0.1,
num_classes=19,
norm_cfg=dict(type='SyncBN', requires_grad=True),
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
auxiliary_head=dict(
type='FCNHead',
in_channels=1024,
in_index=2,
channels=256,
num_convs=1,
concat_input=False,
dropout_ratio=0.1,
num_classes=19,
norm_cfg=norm_cfg,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
# model training and testing settings
train_cfg=dict(),
test_cfg=dict(mode='whole'))

파일 보기

@ -1,68 +0,0 @@
# model settings
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
type='EncoderDecoder',
backbone=dict(
type='BiSeNetV1',
in_channels=3,
context_channels=(128, 256, 512),
spatial_channels=(64, 64, 64, 128),
out_indices=(0, 1, 2),
out_channels=256,
backbone_cfg=dict(
type='ResNet',
in_channels=3,
depth=18,
num_stages=4,
out_indices=(0, 1, 2, 3),
dilations=(1, 1, 1, 1),
strides=(1, 2, 2, 2),
norm_cfg=norm_cfg,
norm_eval=False,
style='pytorch',
contract_dilation=True),
norm_cfg=norm_cfg,
align_corners=False,
init_cfg=None),
decode_head=dict(
type='FCNHead',
in_channels=256,
in_index=0,
channels=256,
num_convs=1,
concat_input=False,
dropout_ratio=0.1,
num_classes=19,
norm_cfg=norm_cfg,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
auxiliary_head=[
dict(
type='FCNHead',
in_channels=128,
channels=64,
num_convs=1,
num_classes=19,
in_index=1,
norm_cfg=norm_cfg,
concat_input=False,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
dict(
type='FCNHead',
in_channels=128,
channels=64,
num_convs=1,
num_classes=19,
in_index=2,
norm_cfg=norm_cfg,
concat_input=False,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
],
# model training and testing settings
train_cfg=dict(),
test_cfg=dict(mode='whole'))

파일 보기

@ -1,80 +0,0 @@
# model settings
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
type='EncoderDecoder',
pretrained=None,
backbone=dict(
type='BiSeNetV2',
detail_channels=(64, 64, 128),
semantic_channels=(16, 32, 64, 128),
semantic_expansion_ratio=6,
bga_channels=128,
out_indices=(0, 1, 2, 3, 4),
init_cfg=None,
align_corners=False),
decode_head=dict(
type='FCNHead',
in_channels=128,
in_index=0,
channels=1024,
num_convs=1,
concat_input=False,
dropout_ratio=0.1,
num_classes=19,
norm_cfg=norm_cfg,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
auxiliary_head=[
dict(
type='FCNHead',
in_channels=16,
channels=16,
num_convs=2,
num_classes=19,
in_index=1,
norm_cfg=norm_cfg,
concat_input=False,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
dict(
type='FCNHead',
in_channels=32,
channels=64,
num_convs=2,
num_classes=19,
in_index=2,
norm_cfg=norm_cfg,
concat_input=False,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
dict(
type='FCNHead',
in_channels=64,
channels=256,
num_convs=2,
num_classes=19,
in_index=3,
norm_cfg=norm_cfg,
concat_input=False,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
dict(
type='FCNHead',
in_channels=128,
channels=1024,
num_convs=2,
num_classes=19,
in_index=4,
norm_cfg=norm_cfg,
concat_input=False,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
],
# model training and testing settings
train_cfg=dict(),
test_cfg=dict(mode='whole'))

파일 보기

@ -1,44 +0,0 @@
# model settings
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
type='EncoderDecoder',
pretrained='open-mmlab://resnet50_v1c',
backbone=dict(
type='ResNetV1c',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
dilations=(1, 1, 2, 4),
strides=(1, 2, 1, 1),
norm_cfg=norm_cfg,
norm_eval=False,
style='pytorch',
contract_dilation=True),
decode_head=dict(
type='CCHead',
in_channels=2048,
in_index=3,
channels=512,
recurrence=2,
dropout_ratio=0.1,
num_classes=19,
norm_cfg=norm_cfg,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
auxiliary_head=dict(
type='FCNHead',
in_channels=1024,
in_index=2,
channels=256,
num_convs=1,
concat_input=False,
dropout_ratio=0.1,
num_classes=19,
norm_cfg=norm_cfg,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
# model training and testing settings
train_cfg=dict(),
test_cfg=dict(mode='whole'))

파일 보기

@ -1,35 +0,0 @@
# model settings
norm_cfg = dict(type='SyncBN', eps=1e-03, requires_grad=True)
model = dict(
type='EncoderDecoder',
backbone=dict(
type='CGNet',
norm_cfg=norm_cfg,
in_channels=3,
num_channels=(32, 64, 128),
num_blocks=(3, 21),
dilations=(2, 4),
reductions=(8, 16)),
decode_head=dict(
type='FCNHead',
in_channels=256,
in_index=2,
channels=256,
num_convs=0,
concat_input=False,
dropout_ratio=0,
num_classes=19,
norm_cfg=norm_cfg,
loss_decode=dict(
type='CrossEntropyLoss',
use_sigmoid=False,
loss_weight=1.0,
class_weight=[
2.5959933, 6.7415504, 3.5354059, 9.8663225, 9.690899, 9.369352,
10.289121, 9.953208, 4.3097677, 9.490387, 7.674431, 9.396905,
10.347791, 6.3927646, 10.226669, 10.241062, 10.280587,
10.396974, 10.055647
])),
# model training and testing settings
train_cfg=dict(sampler=None),
test_cfg=dict(mode='whole'))

파일 보기

@ -1,44 +0,0 @@
# model settings
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
type='EncoderDecoder',
pretrained='open-mmlab://resnet101_v1c',
backbone=dict(
type='ResNetV1c',
depth=101,
num_stages=4,
out_indices=(0, 1, 2, 3),
dilations=(1, 1, 2, 4),
strides=(1, 2, 1, 1),
norm_cfg=norm_cfg,
norm_eval=False,
style='pytorch',
contract_dilation=True),
decode_head=dict(
type='DAHead',
in_channels=2048,
in_index=3,
channels=512,
pam_channels=64,
dropout_ratio=0.1,
num_classes=5, # Need to change
norm_cfg=norm_cfg,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
auxiliary_head=dict(
type='FCNHead',
in_channels=1024,
in_index=2,
channels=256,
num_convs=1,
concat_input=False,
dropout_ratio=0.1,
num_classes=5, # Need to change
norm_cfg=norm_cfg,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
# model training and testing settings
train_cfg=dict(),
test_cfg=dict(mode='whole'))

파일 보기

@ -1,44 +0,0 @@
# model settings
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
type='EncoderDecoder',
pretrained='open-mmlab://resnet50_v1c',
backbone=dict(
type='ResNetV1c',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
dilations=(1, 1, 2, 4),
strides=(1, 2, 1, 1),
norm_cfg=norm_cfg,
norm_eval=False,
style='pytorch',
contract_dilation=True),
decode_head=dict(
type='ASPPHead',
in_channels=2048,
in_index=3,
channels=512,
dilations=(1, 12, 24, 36),
dropout_ratio=0.1,
num_classes=19,
norm_cfg=norm_cfg,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
auxiliary_head=dict(
type='FCNHead',
in_channels=1024,
in_index=2,
channels=256,
num_convs=1,
concat_input=False,
dropout_ratio=0.1,
num_classes=19,
norm_cfg=norm_cfg,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
# model training and testing settings
train_cfg=dict(),
test_cfg=dict(mode='whole'))

파일 보기

@ -1,50 +0,0 @@
# model settings
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
type='EncoderDecoder',
pretrained=None,
backbone=dict(
type='UNet',
in_channels=3,
base_channels=64,
num_stages=5,
strides=(1, 1, 1, 1, 1),
enc_num_convs=(2, 2, 2, 2, 2),
dec_num_convs=(2, 2, 2, 2),
downsamples=(True, True, True, True),
enc_dilations=(1, 1, 1, 1, 1),
dec_dilations=(1, 1, 1, 1),
with_cp=False,
conv_cfg=None,
norm_cfg=norm_cfg,
act_cfg=dict(type='ReLU'),
upsample_cfg=dict(type='InterpConv'),
norm_eval=False),
decode_head=dict(
type='ASPPHead',
in_channels=64,
in_index=4,
channels=16,
dilations=(1, 12, 24, 36),
dropout_ratio=0.1,
num_classes=2,
norm_cfg=norm_cfg,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
auxiliary_head=dict(
type='FCNHead',
in_channels=128,
in_index=3,
channels=64,
num_convs=1,
concat_input=False,
dropout_ratio=0.1,
num_classes=2,
norm_cfg=norm_cfg,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
# model training and testing settings
train_cfg=dict(),
test_cfg=dict(mode='slide', crop_size=256, stride=170))

파일 보기

@ -1,46 +0,0 @@
# model settings
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
type='EncoderDecoder',
pretrained='open-mmlab://resnet50_v1c',
backbone=dict(
type='ResNetV1c',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
dilations=(1, 1, 2, 4),
strides=(1, 2, 1, 1),
norm_cfg=norm_cfg,
norm_eval=False,
style='pytorch',
contract_dilation=True),
decode_head=dict(
type='DepthwiseSeparableASPPHead',
in_channels=2048,
in_index=3,
channels=512,
dilations=(1, 12, 24, 36),
c1_in_channels=256,
c1_channels=48,
dropout_ratio=0.1,
num_classes=19,
norm_cfg=norm_cfg,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
auxiliary_head=dict(
type='FCNHead',
in_channels=1024,
in_index=2,
channels=256,
num_convs=1,
concat_input=False,
dropout_ratio=0.1,
num_classes=19,
norm_cfg=norm_cfg,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
# model training and testing settings
train_cfg=dict(),
test_cfg=dict(mode='whole'))

파일 보기

@ -1,44 +0,0 @@
# model settings
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
type='EncoderDecoder',
pretrained='open-mmlab://resnet50_v1c',
backbone=dict(
type='ResNetV1c',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
dilations=(1, 1, 2, 4),
strides=(1, 2, 1, 1),
norm_cfg=norm_cfg,
norm_eval=False,
style='pytorch',
contract_dilation=True),
decode_head=dict(
type='DMHead',
in_channels=2048,
in_index=3,
channels=512,
filter_sizes=(1, 3, 5, 7),
dropout_ratio=0.1,
num_classes=19,
norm_cfg=dict(type='SyncBN', requires_grad=True),
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
auxiliary_head=dict(
type='FCNHead',
in_channels=1024,
in_index=2,
channels=256,
num_convs=1,
concat_input=False,
dropout_ratio=0.1,
num_classes=19,
norm_cfg=norm_cfg,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
# model training and testing settings
train_cfg=dict(),
test_cfg=dict(mode='whole'))

Some files were not shown because too many files have changed in this diff Show More