# Exported 2026-03-18 12:44:44 +01:00 -- 170 lines, 6.3 KiB, Python (viewer banner, not code)
"""
Fine-tune EfficientNet V2-S on the Wild Forest Animals dataset.
Produces: efficientnet_v2_wild_forest_animals.pt
Usage: uv run python train.py
uv run python train.py --epochs 5 --lr 0.0005 --batch-size 16
"""
import argparse
import csv
import random
from pathlib import Path
import numpy as np
import torch
from torch import nn
from torch.amp import GradScaler, autocast
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.models import efficientnet_v2_s, EfficientNet_V2_S_Weights
from PIL import Image
# Reproducibility + dataset constants.
SEED = 42
# One-hot column order in the Roboflow `_classes.csv` export.
CLASS_NAMES = ["bear", "deer", "fox", "hare", "moose", "person", "wolf"]
DATASET_DIR = Path("wild-forest-animals-and-person-1")
WEIGHTS_PATH = Path("efficientnet_v2_wild_forest_animals.pt")


def seed_everything(seed: int = SEED):
    """Seed every RNG this script relies on (random, NumPy, torch, CUDA).

    Also forces cuDNN into deterministic mode, trading kernel-selection
    speed for run-to-run reproducibility.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
def ensure_dataset():
    """Download the Roboflow dataset if it is not already on disk.

    No-op when DATASET_DIR already exists.

    Raises:
        RuntimeError: if the download finishes but DATASET_DIR still
            does not exist (e.g. the export name changed upstream).
    """
    if DATASET_DIR.exists():
        return
    print("Dataset not found -- downloading from Roboflow ...")
    import os
    from roboflow import Roboflow

    # SECURITY: an API key was hard-coded here and is now in version-control
    # history. Prefer the ROBOFLOW_API_KEY environment variable; the original
    # key remains as a fallback only for backward compatibility -- rotate the
    # leaked key and remove the fallback as soon as possible.
    api_key = os.environ.get("ROBOFLOW_API_KEY", "VCZWezdoCHQz7juipBdt")
    rf = Roboflow(api_key=api_key)
    project = rf.workspace("forestanimals").project("wild-forest-animals-and-person")
    project.version(1).download("multiclass")
    if not DATASET_DIR.exists():
        raise RuntimeError(f"Download finished but {DATASET_DIR} not found.")
class WildForestAnimalsDataset(Dataset):
    """Single-label image dataset backed by a Roboflow "multiclass" export.

    Each split directory holds a `_classes.csv` with a `filename` column and
    one 0/1 column per entry of CLASS_NAMES. Rows whose one-hot columns
    contain no `1` are skipped.
    """

    def __init__(self, root: Path, split: str, transform=None):
        """Index all labelled rows of `<root>/<split>/_classes.csv`.

        Raises FileNotFoundError if the CSV is missing and RuntimeError if
        it yields no usable samples.
        """
        self.root = Path(root)
        self.transform = transform
        split_dir = self.root / split
        csv_path = split_dir / "_classes.csv"
        if not csv_path.exists():
            raise FileNotFoundError(f"CSV not found for split '{split}' at {csv_path}")
        self.samples: list[tuple[Path, int]] = []
        with csv_path.open(newline="") as f:
            for row in csv.DictReader(f):
                one_hot = [int(row[name]) for name in CLASS_NAMES]
                if 1 not in one_hot:
                    continue  # unlabeled row -- skip
                self.samples.append((split_dir / row["filename"], one_hot.index(1)))
        if not self.samples:
            raise RuntimeError(f"No samples loaded for split '{split}'")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return `(image, label)`; the image is RGB, transformed if configured."""
        path, label = self.samples[idx]
        image = Image.open(path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, label
def build_transforms(train: bool):
    """Build the preprocessing pipeline (ImageNet normalization at 224x224).

    The training pipeline prepends a random horizontal flip; evaluation
    pipelines are fully deterministic.
    """
    steps = []
    if train:
        steps.append(transforms.RandomHorizontalFlip(p=0.5))
    steps.extend(
        [
            transforms.ToTensor(),
            transforms.Resize((224, 224), antialias=True),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )
    return transforms.Compose(steps)
def create_model(num_classes: int) -> nn.Module:
    """Return EfficientNet V2-S pretrained on ImageNet with a fresh head.

    Only the final Linear layer is replaced; all backbone weights are kept.
    """
    net = efficientnet_v2_s(weights=EfficientNet_V2_S_Weights.IMAGENET1K_V1)
    in_features = net.classifier[1].in_features
    net.classifier[1] = nn.Linear(in_features, num_classes)
    return net
def train_one_epoch(model, loader, criterion, optimizer, scaler, device):
model.train()
total_loss = 0.0
for images, labels in loader:
images, labels = images.to(device), labels.to(device)
optimizer.zero_grad()
with autocast(device.type, enabled=(device.type == "cuda")):
loss = criterion(model(images), labels)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
total_loss += loss.item() * images.size(0)
return total_loss / len(loader.dataset)
def evaluate(model, loader, criterion, device):
    """Return (mean per-sample loss, accuracy in percent) over `loader`.

    Runs under `torch.no_grad()`; autocast is enabled only on CUDA.
    Does not modify model weights.
    """
    model.eval()
    loss_sum = 0.0
    hits = 0
    seen = 0
    use_amp = device.type == "cuda"
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)
            labels = labels.to(device)
            with autocast(device.type, enabled=use_amp):
                logits = model(images)
                batch_loss = criterion(logits, labels)
            loss_sum += batch_loss.item() * images.size(0)
            hits += (logits.argmax(1) == labels).sum().item()
            seen += labels.size(0)
    return loss_sum / len(loader.dataset), 100.0 * hits / seen
def main():
    """CLI entry point: parse args, ensure data, train, evaluate, save weights."""
    parser = argparse.ArgumentParser(description="Fine-tune EfficientNet V2-S")
    parser.add_argument("--epochs", type=int, default=3)
    parser.add_argument("--batch-size", type=int, default=32)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--output", type=Path, default=WEIGHTS_PATH)
    args = parser.parse_args()
    seed_everything()
    ensure_dataset()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")
    # Augmentation only on the training split; val/test use deterministic preprocessing.
    train_ds = WildForestAnimalsDataset(DATASET_DIR, "train", build_transforms(train=True))
    val_ds = WildForestAnimalsDataset(DATASET_DIR, "valid", build_transforms(train=False))
    test_ds = WildForestAnimalsDataset(DATASET_DIR, "test", build_transforms(train=False))
    print(f"Splits: {len(train_ds)} train, {len(val_ds)} val, {len(test_ds)} test")
    # Seeded generator makes the training shuffle order reproducible across runs.
    gen = torch.Generator().manual_seed(SEED)
    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, generator=gen)
    val_loader = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False)
    test_loader = DataLoader(test_ds, batch_size=args.batch_size, shuffle=False)
    model = create_model(len(CLASS_NAMES)).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    # Mixed precision is active only on CUDA; on CPU the scaler is disabled (no-op).
    scaler = GradScaler(device.type, enabled=(device.type == "cuda"))
    for epoch in range(1, args.epochs + 1):
        train_loss = train_one_epoch(model, train_loader, criterion, optimizer, scaler, device)
        val_loss, val_acc = evaluate(model, val_loader, criterion, device)
        print(f"Epoch {epoch}/{args.epochs} -- "
              f"train loss: {train_loss:.4f}, val loss: {val_loss:.4f}, val acc: {val_acc:.2f}%")
    # Final held-out evaluation after all epochs complete.
    test_loss, test_acc = evaluate(model, test_loader, criterion, device)
    print(f"Test loss: {test_loss:.4f}, test accuracy: {test_acc:.2f}%")
    torch.save(model.state_dict(), args.output)
    print(f"Saved weights to {args.output.resolve()}")
# Script entry point -- run training only when executed directly, not on import.
if __name__ == "__main__":
    main()