mirror of https://github.com/macaodha/batdetect2.git
Updating to new Augmentation object
commit 285c6a3347
parent 541be15c9e
@@ -5,16 +5,10 @@ import click

 from batdetect2.cli.base import cli
 from batdetect2.data import load_dataset_from_config
-from batdetect2.preprocess import (
-    load_preprocessing_config,
-)
-from batdetect2.targets import (
-    load_label_config,
-    load_target_config,
-)
-from batdetect2.train import (
-    preprocess_annotations,
-)
+from batdetect2.preprocess import build_preprocessor, load_preprocessing_config
+from batdetect2.targets import build_targets, load_target_config
+from batdetect2.train import load_label_config, preprocess_annotations
+from batdetect2.train.labels import build_clip_labeler

 __all__ = ["train"]

@@ -129,9 +123,9 @@ def train(): ...
 def preprocess(
     dataset_config: Path,
     output: Path,
+    target_config: Path,
     base_dir: Optional[Path] = None,
     preprocess_config: Optional[Path] = None,
-    target_config: Optional[Path] = None,
     label_config: Optional[Path] = None,
     force: bool = False,
     num_workers: Optional[int] = None,
@@ -152,13 +146,9 @@ def preprocess(
         else None
     )

-    target = (
-        load_target_config(
-            target_config,
-            field=target_config_field,
-        )
-        if target_config
-        else None
+    target = load_target_config(
+        target_config,
+        field=target_config_field,
     )

     label = (
@@ -176,15 +166,18 @@ def preprocess(
         base_dir=base_dir,
     )

+    targets = build_targets(config=target)
+    preprocessor = build_preprocessor(config=preprocess)
+    labeller = build_clip_labeler(targets, config=label)
+
     if not output.exists():
         output.mkdir(parents=True)

     preprocess_annotations(
-        dataset.clip_annotations,
+        dataset,
         output_dir=output,
+        preprocessor=preprocessor,
+        labeller=labeller,
         replace=force,
-        preprocessing_config=preprocess,
-        label_config=label,
-        target_config=target,
         max_workers=num_workers,
     )
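Taken together with the signature change above, the preprocess command now builds concrete objects (targets, preprocessor, labeller) from their configs and passes them straight into preprocess_annotations, instead of forwarding the raw config objects. Below is a minimal sketch of the same flow outside the CLI; the paths are placeholders, and anything not visible in this diff (extra keyword arguments, return types) is an assumption rather than documented batdetect2 API.

# Sketch only: mirrors the call pattern in the hunk above; paths and details
# not shown in the diff are assumptions.
from pathlib import Path

from batdetect2.data import load_dataset_from_config
from batdetect2.preprocess import build_preprocessor, load_preprocessing_config
from batdetect2.targets import build_targets, load_target_config
from batdetect2.train import preprocess_annotations
from batdetect2.train.labels import build_clip_labeler

dataset = load_dataset_from_config(Path("dataset.yaml"))            # placeholder path
target = load_target_config(Path("targets.yaml"))                   # placeholder path
preprocess = load_preprocessing_config(Path("preprocessing.yaml"))  # placeholder path

targets = build_targets(config=target)
preprocessor = build_preprocessor(config=preprocess)
labeller = build_clip_labeler(targets)  # config is optional after the labels change below

output = Path("preprocessed")  # placeholder output directory
output.mkdir(parents=True, exist_ok=True)

preprocess_annotations(
    dataset,
    output_dir=output,
    preprocessor=preprocessor,
    labeller=labeller,
)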
@@ -63,7 +63,7 @@ def extract_values_at_positions(
             ],
             Dimensions.time.value: positions.coords[Dimensions.time.value],
         }
-    )
+    ).T


 def extract_detection_xr_dataset(
@@ -109,9 +109,9 @@ def extract_detection_xr_dataset(
         DataArrays share the 'detection' dimension and associated
         time/frequency coordinates.
     """
-    sizes = extract_values_at_positions(sizes, positions).T
-    classes = extract_values_at_positions(classes, positions).T
-    features = extract_values_at_positions(features, positions).T
+    sizes = extract_values_at_positions(sizes, positions)
+    classes = extract_values_at_positions(classes, positions)
+    features = extract_values_at_positions(features, positions)
     return xr.Dataset(
         {
             "scores": positions,
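These two hunks move the transpose from the three call sites into extract_values_at_positions itself, so callers now receive the values already transposed. As a standalone illustration (not batdetect2 code), .T on an xarray.DataArray simply reverses the dimension order:

# Toy example of xarray's .T, unrelated to the actual batdetect2 arrays.
import numpy as np
import xarray as xr

arr = xr.DataArray(np.zeros((3, 5)), dims=("feature", "detection"))
print(arr.dims)    # ('feature', 'detection')
print(arr.T.dims)  # ('detection', 'feature')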
@@ -1,8 +1,12 @@
 from batdetect2.train.augmentations import (
     AugmentationsConfig,
+    EchoAugmentationConfig,
+    FrequencyMaskAugmentationConfig,
+    TimeMaskAugmentationConfig,
+    VolumeAugmentationConfig,
+    WarpAugmentationConfig,
     add_echo,
-    augment_example,
-    load_agumentation_config,
+    build_augmentations,
     mask_frequency,
     mask_time,
     mix_examples,
@@ -17,6 +21,7 @@ from batdetect2.train.dataset import (
     TrainExample,
     get_preprocessed_files,
 )
+from batdetect2.train.labels import load_label_config
 from batdetect2.train.losses import compute_loss
 from batdetect2.train.preprocess import (
     generate_train_example,
@@ -26,17 +31,21 @@ from batdetect2.train.train import TrainerConfig, load_trainer_config, train

 __all__ = [
     "AugmentationsConfig",
+    "EchoAugmentationConfig",
+    "FrequencyMaskAugmentationConfig",
     "LabeledDataset",
     "SubclipConfig",
+    "TimeMaskAugmentationConfig",
     "TrainExample",
     "TrainerConfig",
     "TrainingConfig",
+    "VolumeAugmentationConfig",
+    "WarpAugmentationConfig",
     "add_echo",
-    "augment_example",
+    "build_augmentations",
     "compute_loss",
     "generate_train_example",
     "get_preprocessed_files",
-    "load_agumentation_config",
     "load_train_config",
     "load_trainer_config",
     "mask_frequency",
@@ -47,4 +56,5 @@ __all__ = [
     "select_subclip",
     "train",
     "warp_spectrogram",
+    "load_label_config",
 ]
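The train package now exports one config class per augmentation plus build_augmentations, replacing augment_example and the misspelled load_agumentation_config. How the new pieces fit together is not shown in this diff, so the sketch below is an assumption: it presumes AugmentationsConfig can be default-constructed and that build_augmentations turns a config into a callable Augmentation.

# Assumed usage; only the imported names appear in this diff.
from batdetect2.train import AugmentationsConfig, build_augmentations

config = AugmentationsConfig()               # assumption: default-constructible
augmentation = build_augmentations(config)   # assumption: returns an Augmentation callable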
@@ -11,8 +11,8 @@ from torch.utils.data import Dataset

 from batdetect2.configs import BaseConfig
 from batdetect2.train.augmentations import (
+    Augmentation,
     AugmentationsConfig,
-    augment_example,
     select_subclip,
 )
 from batdetect2.train.preprocess import PreprocessorProtocol
@@ -54,7 +54,7 @@ class LabeledDataset(Dataset):
         preprocessor: PreprocessorProtocol,
         filenames: Sequence[PathLike],
         subclip: Optional[SubclipConfig] = None,
-        augmentation: Optional[AugmentationsConfig] = None,
+        augmentation: Optional[Augmentation] = None,
     ):
         self.preprocessor = preprocessor
         self.filenames = filenames
@@ -76,12 +76,7 @@ class LabeledDataset(Dataset):
             )

         if self.augmentation:
-            dataset = augment_example(
-                dataset,
-                self.augmentation,
-                preprocessor=self.preprocessor,
-                others=self.get_random_example,
-            )
+            dataset = self.augmentation(dataset)

         return TrainExample(
             spec=self.to_tensor(dataset["spectrogram"]).unsqueeze(0),
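Instead of calling augment_example with the config, the preprocessor, and a sampler of other examples, the dataset now applies a pre-built callable directly to the loaded example. The line `dataset = self.augmentation(dataset)` suggests an Augmentation is any callable mapping an xr.Dataset to an xr.Dataset; a minimal sketch under that assumption:

import xarray as xr

def no_op_augmentation(example: xr.Dataset) -> xr.Dataset:
    # A real augmentation would perturb example["spectrogram"] (and possibly the
    # label arrays) before returning it; this placeholder passes it through.
    return example

# Could then be supplied as LabeledDataset(..., augmentation=no_op_augmentation).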
@@ -98,7 +93,7 @@ class LabeledDataset(Dataset):
         preprocessor: PreprocessorProtocol,
         extension: str = ".nc",
         subclip: Optional[SubclipConfig] = None,
-        augmentation: Optional[AugmentationsConfig] = None,
+        augmentation: Optional[Augmentation] = None,
     ):
         return cls(
             preprocessor=preprocessor,
@@ -71,7 +71,7 @@ class LabelConfig(BaseConfig):

 def build_clip_labeler(
     targets: TargetProtocol,
-    config: LabelConfig,
+    config: Optional[LabelConfig] = None,
 ) -> ClipLabeller:
     """Construct the final clip labelling function.

@@ -98,7 +98,7 @@ def build_clip_labeler(
     return partial(
         generate_clip_label,
         targets=targets,
-        config=config,
+        config=config or LabelConfig(),
     )

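Because config now defaults to None and falls back to `config or LabelConfig()`, callers such as the CLI hunk above can omit the label config and still get a labeller built from defaults. The fallback idiom itself, shown with a hypothetical stand-in class rather than batdetect2's LabelConfig:

from dataclasses import dataclass
from typing import Optional

@dataclass
class StandInConfig:  # hypothetical stand-in for LabelConfig
    sigma: float = 3.0

def build(config: Optional[StandInConfig] = None) -> StandInConfig:
    # `or` falls back to a default-constructed config whenever config is None.
    return config or StandInConfig()

assert build().sigma == 3.0
assert build(StandInConfig(sigma=1.5)).sigma == 1.5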