Add model checkpoint callback

This commit is contained in:
mbsantiago 2025-06-28 15:46:50 -06:00
parent 8368aad178
commit 6732394f50
4 changed files with 12 additions and 4006 deletions

View File

@ -118,7 +118,12 @@ train:
weight: 0.1
logger:
logger_type: mlflow
experiment_name: batdetect2
tracking_uri: http://localhost:5000
log_model: true
save_dir: outputs/log/
artifact_location: outputs/artifacts/
checkpoint_path_prefix: outputs/checkpoints/
augmentations:
steps:
- augmentation_type: mix_audio

View File

@ -21,7 +21,6 @@ dependencies = [
"click>=8.1.7",
"netcdf4>=1.6.5",
"tqdm>=4.66.2",
"pytorch-lightning>=2.2.2",
"cf-xarray>=0.9.0",
"onnx>=1.16.0",
"lightning[extra]>=2.2.2",

View File

@ -3,7 +3,7 @@ from typing import List, Optional
import yaml
from lightning import Trainer
from lightning.pytorch.callbacks import Callback
from lightning.pytorch.callbacks import Callback, ModelCheckpoint
from loguru import logger
from soundevent import data
from torch.utils.data import DataLoader
@ -89,6 +89,11 @@ def train(
def build_trainer_callbacks(targets: TargetProtocol) -> List[Callback]:
return [
ModelCheckpoint(
dirpath="outputs/checkpoints",
save_top_k=1,
monitor="total_loss/val",
),
ValidationMetrics(
metrics=[
DetectionAveragePrecision(),
@ -97,7 +102,7 @@ def build_trainer_callbacks(targets: TargetProtocol) -> List[Callback]:
),
ClassificationAccuracy(class_names=targets.class_names),
]
)
),
]

4003
uv.lock generated

File diff suppressed because it is too large Load Diff