Compare commits

..

No commits in common. "615c811bb43fd54baf71d6ac7d758ccf2cc351f6" and "115084fd2bb2f2e141856cb6048c4782746eff7d" have entirely different histories.

7 changed files with 28 additions and 45 deletions

View File

@ -7,7 +7,6 @@ from loguru import logger
from batdetect2.cli.base import cli from batdetect2.cli.base import cli
from batdetect2.data import load_dataset_from_config from batdetect2.data import load_dataset_from_config
from batdetect2.targets import load_target_config
from batdetect2.train import ( from batdetect2.train import (
FullTrainingConfig, FullTrainingConfig,
load_full_training_config, load_full_training_config,
@ -21,7 +20,6 @@ __all__ = ["train_command"]
@click.argument("train_dataset", type=click.Path(exists=True)) @click.argument("train_dataset", type=click.Path(exists=True))
@click.option("--val-dataset", type=click.Path(exists=True)) @click.option("--val-dataset", type=click.Path(exists=True))
@click.option("--model-path", type=click.Path(exists=True)) @click.option("--model-path", type=click.Path(exists=True))
@click.option("--targets", type=click.Path(exists=True))
@click.option("--ckpt-dir", type=click.Path(exists=True)) @click.option("--ckpt-dir", type=click.Path(exists=True))
@click.option("--log-dir", type=click.Path(exists=True)) @click.option("--log-dir", type=click.Path(exists=True))
@click.option("--config", type=click.Path(exists=True)) @click.option("--config", type=click.Path(exists=True))
@ -44,7 +42,6 @@ def train_command(
ckpt_dir: Optional[Path] = None, ckpt_dir: Optional[Path] = None,
log_dir: Optional[Path] = None, log_dir: Optional[Path] = None,
config: Optional[Path] = None, config: Optional[Path] = None,
targets: Optional[Path] = None,
config_field: Optional[str] = None, config_field: Optional[str] = None,
seed: Optional[int] = None, seed: Optional[int] = None,
train_workers: int = 0, train_workers: int = 0,
@ -65,18 +62,12 @@ def train_command(
logger.info("Initiating training process...") logger.info("Initiating training process...")
logger.info("Loading training configuration...") logger.info("Loading training configuration...")
conf = ( conf = (
load_full_training_config(config, field=config_field) load_full_training_config(config, field=config_field)
if config is not None if config is not None
else FullTrainingConfig() else FullTrainingConfig()
) )
if targets is not None:
logger.info("Loading targets configuration...")
targets_config = load_target_config(targets)
conf = conf.model_copy(update=dict(targets=targets_config))
logger.info("Loading training dataset...") logger.info("Loading training dataset...")
train_annotations = load_dataset_from_config(train_dataset) train_annotations = load_dataset_from_config(train_dataset)
logger.debug( logger.debug(

View File

@ -301,8 +301,7 @@ def load_batdetect2_merged_annotated_dataset(
for ann in content: for ann in content:
try: try:
ann = FileAnnotation.model_validate(ann) ann = FileAnnotation.model_validate(ann)
except ValueError as err: except ValueError:
logger.warning(f"Invalid annotation file: {err}")
continue continue
if ( if (
@ -310,17 +309,14 @@ def load_batdetect2_merged_annotated_dataset(
and dataset.filter.only_annotated and dataset.filter.only_annotated
and not ann.annotated and not ann.annotated
): ):
logger.debug(f"Skipping incomplete annotation {ann.id}")
continue continue
if dataset.filter and dataset.filter.exclude_issues and ann.issues: if dataset.filter and dataset.filter.exclude_issues and ann.issues:
logger.debug(f"Skipping annotation with issues {ann.id}")
continue continue
try: try:
clip = file_annotation_to_clip(ann, audio_dir=audio_dir) clip = file_annotation_to_clip(ann, audio_dir=audio_dir)
except FileNotFoundError as err: except FileNotFoundError:
logger.warning(f"Error loading annotations: {err}")
continue continue
annotations.append(file_annotation_to_clip_annotation(ann, clip)) annotations.append(file_annotation_to_clip_annotation(ann, clip))

View File

@ -100,11 +100,8 @@ def extract_sound_events_df(
class_name = targets.encode_class(sound_event) class_name = targets.encode_class(sound_event)
if class_name is None: if class_name is None and exclude_generic:
if exclude_generic:
continue continue
else:
class_name = targets.detection_class_name
start_time, low_freq, end_time, high_freq = compute_bounds( start_time, low_freq, end_time, high_freq = compute_bounds(
sound_event.sound_event.geometry sound_event.sound_event.geometry
@ -156,7 +153,7 @@ def compute_class_summary(
sound_events = extract_sound_events_df( sound_events = extract_sound_events_df(
dataset, dataset,
targets, targets,
exclude_generic=False, exclude_generic=True,
exclude_non_target=True, exclude_non_target=True,
) )
recordings = extract_recordings_df(dataset) recordings = extract_recordings_df(dataset)

View File

@ -103,7 +103,7 @@ def convert_raw_prediction_to_sound_event_prediction(
tags = [ tags = [
*get_generic_tags( *get_generic_tags(
raw_prediction.detection_score, raw_prediction.detection_score,
generic_class_tags=targets.detection_class_tags, generic_class_tags=targets.generic_class_tags,
), ),
*get_class_tags( *get_class_tags(
raw_prediction.class_scores, raw_prediction.class_scores,

View File

@ -140,18 +140,16 @@ class Targets(TargetProtocol):
""" """
class_names: List[str] class_names: List[str]
detection_class_tags: List[data.Tag] generic_class_tags: List[data.Tag]
dimension_names: List[str] dimension_names: List[str]
detection_class_name: str
def __init__( def __init__(
self, self,
detection_class_name: str,
encode_fn: SoundEventEncoder, encode_fn: SoundEventEncoder,
decode_fn: SoundEventDecoder, decode_fn: SoundEventDecoder,
roi_mapper: ROITargetMapper, roi_mapper: ROITargetMapper,
class_names: list[str], class_names: list[str],
detection_class_tags: List[data.Tag], generic_class_tags: List[data.Tag],
filter_fn: Optional[SoundEventCondition] = None, filter_fn: Optional[SoundEventCondition] = None,
roi_mapper_overrides: Optional[dict[str, ROITargetMapper]] = None, roi_mapper_overrides: Optional[dict[str, ROITargetMapper]] = None,
): ):
@ -177,9 +175,8 @@ class Targets(TargetProtocol):
transform_fn : SoundEventTransformation, optional transform_fn : SoundEventTransformation, optional
Configured function to transform annotation tags. Defaults to None. Configured function to transform annotation tags. Defaults to None.
""" """
self.detection_class_name = detection_class_name
self.class_names = class_names self.class_names = class_names
self.detection_class_tags = detection_class_tags self.generic_class_tags = generic_class_tags
self.dimension_names = roi_mapper.dimension_names self.dimension_names = roi_mapper.dimension_names
self._roi_mapper = roi_mapper self._roi_mapper = roi_mapper
@ -384,8 +381,7 @@ def build_targets(config: Optional[TargetConfig] = None) -> Targets:
decode_fn=decode_fn, decode_fn=decode_fn,
class_names=class_names, class_names=class_names,
roi_mapper=roi_mapper, roi_mapper=roi_mapper,
detection_class_name=config.detection_target.name, generic_class_tags=generic_class_tags,
detection_class_tags=generic_class_tags,
roi_mapper_overrides=roi_overrides, roi_mapper_overrides=roi_overrides,
) )

View File

@ -152,19 +152,25 @@ def build_trainer_callbacks(
if checkpoint_dir is None: if checkpoint_dir is None:
checkpoint_dir = DEFAULT_CHECKPOINT_DIR checkpoint_dir = DEFAULT_CHECKPOINT_DIR
if experiment_name is not None: filename = "best-{epoch:02d}-{val_loss:.0f}"
checkpoint_dir = checkpoint_dir / experiment_name
if run_name is not None: if run_name is not None:
checkpoint_dir = checkpoint_dir / run_name filename = f"run_{run_name}_{filename}"
return [ if experiment_name is not None:
ModelCheckpoint( filename = f"experiment_{experiment_name}_{filename}"
model_checkpoint = ModelCheckpoint(
dirpath=str(checkpoint_dir), dirpath=str(checkpoint_dir),
save_top_k=1, save_top_k=1,
filename="best-{epoch:02d}-{val_loss:.0f}", filename=filename,
monitor="total_loss/val", monitor="total_loss/val",
), )
model_checkpoint.CHECKPOINT_EQUALS_CHAR = "_" # type: ignore
return [
model_checkpoint,
ValidationMetrics( ValidationMetrics(
metrics=[ metrics=[
DetectionAveragePrecision(), DetectionAveragePrecision(),
@ -214,8 +220,7 @@ def build_trainer(
config=conf.evaluation, config=conf.evaluation,
preprocessor=build_preprocessor(conf.preprocess), preprocessor=build_preprocessor(conf.preprocess),
checkpoint_dir=checkpoint_dir, checkpoint_dir=checkpoint_dir,
experiment_name=experiment_name, experiment_name=train_logger.name,
run_name=run_name,
), ),
) )

View File

@ -94,10 +94,8 @@ class TargetProtocol(Protocol):
class_names: List[str] class_names: List[str]
"""Ordered list of unique names for the specific target classes.""" """Ordered list of unique names for the specific target classes."""
detection_class_tags: List[data.Tag] generic_class_tags: List[data.Tag]
"""List of tags representing the detection category (unclassified).""" """List of tags representing the generic (unclassified) category."""
detection_class_name: str
dimension_names: List[str] dimension_names: List[str]
"""Names of the size dimensions (e.g., ['width', 'height']).""" """Names of the size dimensions (e.g., ['width', 'height'])."""