mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-10 17:19:34 +01:00
Add model checkpoint callback
This commit is contained in:
parent
8368aad178
commit
6732394f50
@ -118,7 +118,12 @@ train:
|
||||
weight: 0.1
|
||||
logger:
|
||||
logger_type: mlflow
|
||||
experiment_name: batdetect2
|
||||
tracking_uri: http://localhost:5000
|
||||
log_model: true
|
||||
save_dir: outputs/log/
|
||||
artifact_location: outputs/artifacts/
|
||||
checkpoint_path_prefix: outputs/checkpoints/
|
||||
augmentations:
|
||||
steps:
|
||||
- augmentation_type: mix_audio
|
||||
|
||||
@ -21,7 +21,6 @@ dependencies = [
|
||||
"click>=8.1.7",
|
||||
"netcdf4>=1.6.5",
|
||||
"tqdm>=4.66.2",
|
||||
"pytorch-lightning>=2.2.2",
|
||||
"cf-xarray>=0.9.0",
|
||||
"onnx>=1.16.0",
|
||||
"lightning[extra]>=2.2.2",
|
||||
|
||||
@ -3,7 +3,7 @@ from typing import List, Optional
|
||||
|
||||
import yaml
|
||||
from lightning import Trainer
|
||||
from lightning.pytorch.callbacks import Callback
|
||||
from lightning.pytorch.callbacks import Callback, ModelCheckpoint
|
||||
from loguru import logger
|
||||
from soundevent import data
|
||||
from torch.utils.data import DataLoader
|
||||
@ -89,6 +89,11 @@ def train(
|
||||
|
||||
def build_trainer_callbacks(targets: TargetProtocol) -> List[Callback]:
|
||||
return [
|
||||
ModelCheckpoint(
|
||||
dirpath="outputs/checkpoints",
|
||||
save_top_k=1,
|
||||
monitor="total_loss/val",
|
||||
),
|
||||
ValidationMetrics(
|
||||
metrics=[
|
||||
DetectionAveragePrecision(),
|
||||
@ -97,7 +102,7 @@ def build_trainer_callbacks(targets: TargetProtocol) -> List[Callback]:
|
||||
),
|
||||
ClassificationAccuracy(class_names=targets.class_names),
|
||||
]
|
||||
)
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user