mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-05-22 22:32:18 +02:00
fix: resolve remaining type check issues
This commit is contained in:
parent
ce6975770e
commit
b0f85b96e3
@ -102,19 +102,19 @@ def convert_to_annotation_group(
|
||||
x_inds.append(0)
|
||||
y_inds.append(0)
|
||||
|
||||
annotations.append(
|
||||
Annotation(
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
low_freq=low_freq,
|
||||
high_freq=high_freq,
|
||||
class_prob=1.0,
|
||||
det_prob=1.0,
|
||||
individual="0",
|
||||
event=event,
|
||||
class_id=class_id,
|
||||
)
|
||||
)
|
||||
annotation_entry: Annotation = {
|
||||
"start_time": start_time,
|
||||
"end_time": end_time,
|
||||
"low_freq": low_freq,
|
||||
"high_freq": high_freq,
|
||||
"class_prob": 1.0,
|
||||
"det_prob": 1.0,
|
||||
"individual": "0",
|
||||
"event": event,
|
||||
"class": get_recording_class_name(recording),
|
||||
"class_id": class_id,
|
||||
}
|
||||
annotations.append(annotation_entry)
|
||||
|
||||
return {
|
||||
"id": str(recording.path),
|
||||
|
||||
@ -53,7 +53,7 @@ class Registry(Generic[T_Type, P_Type]):
|
||||
def __init__(self, name: str, discriminator: str = "name"):
|
||||
self._name = name
|
||||
self._registry: dict[
|
||||
str, Callable[Concatenate[..., P_Type], T_Type]
|
||||
str, Callable[Concatenate[Any, P_Type], T_Type]
|
||||
] = {}
|
||||
self._discriminator = discriminator
|
||||
self._config_types: dict[str, Type[BaseModel]] = {}
|
||||
@ -80,7 +80,7 @@ class Registry(Generic[T_Type, P_Type]):
|
||||
)
|
||||
|
||||
def decorator(
|
||||
func: Callable[Concatenate[T_Config, P_Type], T_Type],
|
||||
func: Callable[..., T_Type],
|
||||
):
|
||||
self._registry[name] = func
|
||||
return func
|
||||
@ -102,8 +102,8 @@ class Registry(Generic[T_Type, P_Type]):
|
||||
def build(
|
||||
self,
|
||||
config: BaseModel,
|
||||
*args: P_Type.args,
|
||||
**kwargs: P_Type.kwargs,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> T_Type:
|
||||
"""Builds a logic instance from a config object."""
|
||||
|
||||
|
||||
@ -12,13 +12,15 @@ __all__ = [
|
||||
]
|
||||
|
||||
|
||||
def _default_tasks() -> list[TaskConfig]:
|
||||
return [
|
||||
DetectionTaskConfig(),
|
||||
ClassificationTaskConfig(),
|
||||
]
|
||||
|
||||
|
||||
class EvaluationConfig(BaseConfig):
|
||||
tasks: List[TaskConfig] = Field(
|
||||
default_factory=lambda: [
|
||||
DetectionTaskConfig(),
|
||||
ClassificationTaskConfig(),
|
||||
]
|
||||
)
|
||||
tasks: List[TaskConfig] = Field(default_factory=_default_tasks)
|
||||
|
||||
|
||||
def get_default_eval_config() -> EvaluationConfig:
|
||||
|
||||
@ -25,11 +25,15 @@ from batdetect2.postprocess.types import ClipDetections, Detection
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
|
||||
|
||||
def _default_metrics() -> list[ClassificationMetricConfig]:
|
||||
return [ClassificationAveragePrecisionConfig()]
|
||||
|
||||
|
||||
class ClassificationTaskConfig(BaseSEDTaskConfig):
|
||||
name: Literal["sound_event_classification"] = "sound_event_classification"
|
||||
prefix: str = "classification"
|
||||
metrics: list[ClassificationMetricConfig] = Field(
|
||||
default_factory=lambda: [ClassificationAveragePrecisionConfig()]
|
||||
default_factory=_default_metrics
|
||||
)
|
||||
plots: list[ClassificationPlotConfig] = Field(default_factory=list)
|
||||
include_generics: bool = True
|
||||
|
||||
@ -23,13 +23,15 @@ from batdetect2.postprocess.types import ClipDetections
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
|
||||
|
||||
def _default_metrics() -> list[ClipClassificationMetricConfig]:
|
||||
return [ClipClassificationAveragePrecisionConfig()]
|
||||
|
||||
|
||||
class ClipClassificationTaskConfig(BaseTaskConfig):
|
||||
name: Literal["clip_classification"] = "clip_classification"
|
||||
prefix: str = "clip_classification"
|
||||
metrics: list[ClipClassificationMetricConfig] = Field(
|
||||
default_factory=lambda: [
|
||||
ClipClassificationAveragePrecisionConfig(),
|
||||
]
|
||||
default_factory=_default_metrics
|
||||
)
|
||||
plots: list[ClipClassificationPlotConfig] = Field(default_factory=list)
|
||||
|
||||
|
||||
@ -22,13 +22,15 @@ from batdetect2.postprocess.types import ClipDetections
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
|
||||
|
||||
def _default_metrics() -> list[ClipDetectionMetricConfig]:
|
||||
return [ClipDetectionAveragePrecisionConfig()]
|
||||
|
||||
|
||||
class ClipDetectionTaskConfig(BaseTaskConfig):
|
||||
name: Literal["clip_detection"] = "clip_detection"
|
||||
prefix: str = "clip_detection"
|
||||
metrics: list[ClipDetectionMetricConfig] = Field(
|
||||
default_factory=lambda: [
|
||||
ClipDetectionAveragePrecisionConfig(),
|
||||
]
|
||||
default_factory=_default_metrics
|
||||
)
|
||||
plots: list[ClipDetectionPlotConfig] = Field(default_factory=list)
|
||||
|
||||
|
||||
@ -24,11 +24,15 @@ from batdetect2.postprocess.types import ClipDetections
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
|
||||
|
||||
def _default_metrics() -> list[DetectionMetricConfig]:
|
||||
return [DetectionAveragePrecisionConfig()]
|
||||
|
||||
|
||||
class DetectionTaskConfig(BaseSEDTaskConfig):
|
||||
name: Literal["sound_event_detection"] = "sound_event_detection"
|
||||
prefix: str = "detection"
|
||||
metrics: list[DetectionMetricConfig] = Field(
|
||||
default_factory=lambda: [DetectionAveragePrecisionConfig()]
|
||||
default_factory=_default_metrics
|
||||
)
|
||||
plots: list[DetectionPlotConfig] = Field(default_factory=list)
|
||||
|
||||
|
||||
@ -24,11 +24,15 @@ from batdetect2.postprocess.types import ClipDetections
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
|
||||
|
||||
def _default_metrics() -> list[TopClassMetricConfig]:
|
||||
return [TopClassAveragePrecisionConfig()]
|
||||
|
||||
|
||||
class TopClassDetectionTaskConfig(BaseSEDTaskConfig):
|
||||
name: Literal["top_class_detection"] = "top_class_detection"
|
||||
prefix: str = "top_class"
|
||||
metrics: list[TopClassMetricConfig] = Field(
|
||||
default_factory=lambda: [TopClassAveragePrecisionConfig()]
|
||||
default_factory=_default_metrics
|
||||
)
|
||||
plots: list[TopClassPlotConfig] = Field(default_factory=list)
|
||||
|
||||
|
||||
@ -154,17 +154,18 @@ class BatDetect2Formatter(OutputFormatterProtocol[FileAnnotation]):
|
||||
top_class_index = int(np.argmax(prediction.class_scores))
|
||||
top_class_score = float(prediction.class_scores[top_class_index])
|
||||
top_class = self.get_class_name(top_class_index)
|
||||
return Annotation(
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
low_freq=low_freq,
|
||||
high_freq=high_freq,
|
||||
class_prob=top_class_score,
|
||||
det_prob=float(prediction.detection_score),
|
||||
individual="",
|
||||
event=self.event_name,
|
||||
**{"class": top_class},
|
||||
)
|
||||
annotation: Annotation = {
|
||||
"start_time": start_time,
|
||||
"end_time": end_time,
|
||||
"low_freq": low_freq,
|
||||
"high_freq": high_freq,
|
||||
"class_prob": top_class_score,
|
||||
"det_prob": float(prediction.detection_score),
|
||||
"individual": "",
|
||||
"event": self.event_name,
|
||||
"class": top_class,
|
||||
}
|
||||
return annotation
|
||||
|
||||
@output_formatters.register(BatDetect2OutputConfig)
|
||||
@staticmethod
|
||||
|
||||
@ -26,6 +26,13 @@ __all__ = [
|
||||
]
|
||||
|
||||
|
||||
def _default_spectrogram_transforms() -> list[SpectrogramTransform]:
|
||||
return [
|
||||
PcenConfig(),
|
||||
SpectralMeanSubtractionConfig(),
|
||||
]
|
||||
|
||||
|
||||
class PreprocessingConfig(BaseConfig):
|
||||
"""Unified configuration for the audio preprocessing pipeline.
|
||||
|
||||
@ -58,10 +65,7 @@ class PreprocessingConfig(BaseConfig):
|
||||
audio_transforms: List[AudioTransform] = Field(default_factory=list)
|
||||
|
||||
spectrogram_transforms: List[SpectrogramTransform] = Field(
|
||||
default_factory=lambda: [
|
||||
PcenConfig(),
|
||||
SpectralMeanSubtractionConfig(),
|
||||
]
|
||||
default_factory=_default_spectrogram_transforms
|
||||
)
|
||||
|
||||
stft: STFTConfig = Field(default_factory=STFTConfig)
|
||||
|
||||
@ -71,7 +71,7 @@ class TargetClassConfig(BaseConfig):
|
||||
|
||||
DEFAULT_DETECTION_CLASS = TargetClassConfig(
|
||||
name="bat",
|
||||
match_if=AllOfConfig( # ty: ignore[unknown-argument]
|
||||
match_if=AllOfConfig(
|
||||
conditions=[
|
||||
HasTagConfig(tag=data.Tag(term=call_type, value="Echolocation")),
|
||||
NotConfig(
|
||||
|
||||
Loading…
Reference in New Issue
Block a user