Formatting

This commit is contained in:
mbsantiago 2025-04-04 10:58:54 +01:00
parent 98c6da6d42
commit acda71ea45
6 changed files with 26 additions and 11 deletions

View File

@ -6,7 +6,7 @@ from sklearn.metrics import auc, roc_curve
from soundevent import data from soundevent import data
from soundevent.evaluation import match_geometries from soundevent.evaluation import match_geometries
from batdetect2.train.targets import build_encoder, get_class_names from batdetect2.train.targets import build_target_encoder, get_class_names
def match_predictions_and_annotations( def match_predictions_and_annotations(

View File

@ -29,7 +29,7 @@ from batdetect2.train.losses import compute_loss
from batdetect2.train.targets import ( from batdetect2.train.targets import (
TargetConfig, TargetConfig,
build_decoder, build_decoder,
build_encoder, build_target_encoder,
get_class_names, get_class_names,
) )
@ -78,7 +78,7 @@ class DetectorModel(L.LightningModule):
# Training targets # Training targets
self.class_names = get_class_names(self.config.targets.classes) self.class_names = get_class_names(self.config.targets.classes)
self.encoder = build_encoder( self.encoder = build_target_encoder(
self.config.targets.classes, self.config.targets.classes,
replacement_rules=self.config.targets.replace, replacement_rules=self.config.targets.replace,
) )

View File

@ -15,10 +15,19 @@ from batdetect2.train.dataset import (
LabeledDataset, LabeledDataset,
SubclipConfig, SubclipConfig,
TrainExample, TrainExample,
get_preprocessed_files,
) )
from batdetect2.train.labels import LabelConfig, load_label_config from batdetect2.train.labels import LabelConfig, load_label_config
from batdetect2.train.preprocess import preprocess_annotations from batdetect2.train.preprocess import (
from batdetect2.train.targets import TargetConfig, load_target_config generate_train_example,
preprocess_annotations,
)
from batdetect2.train.targets import (
TagInfo,
TargetConfig,
build_target_encoder,
load_target_config,
)
from batdetect2.train.train import TrainerConfig, load_trainer_config, train from batdetect2.train.train import TrainerConfig, load_trainer_config, train
__all__ = [ __all__ = [
@ -26,12 +35,16 @@ __all__ = [
"LabelConfig", "LabelConfig",
"LabeledDataset", "LabeledDataset",
"SubclipConfig", "SubclipConfig",
"TagInfo",
"TargetConfig", "TargetConfig",
"TrainExample", "TrainExample",
"TrainerConfig", "TrainerConfig",
"TrainingConfig", "TrainingConfig",
"add_echo", "add_echo",
"augment_example", "augment_example",
"build_target_encoder",
"generate_train_example",
"get_preprocessed_files",
"load_agumentation_config", "load_agumentation_config",
"load_label_config", "load_label_config",
"load_target_config", "load_target_config",

View File

@ -101,7 +101,7 @@ class LabeledDataset(Dataset):
preprocessing: Optional[PreprocessingConfig] = None, preprocessing: Optional[PreprocessingConfig] = None,
): ):
return cls( return cls(
get_files(directory, extension), get_preprocessed_files(directory, extension),
subclip=subclip, subclip=subclip,
augmentation=augmentation, augmentation=augmentation,
preprocessing=preprocessing, preprocessing=preprocessing,
@ -143,5 +143,7 @@ class LabeledDataset(Dataset):
return adjust_width(tensor, width) return adjust_width(tensor, width)
def get_files(directory: PathLike, extension: str = ".nc") -> Sequence[Path]: def get_preprocessed_files(
directory: PathLike, extension: str = ".nc"
) -> Sequence[Path]:
return list(Path(directory).glob(f"*{extension}")) return list(Path(directory).glob(f"*{extension}"))

View File

@ -20,7 +20,7 @@ from batdetect2.preprocess import (
from batdetect2.train.labels import LabelConfig, generate_heatmaps from batdetect2.train.labels import LabelConfig, generate_heatmaps
from batdetect2.train.targets import ( from batdetect2.train.targets import (
TargetConfig, TargetConfig,
build_encoder, build_target_encoder,
build_sound_event_filter, build_sound_event_filter,
get_class_names, get_class_names,
) )
@ -76,7 +76,7 @@ def generate_train_example(
event for event in clip_annotation.sound_events if filter_fn(event) event for event in clip_annotation.sound_events if filter_fn(event)
] ]
encoder = build_encoder( encoder = build_target_encoder(
config.target.classes, config.target.classes,
replacement_rules=config.target.replace, replacement_rules=config.target.replace,
) )

View File

@ -12,7 +12,7 @@ from batdetect2.terms import TagInfo, get_tag_from_info
__all__ = [ __all__ = [
"TargetConfig", "TargetConfig",
"load_target_config", "load_target_config",
"build_encoder", "build_target_encoder",
"build_decoder", "build_decoder",
"filter_sound_event", "filter_sound_event",
] ]
@ -91,7 +91,7 @@ def build_replacer(
return replacer return replacer
def build_encoder( def build_target_encoder(
classes: List[TagInfo], classes: List[TagInfo],
replacement_rules: Optional[List[ReplaceConfig]] = None, replacement_rules: Optional[List[ReplaceConfig]] = None,
) -> Callable[[Iterable[data.Tag]], Optional[str]]: ) -> Callable[[Iterable[data.Tag]], Optional[str]]: