Merge branch 'train' into doc

This commit is contained in:
mbsantiago 2026-04-30 11:50:04 +01:00
commit 9da05c172c
7 changed files with 77 additions and 40 deletions

View File

@ -2,10 +2,6 @@
import click
from batdetect2.logging import enable_logging
# from batdetect2.cli.ascii import BATDETECT_ASCII_ART
__all__ = [
"cli",
]
@ -34,5 +30,7 @@ def cli(verbose: int = 0):
"""
click.echo(INFO_STR)
from batdetect2.logging import enable_logging
enable_logging(verbose)
# click.echo(BATDETECT_ASCII_ART)

View File

@ -4,8 +4,6 @@ from typing import TYPE_CHECKING
import click
from loguru import logger
from soundevent import io
from soundevent.audio.files import get_audio_files
from batdetect2.cli.base import cli
@ -219,6 +217,8 @@ def predict_directory_command(
Loads a checkpoint, scans `audio_dir` for supported audio files, runs
inference, and saves predictions to `output_path`.
"""
from soundevent.audio.files import get_audio_files
audio_files = list(get_audio_files(audio_dir))
_run_prediction(
model_path=model_path,
@ -309,6 +309,8 @@ def predict_dataset_command(
The dataset is read as a soundevent annotation set and unique recording
paths are extracted before inference.
"""
from soundevent import io
dataset_path = Path(dataset_path)
dataset = io.load(dataset_path, type="annotation_set")
audio_files = sorted(

View File

@ -1,4 +1,6 @@
import operator
from collections.abc import Callable, Sequence
from functools import partial
from typing import Annotated, Literal
from pydantic import Field
@ -78,25 +80,23 @@ class DurationConfig(BaseConfig):
seconds: float
def _build_comparator(
operator: Operator, value: float
) -> Callable[[float], bool]:
if operator == "gt":
return lambda x: x > value
def _build_comparator(op: Operator, value: float) -> Callable[[float], bool]:
if op == "gt":
return partial(operator.lt, value)
if operator == "gte":
return lambda x: x >= value
if op == "gte":
return partial(operator.le, value)
if operator == "lt":
return lambda x: x < value
if op == "lt":
return partial(operator.gt, value)
if operator == "lte":
return lambda x: x <= value
if op == "lte":
return partial(operator.ge, value)
if operator == "eq":
return lambda x: x == value
if op == "eq":
return partial(operator.eq, b=value)
raise ValueError(f"Invalid operator {operator}")
raise ValueError(f"Invalid operator {op}")
class Duration:

View File

@ -204,15 +204,19 @@ class ClassificationROCAUC(BaseClassificationMetric):
ignore_generic=self.ignore_generic,
)
class_scores = {
class_name: float(
class_scores = {}
for class_name in self.targets.class_names:
if len(y_true[class_name]) == 0:
class_scores[class_name] = np.nan
continue
class_scores[class_name] = float(
metrics.roc_auc_score(
y_true[class_name],
y_score[class_name],
)
)
for class_name in self.targets.class_names
}
mean_score = float(
np.mean([v for v in class_scores.values() if v != np.nan])

View File

@ -133,6 +133,9 @@ class DetectionROCAUC:
y_true.append(m.is_ground_truth)
y_score.append(m.score)
if len(y_true) == 0:
return {self.label: np.nan}
score = float(metrics.roc_auc_score(y_true, y_score))
return {self.label: score}

View File

@ -1,9 +1,12 @@
from __future__ import annotations
import io
import sys
from collections.abc import Callable
from functools import partial
from pathlib import Path
from typing import (
TYPE_CHECKING,
Annotated,
Any,
Dict,
@ -13,21 +16,23 @@ from typing import (
TypeVar,
)
import numpy as np
import pandas as pd
from lightning.pytorch.loggers import (
CSVLogger,
Logger,
MLFlowLogger,
TensorBoardLogger,
)
from loguru import logger
from matplotlib.figure import Figure
from pydantic import Field
from soundevent import data
from batdetect2.core.configs import BaseConfig
if TYPE_CHECKING:
import numpy as np
import pandas as pd
from lightning.pytorch.loggers import (
CSVLogger,
Logger,
MLFlowLogger,
TensorBoardLogger,
)
from matplotlib.figure import Figure
from soundevent import data
DEFAULT_LOGS_DIR: Path = Path("outputs") / "logs"
__all__ = [
@ -271,10 +276,16 @@ def build_logger(
)
PlotLogger = Callable[[str, Figure, int], None]
PlotLogger = Callable[[str, "Figure", int], None]
def get_image_logger(logger: Logger) -> PlotLogger | None:
from lightning.pytorch.loggers import (
CSVLogger,
MLFlowLogger,
TensorBoardLogger,
)
if isinstance(logger, TensorBoardLogger):
return logger.experiment.add_figure
@ -296,10 +307,16 @@ def get_image_logger(logger: Logger) -> PlotLogger | None:
return partial(save_figure, dir=Path(logger.log_dir))
TableLogger = Callable[[str, pd.DataFrame, int], None]
TableLogger = Callable[[str, "pd.DataFrame", int], None]
def get_table_logger(logger: Logger) -> TableLogger | None:
from lightning.pytorch.loggers import (
CSVLogger,
MLFlowLogger,
TensorBoardLogger,
)
if isinstance(logger, TensorBoardLogger):
return partial(save_table, dir=Path(logger.log_dir))
@ -337,6 +354,8 @@ def save_figure(name: str, fig: Figure, step: int, dir: Path) -> None:
def _convert_figure_to_array(figure: Figure) -> np.ndarray:
import numpy as np
with io.BytesIO() as buff:
figure.savefig(buff, format="raw")
buff.seek(0)

View File

@ -15,11 +15,11 @@ GENERIC_CLASS_KEY = "class"
data_source = data.Term(
name="soundevent:data_source",
label="Data Source",
name="dcterms:source",
label="Source",
uri="http://purl.org/dc/terms/source",
definition=(
"A unique identifier for the source of the data, typically "
"representing the project, site, or deployment context."
"A related resource from which the described resource is derived."
),
)
@ -45,6 +45,17 @@ individual = data.Term(
)
"""Term used for tags identifying a specific individual animal."""
dataset_split = data.Term(
name="batdetect2:split",
label="Dataset Split",
definition=(
"Identifies the specific data partition (e.g., 'train', 'test') "
"that the item belongs to within an experimental setup. "
"The expected value is a literal text string."
),
)
"""Custom metadata term defining the machine learning partition of an item."""
generic_class = data.Term(
name="soundevent:class",
label="Class",