diff --git a/batdetect2/train/evaluate.py b/batdetect2/evaluate/evaluate.py similarity index 100% rename from batdetect2/train/evaluate.py rename to batdetect2/evaluate/evaluate.py diff --git a/batdetect2/evaluate/evaluate_models.py b/batdetect2/evaluate/legacy/evaluate_models.py similarity index 99% rename from batdetect2/evaluate/evaluate_models.py rename to batdetect2/evaluate/legacy/evaluate_models.py index 602faa8..88cab61 100644 --- a/batdetect2/evaluate/evaluate_models.py +++ b/batdetect2/evaluate/legacy/evaluate_models.py @@ -12,8 +12,8 @@ import pandas as pd import torch from sklearn.ensemble import RandomForestClassifier -import batdetect2.train.evaluate as evl -import batdetect2.train.train_utils as tu +import batdetect2.evaluate.legacy.evaluate_models as evl +import batdetect2.train.legacy.train_utils as tu import batdetect2.utils.detector_utils as du import batdetect2.utils.plot_utils as pu from batdetect2.detector import parameters diff --git a/batdetect2/evaluate/readme.md b/batdetect2/evaluate/legacy/readme.md similarity index 100% rename from batdetect2/evaluate/readme.md rename to batdetect2/evaluate/legacy/readme.md diff --git a/batdetect2/finetune/finetune_model.py b/batdetect2/finetune/finetune_model.py index 4c2d1c2..3135d78 100644 --- a/batdetect2/finetune/finetune_model.py +++ b/batdetect2/finetune/finetune_model.py @@ -8,10 +8,10 @@ import torch.utils.data from torch.optim.lr_scheduler import CosineAnnealingLR import batdetect2.detector.parameters as parameters -import batdetect2.train.audio_dataloader as adl +import batdetect2.train.legacy.audio_dataloader as adl +import batdetect2.train.legacy.train_model as tm +import batdetect2.train.legacy.train_utils as tu import batdetect2.train.losses as losses -import batdetect2.train.train_model as tm -import batdetect2.train.train_utils as tu import batdetect2.utils.detector_utils as du import batdetect2.utils.plot_utils as pu from batdetect2 import types diff --git a/batdetect2/train/readme.md b/batdetect2/train/legacy/readme.md similarity index 100% rename from batdetect2/train/readme.md rename to batdetect2/train/legacy/readme.md diff --git a/batdetect2/train/legacy/train.py b/batdetect2/train/legacy/train.py new file mode 100644 index 0000000..a59b13a --- /dev/null +++ b/batdetect2/train/legacy/train.py @@ -0,0 +1,82 @@ +from typing import Callable, NamedTuple, Optional + +import torch +from soundevent import data +from torch.optim import Adam +from torch.optim.lr_scheduler import CosineAnnealingLR +from torch.utils.data import DataLoader + +from batdetect2.models.typing import DetectionModel +from batdetect2.train.dataset import LabeledDataset + + +class TrainInputs(NamedTuple): + spec: torch.Tensor + detection_heatmap: torch.Tensor + class_heatmap: torch.Tensor + size_heatmap: torch.Tensor + + +def train_loop( + model: DetectionModel, + train_dataset: LabeledDataset[TrainInputs], + validation_dataset: LabeledDataset[TrainInputs], + device: Optional[torch.device] = None, + num_epochs: int = 100, + learning_rate: float = 1e-4, +): + train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True) + validation_loader = DataLoader(validation_dataset, batch_size=32) + + model.to(device) + + optimizer = Adam(model.parameters(), lr=learning_rate) + scheduler = CosineAnnealingLR( + optimizer, + num_epochs * len(train_loader), + ) + + for epoch in range(num_epochs): + train_loss = train_single_epoch( + model, + train_loader, + optimizer, + device, + scheduler, + ) + + +def train_single_epoch( + model: DetectionModel, + train_loader: DataLoader, + optimizer: Adam, + device: torch.device, + scheduler: CosineAnnealingLR, +): + model.train() + train_loss = tu.AverageMeter() + + for batch in train_loader: + optimizer.zero_grad() + + spec = batch.spec.to(device) + detection_heatmap = batch.detection_heatmap.to(device) + class_heatmap = batch.class_heatmap.to(device) + size_heatmap = batch.size_heatmap.to(device) + + outputs = model(spec) + + loss = loss_fun( + outputs, + gt_det, + gt_size, + gt_class, + det_criterion, + params, + class_inv_freq, + ) + + train_loss.update(loss.item(), data.shape[0]) + loss.backward() + optimizer.step() + scheduler.step() diff --git a/batdetect2/train/train_model.py b/batdetect2/train/legacy/train_model.py similarity index 100% rename from batdetect2/train/train_model.py rename to batdetect2/train/legacy/train_model.py diff --git a/batdetect2/train/train_split.py b/batdetect2/train/legacy/train_split.py similarity index 100% rename from batdetect2/train/train_split.py rename to batdetect2/train/legacy/train_split.py diff --git a/batdetect2/train/train_utils.py b/batdetect2/train/legacy/train_utils.py similarity index 100% rename from batdetect2/train/train_utils.py rename to batdetect2/train/legacy/train_utils.py