mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-11 17:29:34 +01:00
Add model checkpoint callback
This commit is contained in:
parent
8368aad178
commit
6732394f50
@ -118,7 +118,12 @@ train:
|
|||||||
weight: 0.1
|
weight: 0.1
|
||||||
logger:
|
logger:
|
||||||
logger_type: mlflow
|
logger_type: mlflow
|
||||||
|
experiment_name: batdetect2
|
||||||
tracking_uri: http://localhost:5000
|
tracking_uri: http://localhost:5000
|
||||||
|
log_model: true
|
||||||
|
save_dir: outputs/log/
|
||||||
|
artifact_location: outputs/artifacts/
|
||||||
|
checkpoint_path_prefix: outputs/checkpoints/
|
||||||
augmentations:
|
augmentations:
|
||||||
steps:
|
steps:
|
||||||
- augmentation_type: mix_audio
|
- augmentation_type: mix_audio
|
||||||
|
|||||||
@ -21,7 +21,6 @@ dependencies = [
|
|||||||
"click>=8.1.7",
|
"click>=8.1.7",
|
||||||
"netcdf4>=1.6.5",
|
"netcdf4>=1.6.5",
|
||||||
"tqdm>=4.66.2",
|
"tqdm>=4.66.2",
|
||||||
"pytorch-lightning>=2.2.2",
|
|
||||||
"cf-xarray>=0.9.0",
|
"cf-xarray>=0.9.0",
|
||||||
"onnx>=1.16.0",
|
"onnx>=1.16.0",
|
||||||
"lightning[extra]>=2.2.2",
|
"lightning[extra]>=2.2.2",
|
||||||
|
|||||||
@ -3,7 +3,7 @@ from typing import List, Optional
|
|||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
from lightning import Trainer
|
from lightning import Trainer
|
||||||
from lightning.pytorch.callbacks import Callback
|
from lightning.pytorch.callbacks import Callback, ModelCheckpoint
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
@ -89,6 +89,11 @@ def train(
|
|||||||
|
|
||||||
def build_trainer_callbacks(targets: TargetProtocol) -> List[Callback]:
|
def build_trainer_callbacks(targets: TargetProtocol) -> List[Callback]:
|
||||||
return [
|
return [
|
||||||
|
ModelCheckpoint(
|
||||||
|
dirpath="outputs/checkpoints",
|
||||||
|
save_top_k=1,
|
||||||
|
monitor="total_loss/val",
|
||||||
|
),
|
||||||
ValidationMetrics(
|
ValidationMetrics(
|
||||||
metrics=[
|
metrics=[
|
||||||
DetectionAveragePrecision(),
|
DetectionAveragePrecision(),
|
||||||
@ -97,7 +102,7 @@ def build_trainer_callbacks(targets: TargetProtocol) -> List[Callback]:
|
|||||||
),
|
),
|
||||||
ClassificationAccuracy(class_names=targets.class_names),
|
ClassificationAccuracy(class_names=targets.class_names),
|
||||||
]
|
]
|
||||||
)
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user