diff --git a/src/batdetect2/api_v2.py b/src/batdetect2/api_v2.py
index d48944c..20f6234 100644
--- a/src/batdetect2/api_v2.py
+++ b/src/batdetect2/api_v2.py
@@ -1,5 +1,6 @@
+import json
 from pathlib import Path
-from typing import List, Optional, Sequence
+from typing import Dict, List, Optional, Sequence, Tuple
 
 import numpy as np
 import torch
@@ -9,6 +10,14 @@ from soundevent.audio.files import get_audio_files
 from batdetect2.audio import build_audio_loader
 from batdetect2.config import BatDetect2Config
 from batdetect2.core import merge_configs
+from batdetect2.data import (
+    OutputFormatConfig,
+    build_output_formatter,
+    get_output_formatter,
+    load_dataset_from_config,
+)
+from batdetect2.data.datasets import Dataset
+from batdetect2.data.predictions.base import OutputFormatterProtocol
 from batdetect2.evaluate import DEFAULT_EVAL_DIR, build_evaluator, evaluate
 from batdetect2.inference import process_file_list, run_batch_inference
 from batdetect2.logging import DEFAULT_LOGS_DIR
@@ -41,6 +50,7 @@ class BatDetect2API:
         preprocessor: PreprocessorProtocol,
         postprocessor: PostprocessorProtocol,
         evaluator: EvaluatorProtocol,
+        formatter: OutputFormatterProtocol,
         model: Model,
     ):
         self.config = config
@@ -50,9 +60,17 @@ class BatDetect2API:
         self.postprocessor = postprocessor
         self.evaluator = evaluator
         self.model = model
+        self.formatter = formatter
 
         self.model.eval()
 
+    def load_annotations(
+        self,
+        path: data.PathLike,
+        base_dir: Optional[data.PathLike] = None,
+    ) -> Dataset:
+        return load_dataset_from_config(path, base_dir=base_dir)
+
     def train(
         self,
         train_annotations: Sequence[data.ClipAnnotation],
@@ -91,7 +109,8 @@ class BatDetect2API:
         output_dir: data.PathLike = DEFAULT_EVAL_DIR,
         experiment_name: Optional[str] = None,
         run_name: Optional[str] = None,
-    ):
+        save_predictions: bool = True,
+    ) -> Tuple[List[Dict[str, float]], List[BatDetect2Prediction]]:
         return evaluate(
             self.model,
             test_annotations,
@@ -103,8 +122,41 @@ class BatDetect2API:
             output_dir=output_dir,
             experiment_name=experiment_name,
             run_name=run_name,
+            formatter=self.formatter if save_predictions else None,
         )
 
+    def evaluate_predictions(
+        self,
+        annotations: Sequence[data.ClipAnnotation],
+        predictions: Sequence[BatDetect2Prediction],
+        output_dir: Optional[data.PathLike] = None,
+    ):
+        clip_evals = self.evaluator.evaluate(
+            annotations,
+            predictions,
+        )
+
+        metrics = self.evaluator.compute_metrics(clip_evals)
+
+        if output_dir is not None:
+            output_dir = Path(output_dir)
+
+            if not output_dir.is_dir():
+                output_dir.mkdir(parents=True)
+
+            metrics_path = output_dir / "metrics.json"
+            metrics_path.write_text(json.dumps(metrics))
+
+            for figure_name, fig in self.evaluator.generate_plots(clip_evals):
+                fig_path = output_dir / figure_name
+
+                if not fig_path.parent.is_dir():
+                    fig_path.parent.mkdir(parents=True)
+
+                fig.savefig(fig_path)
+
+        return metrics
+
     def load_audio(self, path: data.PathLike) -> np.ndarray:
         return self.audio_loader.load_file(path)
 
@@ -194,8 +246,38 @@ class BatDetect2API:
             config=self.config,
         )
 
+    def save_predictions(
+        self,
+        predictions: Sequence[BatDetect2Prediction],
+        path: data.PathLike,
+        audio_dir: Optional[data.PathLike] = None,
+        format: Optional[str] = None,
+        config: Optional[OutputFormatConfig] = None,
+    ):
+        formatter = self.formatter
+
+        if format is not None or config is not None:
+            format = format or config.name  # type: ignore
+            formatter = get_output_formatter(
+                name=format,
+                targets=self.targets,
+                config=config,
+            )
+
+        outs = formatter.format(predictions)
+        formatter.save(outs, audio_dir=audio_dir, path=path)
+
+    def load_predictions(
+        self,
+        path: data.PathLike,
+    ) -> List[BatDetect2Prediction]:
+        return self.formatter.load(path)
+
     @classmethod
-    def from_config(cls, config: BatDetect2Config):
+    def from_config(
+        cls,
+        config: BatDetect2Config,
+    ):
         targets = build_targets(config=config.targets)
         audio_loader = build_audio_loader(config=config.audio)
 
@@ -228,6 +310,8 @@ class BatDetect2API:
             ),
         )
 
+        formatter = build_output_formatter(targets, config=config.output)
+
         return cls(
             config=config,
             targets=targets,
@@ -236,6 +320,7 @@ class BatDetect2API:
             postprocessor=postprocessor,
             evaluator=evaluator,
             model=model,
+            formatter=formatter,
         )
 
     @classmethod
@@ -266,6 +351,8 @@ class BatDetect2API:
 
         evaluator = build_evaluator(config=config.evaluation, targets=targets)
 
+        formatter = build_output_formatter(targets, config=config.output)
+
         return cls(
             config=config,
             targets=targets,
@@ -274,4 +361,5 @@ class BatDetect2API:
             postprocessor=postprocessor,
             evaluator=evaluator,
             model=model,
+            formatter=formatter,
         )
diff --git a/src/batdetect2/cli/base.py b/src/batdetect2/cli/base.py
index 8846685..d0fd5fc 100644
--- a/src/batdetect2/cli/base.py
+++ b/src/batdetect2/cli/base.py
@@ -1,9 +1,8 @@
 """BatDetect2 command line interface."""
 
-import sys
-
 import click
-from loguru import logger
+
+from batdetect2.logging import enable_logging
 
 # from batdetect2.cli.ascii import BATDETECT_ASCII_ART
 
@@ -27,22 +26,9 @@ BatDetect2 - Detection and Classification
     count=True,
     help="Increase verbosity. -v for INFO, -vv for DEBUG.",
 )
-def cli(
-    verbose: int = 0,
-):
+def cli(verbose: int = 0):
     """BatDetect2 - Bat Call Detection and Classification."""
     click.echo(INFO_STR)
-    logger.remove()
-
-    if verbose == 0:
-        log_level = "WARNING"
-    elif verbose == 1:
-        log_level = "INFO"
-    else:
-        log_level = "DEBUG"
-
-    logger.add(sys.stderr, level=log_level)
-
-    logger.enable("batdetect2")
+    enable_logging(verbose)
 
     # click.echo(BATDETECT_ASCII_ART)
diff --git a/src/batdetect2/cli/data.py b/src/batdetect2/cli/data.py
index 64f3757..4b02d8a 100644
--- a/src/batdetect2/cli/data.py
+++ b/src/batdetect2/cli/data.py
@@ -43,3 +43,55 @@ def summary(
     )
 
     print(f"Number of annotated clips: {len(dataset)}")
+
+
+@data.command()
+@click.argument(
+    "dataset_config",
+    type=click.Path(exists=True),
+)
+@click.option(
+    "--field",
+    type=str,
+    help="If the dataset info is stored in a nested field, specify the field name here.",
+)
+@click.option(
+    "--output",
+    type=click.Path(exists=False),
+    default="annotations.json",
+)
+@click.option(
+    "--base-dir",
+    type=click.Path(exists=True),
+    help="The base directory to which all recording and annotation paths are relative.",
+)
+def convert(
+    dataset_config: Path,
+    field: Optional[str] = None,
+    output: Path = Path("annotations.json"),
+    base_dir: Optional[Path] = None,
+):
+    """Convert a dataset config file to soundevent format."""
+    from soundevent import data, io
+
+    from batdetect2.data import load_dataset, load_dataset_config
+
+    base_dir = base_dir or Path.cwd()
+
+    config = load_dataset_config(
+        dataset_config,
+        field=field,
+    )
+
+    dataset = load_dataset(
+        config,
+        base_dir=base_dir,
+    )
+
+    annotation_set = data.AnnotationSet(
+        clip_annotations=list(dataset),
+        name=config.name,
+        description=config.description,
+    )
+
+    io.save(annotation_set, output)
diff --git a/src/batdetect2/config.py b/src/batdetect2/config.py
index 8fef59a..6a17b68 100644
--- a/src/batdetect2/config.py
+++ b/src/batdetect2/config.py
@@ -4,9 +4,13 @@
 from pydantic import Field
 from soundevent.data
import PathLike from batdetect2.audio import AudioConfig -from batdetect2.core import BaseConfig -from batdetect2.core.configs import load_config -from batdetect2.evaluate.config import EvaluationConfig +from batdetect2.core.configs import BaseConfig, load_config +from batdetect2.data.predictions import OutputFormatConfig +from batdetect2.data.predictions.raw import RawOutputConfig +from batdetect2.evaluate.config import ( + EvaluationConfig, + get_default_eval_config, +) from batdetect2.inference.config import InferenceConfig from batdetect2.models.config import BackboneConfig from batdetect2.postprocess.config import PostprocessConfig @@ -17,6 +21,7 @@ from batdetect2.train.config import TrainingConfig __all__ = [ "BatDetect2Config", "load_full_config", + "validate_config", ] @@ -24,7 +29,9 @@ class BatDetect2Config(BaseConfig): config_version: Literal["v1"] = "v1" train: TrainingConfig = Field(default_factory=TrainingConfig) - evaluation: EvaluationConfig = Field(default_factory=EvaluationConfig) + evaluation: EvaluationConfig = Field( + default_factory=get_default_eval_config + ) model: BackboneConfig = Field(default_factory=BackboneConfig) preprocess: PreprocessingConfig = Field( default_factory=PreprocessingConfig @@ -33,6 +40,14 @@ class BatDetect2Config(BaseConfig): audio: AudioConfig = Field(default_factory=AudioConfig) targets: TargetConfig = Field(default_factory=TargetConfig) inference: InferenceConfig = Field(default_factory=InferenceConfig) + output: OutputFormatConfig = Field(default_factory=RawOutputConfig) + + +def validate_config(config: Optional[dict]) -> BatDetect2Config: + if config is None: + return BatDetect2Config() + + return BatDetect2Config.model_validate(config) def load_full_config( diff --git a/src/batdetect2/core/configs.py b/src/batdetect2/core/configs.py index e188935..fd2583c 100644 --- a/src/batdetect2/core/configs.py +++ b/src/batdetect2/core/configs.py @@ -29,7 +29,7 @@ class BaseConfig(BaseModel): and serialization capabilities. """ - model_config = ConfigDict(extra="forbid") + model_config = ConfigDict(extra="ignore") def to_yaml_string( self, diff --git a/src/batdetect2/core/registries.py b/src/batdetect2/core/registries.py index d059e8c..80852ac 100644 --- a/src/batdetect2/core/registries.py +++ b/src/batdetect2/core/registries.py @@ -77,6 +77,15 @@ class Registry(Generic[T_Type, P_Type]): def get_config_types(self) -> Tuple[Type[BaseModel], ...]: return tuple(self._config_types.values()) + def get_config_type(self, name: str) -> Type[BaseModel]: + try: + return self._config_types[name] + except KeyError as err: + raise ValueError( + f"No config type with name '{name}' is registered. 
" + f"Existing config types: {list(self._config_types.keys())}" + ) from err + def build( self, config: BaseModel, diff --git a/src/batdetect2/data/__init__.py b/src/batdetect2/data/__init__.py index 54a762a..647a135 100644 --- a/src/batdetect2/data/__init__.py +++ b/src/batdetect2/data/__init__.py @@ -1,5 +1,6 @@ from batdetect2.data.annotations import ( AnnotatedDataset, + AnnotationFormats, AOEFAnnotations, BatDetect2FilesAnnotations, BatDetect2MergedAnnotations, @@ -11,6 +12,14 @@ from batdetect2.data.datasets import ( load_dataset_config, load_dataset_from_config, ) +from batdetect2.data.predictions import ( + BatDetect2OutputConfig, + OutputFormatConfig, + RawOutputConfig, + SoundEventOutputConfig, + build_output_formatter, + get_output_formatter, +) from batdetect2.data.summary import ( compute_class_summary, extract_recordings_df, @@ -20,12 +29,19 @@ from batdetect2.data.summary import ( __all__ = [ "AOEFAnnotations", "AnnotatedDataset", + "AnnotationFormats", "BatDetect2FilesAnnotations", "BatDetect2MergedAnnotations", + "BatDetect2OutputConfig", "DatasetConfig", + "OutputFormatConfig", + "RawOutputConfig", + "SoundEventOutputConfig", + "build_output_formatter", "compute_class_summary", "extract_recordings_df", "extract_sound_events_df", + "get_output_formatter", "load_annotated_dataset", "load_dataset", "load_dataset_config", diff --git a/src/batdetect2/data/annotations/__init__.py b/src/batdetect2/data/annotations/__init__.py index e4dc183..1d499c3 100644 --- a/src/batdetect2/data/annotations/__init__.py +++ b/src/batdetect2/data/annotations/__init__.py @@ -13,7 +13,6 @@ format-specific loading function to retrieve the annotations as a standard `soundevent.data.AnnotationSet`. """ -from pathlib import Path from typing import Annotated, Optional, Union from pydantic import Field @@ -64,7 +63,7 @@ source configuration represents. def load_annotated_dataset( dataset: AnnotatedDataset, - base_dir: Optional[Path] = None, + base_dir: Optional[data.PathLike] = None, ) -> data.AnnotationSet: """Load annotations for a single data source based on its configuration. @@ -97,6 +96,7 @@ def load_annotated_dataset( known format-specific loading functions implemented in the dispatch logic. """ + if isinstance(dataset, AOEFAnnotations): return load_aoef_annotated_dataset(dataset, base_dir=base_dir) diff --git a/src/batdetect2/data/annotations/aoef.py b/src/batdetect2/data/annotations/aoef.py index 748dac8..e73cf60 100644 --- a/src/batdetect2/data/annotations/aoef.py +++ b/src/batdetect2/data/annotations/aoef.py @@ -84,7 +84,7 @@ class AOEFAnnotations(AnnotatedDataset): def load_aoef_annotated_dataset( dataset: AOEFAnnotations, - base_dir: Optional[Path] = None, + base_dir: Optional[data.PathLike] = None, ) -> data.AnnotationSet: """Load annotations from an AnnotationSet or AnnotationProject file. diff --git a/src/batdetect2/data/datasets.py b/src/batdetect2/data/datasets.py index db7f728..3a442ab 100644 --- a/src/batdetect2/data/datasets.py +++ b/src/batdetect2/data/datasets.py @@ -19,7 +19,7 @@ The core components are: """ from pathlib import Path -from typing import List, Optional +from typing import List, Optional, Sequence from loguru import logger from pydantic import Field @@ -53,7 +53,7 @@ __all__ = [ ] -Dataset = List[data.ClipAnnotation] +Dataset = Sequence[data.ClipAnnotation] """Type alias for a loaded dataset representation. 
Represents an entire dataset *after loading* as a flat Python list containing @@ -77,7 +77,7 @@ class DatasetConfig(BaseConfig): def load_dataset( config: DatasetConfig, - base_dir: Optional[Path] = None, + base_dir: Optional[data.PathLike] = None, ) -> Dataset: """Load all clip annotations from the sources defined in a DatasetConfig.""" clip_annotations = [] @@ -168,7 +168,7 @@ def load_dataset_config(path: data.PathLike, field: Optional[str] = None): def load_dataset_from_config( path: data.PathLike, field: Optional[str] = None, - base_dir: Optional[Path] = None, + base_dir: Optional[data.PathLike] = None, ) -> Dataset: """Load dataset annotation metadata from a configuration file. @@ -250,6 +250,6 @@ def save_dataset( annotation_set = data.AnnotationSet( name=name, description=description, - clip_annotations=dataset, + clip_annotations=list(dataset), ) io.save(annotation_set, path, audio_dir=audio_dir) diff --git a/src/batdetect2/data/predictions/__init__.py b/src/batdetect2/data/predictions/__init__.py new file mode 100644 index 0000000..8839636 --- /dev/null +++ b/src/batdetect2/data/predictions/__init__.py @@ -0,0 +1,58 @@ +from typing import Annotated, Optional, Union + +from pydantic import Field + +from batdetect2.data.predictions.base import ( + OutputFormatterProtocol, + prediction_formatters, +) +from batdetect2.data.predictions.batdetect2 import BatDetect2OutputConfig +from batdetect2.data.predictions.raw import RawOutputConfig +from batdetect2.data.predictions.soundevent import SoundEventOutputConfig +from batdetect2.typing import TargetProtocol + +__all__ = [ + "build_output_formatter", + "get_output_formatter", + "BatDetect2OutputConfig", + "RawOutputConfig", + "SoundEventOutputConfig", +] + + +OutputFormatConfig = Annotated[ + Union[BatDetect2OutputConfig, SoundEventOutputConfig, RawOutputConfig], + Field(discriminator="name"), +] + + +def build_output_formatter( + targets: Optional[TargetProtocol] = None, + config: Optional[OutputFormatConfig] = None, +) -> OutputFormatterProtocol: + """Construct the final output formatter.""" + from batdetect2.targets import build_targets + + config = config or RawOutputConfig() + + targets = targets or build_targets() + return prediction_formatters.build(config, targets) + + +def get_output_formatter( + name: str, + targets: Optional[TargetProtocol] = None, + config: Optional[OutputFormatConfig] = None, +) -> OutputFormatterProtocol: + """Get the output formatter by name.""" + + if config is None: + config_class = prediction_formatters.get_config_type(name) + config = config_class() # type: ignore + + if config.name != name: # type: ignore + raise ValueError( + f"Config name {config.name} does not match formatter name {name}" # type: ignore + ) + + return build_output_formatter(targets, config) diff --git a/src/batdetect2/data/predictions/base.py b/src/batdetect2/data/predictions/base.py new file mode 100644 index 0000000..f6f253e --- /dev/null +++ b/src/batdetect2/data/predictions/base.py @@ -0,0 +1,29 @@ +from pathlib import Path + +from soundevent.data import PathLike + +from batdetect2.core import Registry +from batdetect2.typing import ( + OutputFormatterProtocol, + TargetProtocol, +) + + +def make_path_relative(path: PathLike, audio_dir: PathLike) -> Path: + path = Path(path) + audio_dir = Path(audio_dir) + + if path.is_absolute(): + if not path.is_relative_to(audio_dir): + raise ValueError( + f"Audio file {path} is not in audio_dir {audio_dir}" + ) + + return path.relative_to(audio_dir) + + return path + + +prediction_formatters: 
Registry[OutputFormatterProtocol, [TargetProtocol]] = ( + Registry(name="output_formatter") +) diff --git a/src/batdetect2/data/predictions/batdetect2.py b/src/batdetect2/data/predictions/batdetect2.py new file mode 100644 index 0000000..0f4b382 --- /dev/null +++ b/src/batdetect2/data/predictions/batdetect2.py @@ -0,0 +1,227 @@ +import json +from pathlib import Path +from typing import List, Literal, Optional, Sequence, TypedDict + +import numpy as np +from soundevent import data +from soundevent.geometry import compute_bounds + +from batdetect2.core import BaseConfig +from batdetect2.data.predictions.base import ( + make_path_relative, + prediction_formatters, +) +from batdetect2.targets import terms +from batdetect2.typing import ( + BatDetect2Prediction, + OutputFormatterProtocol, + RawPrediction, + TargetProtocol, +) + +try: + from typing import NotRequired # type: ignore +except ImportError: + from typing_extensions import NotRequired + +DictWithClass = TypedDict("DictWithClass", {"class": str}) + + +class Annotation(DictWithClass): + """Format of annotations. + + This is the format of a single annotation as expected by the + annotation tool. + """ + + start_time: float + """Start time in seconds.""" + + end_time: float + """End time in seconds.""" + + low_freq: float + """Low frequency in Hz.""" + + high_freq: float + """High frequency in Hz.""" + + class_prob: float + """Probability of class assignment.""" + + det_prob: float + """Probability of detection.""" + + individual: str + """Individual ID.""" + + event: str + """Type of detected event.""" + + +class FileAnnotation(TypedDict): + """Format of results. + + This is the format of the results expected by the annotation tool. + """ + + id: str + """File ID.""" + + annotated: bool + """Whether file has been annotated.""" + + duration: float + """Duration of audio file.""" + + issues: bool + """Whether file has issues.""" + + time_exp: float + """Time expansion factor.""" + + class_name: str + """Class predicted at file level.""" + + notes: str + """Notes of file.""" + + annotation: List[Annotation] + """List of annotations.""" + + file_path: NotRequired[str] + """Path to file.""" + + +class BatDetect2OutputConfig(BaseConfig): + name: Literal["batdetect2"] = "batdetect2" + + event_name: str = "Echolocation" + + annotation_note: str = "Automatically generated." 
+ + +class BatDetect2Formatter(OutputFormatterProtocol[FileAnnotation]): + def __init__( + self, + targets: TargetProtocol, + event_name: str, + annotation_note: str, + ): + self.targets = targets + self.event_name = event_name + self.annotation_note = annotation_note + + def format( + self, predictions: Sequence[BatDetect2Prediction] + ) -> List[FileAnnotation]: + return [ + self.format_prediction(prediction) for prediction in predictions + ] + + def save( + self, + predictions: Sequence[FileAnnotation], + path: data.PathLike, + audio_dir: Optional[data.PathLike] = None, + ) -> None: + path = Path(path) + + if not path.is_dir(): + path.mkdir(parents=True) + + for prediction in predictions: + pred_path = path / (prediction["id"] + ".json") + + if audio_dir is not None and "file_path" in prediction: + prediction["file_path"] = str( + make_path_relative( + prediction["file_path"], + audio_dir, + ) + ) + + pred_path.write_text(json.dumps(prediction)) + + def load(self, path: data.PathLike) -> List[FileAnnotation]: + path = Path(path) + + files = list(path.glob("*.json")) + + if not files: + return [] + + return [ + json.loads(file.read_text()) for file in files if file.is_file() + ] + + def get_recording_class(self, annotations: List[Annotation]) -> str: + """Get class of recording from annotations.""" + + if not annotations: + return "" + + highest_scoring = max(annotations, key=lambda x: x["class_prob"]) + return highest_scoring["class"] + + def format_prediction( + self, prediction: BatDetect2Prediction + ) -> FileAnnotation: + recording = prediction.clip.recording + + annotations = [ + self.format_sound_event_prediction(pred) + for pred in prediction.predictions + ] + + return FileAnnotation( + id=recording.path.name, + file_path=str(recording.path), + annotated=False, + duration=recording.duration, + issues=False, + time_exp=recording.time_expansion, + class_name=self.get_recording_class(annotations), + notes=self.annotation_note, + annotation=annotations, + ) + + def get_class_name(self, class_index: int) -> str: + class_name = self.targets.class_names[class_index] + tags = self.targets.decode_class(class_name) + return data.find_tag_value( + tags, + term=terms.generic_class, + default=class_name, + ) # type: ignore + + def format_sound_event_prediction( + self, prediction: RawPrediction + ) -> Annotation: + start_time, low_freq, end_time, high_freq = compute_bounds( + prediction.geometry + ) + + top_class_index = int(np.argmax(prediction.class_scores)) + top_class_score = float(prediction.class_scores[top_class_index]) + top_class = self.get_class_name(top_class_index) + return Annotation( + start_time=start_time, + end_time=end_time, + low_freq=low_freq, + high_freq=high_freq, + class_prob=top_class_score, + det_prob=float(prediction.detection_score), + individual="", + event=self.event_name, + **{"class": top_class}, + ) + + @prediction_formatters.register(BatDetect2OutputConfig) + @staticmethod + def from_config(config: BatDetect2OutputConfig, targets: TargetProtocol): + return BatDetect2Formatter( + targets, + event_name=config.event_name, + annotation_note=config.annotation_note, + ) diff --git a/src/batdetect2/data/predictions/raw.py b/src/batdetect2/data/predictions/raw.py new file mode 100644 index 0000000..715a595 --- /dev/null +++ b/src/batdetect2/data/predictions/raw.py @@ -0,0 +1,246 @@ +from collections import defaultdict +from pathlib import Path +from typing import List, Literal, Optional, Sequence +from uuid import UUID, uuid4 + +import numpy as np +import xarray as xr 
+from soundevent import data +from soundevent.geometry import compute_bounds + +from batdetect2.core import BaseConfig +from batdetect2.data.predictions.base import ( + make_path_relative, + prediction_formatters, +) +from batdetect2.typing import ( + BatDetect2Prediction, + OutputFormatterProtocol, + RawPrediction, + TargetProtocol, +) + + +class RawOutputConfig(BaseConfig): + name: Literal["raw"] = "raw" + + include_class_scores: bool = True + include_features: bool = True + include_geometry: bool = True + + +class RawFormatter(OutputFormatterProtocol[BatDetect2Prediction]): + def __init__( + self, + targets: TargetProtocol, + include_class_scores: bool = True, + include_features: bool = True, + include_geometry: bool = True, + ): + self.targets = targets + self.include_class_scores = include_class_scores + self.include_features = include_features + self.include_geometry = include_geometry + + def format( + self, + predictions: Sequence[BatDetect2Prediction], + ) -> List[BatDetect2Prediction]: + return list(predictions) + + def save( + self, + predictions: Sequence[BatDetect2Prediction], + path: data.PathLike, + audio_dir: Optional[data.PathLike] = None, + ) -> None: + num_features = 0 + + tree = xr.DataTree() + + for prediction in predictions: + clip = prediction.clip + recording = clip.recording + + if audio_dir is not None: + recording = recording.model_copy( + update=dict( + path=make_path_relative(recording.path, audio_dir) + ) + ) + + clip_data = defaultdict(list) + + for pred in prediction.predictions: + detection_id = str(uuid4()) + + clip_data["detection_id"].append(detection_id) + clip_data["detection_score"].append(pred.detection_score) + + start_time, low_freq, end_time, high_freq = compute_bounds( + pred.geometry + ) + + clip_data["start_time"].append(start_time) + clip_data["end_time"].append(end_time) + clip_data["low_freq"].append(low_freq) + clip_data["high_freq"].append(high_freq) + + clip_data["geometry"].append(pred.geometry.model_dump_json()) + + top_class_index = int(np.argmax(pred.class_scores)) + top_class_score = float(pred.class_scores[top_class_index]) + top_class = self.targets.class_names[top_class_index] + + clip_data["top_class"].append(top_class) + clip_data["top_class_score"].append(top_class_score) + + clip_data["class_scores"].append(pred.class_scores) + clip_data["features"].append(pred.features) + + num_features = len(pred.features) + + data_vars = { + "score": (["detection"], clip_data["detection_score"]), + "start_time": (["detection"], clip_data["start_time"]), + "end_time": (["detection"], clip_data["end_time"]), + "low_freq": (["detection"], clip_data["low_freq"]), + "high_freq": (["detection"], clip_data["high_freq"]), + "top_class": (["detection"], clip_data["top_class"]), + "top_class_score": ( + ["detection"], + clip_data["top_class_score"], + ), + } + + coords = { + "detection": ("detection", clip_data["detection_id"]), + "clip_start": clip.start_time, + "clip_end": clip.end_time, + "clip_id": str(clip.uuid), + } + + if self.include_class_scores: + data_vars["class_scores"] = ( + ["detection", "classes"], + clip_data["class_scores"], + ) + coords["classes"] = ("classes", self.targets.class_names) + + if self.include_features: + data_vars["features"] = ( + ["detection", "feature"], + clip_data["features"], + ) + coords["feature"] = ("feature", np.arange(num_features)) + + if self.include_geometry: + data_vars["geometry"] = (["detection"], clip_data["geometry"]) + + dataset = xr.Dataset( + data_vars=data_vars, + coords=coords, + attrs={ + 
"recording": recording.model_dump_json(exclude_none=True), + }, + ) + + tree = tree.assign( + { + str(clip.uuid): xr.DataTree( + dataset=dataset, + name=str(clip.uuid), + ) + } + ) + + path = Path(path) + + if not path.suffix == ".nc": + path = Path(path).with_suffix(".nc") + + tree.to_netcdf(path) + + def load(self, path: data.PathLike) -> List[BatDetect2Prediction]: + path = Path(path) + + root = xr.load_datatree(path) + + predictions: List[BatDetect2Prediction] = [] + + for _, clip_data in root.items(): + recording = data.Recording.model_validate_json( + clip_data.attrs["recording"] + ) + + clip_id = clip_data.clip_id.item() + clip = data.Clip( + recording=recording, + uuid=UUID(clip_id), + start_time=clip_data.clip_start, + end_time=clip_data.clip_end, + ) + + sound_events = [] + + for detection in clip_data.detection: + score = clip_data.score.sel(detection=detection).item() + + if "geometry" in clip_data: + geometry = data.geometry_validate( + clip_data.geometry.sel(detection=detection).item() + ) + else: + start_time = clip_data.start_time.sel(detection=detection) + end_time = clip_data.end_time.sel(detection=detection) + low_freq = clip_data.low_freq.sel(detection=detection) + high_freq = clip_data.high_freq.sel(detection=detection) + geometry = data.BoundingBox( + coordinates=[start_time, low_freq, end_time, high_freq] + ) + + if "class_scores" in clip_data: + class_scores = clip_data.class_scores.sel( + detection=detection + ).data + else: + class_scores = np.zeros(len(self.targets.class_names)) + class_index = self.targets.class_names.index( + clip_data.top_class.sel(detection=detection).item() + ) + class_scores[class_index] = clip_data.top_class_score.sel( + detection=detection + ).item() + + if "features" in clip_data: + features = clip_data.features.sel(detection=detection).data + else: + features = np.zeros(0) + + sound_events.append( + RawPrediction( + geometry=geometry, + detection_score=score, + class_scores=class_scores, + features=features, + ) + ) + + predictions.append( + BatDetect2Prediction( + clip=clip, + predictions=sound_events, + ) + ) + + return predictions + + @prediction_formatters.register(RawOutputConfig) + @staticmethod + def from_config(config: RawOutputConfig, targets: TargetProtocol): + return RawFormatter( + targets, + include_class_scores=config.include_class_scores, + include_features=config.include_features, + include_geometry=config.include_geometry, + ) diff --git a/src/batdetect2/data/predictions/soundevent.py b/src/batdetect2/data/predictions/soundevent.py new file mode 100644 index 0000000..f400f42 --- /dev/null +++ b/src/batdetect2/data/predictions/soundevent.py @@ -0,0 +1,131 @@ +from pathlib import Path +from typing import List, Literal, Optional, Sequence + +import numpy as np +from soundevent import data, io + +from batdetect2.core import BaseConfig +from batdetect2.data.predictions.base import ( + prediction_formatters, +) +from batdetect2.typing import ( + BatDetect2Prediction, + OutputFormatterProtocol, + RawPrediction, + TargetProtocol, +) + + +class SoundEventOutputConfig(BaseConfig): + name: Literal["soundevent"] = "soundevent" + top_k: Optional[int] = 1 + min_score: Optional[float] = None + + +class SoundEventOutputFormatter(OutputFormatterProtocol[data.ClipPrediction]): + def __init__( + self, + targets: TargetProtocol, + top_k: Optional[int] = 1, + min_score: Optional[float] = 0, + ): + self.targets = targets + self.top_k = top_k + self.min_score = min_score + + def format( + self, + predictions: Sequence[BatDetect2Prediction], + 
) -> List[data.ClipPrediction]: + return [ + self.format_prediction(prediction) for prediction in predictions + ] + + def save( + self, + predictions: Sequence[data.ClipPrediction], + path: data.PathLike, + audio_dir: Optional[data.PathLike] = None, + ) -> None: + run = data.PredictionSet(clip_predictions=list(predictions)) + + path = Path(path) + + if not path.suffix == ".json": + path = Path(path).with_suffix(".json") + + io.save(run, path, audio_dir=audio_dir) + + def load(self, path: data.PathLike) -> List[data.ClipPrediction]: + path = Path(path) + run = io.load(path, type="prediction_set") + return run.clip_predictions + + def format_prediction( + self, + prediction: BatDetect2Prediction, + ) -> data.ClipPrediction: + recording = prediction.clip.recording + return data.ClipPrediction( + clip=prediction.clip, + sound_events=[ + self.format_sound_event_prediction(pred, recording) + for pred in prediction.predictions + ], + ) + + def format_sound_event_prediction( + self, + prediction: RawPrediction, + recording: data.Recording, + ) -> data.SoundEventPrediction: + return data.SoundEventPrediction( + sound_event=data.SoundEvent( + recording=recording, + geometry=prediction.geometry, + ), + score=prediction.detection_score, + tags=self.get_sound_event_tags(prediction), + ) + + def get_sound_event_tags( + self, prediction: RawPrediction + ) -> List[data.PredictedTag]: + sorted_indices = np.argsort(prediction.class_scores)[::-1] + + tags = [ + data.PredictedTag( + tag=tag, + score=prediction.detection_score, + ) + for tag in self.targets.detection_class_tags + ] + + top_k = self.top_k or len(sorted_indices) + + for ind in sorted_indices[:top_k]: + score = float(prediction.class_scores[ind]) + + if self.min_score is not None and score < self.min_score: + break + + class_name = self.targets.class_names[ind] + class_tags = self.targets.decode_class(class_name) + tags.extend( + data.PredictedTag( + tag=tag, + score=score, + ) + for tag in class_tags + ) + + return tags + + @prediction_formatters.register(SoundEventOutputConfig) + @staticmethod + def from_config(config: SoundEventOutputConfig, targets: TargetProtocol): + return SoundEventOutputFormatter( + targets, + top_k=config.top_k, + min_score=config.min_score, + ) diff --git a/src/batdetect2/data/summary.py b/src/batdetect2/data/summary.py index f4994d5..d550d9d 100644 --- a/src/batdetect2/data/summary.py +++ b/src/batdetect2/data/summary.py @@ -159,6 +159,7 @@ def compute_class_summary( exclude_generic=False, exclude_non_target=True, ) + recordings = extract_recordings_df(dataset) num_calls = ( diff --git a/src/batdetect2/evaluate/config.py b/src/batdetect2/evaluate/config.py index 4d5510c..fbcb281 100644 --- a/src/batdetect2/evaluate/config.py +++ b/src/batdetect2/evaluate/config.py @@ -27,6 +27,28 @@ class EvaluationConfig(BaseConfig): logger: LoggerConfig = Field(default_factory=CSVLoggerConfig) +def get_default_eval_config() -> EvaluationConfig: + return EvaluationConfig.model_validate( + { + "tasks": [ + { + "name": "sound_event_detection", + "plots": [ + {"name": "pr_curve"}, + {"name": "score_distribution"}, + ], + }, + { + "name": "sound_event_classification", + "plots": [ + {"name": "pr_curve"}, + ], + }, + ] + } + ) + + def load_evaluation_config( path: data.PathLike, field: Optional[str] = None, diff --git a/src/batdetect2/evaluate/evaluate.py b/src/batdetect2/evaluate/evaluate.py index 7312639..0029b7b 100644 --- a/src/batdetect2/evaluate/evaluate.py +++ b/src/batdetect2/evaluate/evaluate.py @@ -1,5 +1,5 @@ from pathlib import 
Path
-from typing import TYPE_CHECKING, Optional, Sequence
+from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple
 
 from lightning import Trainer
 from soundevent import data
@@ -12,11 +12,13 @@ from batdetect2.logging import build_logger
 from batdetect2.models import Model
 from batdetect2.preprocess import build_preprocessor
 from batdetect2.targets import build_targets
+from batdetect2.typing.postprocess import BatDetect2Prediction
 
 if TYPE_CHECKING:
     from batdetect2.config import BatDetect2Config
     from batdetect2.typing import (
         AudioLoader,
+        OutputFormatterProtocol,
         PreprocessorProtocol,
         TargetProtocol,
     )
@@ -31,11 +33,12 @@ def evaluate(
     audio_loader: Optional["AudioLoader"] = None,
     preprocessor: Optional["PreprocessorProtocol"] = None,
     config: Optional["BatDetect2Config"] = None,
+    formatter: Optional["OutputFormatterProtocol"] = None,
     num_workers: Optional[int] = None,
     output_dir: data.PathLike = DEFAULT_EVAL_DIR,
     experiment_name: Optional[str] = None,
     run_name: Optional[str] = None,
-):
+) -> Tuple[List[Dict[str, float]], List[BatDetect2Prediction]]:
     from batdetect2.config import BatDetect2Config
 
     config = config or BatDetect2Config()
@@ -66,4 +69,12 @@ def evaluate(
     )
     module = EvaluationModule(model, evaluator)
     trainer = Trainer(logger=logger, enable_checkpointing=False)
-    return trainer.test(module, loader)
+    metrics = trainer.test(module, loader)
+
+    if formatter is not None and logger.log_dir is not None:
+        formatter.save(
+            module.predictions,
+            path=Path(logger.log_dir) / "predictions",
+        )
+
+    return metrics, module.predictions
diff --git a/src/batdetect2/evaluate/evaluator.py b/src/batdetect2/evaluate/evaluator.py
index 8126dda..181994b 100644
--- a/src/batdetect2/evaluate/evaluator.py
+++ b/src/batdetect2/evaluate/evaluator.py
@@ -6,7 +6,8 @@ from soundevent import data
 from batdetect2.evaluate.config import EvaluationConfig
 from batdetect2.evaluate.tasks import build_task
 from batdetect2.targets import build_targets
-from batdetect2.typing import EvaluatorProtocol, RawPrediction, TargetProtocol
+from batdetect2.typing import EvaluatorProtocol, TargetProtocol
+from batdetect2.typing.postprocess import BatDetect2Prediction
 
 __all__ = [
     "Evaluator",
@@ -26,7 +27,7 @@ class Evaluator:
     def evaluate(
         self,
         clip_annotations: Sequence[data.ClipAnnotation],
-        predictions: Sequence[Sequence[RawPrediction]],
+        predictions: Sequence[BatDetect2Prediction],
     ) -> List[Any]:
         return [
             task.evaluate(clip_annotations, predictions) for task in self.tasks
diff --git a/src/batdetect2/evaluate/lightning.py b/src/batdetect2/evaluate/lightning.py
index da2e3e3..f7dac85 100644
--- a/src/batdetect2/evaluate/lightning.py
+++ b/src/batdetect2/evaluate/lightning.py
@@ -9,7 +9,7 @@ from batdetect2.logging import get_image_logger
 from batdetect2.models import Model
 from batdetect2.postprocess import to_raw_predictions
 from batdetect2.typing import EvaluatorProtocol
-from batdetect2.typing.postprocess import RawPrediction
+from batdetect2.typing.postprocess import BatDetect2Prediction
 
 
 class EvaluationModule(LightningModule):
@@ -24,7 +24,7 @@
         self.evaluator = evaluator
 
         self.clip_annotations: List[data.ClipAnnotation] = []
-        self.predictions: List[List[RawPrediction]] = []
+        self.predictions: List[BatDetect2Prediction] = []
 
     def test_step(self, batch: TestExample, batch_idx: int):
         dataset = self.get_dataset()
@@ -39,11 +39,16 @@
             start_times=[ca.clip.start_time for ca in clip_annotations],
         )
         predictions = [
-
to_raw_predictions( - clip_dets.numpy(), - targets=self.evaluator.targets, + BatDetect2Prediction( + clip=clip_annotation.clip, + predictions=to_raw_predictions( + clip_dets.numpy(), + targets=self.evaluator.targets, + ), + ) + for clip_annotation, clip_dets in zip( + clip_annotations, clip_detections ) - for clip_dets in clip_detections ] self.clip_annotations.extend(clip_annotations) diff --git a/src/batdetect2/evaluate/tasks/base.py b/src/batdetect2/evaluate/tasks/base.py index 9a3ea8a..3835f86 100644 --- a/src/batdetect2/evaluate/tasks/base.py +++ b/src/batdetect2/evaluate/tasks/base.py @@ -23,7 +23,7 @@ from batdetect2.evaluate.match import ( build_matcher, ) from batdetect2.typing.evaluate import EvaluatorProtocol, MatcherProtocol -from batdetect2.typing.postprocess import RawPrediction +from batdetect2.typing.postprocess import BatDetect2Prediction, RawPrediction from batdetect2.typing.targets import TargetProtocol __all__ = [ @@ -99,7 +99,7 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]): def evaluate( self, clip_annotations: Sequence[data.ClipAnnotation], - predictions: Sequence[Sequence[RawPrediction]], + predictions: Sequence[BatDetect2Prediction], ) -> List[T_Output]: return [ self.evaluate_clip(clip_annotation, preds) @@ -109,7 +109,7 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]): def evaluate_clip( self, clip_annotation: data.ClipAnnotation, - predictions: Sequence[RawPrediction], + prediction: BatDetect2Prediction, ) -> T_Output: ... def include_sound_event_annotation( diff --git a/src/batdetect2/evaluate/tasks/classification.py b/src/batdetect2/evaluate/tasks/classification.py index 5d63c3f..9a31db2 100644 --- a/src/batdetect2/evaluate/tasks/classification.py +++ b/src/batdetect2/evaluate/tasks/classification.py @@ -1,7 +1,6 @@ from typing import ( List, Literal, - Sequence, ) from pydantic import Field @@ -23,7 +22,7 @@ from batdetect2.evaluate.tasks.base import ( BaseTaskConfig, tasks_registry, ) -from batdetect2.typing import RawPrediction, TargetProtocol +from batdetect2.typing import BatDetect2Prediction, TargetProtocol class ClassificationTaskConfig(BaseTaskConfig): @@ -49,12 +48,14 @@ class ClassificationTask(BaseTask[ClipEval]): def evaluate_clip( self, clip_annotation: data.ClipAnnotation, - predictions: Sequence[RawPrediction], + prediction: BatDetect2Prediction, ) -> ClipEval: clip = clip_annotation.clip preds = [ - pred for pred in predictions if self.include_prediction(pred, clip) + pred + for pred in prediction.predictions + if self.include_prediction(pred, clip) ] all_gts = [ diff --git a/src/batdetect2/evaluate/tasks/clip_classification.py b/src/batdetect2/evaluate/tasks/clip_classification.py index 798f79b..8215555 100644 --- a/src/batdetect2/evaluate/tasks/clip_classification.py +++ b/src/batdetect2/evaluate/tasks/clip_classification.py @@ -1,5 +1,5 @@ from collections import defaultdict -from typing import List, Literal, Sequence +from typing import List, Literal from pydantic import Field from soundevent import data @@ -19,7 +19,8 @@ from batdetect2.evaluate.tasks.base import ( BaseTaskConfig, tasks_registry, ) -from batdetect2.typing import RawPrediction, TargetProtocol +from batdetect2.typing import TargetProtocol +from batdetect2.typing.postprocess import BatDetect2Prediction class ClipClassificationTaskConfig(BaseTaskConfig): @@ -37,7 +38,7 @@ class ClipClassificationTask(BaseTask[ClipEval]): def evaluate_clip( self, clip_annotation: data.ClipAnnotation, - predictions: Sequence[RawPrediction], + prediction: BatDetect2Prediction, ) 
-> ClipEval: clip = clip_annotation.clip @@ -54,7 +55,7 @@ class ClipClassificationTask(BaseTask[ClipEval]): gt_classes.add(class_name) pred_scores = defaultdict(float) - for pred in predictions: + for pred in prediction.predictions: if not self.include_prediction(pred, clip): continue diff --git a/src/batdetect2/evaluate/tasks/clip_detection.py b/src/batdetect2/evaluate/tasks/clip_detection.py index 2fb60a7..66a7a9d 100644 --- a/src/batdetect2/evaluate/tasks/clip_detection.py +++ b/src/batdetect2/evaluate/tasks/clip_detection.py @@ -1,4 +1,4 @@ -from typing import List, Literal, Sequence +from typing import List, Literal from pydantic import Field from soundevent import data @@ -18,7 +18,8 @@ from batdetect2.evaluate.tasks.base import ( BaseTaskConfig, tasks_registry, ) -from batdetect2.typing import RawPrediction, TargetProtocol +from batdetect2.typing import TargetProtocol +from batdetect2.typing.postprocess import BatDetect2Prediction class ClipDetectionTaskConfig(BaseTaskConfig): @@ -36,7 +37,7 @@ class ClipDetectionTask(BaseTask[ClipEval]): def evaluate_clip( self, clip_annotation: data.ClipAnnotation, - predictions: Sequence[RawPrediction], + prediction: BatDetect2Prediction, ) -> ClipEval: clip = clip_annotation.clip @@ -46,7 +47,7 @@ class ClipDetectionTask(BaseTask[ClipEval]): ) pred_score = 0 - for pred in predictions: + for pred in prediction.predictions: if not self.include_prediction(pred, clip): continue diff --git a/src/batdetect2/evaluate/tasks/detection.py b/src/batdetect2/evaluate/tasks/detection.py index 0c13914..8e404d9 100644 --- a/src/batdetect2/evaluate/tasks/detection.py +++ b/src/batdetect2/evaluate/tasks/detection.py @@ -1,4 +1,4 @@ -from typing import List, Literal, Sequence +from typing import List, Literal from pydantic import Field from soundevent import data @@ -19,7 +19,8 @@ from batdetect2.evaluate.tasks.base import ( BaseTaskConfig, tasks_registry, ) -from batdetect2.typing import RawPrediction, TargetProtocol +from batdetect2.typing import TargetProtocol +from batdetect2.typing.postprocess import BatDetect2Prediction class DetectionTaskConfig(BaseTaskConfig): @@ -35,7 +36,7 @@ class DetectionTask(BaseTask[ClipEval]): def evaluate_clip( self, clip_annotation: data.ClipAnnotation, - predictions: Sequence[RawPrediction], + prediction: BatDetect2Prediction, ) -> ClipEval: clip = clip_annotation.clip @@ -45,7 +46,9 @@ class DetectionTask(BaseTask[ClipEval]): if self.include_sound_event_annotation(sound_event, clip) ] preds = [ - pred for pred in predictions if self.include_prediction(pred, clip) + pred + for pred in prediction.predictions + if self.include_prediction(pred, clip) ] scores = [pred.detection_score for pred in preds] diff --git a/src/batdetect2/evaluate/tasks/top_class.py b/src/batdetect2/evaluate/tasks/top_class.py index 78533d8..ef94041 100644 --- a/src/batdetect2/evaluate/tasks/top_class.py +++ b/src/batdetect2/evaluate/tasks/top_class.py @@ -1,4 +1,4 @@ -from typing import List, Literal, Sequence +from typing import List, Literal from pydantic import Field from soundevent import data @@ -19,7 +19,8 @@ from batdetect2.evaluate.tasks.base import ( BaseTaskConfig, tasks_registry, ) -from batdetect2.typing import RawPrediction, TargetProtocol +from batdetect2.typing import TargetProtocol +from batdetect2.typing.postprocess import BatDetect2Prediction class TopClassDetectionTaskConfig(BaseTaskConfig): @@ -35,7 +36,7 @@ class TopClassDetectionTask(BaseTask[ClipEval]): def evaluate_clip( self, clip_annotation: data.ClipAnnotation, - predictions: 
Sequence[RawPrediction], + prediction: BatDetect2Prediction, ) -> ClipEval: clip = clip_annotation.clip @@ -45,7 +46,9 @@ class TopClassDetectionTask(BaseTask[ClipEval]): if self.include_sound_event_annotation(sound_event, clip) ] preds = [ - pred for pred in predictions if self.include_prediction(pred, clip) + pred + for pred in prediction.predictions + if self.include_prediction(pred, clip) ] # Take the highest score for each prediction scores = [pred.class_scores.max() for pred in preds] diff --git a/src/batdetect2/logging.py b/src/batdetect2/logging.py index 67bf11d..a423ca2 100644 --- a/src/batdetect2/logging.py +++ b/src/batdetect2/logging.py @@ -1,4 +1,5 @@ import io +import sys from collections.abc import Callable from functools import partial from pathlib import Path @@ -32,6 +33,20 @@ from batdetect2.core.configs import BaseConfig DEFAULT_LOGS_DIR: Path = Path("outputs") / "logs" +def enable_logging(level: int): + logger.remove() + + if level == 0: + log_level = "WARNING" + elif level == 1: + log_level = "INFO" + else: + log_level = "DEBUG" + + logger.add(sys.stderr, level=log_level) + logger.enable("batdetect2") + + class BaseLoggerConfig(BaseConfig): log_dir: Path = DEFAULT_LOGS_DIR experiment_name: Optional[str] = None diff --git a/src/batdetect2/models/__init__.py b/src/batdetect2/models/__init__.py index 1a74eb9..7ab11e7 100644 --- a/src/batdetect2/models/__init__.py +++ b/src/batdetect2/models/__init__.py @@ -30,10 +30,7 @@ from typing import List, Optional import torch -from batdetect2.models.backbones import ( - Backbone, - build_backbone, -) +from batdetect2.models.backbones import Backbone, build_backbone from batdetect2.models.blocks import ( ConvConfig, FreqCoordConvDownConfig, @@ -62,16 +59,13 @@ from batdetect2.models.encoder import ( build_encoder, ) from batdetect2.models.heads import BBoxHead, ClassifierHead, DetectorHead -from batdetect2.postprocess import build_postprocessor -from batdetect2.preprocess import build_preprocessor -from batdetect2.targets import build_targets -from batdetect2.typing.models import DetectionModel -from batdetect2.typing.postprocess import ( +from batdetect2.typing import ( ClipDetectionsTensor, + DetectionModel, PostprocessorProtocol, + PreprocessorProtocol, + TargetProtocol, ) -from batdetect2.typing.preprocess import PreprocessorProtocol -from batdetect2.typing.targets import TargetProtocol __all__ = [ "BBoxHead", @@ -133,6 +127,10 @@ def build_model( preprocessor: Optional[PreprocessorProtocol] = None, postprocessor: Optional[PostprocessorProtocol] = None, ): + from batdetect2.postprocess import build_postprocessor + from batdetect2.preprocess import build_preprocessor + from batdetect2.targets import build_targets + config = config or BackboneConfig() targets = targets or build_targets() preprocessor = preprocessor or build_preprocessor() diff --git a/src/batdetect2/postprocess/remapping.py b/src/batdetect2/postprocess/remapping.py index 1168a96..06f6321 100644 --- a/src/batdetect2/postprocess/remapping.py +++ b/src/batdetect2/postprocess/remapping.py @@ -12,7 +12,7 @@ classification probability maps, size prediction maps, and potentially intermediate features. 
""" -from typing import List +from typing import Dict, List, Optional, Union import numpy as np import torch @@ -30,6 +30,55 @@ __all__ = [ ] +def to_xarray( + array: Union[torch.Tensor, np.ndarray], + start_time: float, + end_time: float, + min_freq: float = MIN_FREQ, + max_freq: float = MAX_FREQ, + name: str = "xarray", + extra_dims: Optional[List[str]] = None, + extra_coords: Optional[Dict[str, np.ndarray]] = None, +) -> xr.DataArray: + if isinstance(array, torch.Tensor): + array = array.detach().cpu().numpy() + + extra_ndims = array.ndim - 2 + + if extra_ndims < 0: + raise ValueError( + "Input array must have at least 2 dimensions, " + f"got shape {array.shape}" + ) + + width = array.shape[-1] + height = array.shape[-2] + + times = np.linspace(start_time, end_time, width, endpoint=False) + freqs = np.linspace(min_freq, max_freq, height, endpoint=False) + + if extra_dims is None: + extra_dims = [f"dim_{i}" for i in range(extra_ndims)] + + if extra_coords is None: + extra_coords = {} + + return xr.DataArray( + data=array, + dims=[ + *extra_dims, + Dimensions.frequency.value, + Dimensions.time.value, + ], + coords={ + **extra_coords, + Dimensions.frequency.value: freqs, + Dimensions.time.value: times, + }, + name=name, + ) + + def map_detection_to_clip( detections: ClipDetectionsTensor, start_time: float, diff --git a/src/batdetect2/train/callbacks.py b/src/batdetect2/train/callbacks.py index fc24cf1..bac1a53 100644 --- a/src/batdetect2/train/callbacks.py +++ b/src/batdetect2/train/callbacks.py @@ -10,9 +10,9 @@ from batdetect2.postprocess import to_raw_predictions from batdetect2.train.dataset import ValidationDataset from batdetect2.train.lightning import TrainingModule from batdetect2.typing import ( + BatDetect2Prediction, EvaluatorProtocol, ModelOutput, - RawPrediction, TrainExample, ) @@ -24,7 +24,7 @@ class ValidationMetrics(Callback): self.evaluator = evaluator self._clip_annotations: List[data.ClipAnnotation] = [] - self._predictions: List[List[RawPrediction]] = [] + self._predictions: List[BatDetect2Prediction] = [] def get_dataset(self, trainer: Trainer) -> ValidationDataset: dataloaders = trainer.val_dataloaders @@ -100,8 +100,15 @@ class ValidationMetrics(Callback): start_times=[ca.clip.start_time for ca in clip_annotations], ) predictions = [ - to_raw_predictions(clip_dets.numpy(), targets=model.targets) - for clip_dets in clip_detections + BatDetect2Prediction( + clip=clip_annotation.clip, + predictions=to_raw_predictions( + clip_dets.numpy(), targets=model.targets + ), + ) + for clip_annotation, clip_dets in zip( + clip_annotations, clip_detections + ) ] self._clip_annotations.extend(clip_annotations) diff --git a/src/batdetect2/train/lightning.py b/src/batdetect2/train/lightning.py index 2527212..92d1468 100644 --- a/src/batdetect2/train/lightning.py +++ b/src/batdetect2/train/lightning.py @@ -6,11 +6,10 @@ from soundevent.data import PathLike from torch.optim.adam import Adam from torch.optim.lr_scheduler import CosineAnnealingLR -from batdetect2.audio import TARGET_SAMPLERATE_HZ from batdetect2.models import Model, build_model -from batdetect2.plotting.clips import build_preprocessor from batdetect2.postprocess import build_postprocessor -from batdetect2.targets.targets import build_targets +from batdetect2.preprocess import build_preprocessor +from batdetect2.targets import build_targets from batdetect2.train.losses import build_loss from batdetect2.typing import ModelOutput, TrainExample @@ -27,20 +26,20 @@ class TrainingModule(L.LightningModule): def __init__( self, 
- config: "BatDetect2Config", - input_samplerate: int = TARGET_SAMPLERATE_HZ, - learning_rate: float = 0.001, + config: Optional[dict] = None, t_max: int = 100, model: Optional[Model] = None, loss: Optional[torch.nn.Module] = None, ): + from batdetect2.config import validate_config + super().__init__() self.save_hyperparameters(logger=False) - self.input_samplerate = input_samplerate - self.config = config - self.learning_rate = learning_rate + self.config = validate_config(config) + self.input_samplerate = self.config.audio.samplerate + self.learning_rate = self.config.train.optimizer.learning_rate self.t_max = t_max if loss is None: @@ -104,14 +103,7 @@ def load_model_from_checkpoint( def build_training_module( - config: Optional["BatDetect2Config"] = None, + config: Optional[dict] = None, t_max: int = 200, ) -> TrainingModule: - from batdetect2.config import BatDetect2Config - - config = config or BatDetect2Config() - return TrainingModule( - config=config, - learning_rate=config.train.optimizer.learning_rate, - t_max=t_max, - ) + return TrainingModule(config=config, t_max=t_max) diff --git a/src/batdetect2/train/train.py b/src/batdetect2/train/train.py index 5673d0d..82350c6 100644 --- a/src/batdetect2/train/train.py +++ b/src/batdetect2/train/train.py @@ -97,7 +97,7 @@ def train( ) module = build_training_module( - config, + config.model_dump(mode="json"), t_max=config.train.optimizer.t_max * len(train_dataloader), ) diff --git a/src/batdetect2/typing/__init__.py b/src/batdetect2/typing/__init__.py index 60697e0..6269528 100644 --- a/src/batdetect2/typing/__init__.py +++ b/src/batdetect2/typing/__init__.py @@ -1,3 +1,4 @@ +from batdetect2.typing.data import OutputFormatterProtocol from batdetect2.typing.evaluate import ( AffinityFunction, ClipMatches, @@ -10,6 +11,7 @@ from batdetect2.typing.evaluate import ( from batdetect2.typing.models import BackboneModel, DetectionModel, ModelOutput from batdetect2.typing.postprocess import ( BatDetect2Prediction, + ClipDetectionsTensor, GeometryDecoder, PostprocessorProtocol, RawPrediction, @@ -43,8 +45,9 @@ __all__ = [ "Augmentation", "BackboneModel", "BatDetect2Prediction", - "ClipMatches", + "ClipDetectionsTensor", "ClipLabeller", + "ClipMatches", "ClipperProtocol", "DetectionModel", "EvaluatorProtocol", @@ -56,6 +59,7 @@ __all__ = [ "MatcherProtocol", "MetricsProtocol", "ModelOutput", + "OutputFormatterProtocol", "PlotterProtocol", "Position", "PostprocessorProtocol", diff --git a/src/batdetect2/typing/data.py b/src/batdetect2/typing/data.py new file mode 100644 index 0000000..52184ce --- /dev/null +++ b/src/batdetect2/typing/data.py @@ -0,0 +1,26 @@ +from typing import Generic, List, Optional, Protocol, Sequence, TypeVar + +from soundevent.data import PathLike + +from batdetect2.typing.postprocess import BatDetect2Prediction + +__all__ = [ + "OutputFormatterProtocol", +] + +T = TypeVar("T") + + +class OutputFormatterProtocol(Protocol, Generic[T]): + def format( + self, predictions: Sequence[BatDetect2Prediction] + ) -> List[T]: ... + + def save( + self, + predictions: Sequence[T], + path: PathLike, + audio_dir: Optional[PathLike] = None, + ) -> None: ... + + def load(self, path: PathLike) -> List[T]: ... 
diff --git a/src/batdetect2/typing/evaluate.py b/src/batdetect2/typing/evaluate.py index 8c22c52..6c1fbfb 100644 --- a/src/batdetect2/typing/evaluate.py +++ b/src/batdetect2/typing/evaluate.py @@ -14,7 +14,7 @@ from typing import ( from matplotlib.figure import Figure from soundevent import data -from batdetect2.typing.postprocess import RawPrediction +from batdetect2.typing.postprocess import BatDetect2Prediction, RawPrediction from batdetect2.typing.targets import TargetProtocol __all__ = [ @@ -115,7 +115,7 @@ class EvaluatorProtocol(Protocol, Generic[EvaluationOutput]): def evaluate( self, clip_annotations: Sequence[data.ClipAnnotation], - predictions: Sequence[Sequence[RawPrediction]], + predictions: Sequence[BatDetect2Prediction], ) -> EvaluationOutput: ... def compute_metrics(
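Taken together, the formatter plumbing gives the high-level API a save/load/evaluate round trip. A usage sketch under stated assumptions: the import path is inferred from `src/batdetect2/api_v2.py`, and the `predictions` placeholder and `dataset.yaml` path are hypothetical stand-ins, since inference entry points are outside this diff:

```python
from batdetect2.api_v2 import BatDetect2API
from batdetect2.config import BatDetect2Config

# The default config selects the "raw" NetCDF-backed formatter
# (RawOutputConfig is the default for the new `output` field).
api = BatDetect2API.from_config(BatDetect2Config())

predictions = ...  # Sequence[BatDetect2Prediction], e.g. from batch inference

# Save with the formatter configured under `output:` (raw -> .nc file).
api.save_predictions(predictions, path="outputs/predictions.nc")

# Or override by registry name; "batdetect2" writes one JSON file per
# clip into the given directory, "soundevent" a PredictionSet JSON.
api.save_predictions(predictions, path="outputs/bd2", format="batdetect2")

# Round-trip via the configured default formatter.
loaded = api.load_predictions("outputs/predictions.nc")

# Evaluate against annotations loaded from a dataset config, writing
# metrics.json and the evaluator's plots into the output directory.
annotations = api.load_annotations("dataset.yaml")
metrics = api.evaluate_predictions(
    annotations, loaded, output_dir="outputs/eval"
)
```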