Merge branch 'train' of github.com:macaodha/batdetect2 into train

This commit is contained in:
mbsantiago 2025-06-20 16:03:15 +01:00
commit ebad489cb1
4 changed files with 38 additions and 8 deletions

View File

@ -84,7 +84,7 @@ def features_to_xarray(
freqs = np.linspace(min_freq, max_freq, height, endpoint=False) freqs = np.linspace(min_freq, max_freq, height, endpoint=False)
return xr.DataArray( return xr.DataArray(
data=features.detach().numpy(), data=features.detach().cpu().numpy(),
dims=[ dims=[
Dimensions.feature.value, Dimensions.feature.value,
Dimensions.frequency.value, Dimensions.frequency.value,
@ -157,7 +157,7 @@ def detection_to_xarray(
freqs = np.linspace(min_freq, max_freq, height, endpoint=False) freqs = np.linspace(min_freq, max_freq, height, endpoint=False)
return xr.DataArray( return xr.DataArray(
data=detection.squeeze(dim=0).detach().numpy(), data=detection.squeeze(dim=0).detach().cpu().numpy(),
dims=[ dims=[
Dimensions.frequency.value, Dimensions.frequency.value,
Dimensions.time.value, Dimensions.time.value,
@ -233,7 +233,7 @@ def classification_to_xarray(
freqs = np.linspace(min_freq, max_freq, height, endpoint=False) freqs = np.linspace(min_freq, max_freq, height, endpoint=False)
return xr.DataArray( return xr.DataArray(
data=classes.detach().numpy(), data=classes.detach().cpu().numpy(),
dims=[ dims=[
"category", "category",
Dimensions.frequency.value, Dimensions.frequency.value,
@ -302,7 +302,7 @@ def sizes_to_xarray(
freqs = np.linspace(min_freq, max_freq, height, endpoint=False) freqs = np.linspace(min_freq, max_freq, height, endpoint=False)
return xr.DataArray( return xr.DataArray(
data=sizes.detach().numpy(), data=sizes.detach().cpu().numpy(),
dims=[ dims=[
"dimension", "dimension",
Dimensions.frequency.value, Dimensions.frequency.value,

View File

@ -9,6 +9,7 @@ from torch.utils.data import Dataset
from batdetect2.train.augmentations import Augmentation from batdetect2.train.augmentations import Augmentation
from batdetect2.train.types import ClipperProtocol, TrainExample from batdetect2.train.types import ClipperProtocol, TrainExample
from batdetect2.utils.tensors import adjust_width
__all__ = [ __all__ = [
"LabeledDataset", "LabeledDataset",
@ -87,6 +88,26 @@ class LabeledDataset(Dataset):
) )
def collate_fn(batch: List[TrainExample], width: int = 512) -> TrainExample:
    """Collate a list of training examples into one batched ``TrainExample``.

    The spectrogram and the three heatmaps of every example are passed
    through ``adjust_width`` with the same target ``width`` before being
    stacked, so that ``torch.stack`` receives equally sized tensors.
    The ``idx``, ``start_time`` and ``end_time`` fields are stacked as-is.

    Parameters
    ----------
    batch
        Examples to combine, as produced by the dataset.
    width
        Target width handed to ``adjust_width`` for the spectrogram and
        heatmaps. Defaults to 512, the value previously hard-coded here.
        To use a different width with a ``DataLoader``, bind it with
        ``functools.partial(collate_fn, width=...)``.

    Returns
    -------
    TrainExample
        A single example whose fields carry a new leading batch dimension.
    """

    def _stack_adjusted(tensors):
        # NOTE(review): presumably adjust_width pads or crops the last
        # (time) axis to `width` -- confirm against
        # batdetect2.utils.tensors.adjust_width.
        return torch.stack([adjust_width(t, width) for t in tensors])

    return TrainExample(
        spec=_stack_adjusted(x.spec for x in batch),
        detection_heatmap=_stack_adjusted(
            x.detection_heatmap for x in batch
        ),
        class_heatmap=_stack_adjusted(x.class_heatmap for x in batch),
        size_heatmap=_stack_adjusted(x.size_heatmap for x in batch),
        idx=torch.stack([x.idx for x in batch]),
        start_time=torch.stack([x.start_time for x in batch]),
        end_time=torch.stack([x.end_time for x in batch]),
    )
def list_preprocessed_files( def list_preprocessed_files(
directory: data.PathLike, extension: str = ".nc" directory: data.PathLike, extension: str = ".nc"
) -> List[Path]: ) -> List[Path]:

View File

@ -276,8 +276,11 @@ def preprocess_single_annotation(
labeller=labeller, labeller=labeller,
) )
except Exception as error: except Exception as error:
raise RuntimeError( logger.error(
f"Failed to process annotation: {clip_annotation.uuid}" "Failed to process annotation: {uuid}. Error {error}",
) from error uuid=clip_annotation.uuid,
error=error,
)
return
_save_xr_dataset_to_file(sample, path) _save_xr_dataset_to_file(sample, path)

View File

@ -17,7 +17,11 @@ from batdetect2.train.augmentations import (
) )
from batdetect2.train.clips import build_clipper from batdetect2.train.clips import build_clipper
from batdetect2.train.config import TrainingConfig from batdetect2.train.config import TrainingConfig
from batdetect2.train.dataset import LabeledDataset, RandomExampleSource from batdetect2.train.dataset import (
LabeledDataset,
RandomExampleSource,
collate_fn,
)
from batdetect2.train.lightning import TrainingModule from batdetect2.train.lightning import TrainingModule
from batdetect2.train.losses import build_loss from batdetect2.train.losses import build_loss
@ -88,6 +92,7 @@ def train(
batch_size=config.batch_size, batch_size=config.batch_size,
shuffle=True, shuffle=True,
num_workers=train_workers, num_workers=train_workers,
collate_fn=collate_fn,
) )
val_dataloader = None val_dataloader = None
@ -101,6 +106,7 @@ def train(
batch_size=config.batch_size, batch_size=config.batch_size,
shuffle=False, shuffle=False,
num_workers=val_workers, num_workers=val_workers,
collate_fn=collate_fn,
) )
trainer.fit( trainer.fit(