Merge branch 'train' of github.com:macaodha/batdetect2 into train

Commit ebad489cb1
@@ -84,7 +84,7 @@ def features_to_xarray(
     freqs = np.linspace(min_freq, max_freq, height, endpoint=False)
 
     return xr.DataArray(
-        data=features.detach().numpy(),
+        data=features.detach().cpu().numpy(),
         dims=[
             Dimensions.feature.value,
             Dimensions.frequency.value,
@@ -157,7 +157,7 @@ def detection_to_xarray(
     freqs = np.linspace(min_freq, max_freq, height, endpoint=False)
 
     return xr.DataArray(
-        data=detection.squeeze(dim=0).detach().numpy(),
+        data=detection.squeeze(dim=0).detach().cpu().numpy(),
         dims=[
             Dimensions.frequency.value,
             Dimensions.time.value,
@@ -233,7 +233,7 @@ def classification_to_xarray(
     freqs = np.linspace(min_freq, max_freq, height, endpoint=False)
 
     return xr.DataArray(
-        data=classes.detach().numpy(),
+        data=classes.detach().cpu().numpy(),
         dims=[
             "category",
             Dimensions.frequency.value,
@@ -302,7 +302,7 @@ def sizes_to_xarray(
     freqs = np.linspace(min_freq, max_freq, height, endpoint=False)
 
     return xr.DataArray(
-        data=sizes.detach().numpy(),
+        data=sizes.detach().cpu().numpy(),
         dims=[
             "dimension",
             Dimensions.frequency.value,
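The four hunks above apply the same fix: NumPy cannot read CUDA memory, so a tensor that may live on the GPU has to be detached and copied to the host before .numpy() is called. A minimal sketch of the pattern (the helper name and tensor shape below are illustrative, not part of this diff):

import torch

def to_numpy(tensor: torch.Tensor):
    # .detach() drops the autograd graph; .cpu() copies the data to host
    # memory so .numpy() works for both CPU and CUDA tensors.
    return tensor.detach().cpu().numpy()

features = torch.rand(32, 128, 512)
if torch.cuda.is_available():
    features = features.cuda()

array = to_numpy(features)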
@@ -9,6 +9,7 @@ from torch.utils.data import Dataset
 
 from batdetect2.train.augmentations import Augmentation
 from batdetect2.train.types import ClipperProtocol, TrainExample
+from batdetect2.utils.tensors import adjust_width
 
 __all__ = [
     "LabeledDataset",
@@ -87,6 +88,26 @@ class LabeledDataset(Dataset):
         )
 
 
+def collate_fn(batch: List[TrainExample]):
+    width = 512
+
+    return TrainExample(
+        spec=torch.stack([adjust_width(x.spec, width) for x in batch]),
+        detection_heatmap=torch.stack(
+            [adjust_width(x.detection_heatmap, width) for x in batch]
+        ),
+        class_heatmap=torch.stack(
+            [adjust_width(x.class_heatmap, width) for x in batch]
+        ),
+        size_heatmap=torch.stack(
+            [adjust_width(x.size_heatmap, width) for x in batch]
+        ),
+        idx=torch.stack([x.idx for x in batch]),
+        start_time=torch.stack([x.start_time for x in batch]),
+        end_time=torch.stack([x.end_time for x in batch]),
+    )
+
+
 def list_preprocessed_files(
     directory: data.PathLike, extension: str = ".nc"
 ) -> List[Path]:
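The new collate_fn stacks every tensor field of TrainExample after forcing it to a common time width of 512, which is what lets torch.stack build a batch from clips of slightly different lengths. adjust_width is imported from batdetect2.utils.tensors; a rough stand-in for what such a helper has to do (an assumption for illustration, not the library's implementation):

import torch

def pad_or_crop_width(tensor: torch.Tensor, width: int) -> torch.Tensor:
    # Hypothetical stand-in for batdetect2.utils.tensors.adjust_width:
    # crop or zero-pad the last (time) axis to a fixed width so that
    # torch.stack can combine examples into one batch tensor.
    current = tensor.shape[-1]
    if current > width:
        return tensor[..., :width]
    if current < width:
        return torch.nn.functional.pad(tensor, (0, width - current))
    return tensor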
@@ -276,8 +276,11 @@ def preprocess_single_annotation(
             labeller=labeller,
         )
     except Exception as error:
-        raise RuntimeError(
-            f"Failed to process annotation: {clip_annotation.uuid}"
-        ) from error
+        logger.error(
+            "Failed to process annotation: {uuid}. Error {error}",
+            uuid=clip_annotation.uuid,
+            error=error,
+        )
+        return
 
     _save_xr_dataset_to_file(sample, path)
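This hunk downgrades a failed annotation from a hard RuntimeError to a logged error plus an early return, so one bad clip no longer aborts the whole preprocessing run; the failure is now only visible in the log. The {uuid}/{error} keyword formatting matches a loguru-style logger rather than the standard library logging module. A minimal sketch of the pattern, assuming loguru and a hypothetical work function:

from loguru import logger

def risky_step(item_id: str) -> None:
    # Hypothetical stand-in for the real preprocessing work.
    raise ValueError(f"cannot handle {item_id}")

def process(item_id: str) -> None:
    try:
        risky_step(item_id)
    except Exception as error:
        # loguru formats the message lazily from the keyword arguments;
        # logging and returning skips this item instead of stopping the run.
        logger.error(
            "Failed to process item: {item_id}. Error {error}",
            item_id=item_id,
            error=error,
        )
        return

process("example-0001")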
@@ -17,7 +17,11 @@ from batdetect2.train.augmentations import (
 )
 from batdetect2.train.clips import build_clipper
 from batdetect2.train.config import TrainingConfig
-from batdetect2.train.dataset import LabeledDataset, RandomExampleSource
+from batdetect2.train.dataset import (
+    LabeledDataset,
+    RandomExampleSource,
+    collate_fn,
+)
 from batdetect2.train.lightning import TrainingModule
 from batdetect2.train.losses import build_loss
 
@@ -88,6 +92,7 @@ def train(
         batch_size=config.batch_size,
         shuffle=True,
         num_workers=train_workers,
+        collate_fn=collate_fn,
     )
 
     val_dataloader = None
@@ -101,6 +106,7 @@ def train(
         batch_size=config.batch_size,
         shuffle=False,
         num_workers=val_workers,
+        collate_fn=collate_fn,
     )
 
     trainer.fit(
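The last two hunks pass the new collate_fn to both DataLoaders, so training and validation batches go through the same width adjustment. A self-contained sketch of how a custom collate function plugs into torch.utils.data.DataLoader (the dummy dataset and pad_collate below are illustrative, not the batdetect2 classes):

import torch
from torch.utils.data import DataLoader, Dataset

class DummyDataset(Dataset):
    # Hypothetical stand-in for LabeledDataset: variable-width spectrograms.
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return torch.rand(1, 128, 400 + idx)

def pad_collate(batch, width=512):
    # Pad each example's last axis to a common width, then stack,
    # mirroring what the new collate_fn does for TrainExample fields.
    padded = [
        torch.nn.functional.pad(x, (0, width - x.shape[-1])) for x in batch
    ]
    return torch.stack(padded)

loader = DataLoader(DummyDataset(), batch_size=4, collate_fn=pad_collate)
batch = next(iter(loader))
print(batch.shape)  # torch.Size([4, 1, 128, 512])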