diff --git a/batdetect2/evaluate/__init__.py b/batdetect2/evaluate/__init__.py
index 58c0b3f..d9235df 100644
--- a/batdetect2/evaluate/__init__.py
+++ b/batdetect2/evaluate/__init__.py
@@ -1,9 +1,13 @@
 from batdetect2.evaluate.evaluate import (
     compute_error_auc,
+)
+from batdetect2.evaluate.match import (
     match_predictions_and_annotations,
+    match_sound_events_and_raw_predictions,
 )
 
 __all__ = [
     "compute_error_auc",
     "match_predictions_and_annotations",
+    "match_sound_events_and_raw_predictions",
 ]
diff --git a/batdetect2/evaluate/evaluate.py b/batdetect2/evaluate/evaluate.py
index af8f603..7e8c33d 100755
--- a/batdetect2/evaluate/evaluate.py
+++ b/batdetect2/evaluate/evaluate.py
@@ -1,51 +1,6 @@
-from typing import List
-
 import numpy as np
 from sklearn.metrics import auc, roc_curve
-from soundevent import data
-from soundevent.evaluation import match_geometries
-
-
-def match_predictions_and_annotations(
-    clip_annotation: data.ClipAnnotation,
-    clip_prediction: data.ClipPrediction,
-) -> List[data.Match]:
-    annotated_sound_events = [
-        sound_event_annotation
-        for sound_event_annotation in clip_annotation.sound_events
-        if sound_event_annotation.sound_event.geometry is not None
-    ]
-
-    predicted_sound_events = [
-        sound_event_prediction
-        for sound_event_prediction in clip_prediction.sound_events
-        if sound_event_prediction.sound_event.geometry is not None
-    ]
-
-    annotated_geometries: List[data.Geometry] = [
-        sound_event.sound_event.geometry
-        for sound_event in annotated_sound_events
-        if sound_event.sound_event.geometry is not None
-    ]
-
-    predicted_geometries: List[data.Geometry] = [
-        sound_event.sound_event.geometry
-        for sound_event in predicted_sound_events
-        if sound_event.sound_event.geometry is not None
-    ]
-
-    matches = []
-    for id1, id2, affinity in match_geometries(
-        annotated_geometries,
-        predicted_geometries,
-    ):
-        target = annotated_sound_events[id1] if id1 is not None else None
-        source = predicted_sound_events[id2] if id2 is not None else None
-        matches.append(
-            data.Match(source=source, target=target, affinity=affinity)
-        )
-
-    return matches
 
 
 def compute_error_auc(op_str, gt, pred, prob):
