From 3cb0d59bd6915ae9e43b14049c0177a0f5724ec3 Mon Sep 17 00:00:00 2001 From: everbarry Date: Wed, 18 Mar 2026 12:44:44 +0100 Subject: [PATCH] add standalone train --- README.md | 21 ++++--- train.py | 169 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 182 insertions(+), 8 deletions(-) create mode 100644 train.py diff --git a/README.md b/README.md index 33904dc..e3b8974 100644 --- a/README.md +++ b/README.md @@ -15,21 +15,25 @@ The dashboard simulates a live feed of camera-trap detections across five locati ### 1. Clone the repository ```bash -git clone -cd assignment-HAI +git clone https://git.barrys.cloud/barry/Wildlife-Detection.git +cd Wildlife-Detection ``` ### 2. Obtain the model weights -The trained model weights are not included in the repository due to their size. Place the file in the project root: +The trained model weights are not included in the repository due to their size. You can regenerate them by running the training script: -| File | Size | Description | -|---|---|---| -| `efficientnet_v2_wild_forest_animals.pt` | ~78 MB | Fine-tuned EfficientNet V2-S weights | +```bash +uv run python train.py +``` -The model was fine-tuned in `final.ipynb`. +This downloads the dataset from Roboflow (if not present), fine-tunes EfficientNet V2-S for 3 epochs, and saves `efficientnet_v2_wild_forest_animals.pt`. Optional flags: -The **dataset** (`wild-forest-animals-and-person-1/`) is downloaded automatically from [Roboflow](https://roboflow.com/) on first launch if not already present on disk. +```bash +uv run python train.py --epochs 5 --lr 0.0005 --batch-size 16 +``` + +The **dataset** (`wild-forest-animals-and-person-1/`) is downloaded automatically from [Roboflow](https://roboflow.com/) on first run if not already present on disk. ### 3. Install dependencies @@ -50,6 +54,7 @@ The server starts at **http://localhost:5000**. 
"""
Fine-tune EfficientNet V2-S on the Wild Forest Animals dataset.

Produces: efficientnet_v2_wild_forest_animals.pt

Usage: uv run python train.py
       uv run python train.py --epochs 5 --lr 0.0005 --batch-size 16

Downloading the dataset from Roboflow requires the ROBOFLOW_API_KEY
environment variable; the key is deliberately NOT hardcoded here --
credentials must never be committed to the repository.
"""

import argparse
import csv
import os
import random
from pathlib import Path

import numpy as np
import torch
from torch import nn
from torch.amp import GradScaler, autocast
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.models import efficientnet_v2_s, EfficientNet_V2_S_Weights
from PIL import Image

SEED = 42
CLASS_NAMES = ["bear", "deer", "fox", "hare", "moose", "person", "wolf"]
DATASET_DIR = Path("wild-forest-animals-and-person-1")
WEIGHTS_PATH = Path("efficientnet_v2_wild_forest_animals.pt")


