diff --git a/src/batdetect2/compat/data.py b/src/batdetect2/compat/data.py index 8803e31..48c15ca 100644 --- a/src/batdetect2/compat/data.py +++ b/src/batdetect2/compat/data.py @@ -102,19 +102,19 @@ def convert_to_annotation_group( x_inds.append(0) y_inds.append(0) - annotations.append( - Annotation( - start_time=start_time, - end_time=end_time, - low_freq=low_freq, - high_freq=high_freq, - class_prob=1.0, - det_prob=1.0, - individual="0", - event=event, - class_id=class_id, - ) - ) + annotation_entry: Annotation = { + "start_time": start_time, + "end_time": end_time, + "low_freq": low_freq, + "high_freq": high_freq, + "class_prob": 1.0, + "det_prob": 1.0, + "individual": "0", + "event": event, + "class": get_recording_class_name(recording), + "class_id": class_id, + } + annotations.append(annotation_entry) return { "id": str(recording.path), diff --git a/src/batdetect2/core/registries.py b/src/batdetect2/core/registries.py index c1538e1..a998c1a 100644 --- a/src/batdetect2/core/registries.py +++ b/src/batdetect2/core/registries.py @@ -53,7 +53,7 @@ class Registry(Generic[T_Type, P_Type]): def __init__(self, name: str, discriminator: str = "name"): self._name = name self._registry: dict[ - str, Callable[Concatenate[..., P_Type], T_Type] + str, Callable[Concatenate[Any, P_Type], T_Type] ] = {} self._discriminator = discriminator self._config_types: dict[str, Type[BaseModel]] = {} @@ -80,7 +80,7 @@ class Registry(Generic[T_Type, P_Type]): ) def decorator( - func: Callable[Concatenate[T_Config, P_Type], T_Type], + func: Callable[..., T_Type], ): self._registry[name] = func return func @@ -102,8 +102,8 @@ class Registry(Generic[T_Type, P_Type]): def build( self, config: BaseModel, - *args: P_Type.args, - **kwargs: P_Type.kwargs, + *args: Any, + **kwargs: Any, ) -> T_Type: """Builds a logic instance from a config object.""" diff --git a/src/batdetect2/evaluate/config.py b/src/batdetect2/evaluate/config.py index bc49a26..ebeb8a7 100644 --- a/src/batdetect2/evaluate/config.py +++ b/src/batdetect2/evaluate/config.py @@ -12,13 +12,15 @@ __all__ = [ ] +def _default_tasks() -> list[TaskConfig]: + return [ + DetectionTaskConfig(), + ClassificationTaskConfig(), + ] + + class EvaluationConfig(BaseConfig): - tasks: List[TaskConfig] = Field( - default_factory=lambda: [ - DetectionTaskConfig(), - ClassificationTaskConfig(), - ] - ) + tasks: List[TaskConfig] = Field(default_factory=_default_tasks) def get_default_eval_config() -> EvaluationConfig: diff --git a/src/batdetect2/evaluate/tasks/classification.py b/src/batdetect2/evaluate/tasks/classification.py index 43977bb..5dbbeaa 100644 --- a/src/batdetect2/evaluate/tasks/classification.py +++ b/src/batdetect2/evaluate/tasks/classification.py @@ -25,11 +25,15 @@ from batdetect2.postprocess.types import ClipDetections, Detection from batdetect2.targets.types import TargetProtocol +def _default_metrics() -> list[ClassificationMetricConfig]: + return [ClassificationAveragePrecisionConfig()] + + class ClassificationTaskConfig(BaseSEDTaskConfig): name: Literal["sound_event_classification"] = "sound_event_classification" prefix: str = "classification" metrics: list[ClassificationMetricConfig] = Field( - default_factory=lambda: [ClassificationAveragePrecisionConfig()] + default_factory=_default_metrics ) plots: list[ClassificationPlotConfig] = Field(default_factory=list) include_generics: bool = True diff --git a/src/batdetect2/evaluate/tasks/clip_classification.py b/src/batdetect2/evaluate/tasks/clip_classification.py index 958d279..de76cde 100644 --- a/src/batdetect2/evaluate/tasks/clip_classification.py +++ b/src/batdetect2/evaluate/tasks/clip_classification.py @@ -23,13 +23,15 @@ from batdetect2.postprocess.types import ClipDetections from batdetect2.targets.types import TargetProtocol +def _default_metrics() -> list[ClipClassificationMetricConfig]: + return [ClipClassificationAveragePrecisionConfig()] + + class ClipClassificationTaskConfig(BaseTaskConfig): name: Literal["clip_classification"] = "clip_classification" prefix: str = "clip_classification" metrics: list[ClipClassificationMetricConfig] = Field( - default_factory=lambda: [ - ClipClassificationAveragePrecisionConfig(), - ] + default_factory=_default_metrics ) plots: list[ClipClassificationPlotConfig] = Field(default_factory=list) diff --git a/src/batdetect2/evaluate/tasks/clip_detection.py b/src/batdetect2/evaluate/tasks/clip_detection.py index ed810d5..09ddf44 100644 --- a/src/batdetect2/evaluate/tasks/clip_detection.py +++ b/src/batdetect2/evaluate/tasks/clip_detection.py @@ -22,13 +22,15 @@ from batdetect2.postprocess.types import ClipDetections from batdetect2.targets.types import TargetProtocol +def _default_metrics() -> list[ClipDetectionMetricConfig]: + return [ClipDetectionAveragePrecisionConfig()] + + class ClipDetectionTaskConfig(BaseTaskConfig): name: Literal["clip_detection"] = "clip_detection" prefix: str = "clip_detection" metrics: list[ClipDetectionMetricConfig] = Field( - default_factory=lambda: [ - ClipDetectionAveragePrecisionConfig(), - ] + default_factory=_default_metrics ) plots: list[ClipDetectionPlotConfig] = Field(default_factory=list) diff --git a/src/batdetect2/evaluate/tasks/detection.py b/src/batdetect2/evaluate/tasks/detection.py index 49099c5..480abd6 100644 --- a/src/batdetect2/evaluate/tasks/detection.py +++ b/src/batdetect2/evaluate/tasks/detection.py @@ -24,11 +24,15 @@ from batdetect2.postprocess.types import ClipDetections from batdetect2.targets.types import TargetProtocol +def _default_metrics() -> list[DetectionMetricConfig]: + return [DetectionAveragePrecisionConfig()] + + class DetectionTaskConfig(BaseSEDTaskConfig): name: Literal["sound_event_detection"] = "sound_event_detection" prefix: str = "detection" metrics: list[DetectionMetricConfig] = Field( - default_factory=lambda: [DetectionAveragePrecisionConfig()] + default_factory=_default_metrics ) plots: list[DetectionPlotConfig] = Field(default_factory=list) diff --git a/src/batdetect2/evaluate/tasks/top_class.py b/src/batdetect2/evaluate/tasks/top_class.py index 337ee6a..d37d9a9 100644 --- a/src/batdetect2/evaluate/tasks/top_class.py +++ b/src/batdetect2/evaluate/tasks/top_class.py @@ -24,11 +24,15 @@ from batdetect2.postprocess.types import ClipDetections from batdetect2.targets.types import TargetProtocol +def _default_metrics() -> list[TopClassMetricConfig]: + return [TopClassAveragePrecisionConfig()] + + class TopClassDetectionTaskConfig(BaseSEDTaskConfig): name: Literal["top_class_detection"] = "top_class_detection" prefix: str = "top_class" metrics: list[TopClassMetricConfig] = Field( - default_factory=lambda: [TopClassAveragePrecisionConfig()] + default_factory=_default_metrics ) plots: list[TopClassPlotConfig] = Field(default_factory=list) diff --git a/src/batdetect2/outputs/formats/batdetect2.py b/src/batdetect2/outputs/formats/batdetect2.py index 4e1433e..115a908 100644 --- a/src/batdetect2/outputs/formats/batdetect2.py +++ b/src/batdetect2/outputs/formats/batdetect2.py @@ -154,17 +154,18 @@ class BatDetect2Formatter(OutputFormatterProtocol[FileAnnotation]): top_class_index = int(np.argmax(prediction.class_scores)) top_class_score = float(prediction.class_scores[top_class_index]) top_class = self.get_class_name(top_class_index) - return Annotation( - start_time=start_time, - end_time=end_time, - low_freq=low_freq, - high_freq=high_freq, - class_prob=top_class_score, - det_prob=float(prediction.detection_score), - individual="", - event=self.event_name, - **{"class": top_class}, - ) + annotation: Annotation = { + "start_time": start_time, + "end_time": end_time, + "low_freq": low_freq, + "high_freq": high_freq, + "class_prob": top_class_score, + "det_prob": float(prediction.detection_score), + "individual": "", + "event": self.event_name, + "class": top_class, + } + return annotation @output_formatters.register(BatDetect2OutputConfig) @staticmethod diff --git a/src/batdetect2/preprocess/config.py b/src/batdetect2/preprocess/config.py index d8b1d17..5ebb099 100644 --- a/src/batdetect2/preprocess/config.py +++ b/src/batdetect2/preprocess/config.py @@ -26,6 +26,13 @@ __all__ = [ ] +def _default_spectrogram_transforms() -> list[SpectrogramTransform]: + return [ + PcenConfig(), + SpectralMeanSubtractionConfig(), + ] + + class PreprocessingConfig(BaseConfig): """Unified configuration for the audio preprocessing pipeline. @@ -58,10 +65,7 @@ class PreprocessingConfig(BaseConfig): audio_transforms: List[AudioTransform] = Field(default_factory=list) spectrogram_transforms: List[SpectrogramTransform] = Field( - default_factory=lambda: [ - PcenConfig(), - SpectralMeanSubtractionConfig(), - ] + default_factory=_default_spectrogram_transforms ) stft: STFTConfig = Field(default_factory=STFTConfig) diff --git a/src/batdetect2/targets/classes.py b/src/batdetect2/targets/classes.py index ee6e0a3..78c66a8 100644 --- a/src/batdetect2/targets/classes.py +++ b/src/batdetect2/targets/classes.py @@ -71,7 +71,7 @@ class TargetClassConfig(BaseConfig): DEFAULT_DETECTION_CLASS = TargetClassConfig( name="bat", - match_if=AllOfConfig( # ty: ignore[unknown-argument] + match_if=AllOfConfig( conditions=[ HasTagConfig(tag=data.Tag(term=call_type, value="Echolocation")), NotConfig(