diff --git a/batdetect2/evaluate/match.py b/batdetect2/evaluate/match.py
new file mode 100644
index 0000000..ccae73f
--- /dev/null
+++ b/batdetect2/evaluate/match.py
@@ -0,0 +1,115 @@
+from typing import List
+
+from soundevent import data
+from soundevent.evaluation import match_geometries
+
+from batdetect2.evaluate.types import Match
+from batdetect2.postprocess.types import RawPrediction
+from batdetect2.targets.types import TargetProtocol
+from batdetect2.utils.arrays import iterate_over_array
+
+
+def match_sound_events_and_raw_predictions(
+    sound_events: List[data.SoundEventAnnotation],
+    raw_predictions: List[RawPrediction],
+    targets: TargetProtocol,
+) -> List[Match]:
+    target_sound_events = [
+        targets.transform(sound_event_annotation)
+        for sound_event_annotation in sound_events
+        if targets.filter(sound_event_annotation)
+        and sound_event_annotation.sound_event.geometry is not None
+    ]
+
+    target_geometries: List[data.Geometry] = [  # type: ignore
+        sound_event_annotation.sound_event.geometry
+        for sound_event_annotation in target_sound_events
+    ]
+
+    predicted_geometries = [
+        raw_prediction.geometry for raw_prediction in raw_predictions
+    ]
+
+    matches = []
+    for id1, id2, affinity in match_geometries(
+        target_geometries,
+        predicted_geometries,
+    ):
+        target = target_sound_events[id1] if id1 is not None else None
+        prediction = raw_predictions[id2] if id2 is not None else None
+
+        gt_uuid = target.uuid if target is not None else None
+        gt_det = target is not None
+        gt_class = targets.encode(target) if target is not None else None
+
+        pred_score = (
+            float(prediction.detection_score)
+            if prediction is not None
+            else 0.0
+        )
+
+        class_scores = (
+            {
+                str(class_name): float(score)
+                for class_name, score in iterate_over_array(
+                    prediction.class_scores
+                )
+            }
+            if prediction is not None
+            else {}
+        )
+
+        matches.append(
+            Match(
+                gt_uuid=gt_uuid,
+                gt_det=gt_det,
+                gt_class=gt_class,
+                pred_score=pred_score,
+                affinity=affinity,
+                class_scores=class_scores,
+            )
+        )
+
+    return matches
+
+
+def match_predictions_and_annotations(
+    clip_annotation: data.ClipAnnotation,
+    clip_prediction: data.ClipPrediction,
+) -> List[data.Match]:
+    annotated_sound_events = [
+        sound_event_annotation
+        for sound_event_annotation in clip_annotation.sound_events
+        if sound_event_annotation.sound_event.geometry is not None
+    ]
+
+    predicted_sound_events = [
+        sound_event_prediction
+        for sound_event_prediction in clip_prediction.sound_events
+        if sound_event_prediction.sound_event.geometry is not None
+    ]
+
+    annotated_geometries: List[data.Geometry] = [
+        sound_event.sound_event.geometry
+        for sound_event in annotated_sound_events
+        if sound_event.sound_event.geometry is not None
+    ]
+
+    predicted_geometries: List[data.Geometry] = [
+        sound_event.sound_event.geometry
+        for sound_event in predicted_sound_events
+        if sound_event.sound_event.geometry is not None
+    ]
+
+    matches = []
+    for id1, id2, affinity in match_geometries(
+        annotated_geometries,
+        predicted_geometries,
+    ):
+        target = annotated_sound_events[id1] if id1 is not None else None
+        source = predicted_sound_events[id2] if id2 is not None else None
+        matches.append(
+            data.Match(source=source, target=target, affinity=affinity)
+        )
+
+    return matches
diff --git a/batdetect2/evaluate/metrics.py b/batdetect2/evaluate/metrics.py
new file mode 100644
index 0000000..3cb762e
--- /dev/null
+++ b/batdetect2/evaluate/metrics.py
@@ -0,0 +1,46 @@
+from typing import List
+
+import pandas as pd
+from sklearn import metrics
+from sklearn.preprocessing import label_binarize
+
+from batdetect2.evaluate.types import Match, MetricsProtocol
+
+__all__ = [
+    "DetectionAveragePrecision",
+    "ClassificationMeanAveragePrecision",
+]
+
+
+class DetectionAveragePrecision(MetricsProtocol):
+    name: str = "detection/average_precision"
+
+    def __call__(self, matches: List[Match]) -> float:
+        y_true, y_score = zip(
+            *[(match.gt_det, match.pred_score) for match in matches]
+        )
+        return float(metrics.average_precision_score(y_true, y_score))
+
+
+class ClassificationMeanAveragePrecision(MetricsProtocol):
+    name: str = "classification/mean_average_precision"
+
+    def __init__(self, class_names: List[str]):
+        self.class_names = class_names
+
+    def __call__(self, matches: List[Match]) -> float:
+        y_true = label_binarize(
+            [
+                match.gt_class if match.gt_class is not None else "__NONE__"
+                for match in matches
+            ],
+            classes=self.class_names,
+        )
+        # Unmatched ground truth has no class scores; fill the gaps with
+        # zeros so scikit-learn does not choke on missing values.
+        y_pred = (
+            pd.DataFrame([match.class_scores for match in matches])
+            .reindex(columns=self.class_names, fill_value=0.0)
+            .fillna(0.0)
+        )
+        return float(metrics.average_precision_score(y_true, y_pred))
diff --git a/batdetect2/evaluate/types.py b/batdetect2/evaluate/types.py
new file mode 100644
index 0000000..081253f
--- /dev/null
+++ b/batdetect2/evaluate/types.py
@@ -0,0 +1,24 @@
+from dataclasses import dataclass
+from typing import Dict, List, Optional, Protocol
+from uuid import UUID
+
+__all__ = [
+    "MetricsProtocol",
+    "Match",
+]
+
+
+@dataclass
+class Match:
+    gt_uuid: Optional[UUID]
+    gt_det: bool
+    gt_class: Optional[str]
+    pred_score: float
+    affinity: float
+    class_scores: Dict[str, float]
+
+
+class MetricsProtocol(Protocol):
+    name: str
+
+    def __call__(self, matches: List[Match]) -> float: ...
diff --git a/batdetect2/postprocess/__init__.py b/batdetect2/postprocess/__init__.py
index b9050ad..cc9295a 100644
--- a/batdetect2/postprocess/__init__.py
+++ b/batdetect2/postprocess/__init__.py
@@ -170,8 +170,8 @@ def load_postprocess_config(
 def build_postprocessor(
     targets: TargetProtocol,
     config: Optional[PostprocessConfig] = None,
-    max_freq: int = MAX_FREQ,
-    min_freq: int = MIN_FREQ,
+    max_freq: float = MAX_FREQ,
+    min_freq: float = MIN_FREQ,
 ) -> PostprocessorProtocol:
     """Factory function to build the standard postprocessor.
 
@@ -234,9 +234,9 @@ class Postprocessor(PostprocessorProtocol):
         recovery.
     config : PostprocessConfig
         Configuration object holding parameters for NMS, thresholds, etc.
-    min_freq : int
+    min_freq : float
         Minimum frequency (Hz) assumed for the model output's frequency axis.
-    max_freq : int
+    max_freq : float
         Maximum frequency (Hz) assumed for the model output's frequency axis.
     """
 
@@ -246,8 +246,8 @@ class Postprocessor(PostprocessorProtocol):
         self,
         targets: TargetProtocol,
         config: PostprocessConfig,
-        min_freq: int = MIN_FREQ,
-        max_freq: int = MAX_FREQ,
+        min_freq: float = MIN_FREQ,
+        max_freq: float = MAX_FREQ,
     ):
         """Initialize the Postprocessor.
diff --git a/batdetect2/postprocess/decoding.py b/batdetect2/postprocess/decoding.py
index 4e62402..2ba6e45 100644
--- a/batdetect2/postprocess/decoding.py
+++ b/batdetect2/postprocess/decoding.py
@@ -32,10 +32,10 @@ from typing import List, Optional
 import numpy as np
 import xarray as xr
 from soundevent import data
-from soundevent.geometry import compute_bounds
 
 from batdetect2.postprocess.types import GeometryBuilder, RawPrediction
 from batdetect2.targets.classes import SoundEventDecoder
+from batdetect2.utils.arrays import iterate_over_array
 
 __all__ = [
     "convert_xr_dataset_to_raw_prediction",
@@ -97,18 +97,14 @@ def convert_xr_dataset_to_raw_prediction(
         det_info = detection_dataset.sel(detection=det_num)
 
         geom = geometry_builder(
-            (det_info.time, det_info.freq),
+            (det_info.time, det_info.frequency),
             det_info.dimensions,
         )
 
-        start_time, low_freq, end_time, high_freq = compute_bounds(geom)
         detections.append(
             RawPrediction(
-                detection_score=det_info.score,
-                start_time=start_time,
-                end_time=end_time,
-                low_freq=low_freq,
-                high_freq=high_freq,
+                detection_score=det_info.scores,
+                geometry=geom,
                 class_scores=det_info.classes,
                 features=det_info.features,
             )
@@ -244,14 +240,7 @@ def convert_raw_prediction_to_sound_event_prediction(
     """
     sound_event = data.SoundEvent(
         recording=recording,
-        geometry=data.BoundingBox(
-            coordinates=[
-                raw_prediction.start_time,
-                raw_prediction.low_freq,
-                raw_prediction.end_time,
-                raw_prediction.high_freq,
-            ]
-        ),
+        geometry=raw_prediction.geometry,
         features=get_prediction_features(raw_prediction.features),
     )
 
