Merge branch 'train' into doc

This commit is contained in:
mbsantiago 2026-04-30 11:50:04 +01:00
commit 9da05c172c
7 changed files with 77 additions and 40 deletions

View File

@ -2,10 +2,6 @@
import click import click
from batdetect2.logging import enable_logging
# from batdetect2.cli.ascii import BATDETECT_ASCII_ART
__all__ = [ __all__ = [
"cli", "cli",
] ]
@ -34,5 +30,7 @@ def cli(verbose: int = 0):
""" """
click.echo(INFO_STR) click.echo(INFO_STR)
from batdetect2.logging import enable_logging
enable_logging(verbose) enable_logging(verbose)
# click.echo(BATDETECT_ASCII_ART) # click.echo(BATDETECT_ASCII_ART)

View File

@ -4,8 +4,6 @@ from typing import TYPE_CHECKING
import click import click
from loguru import logger from loguru import logger
from soundevent import io
from soundevent.audio.files import get_audio_files
from batdetect2.cli.base import cli from batdetect2.cli.base import cli
@ -219,6 +217,8 @@ def predict_directory_command(
Loads a checkpoint, scans `audio_dir` for supported audio files, runs Loads a checkpoint, scans `audio_dir` for supported audio files, runs
inference, and saves predictions to `output_path`. inference, and saves predictions to `output_path`.
""" """
from soundevent.audio.files import get_audio_files
audio_files = list(get_audio_files(audio_dir)) audio_files = list(get_audio_files(audio_dir))
_run_prediction( _run_prediction(
model_path=model_path, model_path=model_path,
@ -309,6 +309,8 @@ def predict_dataset_command(
The dataset is read as a soundevent annotation set and unique recording The dataset is read as a soundevent annotation set and unique recording
paths are extracted before inference. paths are extracted before inference.
""" """
from soundevent import io
dataset_path = Path(dataset_path) dataset_path = Path(dataset_path)
dataset = io.load(dataset_path, type="annotation_set") dataset = io.load(dataset_path, type="annotation_set")
audio_files = sorted( audio_files = sorted(

View File

@ -1,4 +1,6 @@
import operator
from collections.abc import Callable, Sequence from collections.abc import Callable, Sequence
from functools import partial
from typing import Annotated, Literal from typing import Annotated, Literal
from pydantic import Field from pydantic import Field
@ -78,25 +80,23 @@ class DurationConfig(BaseConfig):
seconds: float seconds: float
def _build_comparator( def _build_comparator(op: Operator, value: float) -> Callable[[float], bool]:
operator: Operator, value: float if op == "gt":
) -> Callable[[float], bool]: return partial(operator.lt, value)
if operator == "gt":
return lambda x: x > value
if operator == "gte": if op == "gte":
return lambda x: x >= value return partial(operator.le, value)
if operator == "lt": if op == "lt":
return lambda x: x < value return partial(operator.gt, value)
if operator == "lte": if op == "lte":
return lambda x: x <= value return partial(operator.ge, value)
if operator == "eq": if op == "eq":
return lambda x: x == value return partial(operator.eq, b=value)
raise ValueError(f"Invalid operator {operator}") raise ValueError(f"Invalid operator {op}")
class Duration: class Duration:

View File

@ -204,15 +204,19 @@ class ClassificationROCAUC(BaseClassificationMetric):
ignore_generic=self.ignore_generic, ignore_generic=self.ignore_generic,
) )
class_scores = { class_scores = {}
class_name: float(
for class_name in self.targets.class_names:
if len(y_true[class_name]) == 0:
class_scores[class_name] = np.nan
continue
class_scores[class_name] = float(
metrics.roc_auc_score( metrics.roc_auc_score(
y_true[class_name], y_true[class_name],
y_score[class_name], y_score[class_name],
) )
) )
for class_name in self.targets.class_names
}
mean_score = float( mean_score = float(
np.mean([v for v in class_scores.values() if v != np.nan]) np.mean([v for v in class_scores.values() if v != np.nan])

View File

@ -133,6 +133,9 @@ class DetectionROCAUC:
y_true.append(m.is_ground_truth) y_true.append(m.is_ground_truth)
y_score.append(m.score) y_score.append(m.score)
if len(y_true) == 0:
return {self.label: np.nan}
score = float(metrics.roc_auc_score(y_true, y_score)) score = float(metrics.roc_auc_score(y_true, y_score))
return {self.label: score} return {self.label: score}

View File

@ -1,9 +1,12 @@
from __future__ import annotations
import io import io
import sys import sys
from collections.abc import Callable from collections.abc import Callable
from functools import partial from functools import partial
from pathlib import Path from pathlib import Path
from typing import ( from typing import (
TYPE_CHECKING,
Annotated, Annotated,
Any, Any,
Dict, Dict,
@ -13,20 +16,22 @@ from typing import (
TypeVar, TypeVar,
) )
import numpy as np from loguru import logger
import pandas as pd from pydantic import Field
from lightning.pytorch.loggers import (
from batdetect2.core.configs import BaseConfig
if TYPE_CHECKING:
import numpy as np
import pandas as pd
from lightning.pytorch.loggers import (
CSVLogger, CSVLogger,
Logger, Logger,
MLFlowLogger, MLFlowLogger,
TensorBoardLogger, TensorBoardLogger,
) )
from loguru import logger from matplotlib.figure import Figure
from matplotlib.figure import Figure from soundevent import data
from pydantic import Field
from soundevent import data
from batdetect2.core.configs import BaseConfig
DEFAULT_LOGS_DIR: Path = Path("outputs") / "logs" DEFAULT_LOGS_DIR: Path = Path("outputs") / "logs"
@ -271,10 +276,16 @@ def build_logger(
) )
PlotLogger = Callable[[str, Figure, int], None] PlotLogger = Callable[[str, "Figure", int], None]
def get_image_logger(logger: Logger) -> PlotLogger | None: def get_image_logger(logger: Logger) -> PlotLogger | None:
from lightning.pytorch.loggers import (
CSVLogger,
MLFlowLogger,
TensorBoardLogger,
)
if isinstance(logger, TensorBoardLogger): if isinstance(logger, TensorBoardLogger):
return logger.experiment.add_figure return logger.experiment.add_figure
@ -296,10 +307,16 @@ def get_image_logger(logger: Logger) -> PlotLogger | None:
return partial(save_figure, dir=Path(logger.log_dir)) return partial(save_figure, dir=Path(logger.log_dir))
TableLogger = Callable[[str, pd.DataFrame, int], None] TableLogger = Callable[[str, "pd.DataFrame", int], None]
def get_table_logger(logger: Logger) -> TableLogger | None: def get_table_logger(logger: Logger) -> TableLogger | None:
from lightning.pytorch.loggers import (
CSVLogger,
MLFlowLogger,
TensorBoardLogger,
)
if isinstance(logger, TensorBoardLogger): if isinstance(logger, TensorBoardLogger):
return partial(save_table, dir=Path(logger.log_dir)) return partial(save_table, dir=Path(logger.log_dir))
@ -337,6 +354,8 @@ def save_figure(name: str, fig: Figure, step: int, dir: Path) -> None:
def _convert_figure_to_array(figure: Figure) -> np.ndarray: def _convert_figure_to_array(figure: Figure) -> np.ndarray:
import numpy as np
with io.BytesIO() as buff: with io.BytesIO() as buff:
figure.savefig(buff, format="raw") figure.savefig(buff, format="raw")
buff.seek(0) buff.seek(0)

View File

@ -15,11 +15,11 @@ GENERIC_CLASS_KEY = "class"
data_source = data.Term( data_source = data.Term(
name="soundevent:data_source", name="dcterms:source",
label="Data Source", label="Source",
uri="http://purl.org/dc/terms/source",
definition=( definition=(
"A unique identifier for the source of the data, typically " "A related resource from which the described resource is derived."
"representing the project, site, or deployment context."
), ),
) )
@ -45,6 +45,17 @@ individual = data.Term(
) )
"""Term used for tags identifying a specific individual animal.""" """Term used for tags identifying a specific individual animal."""
# Term used to tag an item with the experimental partition it is assigned to
# (for example 'train' or 'test'); the tag value is plain text.
dataset_split = data.Term(
    name="batdetect2:split",
    label="Dataset Split",
    definition=(
        "Identifies the specific data partition (e.g., 'train', 'test') "
        "that the item belongs to within an experimental setup. "
        "The expected value is a literal text string."
    ),
)
"""Custom metadata term defining the machine learning partition of an item."""
generic_class = data.Term( generic_class = data.Term(
name="soundevent:class", name="soundevent:class",
label="Class", label="Class",