mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-11 09:29:33 +01:00
Compare commits
3 Commits
115084fd2b
...
615c811bb4
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
615c811bb4 | ||
|
|
41b18c3f0a | ||
|
|
16a0fa7b75 |
@ -7,6 +7,7 @@ from loguru import logger
|
|||||||
|
|
||||||
from batdetect2.cli.base import cli
|
from batdetect2.cli.base import cli
|
||||||
from batdetect2.data import load_dataset_from_config
|
from batdetect2.data import load_dataset_from_config
|
||||||
|
from batdetect2.targets import load_target_config
|
||||||
from batdetect2.train import (
|
from batdetect2.train import (
|
||||||
FullTrainingConfig,
|
FullTrainingConfig,
|
||||||
load_full_training_config,
|
load_full_training_config,
|
||||||
@ -20,6 +21,7 @@ __all__ = ["train_command"]
|
|||||||
@click.argument("train_dataset", type=click.Path(exists=True))
|
@click.argument("train_dataset", type=click.Path(exists=True))
|
||||||
@click.option("--val-dataset", type=click.Path(exists=True))
|
@click.option("--val-dataset", type=click.Path(exists=True))
|
||||||
@click.option("--model-path", type=click.Path(exists=True))
|
@click.option("--model-path", type=click.Path(exists=True))
|
||||||
|
@click.option("--targets", type=click.Path(exists=True))
|
||||||
@click.option("--ckpt-dir", type=click.Path(exists=True))
|
@click.option("--ckpt-dir", type=click.Path(exists=True))
|
||||||
@click.option("--log-dir", type=click.Path(exists=True))
|
@click.option("--log-dir", type=click.Path(exists=True))
|
||||||
@click.option("--config", type=click.Path(exists=True))
|
@click.option("--config", type=click.Path(exists=True))
|
||||||
@ -42,6 +44,7 @@ def train_command(
|
|||||||
ckpt_dir: Optional[Path] = None,
|
ckpt_dir: Optional[Path] = None,
|
||||||
log_dir: Optional[Path] = None,
|
log_dir: Optional[Path] = None,
|
||||||
config: Optional[Path] = None,
|
config: Optional[Path] = None,
|
||||||
|
targets: Optional[Path] = None,
|
||||||
config_field: Optional[str] = None,
|
config_field: Optional[str] = None,
|
||||||
seed: Optional[int] = None,
|
seed: Optional[int] = None,
|
||||||
train_workers: int = 0,
|
train_workers: int = 0,
|
||||||
@ -62,12 +65,18 @@ def train_command(
|
|||||||
logger.info("Initiating training process...")
|
logger.info("Initiating training process...")
|
||||||
|
|
||||||
logger.info("Loading training configuration...")
|
logger.info("Loading training configuration...")
|
||||||
|
|
||||||
conf = (
|
conf = (
|
||||||
load_full_training_config(config, field=config_field)
|
load_full_training_config(config, field=config_field)
|
||||||
if config is not None
|
if config is not None
|
||||||
else FullTrainingConfig()
|
else FullTrainingConfig()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if targets is not None:
|
||||||
|
logger.info("Loading targets configuration...")
|
||||||
|
targets_config = load_target_config(targets)
|
||||||
|
conf = conf.model_copy(update=dict(targets=targets_config))
|
||||||
|
|
||||||
logger.info("Loading training dataset...")
|
logger.info("Loading training dataset...")
|
||||||
train_annotations = load_dataset_from_config(train_dataset)
|
train_annotations = load_dataset_from_config(train_dataset)
|
||||||
logger.debug(
|
logger.debug(
|
||||||
|
|||||||
@ -301,7 +301,8 @@ def load_batdetect2_merged_annotated_dataset(
|
|||||||
for ann in content:
|
for ann in content:
|
||||||
try:
|
try:
|
||||||
ann = FileAnnotation.model_validate(ann)
|
ann = FileAnnotation.model_validate(ann)
|
||||||
except ValueError:
|
except ValueError as err:
|
||||||
|
logger.warning(f"Invalid annotation file: {err}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if (
|
if (
|
||||||
@ -309,14 +310,17 @@ def load_batdetect2_merged_annotated_dataset(
|
|||||||
and dataset.filter.only_annotated
|
and dataset.filter.only_annotated
|
||||||
and not ann.annotated
|
and not ann.annotated
|
||||||
):
|
):
|
||||||
|
logger.debug(f"Skipping incomplete annotation {ann.id}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if dataset.filter and dataset.filter.exclude_issues and ann.issues:
|
if dataset.filter and dataset.filter.exclude_issues and ann.issues:
|
||||||
|
logger.debug(f"Skipping annotation with issues {ann.id}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
clip = file_annotation_to_clip(ann, audio_dir=audio_dir)
|
clip = file_annotation_to_clip(ann, audio_dir=audio_dir)
|
||||||
except FileNotFoundError:
|
except FileNotFoundError as err:
|
||||||
|
logger.warning(f"Error loading annotations: {err}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
annotations.append(file_annotation_to_clip_annotation(ann, clip))
|
annotations.append(file_annotation_to_clip_annotation(ann, clip))
|
||||||
|
|||||||
@ -100,8 +100,11 @@ def extract_sound_events_df(
|
|||||||
|
|
||||||
class_name = targets.encode_class(sound_event)
|
class_name = targets.encode_class(sound_event)
|
||||||
|
|
||||||
if class_name is None and exclude_generic:
|
if class_name is None:
|
||||||
continue
|
if exclude_generic:
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
class_name = targets.detection_class_name
|
||||||
|
|
||||||
start_time, low_freq, end_time, high_freq = compute_bounds(
|
start_time, low_freq, end_time, high_freq = compute_bounds(
|
||||||
sound_event.sound_event.geometry
|
sound_event.sound_event.geometry
|
||||||
@ -153,7 +156,7 @@ def compute_class_summary(
|
|||||||
sound_events = extract_sound_events_df(
|
sound_events = extract_sound_events_df(
|
||||||
dataset,
|
dataset,
|
||||||
targets,
|
targets,
|
||||||
exclude_generic=True,
|
exclude_generic=False,
|
||||||
exclude_non_target=True,
|
exclude_non_target=True,
|
||||||
)
|
)
|
||||||
recordings = extract_recordings_df(dataset)
|
recordings = extract_recordings_df(dataset)
|
||||||
|
|||||||
@ -103,7 +103,7 @@ def convert_raw_prediction_to_sound_event_prediction(
|
|||||||
tags = [
|
tags = [
|
||||||
*get_generic_tags(
|
*get_generic_tags(
|
||||||
raw_prediction.detection_score,
|
raw_prediction.detection_score,
|
||||||
generic_class_tags=targets.generic_class_tags,
|
generic_class_tags=targets.detection_class_tags,
|
||||||
),
|
),
|
||||||
*get_class_tags(
|
*get_class_tags(
|
||||||
raw_prediction.class_scores,
|
raw_prediction.class_scores,
|
||||||
|
|||||||
@ -140,16 +140,18 @@ class Targets(TargetProtocol):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
class_names: List[str]
|
class_names: List[str]
|
||||||
generic_class_tags: List[data.Tag]
|
detection_class_tags: List[data.Tag]
|
||||||
dimension_names: List[str]
|
dimension_names: List[str]
|
||||||
|
detection_class_name: str
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
detection_class_name: str,
|
||||||
encode_fn: SoundEventEncoder,
|
encode_fn: SoundEventEncoder,
|
||||||
decode_fn: SoundEventDecoder,
|
decode_fn: SoundEventDecoder,
|
||||||
roi_mapper: ROITargetMapper,
|
roi_mapper: ROITargetMapper,
|
||||||
class_names: list[str],
|
class_names: list[str],
|
||||||
generic_class_tags: List[data.Tag],
|
detection_class_tags: List[data.Tag],
|
||||||
filter_fn: Optional[SoundEventCondition] = None,
|
filter_fn: Optional[SoundEventCondition] = None,
|
||||||
roi_mapper_overrides: Optional[dict[str, ROITargetMapper]] = None,
|
roi_mapper_overrides: Optional[dict[str, ROITargetMapper]] = None,
|
||||||
):
|
):
|
||||||
@ -175,8 +177,9 @@ class Targets(TargetProtocol):
|
|||||||
transform_fn : SoundEventTransformation, optional
|
transform_fn : SoundEventTransformation, optional
|
||||||
Configured function to transform annotation tags. Defaults to None.
|
Configured function to transform annotation tags. Defaults to None.
|
||||||
"""
|
"""
|
||||||
|
self.detection_class_name = detection_class_name
|
||||||
self.class_names = class_names
|
self.class_names = class_names
|
||||||
self.generic_class_tags = generic_class_tags
|
self.detection_class_tags = detection_class_tags
|
||||||
self.dimension_names = roi_mapper.dimension_names
|
self.dimension_names = roi_mapper.dimension_names
|
||||||
|
|
||||||
self._roi_mapper = roi_mapper
|
self._roi_mapper = roi_mapper
|
||||||
@ -381,7 +384,8 @@ def build_targets(config: Optional[TargetConfig] = None) -> Targets:
|
|||||||
decode_fn=decode_fn,
|
decode_fn=decode_fn,
|
||||||
class_names=class_names,
|
class_names=class_names,
|
||||||
roi_mapper=roi_mapper,
|
roi_mapper=roi_mapper,
|
||||||
generic_class_tags=generic_class_tags,
|
detection_class_name=config.detection_target.name,
|
||||||
|
detection_class_tags=generic_class_tags,
|
||||||
roi_mapper_overrides=roi_overrides,
|
roi_mapper_overrides=roi_overrides,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -152,25 +152,19 @@ def build_trainer_callbacks(
|
|||||||
if checkpoint_dir is None:
|
if checkpoint_dir is None:
|
||||||
checkpoint_dir = DEFAULT_CHECKPOINT_DIR
|
checkpoint_dir = DEFAULT_CHECKPOINT_DIR
|
||||||
|
|
||||||
filename = "best-{epoch:02d}-{val_loss:.0f}"
|
if experiment_name is not None:
|
||||||
|
checkpoint_dir = checkpoint_dir / experiment_name
|
||||||
|
|
||||||
if run_name is not None:
|
if run_name is not None:
|
||||||
filename = f"run_{run_name}_{filename}"
|
checkpoint_dir = checkpoint_dir / run_name
|
||||||
|
|
||||||
if experiment_name is not None:
|
|
||||||
filename = f"experiment_{experiment_name}_{filename}"
|
|
||||||
|
|
||||||
model_checkpoint = ModelCheckpoint(
|
|
||||||
dirpath=str(checkpoint_dir),
|
|
||||||
save_top_k=1,
|
|
||||||
filename=filename,
|
|
||||||
monitor="total_loss/val",
|
|
||||||
)
|
|
||||||
|
|
||||||
model_checkpoint.CHECKPOINT_EQUALS_CHAR = "_" # type: ignore
|
|
||||||
|
|
||||||
return [
|
return [
|
||||||
model_checkpoint,
|
ModelCheckpoint(
|
||||||
|
dirpath=str(checkpoint_dir),
|
||||||
|
save_top_k=1,
|
||||||
|
filename="best-{epoch:02d}-{val_loss:.0f}",
|
||||||
|
monitor="total_loss/val",
|
||||||
|
),
|
||||||
ValidationMetrics(
|
ValidationMetrics(
|
||||||
metrics=[
|
metrics=[
|
||||||
DetectionAveragePrecision(),
|
DetectionAveragePrecision(),
|
||||||
@ -220,7 +214,8 @@ def build_trainer(
|
|||||||
config=conf.evaluation,
|
config=conf.evaluation,
|
||||||
preprocessor=build_preprocessor(conf.preprocess),
|
preprocessor=build_preprocessor(conf.preprocess),
|
||||||
checkpoint_dir=checkpoint_dir,
|
checkpoint_dir=checkpoint_dir,
|
||||||
experiment_name=train_logger.name,
|
experiment_name=experiment_name,
|
||||||
|
run_name=run_name,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -94,8 +94,10 @@ class TargetProtocol(Protocol):
|
|||||||
class_names: List[str]
|
class_names: List[str]
|
||||||
"""Ordered list of unique names for the specific target classes."""
|
"""Ordered list of unique names for the specific target classes."""
|
||||||
|
|
||||||
generic_class_tags: List[data.Tag]
|
detection_class_tags: List[data.Tag]
|
||||||
"""List of tags representing the generic (unclassified) category."""
|
"""List of tags representing the detection category (unclassified)."""
|
||||||
|
|
||||||
|
detection_class_name: str
|
||||||
|
|
||||||
dimension_names: List[str]
|
dimension_names: List[str]
|
||||||
"""Names of the size dimensions (e.g., ['width', 'height'])."""
|
"""Names of the size dimensions (e.g., ['width', 'height'])."""
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user