fix: resolve remaining type check issues

This commit is contained in:
mbsantiago 2026-05-06 17:43:29 +01:00
parent ce6975770e
commit b0f85b96e3
11 changed files with 71 additions and 48 deletions

View File

@ -102,19 +102,19 @@ def convert_to_annotation_group(
x_inds.append(0)
y_inds.append(0)
annotations.append(
Annotation(
start_time=start_time,
end_time=end_time,
low_freq=low_freq,
high_freq=high_freq,
class_prob=1.0,
det_prob=1.0,
individual="0",
event=event,
class_id=class_id,
)
)
annotation_entry: Annotation = {
"start_time": start_time,
"end_time": end_time,
"low_freq": low_freq,
"high_freq": high_freq,
"class_prob": 1.0,
"det_prob": 1.0,
"individual": "0",
"event": event,
"class": get_recording_class_name(recording),
"class_id": class_id,
}
annotations.append(annotation_entry)
return {
"id": str(recording.path),

View File

@ -53,7 +53,7 @@ class Registry(Generic[T_Type, P_Type]):
def __init__(self, name: str, discriminator: str = "name"):
self._name = name
self._registry: dict[
str, Callable[Concatenate[..., P_Type], T_Type]
str, Callable[Concatenate[Any, P_Type], T_Type]
] = {}
self._discriminator = discriminator
self._config_types: dict[str, Type[BaseModel]] = {}
@ -80,7 +80,7 @@ class Registry(Generic[T_Type, P_Type]):
)
def decorator(
func: Callable[Concatenate[T_Config, P_Type], T_Type],
func: Callable[..., T_Type],
):
self._registry[name] = func
return func
@ -102,8 +102,8 @@ class Registry(Generic[T_Type, P_Type]):
def build(
self,
config: BaseModel,
*args: P_Type.args,
**kwargs: P_Type.kwargs,
*args: Any,
**kwargs: Any,
) -> T_Type:
"""Builds a logic instance from a config object."""

View File

@ -12,13 +12,15 @@ __all__ = [
]
def _default_tasks() -> list[TaskConfig]:
return [
DetectionTaskConfig(),
ClassificationTaskConfig(),
]
class EvaluationConfig(BaseConfig):
tasks: List[TaskConfig] = Field(
default_factory=lambda: [
DetectionTaskConfig(),
ClassificationTaskConfig(),
]
)
tasks: List[TaskConfig] = Field(default_factory=_default_tasks)
def get_default_eval_config() -> EvaluationConfig:

View File

@ -25,11 +25,15 @@ from batdetect2.postprocess.types import ClipDetections, Detection
from batdetect2.targets.types import TargetProtocol
def _default_metrics() -> list[ClassificationMetricConfig]:
return [ClassificationAveragePrecisionConfig()]
class ClassificationTaskConfig(BaseSEDTaskConfig):
name: Literal["sound_event_classification"] = "sound_event_classification"
prefix: str = "classification"
metrics: list[ClassificationMetricConfig] = Field(
default_factory=lambda: [ClassificationAveragePrecisionConfig()]
default_factory=_default_metrics
)
plots: list[ClassificationPlotConfig] = Field(default_factory=list)
include_generics: bool = True

View File

@ -23,13 +23,15 @@ from batdetect2.postprocess.types import ClipDetections
from batdetect2.targets.types import TargetProtocol
def _default_metrics() -> list[ClipClassificationMetricConfig]:
return [ClipClassificationAveragePrecisionConfig()]
class ClipClassificationTaskConfig(BaseTaskConfig):
name: Literal["clip_classification"] = "clip_classification"
prefix: str = "clip_classification"
metrics: list[ClipClassificationMetricConfig] = Field(
default_factory=lambda: [
ClipClassificationAveragePrecisionConfig(),
]
default_factory=_default_metrics
)
plots: list[ClipClassificationPlotConfig] = Field(default_factory=list)

View File

@ -22,13 +22,15 @@ from batdetect2.postprocess.types import ClipDetections
from batdetect2.targets.types import TargetProtocol
def _default_metrics() -> list[ClipDetectionMetricConfig]:
return [ClipDetectionAveragePrecisionConfig()]
class ClipDetectionTaskConfig(BaseTaskConfig):
name: Literal["clip_detection"] = "clip_detection"
prefix: str = "clip_detection"
metrics: list[ClipDetectionMetricConfig] = Field(
default_factory=lambda: [
ClipDetectionAveragePrecisionConfig(),
]
default_factory=_default_metrics
)
plots: list[ClipDetectionPlotConfig] = Field(default_factory=list)

View File

@ -24,11 +24,15 @@ from batdetect2.postprocess.types import ClipDetections
from batdetect2.targets.types import TargetProtocol
def _default_metrics() -> list[DetectionMetricConfig]:
return [DetectionAveragePrecisionConfig()]
class DetectionTaskConfig(BaseSEDTaskConfig):
name: Literal["sound_event_detection"] = "sound_event_detection"
prefix: str = "detection"
metrics: list[DetectionMetricConfig] = Field(
default_factory=lambda: [DetectionAveragePrecisionConfig()]
default_factory=_default_metrics
)
plots: list[DetectionPlotConfig] = Field(default_factory=list)

View File

@ -24,11 +24,15 @@ from batdetect2.postprocess.types import ClipDetections
from batdetect2.targets.types import TargetProtocol
def _default_metrics() -> list[TopClassMetricConfig]:
return [TopClassAveragePrecisionConfig()]
class TopClassDetectionTaskConfig(BaseSEDTaskConfig):
name: Literal["top_class_detection"] = "top_class_detection"
prefix: str = "top_class"
metrics: list[TopClassMetricConfig] = Field(
default_factory=lambda: [TopClassAveragePrecisionConfig()]
default_factory=_default_metrics
)
plots: list[TopClassPlotConfig] = Field(default_factory=list)

View File

@ -154,17 +154,18 @@ class BatDetect2Formatter(OutputFormatterProtocol[FileAnnotation]):
top_class_index = int(np.argmax(prediction.class_scores))
top_class_score = float(prediction.class_scores[top_class_index])
top_class = self.get_class_name(top_class_index)
return Annotation(
start_time=start_time,
end_time=end_time,
low_freq=low_freq,
high_freq=high_freq,
class_prob=top_class_score,
det_prob=float(prediction.detection_score),
individual="",
event=self.event_name,
**{"class": top_class},
)
annotation: Annotation = {
"start_time": start_time,
"end_time": end_time,
"low_freq": low_freq,
"high_freq": high_freq,
"class_prob": top_class_score,
"det_prob": float(prediction.detection_score),
"individual": "",
"event": self.event_name,
"class": top_class,
}
return annotation
@output_formatters.register(BatDetect2OutputConfig)
@staticmethod

View File

@ -26,6 +26,13 @@ __all__ = [
]
def _default_spectrogram_transforms() -> list[SpectrogramTransform]:
return [
PcenConfig(),
SpectralMeanSubtractionConfig(),
]
class PreprocessingConfig(BaseConfig):
"""Unified configuration for the audio preprocessing pipeline.
@ -58,10 +65,7 @@ class PreprocessingConfig(BaseConfig):
audio_transforms: List[AudioTransform] = Field(default_factory=list)
spectrogram_transforms: List[SpectrogramTransform] = Field(
default_factory=lambda: [
PcenConfig(),
SpectralMeanSubtractionConfig(),
]
default_factory=_default_spectrogram_transforms
)
stft: STFTConfig = Field(default_factory=STFTConfig)

View File

@ -71,7 +71,7 @@ class TargetClassConfig(BaseConfig):
DEFAULT_DETECTION_CLASS = TargetClassConfig(
name="bat",
match_if=AllOfConfig( # ty: ignore[unknown-argument]
match_if=AllOfConfig(
conditions=[
HasTagConfig(tag=data.Tag(term=call_type, value="Echolocation")),
NotConfig(