fix: resolve remaining type check issues

This commit is contained in:
mbsantiago 2026-05-06 17:43:29 +01:00
parent ce6975770e
commit b0f85b96e3
11 changed files with 71 additions and 48 deletions

View File

@ -102,19 +102,19 @@ def convert_to_annotation_group(
x_inds.append(0) x_inds.append(0)
y_inds.append(0) y_inds.append(0)
annotations.append( annotation_entry: Annotation = {
Annotation( "start_time": start_time,
start_time=start_time, "end_time": end_time,
end_time=end_time, "low_freq": low_freq,
low_freq=low_freq, "high_freq": high_freq,
high_freq=high_freq, "class_prob": 1.0,
class_prob=1.0, "det_prob": 1.0,
det_prob=1.0, "individual": "0",
individual="0", "event": event,
event=event, "class": get_recording_class_name(recording),
class_id=class_id, "class_id": class_id,
) }
) annotations.append(annotation_entry)
return { return {
"id": str(recording.path), "id": str(recording.path),

View File

@ -53,7 +53,7 @@ class Registry(Generic[T_Type, P_Type]):
def __init__(self, name: str, discriminator: str = "name"): def __init__(self, name: str, discriminator: str = "name"):
self._name = name self._name = name
self._registry: dict[ self._registry: dict[
str, Callable[Concatenate[..., P_Type], T_Type] str, Callable[Concatenate[Any, P_Type], T_Type]
] = {} ] = {}
self._discriminator = discriminator self._discriminator = discriminator
self._config_types: dict[str, Type[BaseModel]] = {} self._config_types: dict[str, Type[BaseModel]] = {}
@ -80,7 +80,7 @@ class Registry(Generic[T_Type, P_Type]):
) )
def decorator( def decorator(
func: Callable[Concatenate[T_Config, P_Type], T_Type], func: Callable[..., T_Type],
): ):
self._registry[name] = func self._registry[name] = func
return func return func
@ -102,8 +102,8 @@ class Registry(Generic[T_Type, P_Type]):
def build( def build(
self, self,
config: BaseModel, config: BaseModel,
*args: P_Type.args, *args: Any,
**kwargs: P_Type.kwargs, **kwargs: Any,
) -> T_Type: ) -> T_Type:
"""Builds a logic instance from a config object.""" """Builds a logic instance from a config object."""

View File

@ -12,13 +12,15 @@ __all__ = [
] ]
def _default_tasks() -> list[TaskConfig]:
return [
DetectionTaskConfig(),
ClassificationTaskConfig(),
]
class EvaluationConfig(BaseConfig): class EvaluationConfig(BaseConfig):
tasks: List[TaskConfig] = Field( tasks: List[TaskConfig] = Field(default_factory=_default_tasks)
default_factory=lambda: [
DetectionTaskConfig(),
ClassificationTaskConfig(),
]
)
def get_default_eval_config() -> EvaluationConfig: def get_default_eval_config() -> EvaluationConfig:

View File

@ -25,11 +25,15 @@ from batdetect2.postprocess.types import ClipDetections, Detection
from batdetect2.targets.types import TargetProtocol from batdetect2.targets.types import TargetProtocol
def _default_metrics() -> list[ClassificationMetricConfig]:
return [ClassificationAveragePrecisionConfig()]
class ClassificationTaskConfig(BaseSEDTaskConfig): class ClassificationTaskConfig(BaseSEDTaskConfig):
name: Literal["sound_event_classification"] = "sound_event_classification" name: Literal["sound_event_classification"] = "sound_event_classification"
prefix: str = "classification" prefix: str = "classification"
metrics: list[ClassificationMetricConfig] = Field( metrics: list[ClassificationMetricConfig] = Field(
default_factory=lambda: [ClassificationAveragePrecisionConfig()] default_factory=_default_metrics
) )
plots: list[ClassificationPlotConfig] = Field(default_factory=list) plots: list[ClassificationPlotConfig] = Field(default_factory=list)
include_generics: bool = True include_generics: bool = True

View File

@ -23,13 +23,15 @@ from batdetect2.postprocess.types import ClipDetections
from batdetect2.targets.types import TargetProtocol from batdetect2.targets.types import TargetProtocol
def _default_metrics() -> list[ClipClassificationMetricConfig]:
return [ClipClassificationAveragePrecisionConfig()]
class ClipClassificationTaskConfig(BaseTaskConfig): class ClipClassificationTaskConfig(BaseTaskConfig):
name: Literal["clip_classification"] = "clip_classification" name: Literal["clip_classification"] = "clip_classification"
prefix: str = "clip_classification" prefix: str = "clip_classification"
metrics: list[ClipClassificationMetricConfig] = Field( metrics: list[ClipClassificationMetricConfig] = Field(
default_factory=lambda: [ default_factory=_default_metrics
ClipClassificationAveragePrecisionConfig(),
]
) )
plots: list[ClipClassificationPlotConfig] = Field(default_factory=list) plots: list[ClipClassificationPlotConfig] = Field(default_factory=list)

View File

