mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 22:51:58 +02:00
Updating to new Augmentation object
This commit is contained in:
parent
541be15c9e
commit
285c6a3347
@ -5,16 +5,10 @@ import click
|
||||
|
||||
from batdetect2.cli.base import cli
|
||||
from batdetect2.data import load_dataset_from_config
|
||||
from batdetect2.preprocess import (
|
||||
load_preprocessing_config,
|
||||
)
|
||||
from batdetect2.targets import (
|
||||
load_label_config,
|
||||
load_target_config,
|
||||
)
|
||||
from batdetect2.train import (
|
||||
preprocess_annotations,
|
||||
)
|
||||
from batdetect2.preprocess import build_preprocessor, load_preprocessing_config
|
||||
from batdetect2.targets import build_targets, load_target_config
|
||||
from batdetect2.train import load_label_config, preprocess_annotations
|
||||
from batdetect2.train.labels import build_clip_labeler
|
||||
|
||||
__all__ = ["train"]
|
||||
|
||||
@ -129,9 +123,9 @@ def train(): ...
|
||||
def preprocess(
|
||||
dataset_config: Path,
|
||||
output: Path,
|
||||
target_config: Path,
|
||||
base_dir: Optional[Path] = None,
|
||||
preprocess_config: Optional[Path] = None,
|
||||
target_config: Optional[Path] = None,
|
||||
label_config: Optional[Path] = None,
|
||||
force: bool = False,
|
||||
num_workers: Optional[int] = None,
|
||||
@ -152,14 +146,10 @@ def preprocess(
|
||||
else None
|
||||
)
|
||||
|
||||
target = (
|
||||
load_target_config(
|
||||
target = load_target_config(
|
||||
target_config,
|
||||
field=target_config_field,
|
||||
)
|
||||
if target_config
|
||||
else None
|
||||
)
|
||||
|
||||
label = (
|
||||
load_label_config(
|
||||
@ -176,15 +166,18 @@ def preprocess(
|
||||
base_dir=base_dir,
|
||||
)
|
||||
|
||||
targets = build_targets(config=target)
|
||||
preprocessor = build_preprocessor(config=preprocess)
|
||||
labeller = build_clip_labeler(targets, config=label)
|
||||
|
||||
if not output.exists():
|
||||
output.mkdir(parents=True)
|
||||
|
||||
preprocess_annotations(
|
||||
dataset.clip_annotations,
|
||||
dataset,
|
||||
output_dir=output,
|
||||
preprocessor=preprocessor,
|
||||
labeller=labeller,
|
||||
replace=force,
|
||||
preprocessing_config=preprocess,
|
||||
label_config=label,
|
||||
target_config=target,
|
||||
max_workers=num_workers,
|
||||
)
|
||||
|
@ -63,7 +63,7 @@ def extract_values_at_positions(
|
||||
],
|
||||
Dimensions.time.value: positions.coords[Dimensions.time.value],
|
||||
}
|
||||
)
|
||||
).T
|
||||
|
||||
|
||||
def extract_detection_xr_dataset(
|
||||
@ -109,9 +109,9 @@ def extract_detection_xr_dataset(
|
||||
DataArrays share the 'detection' dimension and associated
|
||||
time/frequency coordinates.
|
||||
"""
|
||||
sizes = extract_values_at_positions(sizes, positions).T
|
||||
classes = extract_values_at_positions(classes, positions).T
|
||||
features = extract_values_at_positions(features, positions).T
|
||||
sizes = extract_values_at_positions(sizes, positions)
|
||||
classes = extract_values_at_positions(classes, positions)
|
||||
features = extract_values_at_positions(features, positions)
|
||||
return xr.Dataset(
|
||||
{
|
||||
"scores": positions,
|
||||
|
@ -1,8 +1,12 @@
|
||||
from batdetect2.train.augmentations import (
|
||||
AugmentationsConfig,
|
||||
EchoAugmentationConfig,
|
||||
FrequencyMaskAugmentationConfig,
|
||||
TimeMaskAugmentationConfig,
|
||||
VolumeAugmentationConfig,
|
||||
WarpAugmentationConfig,
|
||||
add_echo,
|
||||
augment_example,
|
||||
load_agumentation_config,
|
||||
build_augmentations,
|
||||
mask_frequency,
|
||||
mask_time,
|
||||
mix_examples,
|
||||
@ -17,6 +21,7 @@ from batdetect2.train.dataset import (
|
||||
TrainExample,
|
||||
get_preprocessed_files,
|
||||
)
|
||||
from batdetect2.train.labels import load_label_config
|
||||
from batdetect2.train.losses import compute_loss
|
||||
from batdetect2.train.preprocess import (
|
||||
generate_train_example,
|
||||
@ -26,17 +31,21 @@ from batdetect2.train.train import TrainerConfig, load_trainer_config, train
|
||||
|
||||
__all__ = [
|
||||
"AugmentationsConfig",
|
||||
"EchoAugmentationConfig",
|
||||
"FrequencyMaskAugmentationConfig",
|
||||
"LabeledDataset",
|
||||
"SubclipConfig",
|
||||
"TimeMaskAugmentationConfig",
|
||||
"TrainExample",
|
||||
"TrainerConfig",
|
||||
"TrainingConfig",
|
||||
"VolumeAugmentationConfig",
|
||||
"WarpAugmentationConfig",
|
||||
"add_echo",
|
||||
"augment_example",
|
||||
"build_augmentations",
|
||||
"compute_loss",
|
||||
"generate_train_example",
|
||||
"get_preprocessed_files",
|
||||
"load_agumentation_config",
|
||||
"load_train_config",
|
||||
"load_trainer_config",
|
||||
"mask_frequency",
|
||||
@ -47,4 +56,5 @@ __all__ = [
|
||||
"select_subclip",
|
||||
"train",
|
||||
"warp_spectrogram",
|
||||
"load_label_config",
|
||||
]
|
||||
|
@ -11,8 +11,8 @@ from torch.utils.data import Dataset
|
||||
|
||||
from batdetect2.configs import BaseConfig
|
||||
from batdetect2.train.augmentations import (
|
||||
Augmentation,
|
||||
AugmentationsConfig,
|
||||
augment_example,
|
||||
select_subclip,
|
||||
)
|
||||
from batdetect2.train.preprocess import PreprocessorProtocol
|
||||
@ -54,7 +54,7 @@ class LabeledDataset(Dataset):
|
||||
preprocessor: PreprocessorProtocol,
|
||||
filenames: Sequence[PathLike],
|
||||
subclip: Optional[SubclipConfig] = None,
|
||||
augmentation: Optional[AugmentationsConfig] = None,
|
||||
augmentation: Optional[Augmentation] = None,
|
||||
):
|
||||
self.preprocessor = preprocessor
|
||||
self.filenames = filenames
|
||||
@ -76,12 +76,7 @@ class LabeledDataset(Dataset):
|
||||
)
|
||||
|
||||
if self.augmentation:
|
||||
dataset = augment_example(
|
||||
dataset,
|
||||
self.augmentation,
|
||||
preprocessor=self.preprocessor,
|
||||
others=self.get_random_example,
|
||||
)
|
||||
dataset = self.augmentation(dataset)
|
||||
|
||||
return TrainExample(
|
||||
spec=self.to_tensor(dataset["spectrogram"]).unsqueeze(0),
|
||||
@ -98,7 +93,7 @@ class LabeledDataset(Dataset):
|
||||
preprocessor: PreprocessorProtocol,
|
||||
extension: str = ".nc",
|
||||
subclip: Optional[SubclipConfig] = None,
|
||||
augmentation: Optional[AugmentationsConfig] = None,
|
||||
augmentation: Optional[Augmentation] = None,
|
||||
):
|
||||
return cls(
|
||||
preprocessor=preprocessor,
|
||||
|
@ -71,7 +71,7 @@ class LabelConfig(BaseConfig):
|
||||
|
||||
def build_clip_labeler(
|
||||
targets: TargetProtocol,
|
||||
config: LabelConfig,
|
||||
config: Optional[LabelConfig] = None,
|
||||
) -> ClipLabeller:
|
||||
"""Construct the final clip labelling function.
|
||||
|
||||
@ -98,7 +98,7 @@ def build_clip_labeler(
|
||||
return partial(
|
||||
generate_clip_label,
|
||||
targets=targets,
|
||||
config=config,
|
||||
config=config or LabelConfig(),
|
||||
)
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user