mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-11 17:29:34 +01:00
Compare commits
No commits in common. "b997a122f1e41e70bf686a8ed72e6fce776c3dd0" and "65d13a32b76fc5868be273d5dabaed7eaabab1e8" have entirely different histories.
b997a122f1
...
65d13a32b7
@ -1,7 +1,6 @@
|
|||||||
"""BatDetect2 command line interface."""
|
"""BatDetect2 command line interface."""
|
||||||
|
|
||||||
import click
|
import click
|
||||||
from loguru import logger
|
|
||||||
|
|
||||||
# from batdetect2.cli.ascii import BATDETECT_ASCII_ART
|
# from batdetect2.cli.ascii import BATDETECT_ASCII_ART
|
||||||
|
|
||||||
@ -22,5 +21,4 @@ BatDetect2 - Detection and Classification
|
|||||||
def cli():
|
def cli():
|
||||||
"""BatDetect2 - Bat Call Detection and Classification."""
|
"""BatDetect2 - Bat Call Detection and Classification."""
|
||||||
click.echo(INFO_STR)
|
click.echo(INFO_STR)
|
||||||
logger.enable("batdetect2")
|
|
||||||
# click.echo(BATDETECT_ASCII_ART)
|
# click.echo(BATDETECT_ASCII_ART)
|
||||||
|
|||||||
@ -1,15 +1,9 @@
|
|||||||
from batdetect2.evaluate.config import (
|
|
||||||
EvaluationConfig,
|
|
||||||
load_evaluation_config,
|
|
||||||
)
|
|
||||||
from batdetect2.evaluate.match import (
|
from batdetect2.evaluate.match import (
|
||||||
match_predictions_and_annotations,
|
match_predictions_and_annotations,
|
||||||
match_sound_events_and_raw_predictions,
|
match_sound_events_and_raw_predictions,
|
||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"EvaluationConfig",
|
|
||||||
"load_evaluation_config",
|
|
||||||
"match_predictions_and_annotations",
|
|
||||||
"match_sound_events_and_raw_predictions",
|
"match_sound_events_and_raw_predictions",
|
||||||
|
"match_predictions_and_annotations",
|
||||||
]
|
]
|
||||||
|
|||||||
@ -1,25 +0,0 @@
|
|||||||
from typing import Optional
|
|
||||||
|
|
||||||
from pydantic import Field
|
|
||||||
from soundevent import data
|
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig, load_config
|
|
||||||
from batdetect2.evaluate.match import DEFAULT_MATCH_CONFIG, MatchConfig
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"EvaluationConfig",
|
|
||||||
"load_evaluation_config",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
class EvaluationConfig(BaseConfig):
|
|
||||||
match: MatchConfig = Field(
|
|
||||||
default_factory=lambda: DEFAULT_MATCH_CONFIG.model_copy(),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def load_evaluation_config(
|
|
||||||
path: data.PathLike,
|
|
||||||
field: Optional[str] = None,
|
|
||||||
) -> EvaluationConfig:
|
|
||||||
return load_config(path, schema=EvaluationConfig, field=field)
|
|
||||||
@ -19,7 +19,7 @@ class ClassExamples:
|
|||||||
cross_triggers: List[MatchEvaluation] = field(default_factory=list)
|
cross_triggers: List[MatchEvaluation] = field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
def plot_example_gallery(
|
def plot_examples(
|
||||||
matches: List[MatchEvaluation],
|
matches: List[MatchEvaluation],
|
||||||
preprocessor: PreprocessorProtocol,
|
preprocessor: PreprocessorProtocol,
|
||||||
n_examples: int = 5,
|
n_examples: int = 5,
|
||||||
|
|||||||
@ -179,7 +179,7 @@ def plot_false_positive_match(
|
|||||||
plt.text(
|
plt.text(
|
||||||
start_time,
|
start_time,
|
||||||
high_freq,
|
high_freq,
|
||||||
f"False Positive \nScore: {match.pred_score} \nTop Class: {match.pred_class} \nTop Class Score: {match.pred_class_score:.2f} ",
|
f"False Positive \nScore: {match.pred_score} \nTop Class: {match.pred_class} \nTop Class Score: {match.pred_class_score} ",
|
||||||
va="top",
|
va="top",
|
||||||
ha="right",
|
ha="right",
|
||||||
color=color,
|
color=color,
|
||||||
@ -326,7 +326,7 @@ def plot_true_positive_match(
|
|||||||
plt.text(
|
plt.text(
|
||||||
start_time,
|
start_time,
|
||||||
high_freq,
|
high_freq,
|
||||||
f"True Positive \nClass: {match.gt_class} \nDet Score: {match.pred_score:.2f} \nTop Class Score: {match.pred_class_score:.2f} ",
|
f"True Positive \nClass: {match.gt_class} \nDet Score: {match.pred_score} \nTop Class Score: {match.pred_class_score} ",
|
||||||
va="top",
|
va="top",
|
||||||
ha="right",
|
ha="right",
|
||||||
color=color,
|
color=color,
|
||||||
@ -407,7 +407,7 @@ def plot_cross_trigger_match(
|
|||||||
plt.text(
|
plt.text(
|
||||||
start_time,
|
start_time,
|
||||||
high_freq,
|
high_freq,
|
||||||
f"Cross Trigger \nTrue Class: {match.gt_class} \nPred Class: {match.pred_class} \nDet Score: {match.pred_score:.2f} \nTop Class Score: {match.pred_class_score:.2f} ",
|
f"Cross Trigger \nTrue Class: {match.gt_class} \nPred Class: {match.pred_class} \nDet Score: {match.pred_score} \nTop Class Score: {match.pred_class_score} ",
|
||||||
va="top",
|
va="top",
|
||||||
ha="right",
|
ha="right",
|
||||||
color=color,
|
color=color,
|
||||||
|
|||||||
@ -1,21 +1,14 @@
|
|||||||
from functools import partial
|
from typing import List
|
||||||
from multiprocessing import Pool
|
|
||||||
from typing import List, Optional, Tuple
|
|
||||||
|
|
||||||
from lightning import LightningModule, Trainer
|
from lightning import LightningModule, Trainer
|
||||||
from lightning.pytorch.callbacks import Callback
|
from lightning.pytorch.callbacks import Callback
|
||||||
from lightning.pytorch.loggers import TensorBoardLogger
|
from lightning.pytorch.loggers import TensorBoardLogger
|
||||||
from loguru import logger
|
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
from batdetect2.evaluate.match import (
|
from batdetect2.evaluate.match import match_sound_events_and_raw_predictions
|
||||||
MatchConfig,
|
|
||||||
match_sound_events_and_raw_predictions,
|
|
||||||
)
|
|
||||||
from batdetect2.evaluate.types import MatchEvaluation, MetricsProtocol
|
from batdetect2.evaluate.types import MatchEvaluation, MetricsProtocol
|
||||||
from batdetect2.plotting.evaluation import plot_example_gallery
|
from batdetect2.plotting.evaluation import plot_examples
|
||||||
from batdetect2.postprocess.types import BatDetect2Prediction
|
|
||||||
from batdetect2.targets.types import TargetProtocol
|
from batdetect2.targets.types import TargetProtocol
|
||||||
from batdetect2.train.dataset import LabeledDataset, TrainExample
|
from batdetect2.train.dataset import LabeledDataset, TrainExample
|
||||||
from batdetect2.train.lightning import TrainingModule
|
from batdetect2.train.lightning import TrainingModule
|
||||||
@ -23,25 +16,16 @@ from batdetect2.train.types import ModelOutput
|
|||||||
|
|
||||||
|
|
||||||
class ValidationMetrics(Callback):
|
class ValidationMetrics(Callback):
|
||||||
def __init__(
|
def __init__(self, metrics: List[MetricsProtocol], plot: bool = True):
|
||||||
self,
|
|
||||||
metrics: List[MetricsProtocol],
|
|
||||||
plot: bool = True,
|
|
||||||
match_config: Optional[MatchConfig] = None,
|
|
||||||
):
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
if len(metrics) == 0:
|
if len(metrics) == 0:
|
||||||
raise ValueError("At least one metric needs to be provided")
|
raise ValueError("At least one metric needs to be provided")
|
||||||
|
|
||||||
self.match_config = match_config
|
self.matches: List[MatchEvaluation] = []
|
||||||
self.metrics = metrics
|
self.metrics = metrics
|
||||||
self.plot = plot
|
self.plot = plot
|
||||||
|
|
||||||
self._matches: List[
|
|
||||||
Tuple[data.ClipAnnotation, List[BatDetect2Prediction]]
|
|
||||||
] = []
|
|
||||||
|
|
||||||
def get_dataset(self, trainer: Trainer) -> LabeledDataset:
|
def get_dataset(self, trainer: Trainer) -> LabeledDataset:
|
||||||
dataloaders = trainer.val_dataloaders
|
dataloaders = trainer.val_dataloaders
|
||||||
assert isinstance(dataloaders, DataLoader)
|
assert isinstance(dataloaders, DataLoader)
|
||||||
@ -49,33 +33,25 @@ class ValidationMetrics(Callback):
|
|||||||
assert isinstance(dataset, LabeledDataset)
|
assert isinstance(dataset, LabeledDataset)
|
||||||
return dataset
|
return dataset
|
||||||
|
|
||||||
def plot_examples(
|
def plot_examples(self, pl_module: LightningModule):
|
||||||
self,
|
|
||||||
pl_module: LightningModule,
|
|
||||||
matches: List[MatchEvaluation],
|
|
||||||
):
|
|
||||||
if not isinstance(pl_module.logger, TensorBoardLogger):
|
if not isinstance(pl_module.logger, TensorBoardLogger):
|
||||||
return
|
return
|
||||||
|
|
||||||
for class_name, fig in plot_example_gallery(
|
for class_name, fig in plot_examples(
|
||||||
matches,
|
self.matches,
|
||||||
preprocessor=pl_module.preprocessor,
|
preprocessor=pl_module.preprocessor,
|
||||||
n_examples=5,
|
n_examples=5,
|
||||||
):
|
):
|
||||||
pl_module.logger.experiment.add_figure(
|
pl_module.logger.experiment.add_figure(
|
||||||
f"images/{class_name}_examples",
|
f"{class_name}/examples",
|
||||||
fig,
|
fig,
|
||||||
pl_module.global_step,
|
pl_module.global_step,
|
||||||
)
|
)
|
||||||
|
|
||||||
def log_metrics(
|
def log_metrics(self, pl_module: LightningModule):
|
||||||
self,
|
|
||||||
pl_module: LightningModule,
|
|
||||||
matches: List[MatchEvaluation],
|
|
||||||
):
|
|
||||||
metrics = {}
|
metrics = {}
|
||||||
for metric in self.metrics:
|
for metric in self.metrics:
|
||||||
metrics.update(metric(matches).items())
|
metrics.update(metric(self.matches).items())
|
||||||
|
|
||||||
pl_module.log_dict(metrics)
|
pl_module.log_dict(metrics)
|
||||||
|
|
||||||
@ -84,16 +60,10 @@ class ValidationMetrics(Callback):
|
|||||||
trainer: Trainer,
|
trainer: Trainer,
|
||||||
pl_module: LightningModule,
|
pl_module: LightningModule,
|
||||||
) -> None:
|
) -> None:
|
||||||
matches = _match_all_collected_examples(
|
self.log_metrics(pl_module)
|
||||||
self._matches,
|
|
||||||
pl_module.targets,
|
|
||||||
config=self.match_config,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.log_metrics(pl_module, matches)
|
|
||||||
|
|
||||||
if self.plot:
|
if self.plot:
|
||||||
self.plot_examples(pl_module, matches)
|
self.plot_examples(pl_module)
|
||||||
|
|
||||||
return super().on_validation_epoch_end(trainer, pl_module)
|
return super().on_validation_epoch_end(trainer, pl_module)
|
||||||
|
|
||||||
@ -102,7 +72,7 @@ class ValidationMetrics(Callback):
|
|||||||
trainer: Trainer,
|
trainer: Trainer,
|
||||||
pl_module: LightningModule,
|
pl_module: LightningModule,
|
||||||
) -> None:
|
) -> None:
|
||||||
self._matches = []
|
self.matches = []
|
||||||
return super().on_validation_epoch_start(trainer, pl_module)
|
return super().on_validation_epoch_start(trainer, pl_module)
|
||||||
|
|
||||||
def on_validation_batch_end( # type: ignore
|
def on_validation_batch_end( # type: ignore
|
||||||
@ -140,26 +110,13 @@ class ValidationMetrics(Callback):
|
|||||||
for clip_annotation, clip_predictions in zip(
|
for clip_annotation, clip_predictions in zip(
|
||||||
clip_annotations, raw_predictions
|
clip_annotations, raw_predictions
|
||||||
):
|
):
|
||||||
self._matches.append((clip_annotation, clip_predictions))
|
self.matches.extend(
|
||||||
|
match_sound_events_and_raw_predictions(
|
||||||
|
clip_annotation=clip_annotation,
|
||||||
def _match_all_collected_examples(
|
raw_predictions=clip_predictions,
|
||||||
pre_matches: List[Tuple[data.ClipAnnotation, List[BatDetect2Prediction]]],
|
targets=pl_module.targets,
|
||||||
targets: TargetProtocol,
|
)
|
||||||
config: Optional[MatchConfig] = None,
|
)
|
||||||
) -> List[MatchEvaluation]:
|
|
||||||
logger.info("Matching all annotations and predictions")
|
|
||||||
|
|
||||||
with Pool() as p:
|
|
||||||
matches = p.starmap(
|
|
||||||
partial(
|
|
||||||
match_sound_events_and_raw_predictions,
|
|
||||||
targets=targets,
|
|
||||||
config=config,
|
|
||||||
),
|
|
||||||
pre_matches,
|
|
||||||
)
|
|
||||||
return [match for clip_matches in matches for match in clip_matches]
|
|
||||||
|
|
||||||
|
|
||||||
def _is_in_subclip(
|
def _is_in_subclip(
|
||||||
|
|||||||
@ -4,7 +4,6 @@ from pydantic import Field
|
|||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig, load_config
|
from batdetect2.configs import BaseConfig, load_config
|
||||||
from batdetect2.evaluate.config import EvaluationConfig
|
|
||||||
from batdetect2.models import BackboneConfig
|
from batdetect2.models import BackboneConfig
|
||||||
from batdetect2.postprocess import PostprocessConfig
|
from batdetect2.postprocess import PostprocessConfig
|
||||||
from batdetect2.preprocess import PreprocessingConfig
|
from batdetect2.preprocess import PreprocessingConfig
|
||||||
@ -95,7 +94,6 @@ class FullTrainingConfig(BaseConfig):
|
|||||||
default_factory=PreprocessingConfig
|
default_factory=PreprocessingConfig
|
||||||
)
|
)
|
||||||
postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig)
|
postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig)
|
||||||
evaluation: EvaluationConfig = Field(default_factory=EvaluationConfig)
|
|
||||||
|
|
||||||
|
|
||||||
def load_full_training_config(
|
def load_full_training_config(
|
||||||
|
|||||||
@ -7,7 +7,6 @@ from loguru import logger
|
|||||||
from soundevent import data
|
from soundevent import data
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
from batdetect2.evaluate.config import EvaluationConfig
|
|
||||||
from batdetect2.evaluate.metrics import (
|
from batdetect2.evaluate.metrics import (
|
||||||
ClassificationAccuracy,
|
ClassificationAccuracy,
|
||||||
ClassificationMeanAveragePrecision,
|
ClassificationMeanAveragePrecision,
|
||||||
@ -83,9 +82,7 @@ def train(
|
|||||||
logger.info("Training complete.")
|
logger.info("Training complete.")
|
||||||
|
|
||||||
|
|
||||||
def build_trainer_callbacks(
|
def build_trainer_callbacks(targets: TargetProtocol) -> List[Callback]:
|
||||||
targets: TargetProtocol, config: EvaluationConfig
|
|
||||||
) -> List[Callback]:
|
|
||||||
return [
|
return [
|
||||||
ModelCheckpoint(
|
ModelCheckpoint(
|
||||||
dirpath="outputs/checkpoints",
|
dirpath="outputs/checkpoints",
|
||||||
@ -99,8 +96,7 @@ def build_trainer_callbacks(
|
|||||||
class_names=targets.class_names
|
class_names=targets.class_names
|
||||||
),
|
),
|
||||||
ClassificationAccuracy(class_names=targets.class_names),
|
ClassificationAccuracy(class_names=targets.class_names),
|
||||||
],
|
]
|
||||||
match_config=config.match,
|
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -117,7 +113,7 @@ def build_trainer(
|
|||||||
return Trainer(
|
return Trainer(
|
||||||
**trainer_conf.model_dump(exclude_none=True),
|
**trainer_conf.model_dump(exclude_none=True),
|
||||||
logger=build_logger(conf.train.logger),
|
logger=build_logger(conf.train.logger),
|
||||||
callbacks=build_trainer_callbacks(targets, config=conf.evaluation),
|
callbacks=build_trainer_callbacks(targets),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user