"""
Fine-tune EfficientNet V2-S on the Wild Forest Animals dataset.

Produces:
    efficientnet_v2_wild_forest_animals.pt

Usage:
    uv run python train.py
    uv run python train.py --epochs 5 --lr 0.0005 --batch-size 16
"""

import argparse
import csv
import os
import random
from pathlib import Path

import numpy as np
import torch
from PIL import Image
from torch import nn
from torch.amp import GradScaler, autocast
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.models import EfficientNet_V2_S_Weights, efficientnet_v2_s

SEED = 42
CLASS_NAMES = ["bear", "deer", "fox", "hare", "moose", "person", "wolf"]
DATASET_DIR = Path("wild-forest-animals-and-person-1")
WEIGHTS_PATH = Path("efficientnet_v2_wild_forest_animals.pt")

# Environment variable that, when set to a non-empty value, overrides the
# built-in Roboflow API key.
ROBOFLOW_API_KEY_ENV = "ROBOFLOW_API_KEY"
# NOTE(review): this key is committed to source control. It is kept only as a
# fallback so existing invocations keep working; it should be rotated and
# removed once every caller supplies ROBOFLOW_API_KEY.
_DEFAULT_ROBOFLOW_API_KEY = "VCZWezdoCHQz7juipBdt"


def seed_everything(seed: int = SEED) -> None:
    """Seed the Python, NumPy and PyTorch RNGs and set the cuDNN flags.

    Args:
        seed: Value passed to every seeding call.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def ensure_dataset() -> None:
    """Download the dataset from Roboflow unless ``DATASET_DIR`` already exists.

    The API key is taken from the ``ROBOFLOW_API_KEY`` environment variable
    when it is set and non-empty; otherwise the built-in key is used.

    Raises:
        RuntimeError: If the download returns but ``DATASET_DIR`` is still
            missing.
    """
    if DATASET_DIR.exists():
        return

    print("Dataset not found -- downloading from Roboflow ...")
    # Imported lazily so the package is only needed when a download happens.
    from roboflow import Roboflow

    api_key = os.environ.get(ROBOFLOW_API_KEY_ENV) or _DEFAULT_ROBOFLOW_API_KEY
    rf = Roboflow(api_key=api_key)
    project = rf.workspace("forestanimals").project("wild-forest-animals-and-person")
    project.version(1).download("multiclass")

    if not DATASET_DIR.exists():
        raise RuntimeError(f"Download finished but {DATASET_DIR} not found.")


class WildForestAnimalsDataset(Dataset):
    """Image dataset backed by one split directory and its ``_classes.csv``.

    Each CSV row supplies a ``filename`` column plus one column per entry of
    ``CLASS_NAMES``. The label of a row is the index of the first class
    column whose value is 1; rows with no such column are skipped.
    """

    def __init__(self, root: Path, split: str, transform=None):
        """Read ``<root>/<split>/_classes.csv`` and collect (path, label) pairs.

        Args:
            root: Dataset root directory.
            split: Name of the sub-directory holding the images and the CSV.
            transform: Optional callable applied to each loaded PIL image.

        Raises:
            FileNotFoundError: If the CSV file for the split does not exist.
            RuntimeError: If the CSV yields no labelled sample.
        """
        self.root, self.transform = Path(root), transform
        split_dir = self.root / split
        csv_path = split_dir / "_classes.csv"
        if not csv_path.exists():
            raise FileNotFoundError(f"CSV not found for split '{split}' at {csv_path}")

        self.samples: list[tuple[Path, int]] = []
        with csv_path.open(newline="") as f:
            # skipinitialspace tolerates a blank after each comma, so a header
            # written as "filename, bear, ..." still maps to the plain names.
            for row in csv.DictReader(f, skipinitialspace=True):
                one_hot = [int(row[name]) for name in CLASS_NAMES]
                try:
                    label = one_hot.index(1)
                except ValueError:
                    # No class column is set to 1: leave the row out.
                    continue
                self.samples.append((split_dir / row["filename"], label))

        if not self.samples:
            raise RuntimeError(f"No samples loaded for split '{split}'")

    def __len__(self) -> int:
        """Return the number of labelled samples."""
        return len(self.samples)

    def __getitem__(self, idx: int):
        """Load sample ``idx`` as RGB and return ``(image, label)``.

        The image is passed through ``self.transform`` when one was given;
        otherwise the PIL image itself is returned.
        """
        path, label = self.samples[idx]
        # convert() returns a new image, so the handle on the source file can
        # be released as soon as the conversion is done.
        with Image.open(path) as handle:
            img = handle.convert("RGB")
        if self.transform is not None:
            img = self.transform(img)
        return img, label


def build_transforms(train: bool):
    """Build the preprocessing pipeline for one split.

    Both variants convert to a tensor, resize to 224x224 and normalise with
    the mean/std constants below; the training variant additionally starts
    with a random horizontal flip (p=0.5).

    Args:
        train: Whether to prepend the random-flip augmentation.

    Returns:
        A ``transforms.Compose`` instance.
    """
    common = [
        transforms.ToTensor(),
        transforms.Resize((224, 224), antialias=True),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
    if train:
        return transforms.Compose([transforms.RandomHorizontalFlip(p=0.5)] + common)
    return transforms.Compose(common)


def create_model(num_classes: int) -> nn.Module:
    """Create EfficientNet V2-S with IMAGENET1K_V1 weights and a new head.

    Args:
        num_classes: Output size of the replacement ``classifier[1]`` layer.

    Returns:
        The model with its final linear layer replaced.
    """
    model = efficientnet_v2_s(weights=EfficientNet_V2_S_Weights.IMAGENET1K_V1)
    model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)
    return model


def train_one_epoch(
    model: nn.Module,
    loader: DataLoader,
    criterion,
    optimizer: torch.optim.Optimizer,
    scaler: GradScaler,
    device: torch.device,
) -> float:
    """Run one optimisation pass over ``loader``.

    Autocast is enabled only when ``device`` is a CUDA device; the backward
    pass and optimizer step go through ``scaler``.

    Returns:
        The batch losses weighted by batch size, divided by
        ``len(loader.dataset)``.
    """
    model.train()
    total_loss = 0.0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        with autocast(device.type, enabled=(device.type == "cuda")):
            loss = criterion(model(images), labels)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        total_loss += loss.item() * images.size(0)
    return total_loss / len(loader.dataset)


def evaluate(
    model: nn.Module,
    loader: DataLoader,
    criterion,
    device: torch.device,
) -> tuple[float, float]:
    """Evaluate ``model`` on ``loader`` with gradients disabled.

    Returns:
        A ``(loss, accuracy)`` pair: the batch losses weighted by batch size
        and divided by ``len(loader.dataset)``, and the percentage of samples
        whose argmax prediction equals the label.
    """
    model.eval()
    total_loss, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            with autocast(device.type, enabled=(device.type == "cuda")):
                outputs = model(images)
                loss = criterion(outputs, labels)
            total_loss += loss.item() * images.size(0)
            correct += (outputs.argmax(1) == labels).sum().item()
            total += labels.size(0)
    return total_loss / len(loader.dataset), 100.0 * correct / total


def main() -> None:
    """Parse CLI options, fine-tune the model, report metrics, save weights."""
    parser = argparse.ArgumentParser(description="Fine-tune EfficientNet V2-S")
    parser.add_argument("--epochs", type=int, default=3)
    parser.add_argument("--batch-size", type=int, default=32)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--output", type=Path, default=WEIGHTS_PATH)
    args = parser.parse_args()

    seed_everything()
    ensure_dataset()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    train_ds = WildForestAnimalsDataset(DATASET_DIR, "train", build_transforms(train=True))
    val_ds = WildForestAnimalsDataset(DATASET_DIR, "valid", build_transforms(train=False))
    test_ds = WildForestAnimalsDataset(DATASET_DIR, "test", build_transforms(train=False))
    print(f"Splits: {len(train_ds)} train, {len(val_ds)} val, {len(test_ds)} test")

    # A dedicated, seeded generator drives the shuffling of the training set.
    gen = torch.Generator().manual_seed(SEED)
    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, generator=gen)
    val_loader = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False)
    test_loader = DataLoader(test_ds, batch_size=args.batch_size, shuffle=False)

    model = create_model(len(CLASS_NAMES)).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    scaler = GradScaler(device.type, enabled=(device.type == "cuda"))

    for epoch in range(1, args.epochs + 1):
        train_loss = train_one_epoch(model, train_loader, criterion, optimizer, scaler, device)
        val_loss, val_acc = evaluate(model, val_loader, criterion, device)
        print(
            f"Epoch {epoch}/{args.epochs} -- "
            f"train loss: {train_loss:.4f}, val loss: {val_loss:.4f}, val acc: {val_acc:.2f}%"
        )

    test_loss, test_acc = evaluate(model, test_loader, criterion, device)
    print(f"Test loss: {test_loss:.4f}, test accuracy: {test_acc:.2f}%")

    # Create the target directory first so a nested --output path does not
    # make the save fail after training has already finished.
    args.output.parent.mkdir(parents=True, exist_ok=True)
    torch.save(model.state_dict(), args.output)
    print(f"Saved weights to {args.output.resolve()}")


if __name__ == "__main__":
    main()