Formatting

This commit is contained in:
mbsantiago 2025-04-04 10:58:54 +01:00
parent 98c6da6d42
commit acda71ea45
6 changed files with 26 additions and 11 deletions

View File

@ -6,7 +6,7 @@ from sklearn.metrics import auc, roc_curve
from soundevent import data
from soundevent.evaluation import match_geometries
from batdetect2.train.targets import build_encoder, get_class_names
from batdetect2.train.targets import build_target_encoder, get_class_names
def match_predictions_and_annotations(

View File

@ -29,7 +29,7 @@ from batdetect2.train.losses import compute_loss
from batdetect2.train.targets import (
TargetConfig,
build_decoder,
build_encoder,
build_target_encoder,
get_class_names,
)
@ -78,7 +78,7 @@ class DetectorModel(L.LightningModule):
# Training targets
self.class_names = get_class_names(self.config.targets.classes)
self.encoder = build_encoder(
self.encoder = build_target_encoder(
self.config.targets.classes,
replacement_rules=self.config.targets.replace,
)

View File

@ -15,10 +15,19 @@ from batdetect2.train.dataset import (
LabeledDataset,
SubclipConfig,
TrainExample,
get_preprocessed_files,
)
from batdetect2.train.labels import LabelConfig, load_label_config
from batdetect2.train.preprocess import preprocess_annotations
from batdetect2.train.targets import TargetConfig, load_target_config
from batdetect2.train.preprocess import (
generate_train_example,
preprocess_annotations,
)
from batdetect2.train.targets import (
TagInfo,
TargetConfig,
build_target_encoder,
load_target_config,
)
from batdetect2.train.train import TrainerConfig, load_trainer_config, train
__all__ = [
@ -26,12 +35,16 @@ __all__ = [
"LabelConfig",
"LabeledDataset",
"SubclipConfig",
"TagInfo",
"TargetConfig",
"TrainExample",
"TrainerConfig",
"TrainingConfig",
"add_echo",
"augment_example",
"build_target_encoder",
"generate_train_example",
"get_preprocessed_files",
"load_agumentation_config",
"load_label_config",
"load_target_config",

View File

@ -101,7 +101,7 @@ class LabeledDataset(Dataset):
preprocessing: Optional[PreprocessingConfig] = None,
):
return cls(
get_files(directory, extension),
get_preprocessed_files(directory, extension),
subclip=subclip,
augmentation=augmentation,
preprocessing=preprocessing,
@ -143,5 +143,7 @@ class LabeledDataset(Dataset):
return adjust_width(tensor, width)
def get_files(directory: PathLike, extension: str = ".nc") -> Sequence[Path]:
def get_preprocessed_files(
directory: PathLike, extension: str = ".nc"
) -> Sequence[Path]:
return list(Path(directory).glob(f"*{extension}"))

View File

@ -20,7 +20,7 @@ from batdetect2.preprocess import (
from batdetect2.train.labels import LabelConfig, generate_heatmaps
from batdetect2.train.targets import (
TargetConfig,
build_encoder,
build_target_encoder,
build_sound_event_filter,
get_class_names,
)
@ -76,7 +76,7 @@ def generate_train_example(
event for event in clip_annotation.sound_events if filter_fn(event)
]
encoder = build_encoder(
encoder = build_target_encoder(
config.target.classes,
replacement_rules=config.target.replace,
)

View File

@ -12,7 +12,7 @@ from batdetect2.terms import TagInfo, get_tag_from_info
__all__ = [
"TargetConfig",
"load_target_config",
"build_encoder",
"build_target_encoder",
"build_decoder",
"filter_sound_event",
]
@ -91,7 +91,7 @@ def build_replacer(
return replacer
def build_encoder(
def build_target_encoder(
classes: List[TagInfo],
replacement_rules: Optional[List[ReplaceConfig]] = None,
) -> Callable[[Iterable[data.Tag]], Optional[str]]: