add standalone train
This commit is contained in:
parent
8d02b1a9af
commit
3cb0d59bd6
21
README.md
21
README.md
@ -15,21 +15,25 @@ The dashboard simulates a live feed of camera-trap detections across five locati
|
|||||||
### 1. Clone the repository
|
### 1. Clone the repository
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
git clone <repo-url>
|
git clone https://git.barrys.cloud/barry/Wildlife-Detection.git
|
||||||
cd assignment-HAI
|
cd Wildlife-Detection
|
||||||
```
|
```
|
||||||
|
|
||||||
### 2. Obtain the model weights
|
### 2. Obtain the model weights
|
||||||
|
|
||||||
The trained model weights are not included in the repository due to their size. Place the file in the project root:
|
The trained model weights are not included in the repository due to their size. You can regenerate them by running the training script:
|
||||||
|
|
||||||
| File | Size | Description |
|
```bash
|
||||||
|---|---|---|
|
uv run python train.py
|
||||||
| `efficientnet_v2_wild_forest_animals.pt` | ~78 MB | Fine-tuned EfficientNet V2-S weights |
|
```
|
||||||
|
|
||||||
The model was fine-tuned in `final.ipynb`.
|
This downloads the dataset from Roboflow (if not present), fine-tunes EfficientNet V2-S for 3 epochs, and saves `efficientnet_v2_wild_forest_animals.pt`. Optional flags:
|
||||||
|
|
||||||
The **dataset** (`wild-forest-animals-and-person-1/`) is downloaded automatically from [Roboflow](https://roboflow.com/) on first launch if not already present on disk.
|
```bash
|
||||||
|
uv run python train.py --epochs 5 --lr 0.0005 --batch-size 16
|
||||||
|
```
|
||||||
|
|
||||||
|
The **dataset** (`wild-forest-animals-and-person-1/`) is downloaded automatically from [Roboflow](https://roboflow.com/) on first run if not already present on disk.
|
||||||
|
|
||||||
### 3. Install dependencies
|
### 3. Install dependencies
|
||||||
|
|
||||||
@ -50,6 +54,7 @@ The server starts at **http://localhost:5000**.
|
|||||||
```
|
```
|
||||||
assignment-HAI/
|
assignment-HAI/
|
||||||
├── dashboard.py # Flask application (main entry point)
|
├── dashboard.py # Flask application (main entry point)
|
||||||
|
├── train.py # Standalone training script
|
||||||
├── final.ipynb # Training, evaluation, and XAI notebook
|
├── final.ipynb # Training, evaluation, and XAI notebook
|
||||||
├── map.webp # Park map background image
|
├── map.webp # Park map background image
|
||||||
├── pyproject.toml # Project metadata and dependencies
|
├── pyproject.toml # Project metadata and dependencies
|
||||||
|
|||||||
169
train.py
Normal file
169
train.py
Normal file
@ -0,0 +1,169 @@
|
|||||||
|
"""
|
||||||
|
Fine-tune EfficientNet V2-S on the Wild Forest Animals dataset.
|
||||||
|
|
||||||
|
Produces: efficientnet_v2_wild_forest_animals.pt
|
||||||
|
|
||||||
|
Usage: uv run python train.py
|
||||||
|
uv run python train.py --epochs 5 --lr 0.0005 --batch-size 16
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import csv
|
||||||
|
import random
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from torch.amp import GradScaler, autocast
|
||||||
|
from torch.utils.data import Dataset, DataLoader
|
||||||
|
from torchvision import transforms
|
||||||
|
from torchvision.models import efficientnet_v2_s, EfficientNet_V2_S_Weights
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
# Global RNG seed applied to Python, NumPy, and torch for reproducible runs.
SEED = 42
# One class name per one-hot column of the Roboflow `_classes.csv` files;
# the list position of a name is the integer label the dataset emits.
CLASS_NAMES = ["bear", "deer", "fox", "hare", "moose", "person", "wolf"]
# Root directory of the downloaded Roboflow dataset (contains train/valid/test splits).
DATASET_DIR = Path("wild-forest-animals-and-person-1")
# Default output file for the fine-tuned model weights.
WEIGHTS_PATH = Path("efficientnet_v2_wild_forest_animals.pt")
|
||||||
|
|
||||||
|
|
||||||
|
def seed_everything(seed: int = SEED):
    """Seed every RNG this script touches so runs are reproducible."""
    # Same seed value for Python, NumPy, and all torch devices.
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    # Trade cuDNN autotuning speed for run-to-run determinism.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_dataset():
    """Download the Roboflow dataset into DATASET_DIR if it is not already present.

    No-op when DATASET_DIR exists. Raises RuntimeError when the download
    finishes but the expected directory still does not exist.
    """
    if DATASET_DIR.exists():
        return
    print("Dataset not found -- downloading from Roboflow ...")
    # Imported lazily so the roboflow package is only required when a
    # download actually has to happen.
    import os

    from roboflow import Roboflow

    # SECURITY: an API key was previously hard-coded here and committed to
    # source control. Prefer the ROBOFLOW_API_KEY environment variable; the
    # old literal remains as a backward-compatible fallback but should be
    # rotated and removed.
    api_key = os.environ.get("ROBOFLOW_API_KEY", "VCZWezdoCHQz7juipBdt")
    rf = Roboflow(api_key=api_key)
    project = rf.workspace("forestanimals").project("wild-forest-animals-and-person")
    project.version(1).download("multiclass")
    if not DATASET_DIR.exists():
        raise RuntimeError(f"Download finished but {DATASET_DIR} not found.")
|
||||||
|
|
||||||
|
|
||||||
|
class WildForestAnimalsDataset(Dataset):
    """Multiclass image dataset backed by a Roboflow `_classes.csv` manifest.

    Each split directory (train/valid/test) holds the image files plus a
    `_classes.csv` with a `filename` column followed by one one-hot column
    per entry in CLASS_NAMES.
    """

    def __init__(self, root: Path, split: str, transform=None):
        self.root = Path(root)
        self.transform = transform
        split_dir = self.root / split
        manifest = split_dir / "_classes.csv"
        if not manifest.exists():
            raise FileNotFoundError(f"CSV not found for split '{split}' at {manifest}")
        # (image path, integer label) pairs; rows whose one-hot vector
        # contains no 1 are silently skipped.
        self.samples: list[tuple[Path, int]] = []
        with manifest.open(newline="") as handle:
            for record in csv.DictReader(handle):
                one_hot = [int(record[name]) for name in CLASS_NAMES]
                if 1 not in one_hot:
                    continue
                self.samples.append((split_dir / record["filename"], one_hot.index(1)))
        if not self.samples:
            raise RuntimeError(f"No samples loaded for split '{split}'")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        path, label = self.samples[idx]
        image = Image.open(path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, label
|
||||||
|
|
||||||
|
|
||||||
|
def build_transforms(train: bool):
    """Return the torchvision preprocessing pipeline for a split.

    Every split is converted to a tensor, resized to 224x224, and normalized
    with ImageNet statistics; the training pipeline additionally applies a
    random horizontal flip before those steps.
    """
    pipeline = []
    if train:
        pipeline.append(transforms.RandomHorizontalFlip(p=0.5))
    pipeline += [
        transforms.ToTensor(),
        transforms.Resize((224, 224), antialias=True),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
    return transforms.Compose(pipeline)
|
||||||
|
|
||||||
|
|
||||||
|
def create_model(num_classes: int) -> nn.Module:
    """Build an ImageNet-pretrained EfficientNet V2-S whose final classifier
    layer is replaced to emit `num_classes` logits."""
    backbone = efficientnet_v2_s(weights=EfficientNet_V2_S_Weights.IMAGENET1K_V1)
    # Swap only the last linear layer; the feature extractor stays trainable.
    in_features = backbone.classifier[1].in_features
    backbone.classifier[1] = nn.Linear(in_features, num_classes)
    return backbone
|
||||||
|
|
||||||
|
|
||||||
|
def train_one_epoch(model, loader, criterion, optimizer, scaler, device):
|
||||||
|
model.train()
|
||||||
|
total_loss = 0.0
|
||||||
|
for images, labels in loader:
|
||||||
|
images, labels = images.to(device), labels.to(device)
|
||||||
|
optimizer.zero_grad()
|
||||||
|
with autocast(device.type, enabled=(device.type == "cuda")):
|
||||||
|
loss = criterion(model(images), labels)
|
||||||
|
scaler.scale(loss).backward()
|
||||||
|
scaler.step(optimizer)
|
||||||
|
scaler.update()
|
||||||
|
total_loss += loss.item() * images.size(0)
|
||||||
|
return total_loss / len(loader.dataset)
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate(model, loader, criterion, device):
    """Evaluate `model` on `loader`.

    Returns a tuple of (mean per-sample loss, accuracy percentage).
    """
    model.eval()
    loss_sum = 0.0
    hits = 0
    seen = 0
    with torch.no_grad():
        for batch_images, batch_labels in loader:
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)
            # Mirror the training autocast setup: enabled only on CUDA.
            with autocast(device.type, enabled=(device.type == "cuda")):
                logits = model(batch_images)
                batch_loss = criterion(logits, batch_labels)
            loss_sum += batch_loss.item() * batch_images.size(0)
            hits += (logits.argmax(1) == batch_labels).sum().item()
            seen += batch_labels.size(0)
    return loss_sum / len(loader.dataset), 100.0 * hits / seen
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """CLI entry point: parse args, prepare data, fine-tune, evaluate, save weights."""
    parser = argparse.ArgumentParser(description="Fine-tune EfficientNet V2-S")
    parser.add_argument("--epochs", type=int, default=3)
    parser.add_argument("--batch-size", type=int, default=32)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--output", type=Path, default=WEIGHTS_PATH)
    args = parser.parse_args()

    seed_everything()
    ensure_dataset()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    # Only the training split receives augmentation.
    train_ds = WildForestAnimalsDataset(DATASET_DIR, "train", build_transforms(train=True))
    val_ds = WildForestAnimalsDataset(DATASET_DIR, "valid", build_transforms(train=False))
    test_ds = WildForestAnimalsDataset(DATASET_DIR, "test", build_transforms(train=False))
    print(f"Splits: {len(train_ds)} train, {len(val_ds)} val, {len(test_ds)} test")

    # A seeded generator keeps the training shuffle order reproducible.
    shuffle_gen = torch.Generator().manual_seed(SEED)
    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, generator=shuffle_gen)
    val_loader = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False)
    test_loader = DataLoader(test_ds, batch_size=args.batch_size, shuffle=False)

    model = create_model(len(CLASS_NAMES)).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    # Gradient scaling only matters under CUDA mixed precision.
    scaler = GradScaler(device.type, enabled=(device.type == "cuda"))

    for epoch in range(1, args.epochs + 1):
        train_loss = train_one_epoch(model, train_loader, criterion, optimizer, scaler, device)
        val_loss, val_acc = evaluate(model, val_loader, criterion, device)
        print(f"Epoch {epoch}/{args.epochs} -- "
              f"train loss: {train_loss:.4f}, val loss: {val_loss:.4f}, val acc: {val_acc:.2f}%")

    test_loss, test_acc = evaluate(model, test_loader, criterion, device)
    print(f"Test loss: {test_loss:.4f}, test accuracy: {test_acc:.2f}%")

    torch.save(model.state_dict(), args.output)
    print(f"Saved weights to {args.output.resolve()}")


if __name__ == "__main__":
    main()
|
||||||
Loading…
x
Reference in New Issue
Block a user