mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-05-22 22:32:18 +02:00
Merge branch 'train' into doc
This commit is contained in:
commit
9da05c172c
@ -2,10 +2,6 @@
|
|||||||
|
|
||||||
import click
|
import click
|
||||||
|
|
||||||
from batdetect2.logging import enable_logging
|
|
||||||
|
|
||||||
# from batdetect2.cli.ascii import BATDETECT_ASCII_ART
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"cli",
|
"cli",
|
||||||
]
|
]
|
||||||
@ -34,5 +30,7 @@ def cli(verbose: int = 0):
|
|||||||
"""
|
"""
|
||||||
click.echo(INFO_STR)
|
click.echo(INFO_STR)
|
||||||
|
|
||||||
|
from batdetect2.logging import enable_logging
|
||||||
|
|
||||||
enable_logging(verbose)
|
enable_logging(verbose)
|
||||||
# click.echo(BATDETECT_ASCII_ART)
|
# click.echo(BATDETECT_ASCII_ART)
|
||||||
|
|||||||
@ -4,8 +4,6 @@ from typing import TYPE_CHECKING
|
|||||||
|
|
||||||
import click
|
import click
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from soundevent import io
|
|
||||||
from soundevent.audio.files import get_audio_files
|
|
||||||
|
|
||||||
from batdetect2.cli.base import cli
|
from batdetect2.cli.base import cli
|
||||||
|
|
||||||
@ -219,6 +217,8 @@ def predict_directory_command(
|
|||||||
Loads a checkpoint, scans `audio_dir` for supported audio files, runs
|
Loads a checkpoint, scans `audio_dir` for supported audio files, runs
|
||||||
inference, and saves predictions to `output_path`.
|
inference, and saves predictions to `output_path`.
|
||||||
"""
|
"""
|
||||||
|
from soundevent.audio.files import get_audio_files
|
||||||
|
|
||||||
audio_files = list(get_audio_files(audio_dir))
|
audio_files = list(get_audio_files(audio_dir))
|
||||||
_run_prediction(
|
_run_prediction(
|
||||||
model_path=model_path,
|
model_path=model_path,
|
||||||
@ -309,6 +309,8 @@ def predict_dataset_command(
|
|||||||
The dataset is read as a soundevent annotation set and unique recording
|
The dataset is read as a soundevent annotation set and unique recording
|
||||||
paths are extracted before inference.
|
paths are extracted before inference.
|
||||||
"""
|
"""
|
||||||
|
from soundevent import io
|
||||||
|
|
||||||
dataset_path = Path(dataset_path)
|
dataset_path = Path(dataset_path)
|
||||||
dataset = io.load(dataset_path, type="annotation_set")
|
dataset = io.load(dataset_path, type="annotation_set")
|
||||||
audio_files = sorted(
|
audio_files = sorted(
|
||||||
|
|||||||
@ -1,4 +1,6 @@
|
|||||||
|
import operator
|
||||||
from collections.abc import Callable, Sequence
|
from collections.abc import Callable, Sequence
|
||||||
|
from functools import partial
|
||||||
from typing import Annotated, Literal
|
from typing import Annotated, Literal
|
||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
@ -78,25 +80,23 @@ class DurationConfig(BaseConfig):
|
|||||||
seconds: float
|
seconds: float
|
||||||
|
|
||||||
|
|
||||||
def _build_comparator(
|
def _build_comparator(op: Operator, value: float) -> Callable[[float], bool]:
|
||||||
operator: Operator, value: float
|
if op == "gt":
|
||||||
) -> Callable[[float], bool]:
|
return partial(operator.lt, value)
|
||||||
if operator == "gt":
|
|
||||||
return lambda x: x > value
|
|
||||||
|
|
||||||
if operator == "gte":
|
if op == "gte":
|
||||||
return lambda x: x >= value
|
return partial(operator.le, value)
|
||||||
|
|
||||||
if operator == "lt":
|
if op == "lt":
|
||||||
return lambda x: x < value
|
return partial(operator.gt, value)
|
||||||
|
|
||||||
if operator == "lte":
|
if op == "lte":
|
||||||
return lambda x: x <= value
|
return partial(operator.ge, value)
|
||||||
|
|
||||||
if operator == "eq":
|
if op == "eq":
|
||||||
return lambda x: x == value
|
return partial(operator.eq, b=value)
|
||||||
|
|
||||||
raise ValueError(f"Invalid operator {operator}")
|
raise ValueError(f"Invalid operator {op}")
|
||||||
|
|
||||||
|
|
||||||
class Duration:
|
class Duration:
|
||||||
|
|||||||
@ -204,15 +204,19 @@ class ClassificationROCAUC(BaseClassificationMetric):
|
|||||||
ignore_generic=self.ignore_generic,
|
ignore_generic=self.ignore_generic,
|
||||||
)
|
)
|
||||||
|
|
||||||
class_scores = {
|
class_scores = {}
|
||||||
class_name: float(
|
|
||||||
|
for class_name in self.targets.class_names:
|
||||||
|
if len(y_true[class_name]) == 0:
|
||||||
|
class_scores[class_name] = np.nan
|
||||||
|
continue
|
||||||
|
|
||||||
|
class_scores[class_name] = float(
|
||||||
metrics.roc_auc_score(
|
metrics.roc_auc_score(
|
||||||
y_true[class_name],
|
y_true[class_name],
|
||||||
y_score[class_name],
|
y_score[class_name],
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
for class_name in self.targets.class_names
|
|
||||||
}
|
|
||||||
|
|
||||||
mean_score = float(
|
mean_score = float(
|
||||||
np.mean([v for v in class_scores.values() if v != np.nan])
|
np.mean([v for v in class_scores.values() if v != np.nan])
|
||||||
|
|||||||
@ -133,6 +133,9 @@ class DetectionROCAUC:
|
|||||||
y_true.append(m.is_ground_truth)
|
y_true.append(m.is_ground_truth)
|
||||||
y_score.append(m.score)
|
y_score.append(m.score)
|
||||||
|
|
||||||
|
if len(y_true) == 0:
|
||||||
|
return {self.label: np.nan}
|
||||||
|
|
||||||
score = float(metrics.roc_auc_score(y_true, y_score))
|
score = float(metrics.roc_auc_score(y_true, y_score))
|
||||||
return {self.label: score}
|
return {self.label: score}
|
||||||
|
|
||||||
|
|||||||
@ -1,9 +1,12 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import io
|
import io
|
||||||
import sys
|
import sys
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import (
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
Annotated,
|
Annotated,
|
||||||
Any,
|
Any,
|
||||||
Dict,
|
Dict,
|
||||||
@ -13,6 +16,12 @@ from typing import (
|
|||||||
TypeVar,
|
TypeVar,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from loguru import logger
|
||||||
|
from pydantic import Field
|
||||||
|
|
||||||
|
from batdetect2.core.configs import BaseConfig
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from lightning.pytorch.loggers import (
|
from lightning.pytorch.loggers import (
|
||||||
@ -21,13 +30,9 @@ from lightning.pytorch.loggers import (
|
|||||||
MLFlowLogger,
|
MLFlowLogger,
|
||||||
TensorBoardLogger,
|
TensorBoardLogger,
|
||||||
)
|
)
|
||||||
from loguru import logger
|
|
||||||
from matplotlib.figure import Figure
|
from matplotlib.figure import Figure
|
||||||
from pydantic import Field
|
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.core.configs import BaseConfig
|
|
||||||
|
|
||||||
DEFAULT_LOGS_DIR: Path = Path("outputs") / "logs"
|
DEFAULT_LOGS_DIR: Path = Path("outputs") / "logs"
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -271,10 +276,16 @@ def build_logger(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
PlotLogger = Callable[[str, Figure, int], None]
|
PlotLogger = Callable[[str, "Figure", int], None]
|
||||||
|
|
||||||
|
|
||||||
def get_image_logger(logger: Logger) -> PlotLogger | None:
|
def get_image_logger(logger: Logger) -> PlotLogger | None:
|
||||||
|
from lightning.pytorch.loggers import (
|
||||||
|
CSVLogger,
|
||||||
|
MLFlowLogger,
|
||||||
|
TensorBoardLogger,
|
||||||
|
)
|
||||||
|
|
||||||
if isinstance(logger, TensorBoardLogger):
|
if isinstance(logger, TensorBoardLogger):
|
||||||
return logger.experiment.add_figure
|
return logger.experiment.add_figure
|
||||||
|
|
||||||
@ -296,10 +307,16 @@ def get_image_logger(logger: Logger) -> PlotLogger | None:
|
|||||||
return partial(save_figure, dir=Path(logger.log_dir))
|
return partial(save_figure, dir=Path(logger.log_dir))
|
||||||
|
|
||||||
|
|
||||||
TableLogger = Callable[[str, pd.DataFrame, int], None]
|
TableLogger = Callable[[str, "pd.DataFrame", int], None]
|
||||||
|
|
||||||
|
|
||||||
def get_table_logger(logger: Logger) -> TableLogger | None:
|
def get_table_logger(logger: Logger) -> TableLogger | None:
|
||||||
|
from lightning.pytorch.loggers import (
|
||||||
|
CSVLogger,
|
||||||
|
MLFlowLogger,
|
||||||
|
TensorBoardLogger,
|
||||||
|
)
|
||||||
|
|
||||||
if isinstance(logger, TensorBoardLogger):
|
if isinstance(logger, TensorBoardLogger):
|
||||||
return partial(save_table, dir=Path(logger.log_dir))
|
return partial(save_table, dir=Path(logger.log_dir))
|
||||||
|
|
||||||
@ -337,6 +354,8 @@ def save_figure(name: str, fig: Figure, step: int, dir: Path) -> None:
|
|||||||
|
|
||||||
|
|
||||||
def _convert_figure_to_array(figure: Figure) -> np.ndarray:
|
def _convert_figure_to_array(figure: Figure) -> np.ndarray:
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
with io.BytesIO() as buff:
|
with io.BytesIO() as buff:
|
||||||
figure.savefig(buff, format="raw")
|
figure.savefig(buff, format="raw")
|
||||||
buff.seek(0)
|
buff.seek(0)
|
||||||
|
|||||||
@ -15,11 +15,11 @@ GENERIC_CLASS_KEY = "class"
|
|||||||
|
|
||||||
|
|
||||||
data_source = data.Term(
|
data_source = data.Term(
|
||||||
name="soundevent:data_source",
|
name="dcterms:source",
|
||||||
label="Data Source",
|
label="Source",
|
||||||
|
uri="http://purl.org/dc/terms/source",
|
||||||
definition=(
|
definition=(
|
||||||
"A unique identifier for the source of the data, typically "
|
"A related resource from which the described resource is derived."
|
||||||
"representing the project, site, or deployment context."
|
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -45,6 +45,17 @@ individual = data.Term(
|
|||||||
)
|
)
|
||||||
"""Term used for tags identifying a specific individual animal."""
|
"""Term used for tags identifying a specific individual animal."""
|
||||||
|
|
||||||
|
dataset_split = data.Term(
|
||||||
|
name="batdetect2:split",
|
||||||
|
label="Dataset Split",
|
||||||
|
definition=(
|
||||||
|
"Identifies the specific data partition (e.g., 'train', 'test') "
|
||||||
|
"that the item belongs to within an experimental setup. "
|
||||||
|
"The expected value is a literal text string."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
"""Custom metadata term defining the machine learning partition of an item."""
|
||||||
|
|
||||||
generic_class = data.Term(
|
generic_class = data.Term(
|
||||||
name="soundevent:class",
|
name="soundevent:class",
|
||||||
label="Class",
|
label="Class",
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user