mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 14:41:58 +02:00
Formatting
This commit is contained in:
parent
98c6da6d42
commit
acda71ea45
@ -6,7 +6,7 @@ from sklearn.metrics import auc, roc_curve
|
||||
from soundevent import data
|
||||
from soundevent.evaluation import match_geometries
|
||||
|
||||
from batdetect2.train.targets import build_encoder, get_class_names
|
||||
from batdetect2.train.targets import build_target_encoder, get_class_names
|
||||
|
||||
|
||||
def match_predictions_and_annotations(
|
||||
|
@ -29,7 +29,7 @@ from batdetect2.train.losses import compute_loss
|
||||
from batdetect2.train.targets import (
|
||||
TargetConfig,
|
||||
build_decoder,
|
||||
build_encoder,
|
||||
build_target_encoder,
|
||||
get_class_names,
|
||||
)
|
||||
|
||||
@ -78,7 +78,7 @@ class DetectorModel(L.LightningModule):
|
||||
|
||||
# Training targets
|
||||
self.class_names = get_class_names(self.config.targets.classes)
|
||||
self.encoder = build_encoder(
|
||||
self.encoder = build_target_encoder(
|
||||
self.config.targets.classes,
|
||||
replacement_rules=self.config.targets.replace,
|
||||
)
|
||||
|
@ -15,10 +15,19 @@ from batdetect2.train.dataset import (
|
||||
LabeledDataset,
|
||||
SubclipConfig,
|
||||
TrainExample,
|
||||
get_preprocessed_files,
|
||||
)
|
||||
from batdetect2.train.labels import LabelConfig, load_label_config
|
||||
from batdetect2.train.preprocess import preprocess_annotations
|
||||
from batdetect2.train.targets import TargetConfig, load_target_config
|
||||
from batdetect2.train.preprocess import (
|
||||
generate_train_example,
|
||||
preprocess_annotations,
|
||||
)
|
||||
from batdetect2.train.targets import (
|
||||
TagInfo,
|
||||
TargetConfig,
|
||||
build_target_encoder,
|
||||
load_target_config,
|
||||
)
|
||||
from batdetect2.train.train import TrainerConfig, load_trainer_config, train
|
||||
|
||||
__all__ = [
|
||||
@ -26,12 +35,16 @@ __all__ = [
|
||||
"LabelConfig",
|
||||
"LabeledDataset",
|
||||
"SubclipConfig",
|
||||
"TagInfo",
|
||||
"TargetConfig",
|
||||
"TrainExample",
|
||||
"TrainerConfig",
|
||||
"TrainingConfig",
|
||||
"add_echo",
|
||||
"augment_example",
|
||||
"build_target_encoder",
|
||||
"generate_train_example",
|
||||
"get_preprocessed_files",
|
||||
"load_agumentation_config",
|
||||
"load_label_config",
|
||||
"load_target_config",
|
||||
|
@ -101,7 +101,7 @@ class LabeledDataset(Dataset):
|
||||
preprocessing: Optional[PreprocessingConfig] = None,
|
||||
):
|
||||
return cls(
|
||||
get_files(directory, extension),
|
||||
get_preprocessed_files(directory, extension),
|
||||
subclip=subclip,
|
||||
augmentation=augmentation,
|
||||
preprocessing=preprocessing,
|
||||
@ -143,5 +143,7 @@ class LabeledDataset(Dataset):
|
||||
return adjust_width(tensor, width)
|
||||
|
||||
|
||||
def get_files(directory: PathLike, extension: str = ".nc") -> Sequence[Path]:
|
||||
def get_preprocessed_files(
|
||||
directory: PathLike, extension: str = ".nc"
|
||||
) -> Sequence[Path]:
|
||||
return list(Path(directory).glob(f"*{extension}"))
|
||||
|
@ -20,7 +20,7 @@ from batdetect2.preprocess import (
|
||||
from batdetect2.train.labels import LabelConfig, generate_heatmaps
|
||||
from batdetect2.train.targets import (
|
||||
TargetConfig,
|
||||
build_encoder,
|
||||
build_target_encoder,
|
||||
build_sound_event_filter,
|
||||
get_class_names,
|
||||
)
|
||||
@ -76,7 +76,7 @@ def generate_train_example(
|
||||
event for event in clip_annotation.sound_events if filter_fn(event)
|
||||
]
|
||||
|
||||
encoder = build_encoder(
|
||||
encoder = build_target_encoder(
|
||||
config.target.classes,
|
||||
replacement_rules=config.target.replace,
|
||||
)
|
||||
|
@ -12,7 +12,7 @@ from batdetect2.terms import TagInfo, get_tag_from_info
|
||||
__all__ = [
|
||||
"TargetConfig",
|
||||
"load_target_config",
|
||||
"build_encoder",
|
||||
"build_target_encoder",
|
||||
"build_decoder",
|
||||
"filter_sound_event",
|
||||
]
|
||||
@ -91,7 +91,7 @@ def build_replacer(
|
||||
return replacer
|
||||
|
||||
|
||||
def build_encoder(
|
||||
def build_target_encoder(
|
||||
classes: List[TagInfo],
|
||||
replacement_rules: Optional[List[ReplaceConfig]] = None,
|
||||
) -> Callable[[Iterable[data.Tag]], Optional[str]]:
|
||||
|
Loading…
Reference in New Issue
Block a user