mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 22:51:58 +02:00
Formatting
This commit is contained in:
parent
98c6da6d42
commit
acda71ea45
@ -6,7 +6,7 @@ from sklearn.metrics import auc, roc_curve
|
|||||||
from soundevent import data
|
from soundevent import data
|
||||||
from soundevent.evaluation import match_geometries
|
from soundevent.evaluation import match_geometries
|
||||||
|
|
||||||
from batdetect2.train.targets import build_encoder, get_class_names
|
from batdetect2.train.targets import build_target_encoder, get_class_names
|
||||||
|
|
||||||
|
|
||||||
def match_predictions_and_annotations(
|
def match_predictions_and_annotations(
|
||||||
|
@ -29,7 +29,7 @@ from batdetect2.train.losses import compute_loss
|
|||||||
from batdetect2.train.targets import (
|
from batdetect2.train.targets import (
|
||||||
TargetConfig,
|
TargetConfig,
|
||||||
build_decoder,
|
build_decoder,
|
||||||
build_encoder,
|
build_target_encoder,
|
||||||
get_class_names,
|
get_class_names,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -78,7 +78,7 @@ class DetectorModel(L.LightningModule):
|
|||||||
|
|
||||||
# Training targets
|
# Training targets
|
||||||
self.class_names = get_class_names(self.config.targets.classes)
|
self.class_names = get_class_names(self.config.targets.classes)
|
||||||
self.encoder = build_encoder(
|
self.encoder = build_target_encoder(
|
||||||
self.config.targets.classes,
|
self.config.targets.classes,
|
||||||
replacement_rules=self.config.targets.replace,
|
replacement_rules=self.config.targets.replace,
|
||||||
)
|
)
|
||||||
|
@ -15,10 +15,19 @@ from batdetect2.train.dataset import (
|
|||||||
LabeledDataset,
|
LabeledDataset,
|
||||||
SubclipConfig,
|
SubclipConfig,
|
||||||
TrainExample,
|
TrainExample,
|
||||||
|
get_preprocessed_files,
|
||||||
)
|
)
|
||||||
from batdetect2.train.labels import LabelConfig, load_label_config
|
from batdetect2.train.labels import LabelConfig, load_label_config
|
||||||
from batdetect2.train.preprocess import preprocess_annotations
|
from batdetect2.train.preprocess import (
|
||||||
from batdetect2.train.targets import TargetConfig, load_target_config
|
generate_train_example,
|
||||||
|
preprocess_annotations,
|
||||||
|
)
|
||||||
|
from batdetect2.train.targets import (
|
||||||
|
TagInfo,
|
||||||
|
TargetConfig,
|
||||||
|
build_target_encoder,
|
||||||
|
load_target_config,
|
||||||
|
)
|
||||||
from batdetect2.train.train import TrainerConfig, load_trainer_config, train
|
from batdetect2.train.train import TrainerConfig, load_trainer_config, train
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -26,12 +35,16 @@ __all__ = [
|
|||||||
"LabelConfig",
|
"LabelConfig",
|
||||||
"LabeledDataset",
|
"LabeledDataset",
|
||||||
"SubclipConfig",
|
"SubclipConfig",
|
||||||
|
"TagInfo",
|
||||||
"TargetConfig",
|
"TargetConfig",
|
||||||
"TrainExample",
|
"TrainExample",
|
||||||
"TrainerConfig",
|
"TrainerConfig",
|
||||||
"TrainingConfig",
|
"TrainingConfig",
|
||||||
"add_echo",
|
"add_echo",
|
||||||
"augment_example",
|
"augment_example",
|
||||||
|
"build_target_encoder",
|
||||||
|
"generate_train_example",
|
||||||
|
"get_preprocessed_files",
|
||||||
"load_agumentation_config",
|
"load_agumentation_config",
|
||||||
"load_label_config",
|
"load_label_config",
|
||||||
"load_target_config",
|
"load_target_config",
|
||||||
|
@ -101,7 +101,7 @@ class LabeledDataset(Dataset):
|
|||||||
preprocessing: Optional[PreprocessingConfig] = None,
|
preprocessing: Optional[PreprocessingConfig] = None,
|
||||||
):
|
):
|
||||||
return cls(
|
return cls(
|
||||||
get_files(directory, extension),
|
get_preprocessed_files(directory, extension),
|
||||||
subclip=subclip,
|
subclip=subclip,
|
||||||
augmentation=augmentation,
|
augmentation=augmentation,
|
||||||
preprocessing=preprocessing,
|
preprocessing=preprocessing,
|
||||||
@ -143,5 +143,7 @@ class LabeledDataset(Dataset):
|
|||||||
return adjust_width(tensor, width)
|
return adjust_width(tensor, width)
|
||||||
|
|
||||||
|
|
||||||
def get_files(directory: PathLike, extension: str = ".nc") -> Sequence[Path]:
|
def get_preprocessed_files(
|
||||||
|
directory: PathLike, extension: str = ".nc"
|
||||||
|
) -> Sequence[Path]:
|
||||||
return list(Path(directory).glob(f"*{extension}"))
|
return list(Path(directory).glob(f"*{extension}"))
|
||||||
|
@ -20,7 +20,7 @@ from batdetect2.preprocess import (
|
|||||||
from batdetect2.train.labels import LabelConfig, generate_heatmaps
|
from batdetect2.train.labels import LabelConfig, generate_heatmaps
|
||||||
from batdetect2.train.targets import (
|
from batdetect2.train.targets import (
|
||||||
TargetConfig,
|
TargetConfig,
|
||||||
build_encoder,
|
build_target_encoder,
|
||||||
build_sound_event_filter,
|
build_sound_event_filter,
|
||||||
get_class_names,
|
get_class_names,
|
||||||
)
|
)
|
||||||
@ -76,7 +76,7 @@ def generate_train_example(
|
|||||||
event for event in clip_annotation.sound_events if filter_fn(event)
|
event for event in clip_annotation.sound_events if filter_fn(event)
|
||||||
]
|
]
|
||||||
|
|
||||||
encoder = build_encoder(
|
encoder = build_target_encoder(
|
||||||
config.target.classes,
|
config.target.classes,
|
||||||
replacement_rules=config.target.replace,
|
replacement_rules=config.target.replace,
|
||||||
)
|
)
|
||||||
|
@ -12,7 +12,7 @@ from batdetect2.terms import TagInfo, get_tag_from_info
|
|||||||
__all__ = [
|
__all__ = [
|
||||||
"TargetConfig",
|
"TargetConfig",
|
||||||
"load_target_config",
|
"load_target_config",
|
||||||
"build_encoder",
|
"build_target_encoder",
|
||||||
"build_decoder",
|
"build_decoder",
|
||||||
"filter_sound_event",
|
"filter_sound_event",
|
||||||
]
|
]
|
||||||
@ -91,7 +91,7 @@ def build_replacer(
|
|||||||
return replacer
|
return replacer
|
||||||
|
|
||||||
|
|
||||||
def build_encoder(
|
def build_target_encoder(
|
||||||
classes: List[TagInfo],
|
classes: List[TagInfo],
|
||||||
replacement_rules: Optional[List[ReplaceConfig]] = None,
|
replacement_rules: Optional[List[ReplaceConfig]] = None,
|
||||||
) -> Callable[[Iterable[data.Tag]], Optional[str]]:
|
) -> Callable[[Iterable[data.Tag]], Optional[str]]:
|
||||||
|
Loading…
Reference in New Issue
Block a user