mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 22:51:58 +02:00
Moved previous code to legacy folders
This commit is contained in:
parent
451093f2da
commit
d9f7304a0f
@ -12,8 +12,8 @@ import pandas as pd
|
|||||||
import torch
|
import torch
|
||||||
from sklearn.ensemble import RandomForestClassifier
|
from sklearn.ensemble import RandomForestClassifier
|
||||||
|
|
||||||
import batdetect2.train.evaluate as evl
|
import batdetect2.evaluate.legacy.evaluate_models as evl
|
||||||
import batdetect2.train.train_utils as tu
|
import batdetect2.train.legacy.train_utils as tu
|
||||||
import batdetect2.utils.detector_utils as du
|
import batdetect2.utils.detector_utils as du
|
||||||
import batdetect2.utils.plot_utils as pu
|
import batdetect2.utils.plot_utils as pu
|
||||||
from batdetect2.detector import parameters
|
from batdetect2.detector import parameters
|
@ -8,10 +8,10 @@ import torch.utils.data
|
|||||||
from torch.optim.lr_scheduler import CosineAnnealingLR
|
from torch.optim.lr_scheduler import CosineAnnealingLR
|
||||||
|
|
||||||
import batdetect2.detector.parameters as parameters
|
import batdetect2.detector.parameters as parameters
|
||||||
import batdetect2.train.audio_dataloader as adl
|
import batdetect2.train.legacy.audio_dataloader as adl
|
||||||
|
import batdetect2.train.legacy.train_model as tm
|
||||||
|
import batdetect2.train.legacy.train_utils as tu
|
||||||
import batdetect2.train.losses as losses
|
import batdetect2.train.losses as losses
|
||||||
import batdetect2.train.train_model as tm
|
|
||||||
import batdetect2.train.train_utils as tu
|
|
||||||
import batdetect2.utils.detector_utils as du
|
import batdetect2.utils.detector_utils as du
|
||||||
import batdetect2.utils.plot_utils as pu
|
import batdetect2.utils.plot_utils as pu
|
||||||
from batdetect2 import types
|
from batdetect2 import types
|
||||||
|
82
batdetect2/train/legacy/train.py
Normal file
82
batdetect2/train/legacy/train.py
Normal file
@ -0,0 +1,82 @@
|
|||||||
|
from typing import Callable, NamedTuple, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from soundevent import data
|
||||||
|
from torch.optim import Adam
|
||||||
|
from torch.optim.lr_scheduler import CosineAnnealingLR
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
from batdetect2.models.typing import DetectionModel
|
||||||
|
from batdetect2.train.dataset import LabeledDataset
|
||||||
|
|
||||||
|
|
||||||
|
class TrainInputs(NamedTuple):
|
||||||
|
spec: torch.Tensor
|
||||||
|
detection_heatmap: torch.Tensor
|
||||||
|
class_heatmap: torch.Tensor
|
||||||
|
size_heatmap: torch.Tensor
|
||||||
|
|
||||||
|
|
||||||
|
def train_loop(
|
||||||
|
model: DetectionModel,
|
||||||
|
train_dataset: LabeledDataset[TrainInputs],
|
||||||
|
validation_dataset: LabeledDataset[TrainInputs],
|
||||||
|
device: Optional[torch.device] = None,
|
||||||
|
num_epochs: int = 100,
|
||||||
|
learning_rate: float = 1e-4,
|
||||||
|
):
|
||||||
|
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
|
||||||
|
validation_loader = DataLoader(validation_dataset, batch_size=32)
|
||||||
|
|
||||||
|
model.to(device)
|
||||||
|
|
||||||
|
optimizer = Adam(model.parameters(), lr=learning_rate)
|
||||||
|
scheduler = CosineAnnealingLR(
|
||||||
|
optimizer,
|
||||||
|
num_epochs * len(train_loader),
|
||||||
|
)
|
||||||
|
|
||||||
|
for epoch in range(num_epochs):
|
||||||
|
train_loss = train_single_epoch(
|
||||||
|
model,
|
||||||
|
train_loader,
|
||||||
|
optimizer,
|
||||||
|
device,
|
||||||
|
scheduler,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def train_single_epoch(
|
||||||
|
model: DetectionModel,
|
||||||
|
train_loader: DataLoader,
|
||||||
|
optimizer: Adam,
|
||||||
|
device: torch.device,
|
||||||
|
scheduler: CosineAnnealingLR,
|
||||||
|
):
|
||||||
|
model.train()
|
||||||
|
train_loss = tu.AverageMeter()
|
||||||
|
|
||||||
|
for batch in train_loader:
|
||||||
|
optimizer.zero_grad()
|
||||||
|
|
||||||
|
spec = batch.spec.to(device)
|
||||||
|
detection_heatmap = batch.detection_heatmap.to(device)
|
||||||
|
class_heatmap = batch.class_heatmap.to(device)
|
||||||
|
size_heatmap = batch.size_heatmap.to(device)
|
||||||
|
|
||||||
|
outputs = model(spec)
|
||||||
|
|
||||||
|
loss = loss_fun(
|
||||||
|
outputs,
|
||||||
|
gt_det,
|
||||||
|
gt_size,
|
||||||
|
gt_class,
|
||||||
|
det_criterion,
|
||||||
|
params,
|
||||||
|
class_inv_freq,
|
||||||
|
)
|
||||||
|
|
||||||
|
train_loss.update(loss.item(), data.shape[0])
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
scheduler.step()
|
Loading…
Reference in New Issue
Block a user