def seed_everything(seed: int = SEED) -> None:
    """Seed every RNG in use (python, numpy, torch/CUDA) for reproducible runs."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Force deterministic cuDNN kernels; reproducibility over speed.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def ensure_dataset() -> None:
    """Download the multiclass dataset from Roboflow if it is not on disk.

    Raises:
        RuntimeError: if ROBOFLOW_API_KEY is unset, or the download finished
            without producing the expected DATASET_DIR.
    """
    if DATASET_DIR.exists():
        return
    print("Dataset not found -- downloading from Roboflow ...")
    # SECURITY: the API key is read from the environment instead of being
    # hardcoded in the source -- never commit credentials to the repository.
    api_key = os.environ.get("ROBOFLOW_API_KEY")
    if not api_key:
        raise RuntimeError(
            "Set the ROBOFLOW_API_KEY environment variable to download the dataset."
        )
    from roboflow import Roboflow  # local import: only needed for the download
    rf = Roboflow(api_key=api_key)
    project = rf.workspace("forestanimals").project("wild-forest-animals-and-person")
    project.version(1).download("multiclass")
    if not DATASET_DIR.exists():
        raise RuntimeError(f"Download finished but {DATASET_DIR} not found.")


class WildForestAnimalsDataset(Dataset):
    """Single-label classification dataset backed by Roboflow's multiclass CSV.

    Each split directory holds a `_classes.csv` with one-hot columns named
    after CLASS_NAMES. Rows without exactly a `1` in one of those columns
    are skipped (unlabeled / multi-label rows).
    """

    def __init__(self, root: Path, split: str, transform=None):
        """Load the (path, label) index for `split` ('train'/'valid'/'test').

        Raises:
            FileNotFoundError: if the split's `_classes.csv` is missing.
            RuntimeError: if the CSV yields no usable samples.
        """
        self.root, self.transform = Path(root), transform
        split_dir = self.root / split
        csv_path = split_dir / "_classes.csv"
        if not csv_path.exists():
            raise FileNotFoundError(f"CSV not found for split '{split}' at {csv_path}")
        self.samples: list[tuple[Path, int]] = []
        with csv_path.open(newline="") as f:
            for row in csv.DictReader(f):
                oh = [int(row[n]) for n in CLASS_NAMES]
                try:
                    label = oh.index(1)
                except ValueError:
                    continue  # no positive class -- skip the row
                self.samples.append((split_dir / row["filename"], label))
        if not self.samples:
            raise RuntimeError(f"No samples loaded for split '{split}'")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        path, label = self.samples[idx]
        img = Image.open(path).convert("RGB")
        return self.transform(img) if self.transform else img, label


def build_transforms(train: bool):
    """Return the preprocessing pipeline; training adds a random horizontal flip."""
    common = [
        transforms.ToTensor(),
        transforms.Resize((224, 224), antialias=True),
        # ImageNet normalization stats -- must match the pretrained backbone.
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
    if train:
        return transforms.Compose([transforms.RandomHorizontalFlip(p=0.5)] + common)
    return transforms.Compose(common)


def create_model(num_classes: int) -> nn.Module:
    """EfficientNet V2-S pretrained on ImageNet, with a fresh classifier head."""
    model = efficientnet_v2_s(weights=EfficientNet_V2_S_Weights.IMAGENET1K_V1)
    model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)
    return model


def train_one_epoch(model, loader, criterion, optimizer, scaler, device):
    """Run one training epoch; return the mean per-sample loss."""
    model.train()
    total_loss = 0.0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        # Mixed precision only on CUDA; on CPU autocast/scaler are no-ops.
        with autocast(device.type, enabled=(device.type == "cuda")):
            loss = criterion(model(images), labels)
        # backward/step stay OUTSIDE the autocast region, per AMP guidance.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        total_loss += loss.item() * images.size(0)
    return total_loss / len(loader.dataset)


def evaluate(model, loader, criterion, device):
    """Return (mean per-sample loss, accuracy in %) of `model` over `loader`."""
    model.eval()
    total_loss, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            with autocast(device.type, enabled=(device.type == "cuda")):
                outputs = model(images)
                loss = criterion(outputs, labels)
            total_loss += loss.item() * images.size(0)
            correct += (outputs.argmax(1) == labels).sum().item()
            total += labels.size(0)
    return total_loss / len(loader.dataset), 100.0 * correct / total


def main():
    """Parse CLI flags, fine-tune, report test metrics, and save the weights."""
    parser = argparse.ArgumentParser(description="Fine-tune EfficientNet V2-S")
    parser.add_argument("--epochs", type=int, default=3)
    parser.add_argument("--batch-size", type=int, default=32)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--output", type=Path, default=WEIGHTS_PATH)
    args = parser.parse_args()

    seed_everything()
    ensure_dataset()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    train_ds = WildForestAnimalsDataset(DATASET_DIR, "train", build_transforms(train=True))
    val_ds = WildForestAnimalsDataset(DATASET_DIR, "valid", build_transforms(train=False))
    test_ds = WildForestAnimalsDataset(DATASET_DIR, "test", build_transforms(train=False))
    print(f"Splits: {len(train_ds)} train, {len(val_ds)} val, {len(test_ds)} test")

    # Seeded generator so the training shuffle order is reproducible.
    gen = torch.Generator().manual_seed(SEED)
    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, generator=gen)
    val_loader = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False)
    test_loader = DataLoader(test_ds, batch_size=args.batch_size, shuffle=False)

    model = create_model(len(CLASS_NAMES)).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    scaler = GradScaler(device.type, enabled=(device.type == "cuda"))

    for epoch in range(1, args.epochs + 1):
        train_loss = train_one_epoch(model, train_loader, criterion, optimizer, scaler, device)
        val_loss, val_acc = evaluate(model, val_loader, criterion, device)
        print(f"Epoch {epoch}/{args.epochs} -- "
              f"train loss: {train_loss:.4f}, val loss: {val_loss:.4f}, val acc: {val_acc:.2f}%")

    test_loss, test_acc = evaluate(model, test_loader, criterion, device)
    print(f"Test loss: {test_loss:.4f}, test accuracy: {test_acc:.2f}%")

    torch.save(model.state_dict(), args.output)
    print(f"Saved weights to {args.output.resolve()}")


if __name__ == "__main__":
    main()