mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 22:51:58 +02:00
Moved previous code to legacy folders
This commit is contained in:
parent
451093f2da
commit
d9f7304a0f
@ -12,8 +12,8 @@ import pandas as pd
|
||||
import torch
|
||||
from sklearn.ensemble import RandomForestClassifier
|
||||
|
||||
import batdetect2.train.evaluate as evl
|
||||
import batdetect2.train.train_utils as tu
|
||||
import batdetect2.evaluate.legacy.evaluate_models as evl
|
||||
import batdetect2.train.legacy.train_utils as tu
|
||||
import batdetect2.utils.detector_utils as du
|
||||
import batdetect2.utils.plot_utils as pu
|
||||
from batdetect2.detector import parameters
|
@ -8,10 +8,10 @@ import torch.utils.data
|
||||
from torch.optim.lr_scheduler import CosineAnnealingLR
|
||||
|
||||
import batdetect2.detector.parameters as parameters
|
||||
import batdetect2.train.audio_dataloader as adl
|
||||
import batdetect2.train.legacy.audio_dataloader as adl
|
||||
import batdetect2.train.legacy.train_model as tm
|
||||
import batdetect2.train.legacy.train_utils as tu
|
||||
import batdetect2.train.losses as losses
|
||||
import batdetect2.train.train_model as tm
|
||||
import batdetect2.train.train_utils as tu
|
||||
import batdetect2.utils.detector_utils as du
|
||||
import batdetect2.utils.plot_utils as pu
|
||||
from batdetect2 import types
|
||||
|
82
batdetect2/train/legacy/train.py
Normal file
82
batdetect2/train/legacy/train.py
Normal file
@ -0,0 +1,82 @@
|
||||
from typing import Callable, NamedTuple, Optional
|
||||
|
||||
import torch
|
||||
from soundevent import data
|
||||
from torch.optim import Adam
|
||||
from torch.optim.lr_scheduler import CosineAnnealingLR
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from batdetect2.models.typing import DetectionModel
|
||||
from batdetect2.train.dataset import LabeledDataset
|
||||
|
||||
|
||||
class TrainInputs(NamedTuple):
|
||||
spec: torch.Tensor
|
||||
detection_heatmap: torch.Tensor
|
||||
class_heatmap: torch.Tensor
|
||||
size_heatmap: torch.Tensor
|
||||
|
||||
|
||||
def train_loop(
|
||||
model: DetectionModel,
|
||||
train_dataset: LabeledDataset[TrainInputs],
|
||||
validation_dataset: LabeledDataset[TrainInputs],
|
||||
device: Optional[torch.device] = None,
|
||||
num_epochs: int = 100,
|
||||
learning_rate: float = 1e-4,
|
||||
):
|
||||
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
|
||||
validation_loader = DataLoader(validation_dataset, batch_size=32)
|
||||
|
||||
model.to(device)
|
||||
|
||||
optimizer = Adam(model.parameters(), lr=learning_rate)
|
||||
scheduler = CosineAnnealingLR(
|
||||
optimizer,
|
||||
num_epochs * len(train_loader),
|
||||
)
|
||||
|
||||
for epoch in range(num_epochs):
|
||||
train_loss = train_single_epoch(
|
||||
model,
|
||||
train_loader,
|
||||
optimizer,
|
||||
device,
|
||||
scheduler,
|
||||
)
|
||||
|
||||
|
||||
def train_single_epoch(
|
||||
model: DetectionModel,
|
||||
train_loader: DataLoader,
|
||||
optimizer: Adam,
|
||||
device: torch.device,
|
||||
scheduler: CosineAnnealingLR,
|
||||
):
|
||||
model.train()
|
||||
train_loss = tu.AverageMeter()
|
||||
|
||||
for batch in train_loader:
|
||||
optimizer.zero_grad()
|
||||
|
||||
spec = batch.spec.to(device)
|
||||
detection_heatmap = batch.detection_heatmap.to(device)
|
||||
class_heatmap = batch.class_heatmap.to(device)
|
||||
size_heatmap = batch.size_heatmap.to(device)
|
||||
|
||||
outputs = model(spec)
|
||||
|
||||
loss = loss_fun(
|
||||
outputs,
|
||||
gt_det,
|
||||
gt_size,
|
||||
gt_class,
|
||||
det_criterion,
|
||||
params,
|
||||
class_inv_freq,
|
||||
)
|
||||
|
||||
train_loss.update(loss.item(), data.shape[0])
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
scheduler.step()
|
Loading…
Reference in New Issue
Block a user