@@ -333,7 +322,7 @@ def get_prediction_features(features: xr.DataArray) -> List[data.Feature]:
             ),
             value=value,
         )
-        for feat_name, value in _iterate_over_array(features)
+        for feat_name, value in iterate_over_array(features)
     ]
 
 
@@ -394,13 +383,6 @@ def get_class_tags(
     return tags
 
 
-def _iterate_over_array(array: xr.DataArray):
-    dim_name = array.dims[0]
-    coords = array.coords[dim_name]
-    for value, coord in zip(array.values, coords.values):
-        yield coord, float(value)
-
-
 def _iterate_sorted(array: xr.DataArray):
     dim_name = array.dims[0]
     coords = array.coords[dim_name].values
diff --git a/batdetect2/postprocess/types.py b/batdetect2/postprocess/types.py
index 1269a21..70f1f39 100644
--- a/batdetect2/postprocess/types.py
+++ b/batdetect2/postprocess/types.py
@@ -47,14 +47,9 @@ class RawPrediction(NamedTuple):
 
     Attributes
     ----------
-    start_time : float
-        Start time of the recovered bounding box in seconds.
-    end_time : float
-        End time of the recovered bounding box in seconds.
-    low_freq : float
-        Lowest frequency of the recovered bounding box in Hz.
-    high_freq : float
-        Highest frequency of the recovered bounding box in Hz.
+    geometry : data.Geometry
+        The recovered geometry of the detected sound event, usually a
+        bounding box.
     detection_score : float
         The confidence score associated with this detection, typically
         from the detection heatmap peak.
@@ -67,10 +62,7 @@ class RawPrediction(NamedTuple):
         detection location. Indexed by a 'feature' coordinate.
     """
 
-    start_time: float
-    end_time: float
-    low_freq: float
-    high_freq: float
+    geometry: data.Geometry
    detection_score: float
     class_scores: xr.DataArray
     features: xr.DataArray
diff --git a/batdetect2/targets/filtering.py b/batdetect2/targets/filtering.py
index 8869050..a1172d0 100644
--- a/batdetect2/targets/filtering.py
+++ b/batdetect2/targets/filtering.py
@@ -106,7 +106,7 @@ def contains_tags(
         False otherwise.
     """
     sound_event_tags = set(sound_event_annotation.tags)
-    return tags < sound_event_tags
+    return tags <= sound_event_tags
 
 
 def does_not_have_tags(
diff --git a/batdetect2/targets/rois.py b/batdetect2/targets/rois.py
index 356004f..1a17949 100644
--- a/batdetect2/targets/rois.py
+++ b/batdetect2/targets/rois.py
@@ -20,14 +20,27 @@ scaling factors) is managed by the `ROIConfig`. This module separates the
 handled in `batdetect2.targets.classes`.
 """
 
-from typing import List, Optional, Protocol, Tuple
+from typing import List, Literal, Optional, Protocol, Tuple
 
 import numpy as np
-from soundevent import data, geometry
-from soundevent.geometry.operations import Positions
+from soundevent import data
 
 from batdetect2.configs import BaseConfig, load_config
 
+Positions = Literal[
+    "bottom-left",
+    "bottom-right",
+    "top-left",
+    "top-right",
+    "center-left",
+    "center-right",
+    "top-center",
+    "bottom-center",
+    "center",
+    "centroid",
+    "point_on_surface",
+]
+
 __all__ = [
     "ROITargetMapper",
     "ROIConfig",
@@ -242,6 +255,8 @@ class BBoxEncoder(ROITargetMapper):
         Tuple[float, float]
             Reference position (time, frequency).
         """
+        from soundevent import geometry
+
         return geometry.get_geometry_point(geom, position=self.position)
 
     def get_roi_size(self, geom: data.Geometry) -> np.ndarray:
@@ -260,6 +275,8 @@ class BBoxEncoder(ROITargetMapper):
         np.ndarray
             A 1D NumPy array: `[scaled_width, scaled_height]`.
         """
+        from soundevent import geometry
+
         start_time, low_freq, end_time, high_freq = geometry.compute_bounds(
             geom
         )
@@ -308,8 +325,8 @@ class BBoxEncoder(ROITargetMapper):
         width, height = dims
         return _build_bounding_box(
             pos,
-            duration=width / self.time_scale,
-            bandwidth=height / self.frequency_scale,
+            duration=float(width) / self.time_scale,
+            bandwidth=float(height) / self.frequency_scale,
             position=self.position,
         )
 
@@ -421,14 +438,16 @@ def _build_bounding_box(
     ValueError
         If `position` is not a recognized value or format.
""" - time, freq = pos + time, freq = map(float, pos) + duration = max(0, duration) + bandwidth = max(0, bandwidth) if position in ["center", "centroid", "point_on_surface"]: return data.BoundingBox( coordinates=[ - time - duration / 2, - freq - bandwidth / 2, - time + duration / 2, - freq + bandwidth / 2, + max(time - duration / 2, 0), + max(freq - bandwidth / 2, 0), + max(time + duration / 2, 0), + max(freq + bandwidth / 2, 0), ] ) @@ -454,9 +473,9 @@ def _build_bounding_box( return data.BoundingBox( coordinates=[ - start_time, - low_freq, - start_time + duration, - low_freq + bandwidth, + max(0, start_time), + max(0, low_freq), + max(0, start_time + duration), + max(0, low_freq + bandwidth), ] ) diff --git a/batdetect2/train/__init__.py b/batdetect2/train/__init__.py index 692de6a..f15b87c 100644 --- a/batdetect2/train/__init__.py +++ b/batdetect2/train/__init__.py @@ -14,20 +14,28 @@ from batdetect2.train.augmentations import ( warp_spectrogram, ) from batdetect2.train.clips import build_clipper, select_subclip -from batdetect2.train.config import TrainingConfig, load_train_config +from batdetect2.train.config import ( + TrainerConfig, + TrainingConfig, + load_train_config, +) from batdetect2.train.dataset import ( LabeledDataset, RandomExampleSource, TrainExample, list_preprocessed_files, ) -from batdetect2.train.labels import load_label_config +from batdetect2.train.labels import build_clip_labeler, load_label_config from batdetect2.train.losses import LossFunction, build_loss from batdetect2.train.preprocess import ( generate_train_example, preprocess_annotations, ) -from batdetect2.train.train import TrainerConfig, load_trainer_config, train +from batdetect2.train.train import ( + build_train_dataset, + build_val_dataset, + train, +) __all__ = [ "AugmentationsConfig", @@ -44,13 +52,15 @@ __all__ = [ "WarpAugmentationConfig", "add_echo", "build_augmentations", + "build_clip_labeler", "build_clipper", "build_loss", + "build_train_dataset", + "build_val_dataset", "generate_train_example", "list_preprocessed_files", "load_label_config", "load_train_config", - "load_trainer_config", "mask_frequency", "mask_time", "mix_examples", @@ -58,5 +68,6 @@ __all__ = [ "scale_volume", "select_subclip", "train", + "train", "warp_spectrogram", ] diff --git a/batdetect2/train/callbacks.py b/batdetect2/train/callbacks.py index cb486b2..f1e0895 100644 --- a/batdetect2/train/callbacks.py +++ b/batdetect2/train/callbacks.py @@ -1,30 +1,51 @@ +from typing import List + from lightning import LightningModule, Trainer from lightning.pytorch.callbacks import Callback +from soundevent import data from torch.utils.data import DataLoader -from batdetect2.postprocess import PostprocessorProtocol +from batdetect2.evaluate.match import match_sound_events_and_raw_predictions +from batdetect2.evaluate.types import Match, MetricsProtocol +from batdetect2.targets.types import TargetProtocol from batdetect2.train.dataset import LabeledDataset, TrainExample -from batdetect2.types import ModelOutput +from batdetect2.train.lightning import TrainingModule +from batdetect2.train.types import ModelOutput class ValidationMetrics(Callback): - def __init__(self, postprocessor: PostprocessorProtocol): + def __init__(self, metrics: List[MetricsProtocol]): super().__init__() - self.postprocessor = postprocessor - self.predictions = [] + + if len(metrics) == 0: + raise ValueError("At least one metric needs to be provided") + + self.matches: List[Match] = [] + self.metrics = metrics + + def on_validation_epoch_end( + self, + trainer: 
Trainer, + pl_module: LightningModule, + ) -> None: + for metric in self.metrics: + value = metric(self.matches) + pl_module.log(f"val/metric/{metric.name}", value, prog_bar=True) + + return super().on_validation_epoch_end(trainer, pl_module) def on_validation_epoch_start( self, trainer: Trainer, pl_module: LightningModule, ) -> None: - self.predictions = [] + self.matches = [] return super().on_validation_epoch_start(trainer, pl_module) def on_validation_batch_end( # type: ignore self, trainer: Trainer, - pl_module: LightningModule, + pl_module: TrainingModule, outputs: ModelOutput, batch: TrainExample, batch_idx: int, @@ -32,24 +53,73 @@ class ValidationMetrics(Callback): ) -> None: dataloaders = trainer.val_dataloaders assert isinstance(dataloaders, DataLoader) + dataset = dataloaders.dataset assert isinstance(dataset, LabeledDataset) - clip_annotation = dataset.get_clip_annotation(batch_idx) - # clip_prediction = postprocess_model_outputs( - # outputs, - # clips=[clip_annotation.clip], - # classes=self.class_names, - # decoder=self.decoder, - # config=self.config.postprocessing, - # )[0] - # - # matches = match_predictions_and_annotations( - # clip_annotation, - # clip_prediction, - # ) - # - # self.validation_predictions.extend(matches) - # return super().on_validation_batch_end( - # trainer, pl_module, outputs, batch, batch_idx, dataloader_idx - # ) + clip_annotations = [ + _get_subclip( + dataset.get_clip_annotation(example_id), + start_time=start_time.item(), + end_time=end_time.item(), + targets=pl_module.targets, + ) + for example_id, start_time, end_time in zip( + batch.idx, + batch.start_time, + batch.end_time, + ) + ] + + clips = [clip_annotation.clip for clip_annotation in clip_annotations] + + raw_predictions = pl_module.postprocessor.get_raw_predictions( + outputs, + clips, + ) + + for clip_annotation, clip_predictions in zip( + clip_annotations, raw_predictions + ): + self.matches.extend( + match_sound_events_and_raw_predictions( + sound_events=clip_annotation.sound_events, + raw_predictions=clip_predictions, + targets=pl_module.targets, + ) + ) + + +def _is_in_subclip( + sound_event_annotation: data.SoundEventAnnotation, + targets: TargetProtocol, + start_time: float, + end_time: float, +) -> bool: + time, _ = targets.get_position(sound_event_annotation) + return start_time <= time <= end_time + + +def _get_subclip( + clip_annotation: data.ClipAnnotation, + start_time: float, + end_time: float, + targets: TargetProtocol, +) -> data.ClipAnnotation: + return data.ClipAnnotation( + clip=data.Clip( + recording=clip_annotation.clip.recording, + start_time=start_time, + end_time=end_time, + ), + sound_events=[ + sound_event_annotation + for sound_event_annotation in clip_annotation.sound_events + if _is_in_subclip( + sound_event_annotation, + targets, + start_time=start_time, + end_time=end_time, + ) + ], + ) diff --git a/batdetect2/train/dataset.py b/batdetect2/train/dataset.py index 3426e9c..f64c8f9 100644 --- a/batdetect2/train/dataset.py +++ b/batdetect2/train/dataset.py @@ -42,8 +42,8 @@ class LabeledDataset(Dataset): class_heatmap=self.to_tensor(dataset["class"]), size_heatmap=self.to_tensor(dataset["size"]), idx=torch.tensor(idx), - start_time=start_time, - end_time=end_time, + start_time=torch.tensor(start_time), + end_time=torch.tensor(end_time), ) @classmethod diff --git a/batdetect2/train/labels.py b/batdetect2/train/labels.py index 11955e4..c0acf45 100644 --- a/batdetect2/train/labels.py +++ b/batdetect2/train/labels.py @@ -23,13 +23,13 @@ parameter specific to this 
 module is the Gaussian smoothing sigma (`sigma`) defined in `LabelConfig`.
 """
 
-import logging
 from collections.abc import Iterable
 from functools import partial
 from typing import Optional
 
 import numpy as np
 import xarray as xr
+from loguru import logger
 from scipy.ndimage import gaussian_filter
 from soundevent import arrays, data
@@ -52,8 +52,6 @@ __all__ = [
 SIZE_DIMENSION = "dimension"
 """Dimension name for the size heatmap."""
 
-logger = logging.getLogger(__name__)
-
 
 class LabelConfig(BaseConfig):
     """Configuration parameters for heatmap generation.
@@ -137,12 +135,29 @@ def generate_clip_label(
         A NamedTuple containing the generated 'detection', 'classes', and
         'size' heatmaps for this clip.
     """
+    logger.debug(
+        "Will generate heatmaps for clip annotation {uuid} with {num} "
+        "annotated sound events",
+        uuid=clip_annotation.uuid,
+        num=len(clip_annotation.sound_events),
+    )
+
+    sound_events = []
+
+    for sound_event_annotation in clip_annotation.sound_events:
+        if not targets.filter(sound_event_annotation):
+            logger.debug(
+                "Sound event {sound_event} did not pass the filter. "
+                "Tags: {tags}",
+                sound_event=sound_event_annotation,
+                tags=sound_event_annotation.tags,
+            )
+            continue
+
+        sound_events.append(targets.transform(sound_event_annotation))
+
     return generate_heatmaps(
-        (
-            targets.transform(sound_event_annotation)
-            for sound_event_annotation in clip_annotation.sound_events
-            if targets.filter(sound_event_annotation)
-        ),
+        sound_events,
         spec=spec,
         targets=targets,
         target_sigma=config.sigma,
diff --git a/batdetect2/train/lightning.py b/batdetect2/train/lightning.py
index ffdebb2..515f374 100644
--- a/batdetect2/train/lightning.py
+++ b/batdetect2/train/lightning.py
@@ -58,7 +58,9 @@ class TrainingModule(L.LightningModule):
 
         return losses.total
 
-    def validation_step(self, batch: TrainExample, batch_idx: int) -> None:
+    def validation_step(  # type: ignore
+        self, batch: TrainExample, batch_idx: int
+    ) -> ModelOutput:
         outputs = self.forward(batch.spec)
 
         losses = self.loss(outputs, batch)
@@ -67,6 +69,8 @@ class TrainingModule(L.LightningModule):
         self.log("val/loss/size", losses.total, logger=True)
         self.log("val/loss/classification", losses.total, logger=True)
 
+        return outputs
+
     def configure_optimizers(self):
         optimizer = Adam(self.parameters(), lr=self.learning_rate)
         scheduler = CosineAnnealingLR(optimizer, T_max=self.t_max)
diff --git a/batdetect2/train/train.py b/batdetect2/train/train.py
index 039bc32..e4611ba 100644
--- a/batdetect2/train/train.py
+++ b/batdetect2/train/train.py
@@ -1,12 +1,16 @@
 from typing import List, Optional
 
 from lightning import Trainer
+from lightning.pytorch.callbacks import Callback
 from soundevent import data
 from torch.utils.data import DataLoader
 
 from batdetect2.models.types import DetectionModel
+from batdetect2.postprocess import build_postprocessor
 from batdetect2.postprocess.types import PostprocessorProtocol
+from batdetect2.preprocess import build_preprocessor
 from batdetect2.preprocess.types import PreprocessorProtocol
+from batdetect2.targets import build_targets
 from batdetect2.targets.types import TargetProtocol
 from batdetect2.train.augmentations import (
     build_augmentations,
@@ -19,25 +23,40 @@ from batdetect2.train.losses import build_loss
 
 __all__ = [
     "train",
+    "build_val_dataset",
+    "build_train_dataset",
 ]
 
 
 def train(
     detector: DetectionModel,
-    targets: TargetProtocol,
-    preprocessor: PreprocessorProtocol,
-    postprocessor: PostprocessorProtocol,
     train_examples: List[data.PathLike],
+    targets: Optional[TargetProtocol] = None,
+    preprocessor: Optional[PreprocessorProtocol] = None,
+    postprocessor: Optional[PostprocessorProtocol] = None,
     val_examples: Optional[List[data.PathLike]] = None,
     config: Optional[TrainingConfig] = None,
+    callbacks: Optional[List[Callback]] = None,
 ) -> None:
     config = config or TrainingConfig()
 
-    train_dataset = build_dataset(
+    if preprocessor is None:
+        preprocessor = build_preprocessor()
+
+    if targets is None:
+        targets = build_targets()
+
+    if postprocessor is None:
+        postprocessor = build_postprocessor(
+            targets,
+            min_freq=preprocessor.min_freq,
+            max_freq=preprocessor.max_freq,
+        )
+
+    train_dataset = build_train_dataset(
         train_examples,
         preprocessor,
         config=config,
-        train=True,
     )
 
     loss = build_loss(config.loss)
@@ -52,7 +71,11 @@ def train(
         t_max=config.optimizer.t_max,
     )
 
-    trainer = Trainer(**config.trainer.model_dump())
+    trainer = Trainer(
+        **config.trainer.model_dump(exclude_none=True),
+        callbacks=callbacks,
+        num_sanity_val_steps=0,
+    )
 
     train_dataloader = DataLoader(
         train_dataset,
@@ -62,11 +87,9 @@ def train(
 
     val_dataloader = None
     if val_examples:
-        val_dataset = build_dataset(
+        val_dataset = build_val_dataset(
             val_examples,
-            preprocessor,
             config=config,
-            train=False,
         )
         val_dataloader = DataLoader(
             val_dataset,
@@ -81,32 +104,38 @@ def train(
     )
 
 
-def build_dataset(
+def build_train_dataset(
     examples: List[data.PathLike],
     preprocessor: PreprocessorProtocol,
     config: Optional[TrainingConfig] = None,
-    train: bool = True,
-):
+) -> LabeledDataset:
     config = config or TrainingConfig()
 
-    clipper = build_clipper(config.cliping, random=train)
+    clipper = build_clipper(config.cliping, random=True)
 
-    augmentations = None
+    random_example_source = RandomExampleSource(
+        examples,
+        clipper=clipper,
+    )
 
-    if train:
-        random_example_source = RandomExampleSource(
-            examples,
-            clipper=clipper,
-        )
-
-        augmentations = build_augmentations(
-            preprocessor,
-            config=config.augmentations,
-            example_source=random_example_source,
-        )
+    augmentations = build_augmentations(
+        preprocessor,
+        config=config.augmentations,
+        example_source=random_example_source,
+    )
 
     return LabeledDataset(
         examples,
         clipper=clipper,
         augmentation=augmentations,
     )
+
+
+def build_val_dataset(
+    examples: List[data.PathLike],
+    config: Optional[TrainingConfig] = None,
+    train: bool = False,
+) -> LabeledDataset:
+    config = config or TrainingConfig()
+    clipper = build_clipper(config.cliping, random=train)
+    return LabeledDataset(examples, clipper=clipper)
diff --git a/batdetect2/train/types.py b/batdetect2/train/types.py
index 1be4e3e..02e84c7 100644
--- a/batdetect2/train/types.py
+++ b/batdetect2/train/types.py
@@ -57,8 +57,8 @@ class TrainExample(NamedTuple):
     class_heatmap: torch.Tensor
     size_heatmap: torch.Tensor
     idx: torch.Tensor
-    start_time: float
-    end_time: float
+    start_time: torch.Tensor
+    end_time: torch.Tensor
 
 
 class Losses(NamedTuple):
diff --git a/batdetect2/utils/arrays.py b/batdetect2/utils/arrays.py
index e88a516..bf00ee7 100644
--- a/batdetect2/utils/arrays.py
+++ b/batdetect2/utils/arrays.py
@@ -1,4 +1,5 @@
 import numpy as np
+import xarray as xr
 
 
 def extend_width(
@@ -59,3 +60,10 @@ def adjust_width(
         for index in range(dims)
     ]
     return array[tuple(slices)]
+
+
+def iterate_over_array(array: xr.DataArray):
+    dim_name = array.dims[0]
+    coords = array.coords[dim_name]
+    for value, coord in zip(array.values, coords.values):
+        yield coord, float(value)