mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-05-22 22:32:18 +02:00
fix: resolve remaining type check issues
This commit is contained in:
parent
ce6975770e
commit
b0f85b96e3
@ -102,19 +102,19 @@ def convert_to_annotation_group(
|
|||||||
x_inds.append(0)
|
x_inds.append(0)
|
||||||
y_inds.append(0)
|
y_inds.append(0)
|
||||||
|
|
||||||
annotations.append(
|
annotation_entry: Annotation = {
|
||||||
Annotation(
|
"start_time": start_time,
|
||||||
start_time=start_time,
|
"end_time": end_time,
|
||||||
end_time=end_time,
|
"low_freq": low_freq,
|
||||||
low_freq=low_freq,
|
"high_freq": high_freq,
|
||||||
high_freq=high_freq,
|
"class_prob": 1.0,
|
||||||
class_prob=1.0,
|
"det_prob": 1.0,
|
||||||
det_prob=1.0,
|
"individual": "0",
|
||||||
individual="0",
|
"event": event,
|
||||||
event=event,
|
"class": get_recording_class_name(recording),
|
||||||
class_id=class_id,
|
"class_id": class_id,
|
||||||
)
|
}
|
||||||
)
|
annotations.append(annotation_entry)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"id": str(recording.path),
|
"id": str(recording.path),
|
||||||
|
|||||||
@ -53,7 +53,7 @@ class Registry(Generic[T_Type, P_Type]):
|
|||||||
def __init__(self, name: str, discriminator: str = "name"):
|
def __init__(self, name: str, discriminator: str = "name"):
|
||||||
self._name = name
|
self._name = name
|
||||||
self._registry: dict[
|
self._registry: dict[
|
||||||
str, Callable[Concatenate[..., P_Type], T_Type]
|
str, Callable[Concatenate[Any, P_Type], T_Type]
|
||||||
] = {}
|
] = {}
|
||||||
self._discriminator = discriminator
|
self._discriminator = discriminator
|
||||||
self._config_types: dict[str, Type[BaseModel]] = {}
|
self._config_types: dict[str, Type[BaseModel]] = {}
|
||||||
@ -80,7 +80,7 @@ class Registry(Generic[T_Type, P_Type]):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def decorator(
|
def decorator(
|
||||||
func: Callable[Concatenate[T_Config, P_Type], T_Type],
|
func: Callable[..., T_Type],
|
||||||
):
|
):
|
||||||
self._registry[name] = func
|
self._registry[name] = func
|
||||||
return func
|
return func
|
||||||
@ -102,8 +102,8 @@ class Registry(Generic[T_Type, P_Type]):
|
|||||||
def build(
|
def build(
|
||||||
self,
|
self,
|
||||||
config: BaseModel,
|
config: BaseModel,
|
||||||
*args: P_Type.args,
|
*args: Any,
|
||||||
**kwargs: P_Type.kwargs,
|
**kwargs: Any,
|
||||||
) -> T_Type:
|
) -> T_Type:
|
||||||
"""Builds a logic instance from a config object."""
|
"""Builds a logic instance from a config object."""
|
||||||
|
|
||||||
|
|||||||
@ -12,13 +12,15 @@ __all__ = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _default_tasks() -> list[TaskConfig]:
|
||||||
|
return [
|
||||||
|
DetectionTaskConfig(),
|
||||||
|
ClassificationTaskConfig(),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class EvaluationConfig(BaseConfig):
|
class EvaluationConfig(BaseConfig):
|
||||||
tasks: List[TaskConfig] = Field(
|
tasks: List[TaskConfig] = Field(default_factory=_default_tasks)
|
||||||
default_factory=lambda: [
|
|
||||||
DetectionTaskConfig(),
|
|
||||||
ClassificationTaskConfig(),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get_default_eval_config() -> EvaluationConfig:
|
def get_default_eval_config() -> EvaluationConfig:
|
||||||
|
|||||||
@ -25,11 +25,15 @@ from batdetect2.postprocess.types import ClipDetections, Detection
|
|||||||
from batdetect2.targets.types import TargetProtocol
|
from batdetect2.targets.types import TargetProtocol
|
||||||
|
|
||||||
|
|
||||||
|
def _default_metrics() -> list[ClassificationMetricConfig]:
|
||||||
|
return [ClassificationAveragePrecisionConfig()]
|
||||||
|
|
||||||
|
|
||||||
class ClassificationTaskConfig(BaseSEDTaskConfig):
|
class ClassificationTaskConfig(BaseSEDTaskConfig):
|
||||||
name: Literal["sound_event_classification"] = "sound_event_classification"
|
name: Literal["sound_event_classification"] = "sound_event_classification"
|
||||||
prefix: str = "classification"
|
prefix: str = "classification"
|
||||||
metrics: list[ClassificationMetricConfig] = Field(
|
metrics: list[ClassificationMetricConfig] = Field(
|
||||||
default_factory=lambda: [ClassificationAveragePrecisionConfig()]
|
default_factory=_default_metrics
|
||||||
)
|
)
|
||||||
plots: list[ClassificationPlotConfig] = Field(default_factory=list)
|
plots: list[ClassificationPlotConfig] = Field(default_factory=list)
|
||||||
include_generics: bool = True
|
include_generics: bool = True
|
||||||
|
|||||||
@ -23,13 +23,15 @@ from batdetect2.postprocess.types import ClipDetections
|
|||||||
from batdetect2.targets.types import TargetProtocol
|
from batdetect2.targets.types import TargetProtocol
|
||||||
|
|
||||||
|
|
||||||
|
def _default_metrics() -> list[ClipClassificationMetricConfig]:
|
||||||
|
return [ClipClassificationAveragePrecisionConfig()]
|
||||||
|
|
||||||
|
|
||||||
class ClipClassificationTaskConfig(BaseTaskConfig):
|
class ClipClassificationTaskConfig(BaseTaskConfig):
|
||||||
name: Literal["clip_classification"] = "clip_classification"
|
name: Literal["clip_classification"] = "clip_classification"
|
||||||
prefix: str = "clip_classification"
|
prefix: str = "clip_classification"
|
||||||
metrics: list[ClipClassificationMetricConfig] = Field(
|
metrics: list[ClipClassificationMetricConfig] = Field(
|
||||||
default_factory=lambda: [
|
default_factory=_default_metrics
|
||||||
ClipClassificationAveragePrecisionConfig(),
|
|
||||||
]
|
|
||||||
)
|
)
|
||||||
plots: list[ClipClassificationPlotConfig] = Field(default_factory=list)
|
plots: list[ClipClassificationPlotConfig] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|||||||
@ -22,13 +22,15 @@ from batdetect2.postprocess.types import ClipDetections
|
|||||||
from batdetect2.targets.types import TargetProtocol
|
from batdetect2.targets.types import TargetProtocol
|
||||||
|
|
||||||
|
|
||||||
|
def _default_metrics() -> list[ClipDetectionMetricConfig]:
|
||||||
|
return [ClipDetectionAveragePrecisionConfig()]
|
||||||
|
|
||||||
|
|
||||||
class ClipDetectionTaskConfig(BaseTaskConfig):
|
class ClipDetectionTaskConfig(BaseTaskConfig):
|
||||||
name: Literal["clip_detection"] = "clip_detection"
|
name: Literal["clip_detection"] = "clip_detection"
|
||||||
prefix: str = "clip_detection"
|
prefix: str = "clip_detection"
|
||||||
metrics: list[ClipDetectionMetricConfig] = Field(
|
metrics: list[ClipDetectionMetricConfig] = Field(
|
||||||
default_factory=lambda: [
|
default_factory=_default_metrics
|
||||||
ClipDetectionAveragePrecisionConfig(),
|
|
||||||
]
|
|
||||||
)
|
)
|
||||||
plots: list[ClipDetectionPlotConfig] = Field(default_factory=list)
|
plots: list[ClipDetectionPlotConfig] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|||||||
@ -24,11 +24,15 @@ from batdetect2.postprocess.types import ClipDetections
|
|||||||
from batdetect2.targets.types import TargetProtocol
|
from batdetect2.targets.types import TargetProtocol
|
||||||
|
|
||||||
|
|
||||||
|
def _default_metrics() -> list[DetectionMetricConfig]:
|
||||||
|
return [DetectionAveragePrecisionConfig()]
|
||||||
|
|
||||||
|
|
||||||
class DetectionTaskConfig(BaseSEDTaskConfig):
|
class DetectionTaskConfig(BaseSEDTaskConfig):
|
||||||
name: Literal["sound_event_detection"] = "sound_event_detection"
|
name: Literal["sound_event_detection"] = "sound_event_detection"
|
||||||
prefix: str = "detection"
|
prefix: str = "detection"
|
||||||
metrics: list[DetectionMetricConfig] = Field(
|
metrics: list[DetectionMetricConfig] = Field(
|
||||||
default_factory=lambda: [DetectionAveragePrecisionConfig()]
|
default_factory=_default_metrics
|
||||||
)
|
)
|
||||||
plots: list[DetectionPlotConfig] = Field(default_factory=list)
|
plots: list[DetectionPlotConfig] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|||||||
@ -24,11 +24,15 @@ from batdetect2.postprocess.types import ClipDetections
|
|||||||
from batdetect2.targets.types import TargetProtocol
|
from batdetect2.targets.types import TargetProtocol
|
||||||
|
|
||||||
|
|
||||||
|
def _default_metrics() -> list[TopClassMetricConfig]:
|
||||||
|
return [TopClassAveragePrecisionConfig()]
|
||||||
|
|
||||||
|
|
||||||
class TopClassDetectionTaskConfig(BaseSEDTaskConfig):
|
class TopClassDetectionTaskConfig(BaseSEDTaskConfig):
|
||||||
name: Literal["top_class_detection"] = "top_class_detection"
|
name: Literal["top_class_detection"] = "top_class_detection"
|
||||||
prefix: str = "top_class"
|
prefix: str = "top_class"
|
||||||
metrics: list[TopClassMetricConfig] = Field(
|
metrics: list[TopClassMetricConfig] = Field(
|
||||||
default_factory=lambda: [TopClassAveragePrecisionConfig()]
|
default_factory=_default_metrics
|
||||||
)
|
)
|
||||||
plots: list[TopClassPlotConfig] = Field(default_factory=list)
|
plots: list[TopClassPlotConfig] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|||||||
@ -154,17 +154,18 @@ class BatDetect2Formatter(OutputFormatterProtocol[FileAnnotation]):
|
|||||||
top_class_index = int(np.argmax(prediction.class_scores))
|
top_class_index = int(np.argmax(prediction.class_scores))
|
||||||
top_class_score = float(prediction.class_scores[top_class_index])
|
top_class_score = float(prediction.class_scores[top_class_index])
|
||||||
top_class = self.get_class_name(top_class_index)
|
top_class = self.get_class_name(top_class_index)
|
||||||
return Annotation(
|
annotation: Annotation = {
|
||||||
start_time=start_time,
|
"start_time": start_time,
|
||||||
end_time=end_time,
|
"end_time": end_time,
|
||||||
low_freq=low_freq,
|
"low_freq": low_freq,
|
||||||
high_freq=high_freq,
|
"high_freq": high_freq,
|
||||||
class_prob=top_class_score,
|
"class_prob": top_class_score,
|
||||||
det_prob=float(prediction.detection_score),
|
"det_prob": float(prediction.detection_score),
|
||||||
individual="",
|
"individual": "",
|
||||||
event=self.event_name,
|
"event": self.event_name,
|
||||||
**{"class": top_class},
|
"class": top_class,
|
||||||
)
|
}
|
||||||
|
return annotation
|
||||||
|
|
||||||
@output_formatters.register(BatDetect2OutputConfig)
|
@output_formatters.register(BatDetect2OutputConfig)
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
@ -26,6 +26,13 @@ __all__ = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _default_spectrogram_transforms() -> list[SpectrogramTransform]:
|
||||||
|
return [
|
||||||
|
PcenConfig(),
|
||||||
|
SpectralMeanSubtractionConfig(),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class PreprocessingConfig(BaseConfig):
|
class PreprocessingConfig(BaseConfig):
|
||||||
"""Unified configuration for the audio preprocessing pipeline.
|
"""Unified configuration for the audio preprocessing pipeline.
|
||||||
|
|
||||||
@ -58,10 +65,7 @@ class PreprocessingConfig(BaseConfig):
|
|||||||
audio_transforms: List[AudioTransform] = Field(default_factory=list)
|
audio_transforms: List[AudioTransform] = Field(default_factory=list)
|
||||||
|
|
||||||
spectrogram_transforms: List[SpectrogramTransform] = Field(
|
spectrogram_transforms: List[SpectrogramTransform] = Field(
|
||||||
default_factory=lambda: [
|
default_factory=_default_spectrogram_transforms
|
||||||
PcenConfig(),
|
|
||||||
SpectralMeanSubtractionConfig(),
|
|
||||||
]
|
|
||||||
)
|
)
|
||||||
|
|
||||||
stft: STFTConfig = Field(default_factory=STFTConfig)
|
stft: STFTConfig = Field(default_factory=STFTConfig)
|
||||||
|
|||||||
@ -71,7 +71,7 @@ class TargetClassConfig(BaseConfig):
|
|||||||
|
|
||||||
DEFAULT_DETECTION_CLASS = TargetClassConfig(
|
DEFAULT_DETECTION_CLASS = TargetClassConfig(
|
||||||
name="bat",
|
name="bat",
|
||||||
match_if=AllOfConfig( # ty: ignore[unknown-argument]
|
match_if=AllOfConfig(
|
||||||
conditions=[
|
conditions=[
|
||||||
HasTagConfig(tag=data.Tag(term=call_type, value="Echolocation")),
|
HasTagConfig(tag=data.Tag(term=call_type, value="Echolocation")),
|
||||||
NotConfig(
|
NotConfig(
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user