@ -22,13 +22,15 @@ from batdetect2.postprocess.types import ClipDetections
from batdetect2.targets.types import TargetProtocol from batdetect2.targets.types import TargetProtocol
def _default_metrics() -> list[ClipDetectionMetricConfig]:
return [ClipDetectionAveragePrecisionConfig()]
class ClipDetectionTaskConfig(BaseTaskConfig): class ClipDetectionTaskConfig(BaseTaskConfig):
name: Literal["clip_detection"] = "clip_detection" name: Literal["clip_detection"] = "clip_detection"
prefix: str = "clip_detection" prefix: str = "clip_detection"
metrics: list[ClipDetectionMetricConfig] = Field( metrics: list[ClipDetectionMetricConfig] = Field(
default_factory=lambda: [ default_factory=_default_metrics
ClipDetectionAveragePrecisionConfig(),
]
) )
plots: list[ClipDetectionPlotConfig] = Field(default_factory=list) plots: list[ClipDetectionPlotConfig] = Field(default_factory=list)

View File

@ -24,11 +24,15 @@ from batdetect2.postprocess.types import ClipDetections
from batdetect2.targets.types import TargetProtocol from batdetect2.targets.types import TargetProtocol
def _default_metrics() -> list[DetectionMetricConfig]:
return [DetectionAveragePrecisionConfig()]
class DetectionTaskConfig(BaseSEDTaskConfig): class DetectionTaskConfig(BaseSEDTaskConfig):
name: Literal["sound_event_detection"] = "sound_event_detection" name: Literal["sound_event_detection"] = "sound_event_detection"
prefix: str = "detection" prefix: str = "detection"
metrics: list[DetectionMetricConfig] = Field( metrics: list[DetectionMetricConfig] = Field(
default_factory=lambda: [DetectionAveragePrecisionConfig()] default_factory=_default_metrics
) )
plots: list[DetectionPlotConfig] = Field(default_factory=list) plots: list[DetectionPlotConfig] = Field(default_factory=list)

View File

@ -24,11 +24,15 @@ from batdetect2.postprocess.types import ClipDetections
from batdetect2.targets.types import TargetProtocol from batdetect2.targets.types import TargetProtocol
def _default_metrics() -> list[TopClassMetricConfig]:
return [TopClassAveragePrecisionConfig()]
class TopClassDetectionTaskConfig(BaseSEDTaskConfig): class TopClassDetectionTaskConfig(BaseSEDTaskConfig):
name: Literal["top_class_detection"] = "top_class_detection" name: Literal["top_class_detection"] = "top_class_detection"
prefix: str = "top_class" prefix: str = "top_class"
metrics: list[TopClassMetricConfig] = Field( metrics: list[TopClassMetricConfig] = Field(
default_factory=lambda: [TopClassAveragePrecisionConfig()] default_factory=_default_metrics
) )
plots: list[TopClassPlotConfig] = Field(default_factory=list) plots: list[TopClassPlotConfig] = Field(default_factory=list)

View File

@ -154,17 +154,18 @@ class BatDetect2Formatter(OutputFormatterProtocol[FileAnnotation]):
top_class_index = int(np.argmax(prediction.class_scores)) top_class_index = int(np.argmax(prediction.class_scores))
top_class_score = float(prediction.class_scores[top_class_index]) top_class_score = float(prediction.class_scores[top_class_index])
top_class = self.get_class_name(top_class_index) top_class = self.get_class_name(top_class_index)
return Annotation( annotation: Annotation = {
start_time=start_time, "start_time": start_time,
end_time=end_time, "end_time": end_time,
low_freq=low_freq, "low_freq": low_freq,
high_freq=high_freq, "high_freq": high_freq,
class_prob=top_class_score, "class_prob": top_class_score,
det_prob=float(prediction.detection_score), "det_prob": float(prediction.detection_score),
individual="", "individual": "",
event=self.event_name, "event": self.event_name,
**{"class": top_class}, "class": top_class,
) }
return annotation
@output_formatters.register(BatDetect2OutputConfig) @output_formatters.register(BatDetect2OutputConfig)
@staticmethod @staticmethod

View File

@ -26,6 +26,13 @@ __all__ = [
] ]
def _default_spectrogram_transforms() -> list[SpectrogramTransform]:
return [
PcenConfig(),
SpectralMeanSubtractionConfig(),
]
class PreprocessingConfig(BaseConfig): class PreprocessingConfig(BaseConfig):
"""Unified configuration for the audio preprocessing pipeline. """Unified configuration for the audio preprocessing pipeline.
@ -58,10 +65,7 @@ class PreprocessingConfig(BaseConfig):
audio_transforms: List[AudioTransform] = Field(default_factory=list) audio_transforms: List[AudioTransform] = Field(default_factory=list)
spectrogram_transforms: List[SpectrogramTransform] = Field( spectrogram_transforms: List[SpectrogramTransform] = Field(
default_factory=lambda: [ default_factory=_default_spectrogram_transforms
PcenConfig(),
SpectralMeanSubtractionConfig(),
]
) )
stft: STFTConfig = Field(default_factory=STFTConfig) stft: STFTConfig = Field(default_factory=STFTConfig)

View File

@ -71,7 +71,7 @@ class TargetClassConfig(BaseConfig):
DEFAULT_DETECTION_CLASS = TargetClassConfig( DEFAULT_DETECTION_CLASS = TargetClassConfig(
name="bat", name="bat",
match_if=AllOfConfig( # ty: ignore[unknown-argument] match_if=AllOfConfig(
conditions=[ conditions=[
HasTagConfig(tag=data.Tag(term=call_type, value="Echolocation")), HasTagConfig(tag=data.Tag(term=call_type, value="Echolocation")),
NotConfig( NotConfig(