Merge branch 'train' of github.com:macaodha/batdetect2 into train

mbsantiago 2025-06-20 16:03:15 +01:00
commit ebad489cb1
4 changed files with 38 additions and 8 deletions


@@ -84,7 +84,7 @@ def features_to_xarray(
     freqs = np.linspace(min_freq, max_freq, height, endpoint=False)
 
     return xr.DataArray(
-        data=features.detach().numpy(),
+        data=features.detach().cpu().numpy(),
         dims=[
             Dimensions.feature.value,
             Dimensions.frequency.value,
@@ -157,7 +157,7 @@ def detection_to_xarray(
     freqs = np.linspace(min_freq, max_freq, height, endpoint=False)
 
     return xr.DataArray(
-        data=detection.squeeze(dim=0).detach().numpy(),
+        data=detection.squeeze(dim=0).detach().cpu().numpy(),
         dims=[
             Dimensions.frequency.value,
             Dimensions.time.value,
@@ -233,7 +233,7 @@ def classification_to_xarray(
     freqs = np.linspace(min_freq, max_freq, height, endpoint=False)
 
     return xr.DataArray(
-        data=classes.detach().numpy(),
+        data=classes.detach().cpu().numpy(),
         dims=[
             "category",
             Dimensions.frequency.value,
@@ -302,7 +302,7 @@ def sizes_to_xarray(
     freqs = np.linspace(min_freq, max_freq, height, endpoint=False)
 
     return xr.DataArray(
-        data=sizes.detach().numpy(),
+        data=sizes.detach().cpu().numpy(),
         dims=[
             "dimension",
             Dimensions.frequency.value,
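Note on the four `.cpu()` additions above: `Tensor.numpy()` only works on CPU tensors, so these conversions raise a TypeError whenever the model output still lives on a CUDA device. Inserting `.cpu()` (a no-op for tensors already on the CPU) makes the xarray conversion device-agnostic. A minimal standalone sketch of the pattern, not batdetect2 code:

import torch

def to_numpy(tensor: torch.Tensor):
    # detach() drops the autograd graph; cpu() copies the data to host memory,
    # which numpy() requires (it raises a TypeError on CUDA tensors) and is a
    # no-op when the tensor is already on the CPU.
    return tensor.detach().cpu().numpy()

features = torch.rand(32, 64, 128)
if torch.cuda.is_available():
    features = features.cuda()
print(to_numpy(features).shape)  # (32, 64, 128) on any device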


@@ -9,6 +9,7 @@ from torch.utils.data import Dataset
 from batdetect2.train.augmentations import Augmentation
 from batdetect2.train.types import ClipperProtocol, TrainExample
+from batdetect2.utils.tensors import adjust_width
 
 __all__ = [
     "LabeledDataset",
@@ -87,6 +88,26 @@ class LabeledDataset(Dataset):
         )
 
 
+def collate_fn(batch: List[TrainExample]):
+    width = 512
+
+    return TrainExample(
+        spec=torch.stack([adjust_width(x.spec, width) for x in batch]),
+        detection_heatmap=torch.stack(
+            [adjust_width(x.detection_heatmap, width) for x in batch]
+        ),
+        class_heatmap=torch.stack(
+            [adjust_width(x.class_heatmap, width) for x in batch]
+        ),
+        size_heatmap=torch.stack(
+            [adjust_width(x.size_heatmap, width) for x in batch]
+        ),
+        idx=torch.stack([x.idx for x in batch]),
+        start_time=torch.stack([x.start_time for x in batch]),
+        end_time=torch.stack([x.end_time for x in batch]),
+    )
+
+
 def list_preprocessed_files(
     directory: data.PathLike, extension: str = ".nc"
 ) -> List[Path]:
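The new `collate_fn` brings every example in a batch to a common width of 512 time bins before stacking, so clips whose spectrograms differ slightly in length can still be batched together. The diff does not show `adjust_width` itself; a hypothetical stand-in with a crop-or-pad behaviour consistent with its use here might look like this:

import torch
import torch.nn.functional as F

def adjust_width_sketch(x: torch.Tensor, width: int) -> torch.Tensor:
    # Hypothetical stand-in for batdetect2.utils.tensors.adjust_width:
    # crop or right-pad the last (time) axis so it is exactly `width` bins.
    current = x.shape[-1]
    if current > width:
        return x[..., :width]
    if current < width:
        return F.pad(x, (0, width - current))
    return x

spec = torch.rand(1, 128, 480)
print(adjust_width_sketch(spec, 512).shape)  # torch.Size([1, 128, 512])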


@@ -276,8 +276,11 @@ def preprocess_single_annotation(
             labeller=labeller,
         )
     except Exception as error:
-        raise RuntimeError(
-            f"Failed to process annotation: {clip_annotation.uuid}"
-        ) from error
+        logger.error(
+            "Failed to process annotation: {uuid}. Error {error}",
+            uuid=clip_annotation.uuid,
+            error=error,
+        )
+        return
 
     _save_xr_dataset_to_file(sample, path)
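With this change a single failing annotation is logged and skipped (`return`) rather than aborting the whole preprocessing run with a RuntimeError. The `{uuid}`/`{error}` keyword style suggests a loguru-style logger, where the message template is formatted from keyword arguments. A minimal illustration of the log-and-skip pattern, assuming loguru (the diff does not show which logger is imported):

from loguru import logger

def process(item: int) -> int:
    if item < 0:
        raise ValueError("negative item")
    return item * 2

results = []
for item in [1, -2, 3]:
    try:
        result = process(item)
    except Exception as error:
        # Log and move on instead of re-raising, so one bad item
        # does not stop the rest of the run.
        logger.error("Failed to process item: {item}. Error {error}",
                     item=item, error=error)
        continue
    results.append(result)

print(results)  # [2, 6]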


@@ -17,7 +17,11 @@ from batdetect2.train.augmentations import (
 )
 from batdetect2.train.clips import build_clipper
 from batdetect2.train.config import TrainingConfig
-from batdetect2.train.dataset import LabeledDataset, RandomExampleSource
+from batdetect2.train.dataset import (
+    LabeledDataset,
+    RandomExampleSource,
+    collate_fn,
+)
 from batdetect2.train.lightning import TrainingModule
 from batdetect2.train.losses import build_loss
@@ -88,6 +92,7 @@ def train(
         batch_size=config.batch_size,
         shuffle=True,
         num_workers=train_workers,
+        collate_fn=collate_fn,
     )
 
     val_dataloader = None
@@ -101,6 +106,7 @@ def train(
             batch_size=config.batch_size,
             shuffle=False,
             num_workers=val_workers,
+            collate_fn=collate_fn,
         )
 
     trainer.fit(
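Both the training and validation loaders now pass the custom `collate_fn` to `torch.utils.data.DataLoader`, presumably because the default collate cannot stack clips of differing widths. A self-contained toy example of the same wiring (not batdetect2 code):

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

class ToyClips(Dataset):
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        # Variable-width "spectrograms" to mimic clips of differing duration.
        return torch.rand(1, 64, 500 + idx)

def pad_collate(batch, width=512):
    # Right-pad every item's time axis to `width`, then stack into one batch.
    return torch.stack([F.pad(x, (0, width - x.shape[-1])) for x in batch])

loader = DataLoader(ToyClips(), batch_size=4, collate_fn=pad_collate)
for batch in loader:
    print(batch.shape)  # torch.Size([4, 1, 64, 512])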