Adding evaluation callback

mbsantiago 2025-04-25 17:12:57 +01:00
parent 9106b9f408
commit bc86c94f8e
18 changed files with 430 additions and 169 deletions

View File

@@ -1,9 +1,13 @@
from batdetect2.evaluate.evaluate import (
    compute_error_auc,
)
from batdetect2.evaluate.match import (
    match_predictions_and_annotations,
    match_sound_events_and_raw_predictions,
)

__all__ = [
    "compute_error_auc",
    "match_predictions_and_annotations",
    "match_sound_events_and_raw_predictions",
]

View File

@@ -1,51 +1,6 @@
from typing import List

import numpy as np
from sklearn.metrics import auc, roc_curve
from soundevent import data
from soundevent.evaluation import match_geometries


def match_predictions_and_annotations(
    clip_annotation: data.ClipAnnotation,
    clip_prediction: data.ClipPrediction,
) -> List[data.Match]:
    annotated_sound_events = [
        sound_event_annotation
        for sound_event_annotation in clip_annotation.sound_events
        if sound_event_annotation.sound_event.geometry is not None
    ]
    predicted_sound_events = [
        sound_event_prediction
        for sound_event_prediction in clip_prediction.sound_events
        if sound_event_prediction.sound_event.geometry is not None
    ]
    annotated_geometries: List[data.Geometry] = [
        sound_event.sound_event.geometry
        for sound_event in annotated_sound_events
        if sound_event.sound_event.geometry is not None
    ]
    predicted_geometries: List[data.Geometry] = [
        sound_event.sound_event.geometry
        for sound_event in predicted_sound_events
        if sound_event.sound_event.geometry is not None
    ]
    matches = []
    for id1, id2, affinity in match_geometries(
        annotated_geometries,
        predicted_geometries,
    ):
        target = annotated_sound_events[id1] if id1 is not None else None
        source = predicted_sound_events[id2] if id2 is not None else None
        matches.append(
            data.Match(source=source, target=target, affinity=affinity)
        )
    return matches


def compute_error_auc(op_str, gt, pred, prob):

View File

@@ -0,0 +1,111 @@
from typing import List

from soundevent import data
from soundevent.evaluation import match_geometries

from batdetect2.evaluate.types import Match
from batdetect2.postprocess.types import RawPrediction
from batdetect2.targets.types import TargetProtocol
from batdetect2.utils.arrays import iterate_over_array


def match_sound_events_and_raw_predictions(
    sound_events: List[data.SoundEventAnnotation],
    raw_predictions: List[RawPrediction],
    targets: TargetProtocol,
) -> List[Match]:
    target_sound_events = [
        targets.transform(sound_event_annotation)
        for sound_event_annotation in sound_events
        if targets.filter(sound_event_annotation)
        and sound_event_annotation.sound_event.geometry is not None
    ]

    target_geometries: List[data.Geometry] = [  # type: ignore
        sound_event_annotation.sound_event.geometry
        for sound_event_annotation in target_sound_events
    ]

    predicted_geometries = [
        raw_prediction.geometry for raw_prediction in raw_predictions
    ]

    matches = []
    for id1, id2, affinity in match_geometries(
        target_geometries,
        predicted_geometries,
    ):
        target = target_sound_events[id1] if id1 is not None else None
        prediction = raw_predictions[id2] if id2 is not None else None

        gt_uuid = target.uuid if target is not None else None
        gt_det = target is not None
        gt_class = targets.encode(target) if target is not None else None

        pred_score = float(prediction.detection_score) if prediction else 0
        class_scores = (
            {
                str(class_name): float(score)
                for class_name, score in iterate_over_array(
                    prediction.class_scores
                )
            }
            if prediction is not None
            else {}
        )

        matches.append(
            Match(
                gt_uuid=gt_uuid,
                gt_det=gt_det,
                gt_class=gt_class,
                pred_score=pred_score,
                affinity=affinity,
                class_scores=class_scores,
            )
        )

    return matches


def match_predictions_and_annotations(
    clip_annotation: data.ClipAnnotation,
    clip_prediction: data.ClipPrediction,
) -> List[data.Match]:
    annotated_sound_events = [
        sound_event_annotation
        for sound_event_annotation in clip_annotation.sound_events
        if sound_event_annotation.sound_event.geometry is not None
    ]
    predicted_sound_events = [
        sound_event_prediction
        for sound_event_prediction in clip_prediction.sound_events
        if sound_event_prediction.sound_event.geometry is not None
    ]
    annotated_geometries: List[data.Geometry] = [
        sound_event.sound_event.geometry
        for sound_event in annotated_sound_events
        if sound_event.sound_event.geometry is not None
    ]
    predicted_geometries: List[data.Geometry] = [
        sound_event.sound_event.geometry
        for sound_event in predicted_sound_events
        if sound_event.sound_event.geometry is not None
    ]
    matches = []
    for id1, id2, affinity in match_geometries(
        annotated_geometries,
        predicted_geometries,
    ):
        target = annotated_sound_events[id1] if id1 is not None else None
        source = predicted_sound_events[id2] if id2 is not None else None
        matches.append(
            data.Match(source=source, target=target, affinity=affinity)
        )
    return matches
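
For orientation, a minimal sketch of how the new matching helper is meant to be called; `targets`, `clip_annotation`, and `raw_predictions` are assumed to come from the targets module and the postprocessor, and are not defined in this file.

from batdetect2.evaluate.match import match_sound_events_and_raw_predictions

matches = match_sound_events_and_raw_predictions(
    sound_events=clip_annotation.sound_events,
    raw_predictions=raw_predictions,
    targets=targets,
)

# Ground-truth events without a matched prediction come back with a zero score.
missed = [m for m in matches if m.gt_det and m.pred_score == 0]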

View File

@@ -0,0 +1,39 @@
from typing import List

import pandas as pd
from sklearn import metrics
from sklearn.preprocessing import label_binarize

from batdetect2.evaluate.types import Match, MetricsProtocol

__all__ = ["DetectionAveragePrecision"]


class DetectionAveragePrecision(MetricsProtocol):
    name: str = "detection/average_precision"

    def __call__(self, matches: List[Match]) -> float:
        y_true, y_score = zip(
            *[(match.gt_det, match.pred_score) for match in matches]
        )
        return float(metrics.average_precision_score(y_true, y_score))


class ClassificationMeanAveragePrecision(MetricsProtocol):
    name: str = "classification/average_precision"

    def __init__(self, class_names: List[str]):
        self.class_names = class_names

    def __call__(self, matches: List[Match]) -> float:
        y_true = label_binarize(
            [
                match.gt_class if match.gt_class is not None else "__NONE__"
                for match in matches
            ],
            classes=self.class_names,
        )
        y_pred = pd.DataFrame([match.class_scores for match in matches])
        return float(
            metrics.average_precision_score(y_true, y_pred[self.class_names])
        )
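
A hedged example of computing the detection metric; the `Match` values below are invented for illustration.

from batdetect2.evaluate.types import Match

matches = [
    Match(gt_uuid=None, gt_det=True, gt_class="myotis", pred_score=0.9,
          affinity=0.8, class_scores={"myotis": 0.7, "nyctalus": 0.2}),
    Match(gt_uuid=None, gt_det=True, gt_class="nyctalus", pred_score=0.6,
          affinity=0.5, class_scores={"myotis": 0.2, "nyctalus": 0.6}),
    Match(gt_uuid=None, gt_det=False, gt_class=None, pred_score=0.2,
          affinity=0.0, class_scores={"myotis": 0.1, "nyctalus": 0.1}),
]

# Both true detections outrank the false positive, so AP is 1.0 here.
print(DetectionAveragePrecision()(matches))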

View File

@@ -0,0 +1,24 @@
from dataclasses import dataclass
from typing import Dict, List, Optional, Protocol
from uuid import UUID

__all__ = [
    "MetricsProtocol",
    "Match",
]


@dataclass
class Match:
    gt_uuid: Optional[UUID]
    gt_det: bool
    gt_class: Optional[str]
    pred_score: float
    affinity: float
    class_scores: Dict[str, float]


class MetricsProtocol(Protocol):
    name: str

    def __call__(self, matches: List[Match]) -> float: ...
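
Since `MetricsProtocol` only requires a `name` and a `__call__`, adding a metric is a matter of one small class. A sketch of a hypothetical recall metric, not part of this commit:

from typing import List

from batdetect2.evaluate.types import Match


class DetectionRecall:
    name: str = "detection/recall"

    def __init__(self, threshold: float = 0.5):
        self.threshold = threshold

    def __call__(self, matches: List[Match]) -> float:
        # Ground-truth events appear as matches with gt_det=True; unmatched
        # ones carry a pred_score of 0 and therefore count as misses.
        positives = [m for m in matches if m.gt_det]
        if not positives:
            return 0.0
        hits = sum(1 for m in positives if m.pred_score >= self.threshold)
        return hits / len(positives)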

View File

@@ -170,8 +170,8 @@ def load_postprocess_config(
def build_postprocessor(
    targets: TargetProtocol,
    config: Optional[PostprocessConfig] = None,
    max_freq: int = MAX_FREQ,
    min_freq: int = MIN_FREQ,
    max_freq: float = MAX_FREQ,
    min_freq: float = MIN_FREQ,
) -> PostprocessorProtocol:
    """Factory function to build the standard postprocessor.

@@ -234,9 +234,9 @@ class Postprocessor(PostprocessorProtocol):
        recovery.
    config : PostprocessConfig
        Configuration object holding parameters for NMS, thresholds, etc.
    min_freq : int
    min_freq : float
        Minimum frequency (Hz) assumed for the model output's frequency axis.
    max_freq : int
    max_freq : float
        Maximum frequency (Hz) assumed for the model output's frequency axis.
    """

@@ -246,8 +246,8 @@ class Postprocessor(PostprocessorProtocol):
        self,
        targets: TargetProtocol,
        config: PostprocessConfig,
        min_freq: int = MIN_FREQ,
        max_freq: int = MAX_FREQ,
        min_freq: float = MIN_FREQ,
        max_freq: float = MAX_FREQ,
    ):
        """Initialize the Postprocessor.

View File

@@ -32,10 +32,10 @@ from typing import List, Optional

import numpy as np
import xarray as xr
from soundevent import data
from soundevent.geometry import compute_bounds

from batdetect2.postprocess.types import GeometryBuilder, RawPrediction
from batdetect2.targets.classes import SoundEventDecoder
from batdetect2.utils.arrays import iterate_over_array

__all__ = [
    "convert_xr_dataset_to_raw_prediction",

@@ -97,18 +97,14 @@ def convert_xr_dataset_to_raw_prediction(
        det_info = detection_dataset.sel(detection=det_num)

        geom = geometry_builder(
            (det_info.time, det_info.freq),
            (det_info.time, det_info.frequency),
            det_info.dimensions,
        )

        start_time, low_freq, end_time, high_freq = compute_bounds(geom)
        detections.append(
            RawPrediction(
                detection_score=det_info.score,
                start_time=start_time,
                end_time=end_time,
                low_freq=low_freq,
                high_freq=high_freq,
                detection_score=det_info.scores,
                geometry=geom,
                class_scores=det_info.classes,
                features=det_info.features,
            )

@@ -244,14 +240,7 @@ def convert_raw_prediction_to_sound_event_prediction(
    """
    sound_event = data.SoundEvent(
        recording=recording,
        geometry=data.BoundingBox(
            coordinates=[
                raw_prediction.start_time,
                raw_prediction.low_freq,
                raw_prediction.end_time,
                raw_prediction.high_freq,
            ]
        ),
        geometry=raw_prediction.geometry,
        features=get_prediction_features(raw_prediction.features),
    )

@@ -333,7 +322,7 @@ def get_prediction_features(features: xr.DataArray) -> List[data.Feature]:
            ),
            value=value,
        )
        for feat_name, value in _iterate_over_array(features)
        for feat_name, value in iterate_over_array(features)
    ]

@@ -394,13 +383,6 @@ def get_class_tags(
    return tags


def _iterate_over_array(array: xr.DataArray):
    dim_name = array.dims[0]
    coords = array.coords[dim_name]
    for value, coord in zip(array.values, coords.values):
        yield coord, float(value)


def _iterate_sorted(array: xr.DataArray):
    dim_name = array.dims[0]
    coords = array.coords[dim_name].values

View File

@@ -47,14 +47,9 @@ class RawPrediction(NamedTuple):

    Attributes
    ----------
    start_time : float
        Start time of the recovered bounding box in seconds.
    end_time : float
        End time of the recovered bounding box in seconds.
    low_freq : float
        Lowest frequency of the recovered bounding box in Hz.
    high_freq : float
        Highest frequency of the recovered bounding box in Hz.
    geometry : data.Geometry
        The geometry recovered for the detected sound event, usually a
        bounding box.
    detection_score : float
        The confidence score associated with this detection, typically from
        the detection heatmap peak.

@@ -67,10 +62,7 @@ class RawPrediction(NamedTuple):
        detection location. Indexed by a 'feature' coordinate.
    """

    start_time: float
    end_time: float
    low_freq: float
    high_freq: float
    geometry: data.Geometry
    detection_score: float
    class_scores: xr.DataArray
    features: xr.DataArray

View File

@@ -106,7 +106,7 @@ def contains_tags(
        False otherwise.
    """
    sound_event_tags = set(sound_event_annotation.tags)
    return tags < sound_event_tags
    return tags <= sound_event_tags


def does_not_have_tags(
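
The one-character change above matters: `<` on Python sets is a strict-subset test, so an annotation whose tags were exactly the required set was wrongly rejected, while `<=` also accepts equality. Illustrated with plain sets:

required = {"species: Myotis myotis"}
present = {"species: Myotis myotis"}

required < present   # False: a set is not a strict subset of itself
required <= present  # True: subset-or-equal accepts an exact tag match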

View File

@@ -20,14 +20,27 @@ scaling factors) is managed by the `ROIConfig`. This module separates the
handled in `batdetect2.targets.classes`.
"""

from typing import List, Optional, Protocol, Tuple
from typing import List, Literal, Optional, Protocol, Tuple

import numpy as np
from soundevent import data, geometry
from soundevent.geometry.operations import Positions
from soundevent import data

from batdetect2.configs import BaseConfig, load_config

Positions = Literal[
    "bottom-left",
    "bottom-right",
    "top-left",
    "top-right",
    "center-left",
    "center-right",
    "top-center",
    "bottom-center",
    "center",
    "centroid",
    "point_on_surface",
]

__all__ = [
    "ROITargetMapper",
    "ROIConfig",

@@ -242,6 +255,8 @@ class BBoxEncoder(ROITargetMapper):
        Tuple[float, float]
            Reference position (time, frequency).
        """
        from soundevent import geometry

        return geometry.get_geometry_point(geom, position=self.position)

    def get_roi_size(self, geom: data.Geometry) -> np.ndarray:

@@ -260,6 +275,8 @@ class BBoxEncoder(ROITargetMapper):
        np.ndarray
            A 1D NumPy array: `[scaled_width, scaled_height]`.
        """
        from soundevent import geometry

        start_time, low_freq, end_time, high_freq = geometry.compute_bounds(
            geom
        )

@@ -308,8 +325,8 @@ class BBoxEncoder(ROITargetMapper):
        width, height = dims
        return _build_bounding_box(
            pos,
            duration=width / self.time_scale,
            bandwidth=height / self.frequency_scale,
            duration=float(width) / self.time_scale,
            bandwidth=float(height) / self.frequency_scale,
            position=self.position,
        )

@@ -421,14 +438,16 @@ def _build_bounding_box(
    ValueError
        If `position` is not a recognized value or format.
    """
    time, freq = pos
    time, freq = map(float, pos)
    duration = max(0, duration)
    bandwidth = max(0, bandwidth)
    if position in ["center", "centroid", "point_on_surface"]:
        return data.BoundingBox(
            coordinates=[
                time - duration / 2,
                freq - bandwidth / 2,
                time + duration / 2,
                freq + bandwidth / 2,
                max(time - duration / 2, 0),
                max(freq - bandwidth / 2, 0),
                max(time + duration / 2, 0),
                max(freq + bandwidth / 2, 0),
            ]
        )

@@ -454,9 +473,9 @@ def _build_bounding_box(
    return data.BoundingBox(
        coordinates=[
            start_time,
            low_freq,
            start_time + duration,
            low_freq + bandwidth,
            max(0, start_time),
            max(0, low_freq),
            max(0, start_time + duration),
            max(0, low_freq + bandwidth),
        ]
    )
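
A worked example of the new clamping behavior, using the 'center' position; the numbers are illustrative:

box = _build_bounding_box(
    (0.005, 20_000.0),  # a detection 5 ms into the clip
    duration=0.02,
    bandwidth=5_000.0,
    position="center",
)
# Without clamping the start time would be 0.005 - 0.01 = -0.005;
# with it, box.coordinates == [0.0, 17500.0, 0.015, 22500.0].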

View File

@@ -14,20 +14,28 @@ from batdetect2.train.augmentations import (
    warp_spectrogram,
)
from batdetect2.train.clips import build_clipper, select_subclip
from batdetect2.train.config import TrainingConfig, load_train_config
from batdetect2.train.config import (
    TrainerConfig,
    TrainingConfig,
    load_train_config,
)
from batdetect2.train.dataset import (
    LabeledDataset,
    RandomExampleSource,
    TrainExample,
    list_preprocessed_files,
)
from batdetect2.train.labels import load_label_config
from batdetect2.train.labels import build_clip_labeler, load_label_config
from batdetect2.train.losses import LossFunction, build_loss
from batdetect2.train.preprocess import (
    generate_train_example,
    preprocess_annotations,
)
from batdetect2.train.train import TrainerConfig, load_trainer_config, train
from batdetect2.train.train import (
    build_train_dataset,
    build_val_dataset,
    train,
)

__all__ = [
    "AugmentationsConfig",

@@ -44,13 +52,15 @@ __all__ = [
    "WarpAugmentationConfig",
    "add_echo",
    "build_augmentations",
    "build_clip_labeler",
    "build_clipper",
    "build_loss",
    "build_train_dataset",
    "build_val_dataset",
    "generate_train_example",
    "list_preprocessed_files",
    "load_label_config",
    "load_train_config",
    "load_trainer_config",
    "mask_frequency",
    "mask_time",
    "mix_examples",

@@ -58,5 +68,6 @@ __all__ = [
    "scale_volume",
    "select_subclip",
    "train",
    "warp_spectrogram",
]

View File

@@ -1,30 +1,51 @@
from typing import List

from lightning import LightningModule, Trainer
from lightning.pytorch.callbacks import Callback
from soundevent import data
from torch.utils.data import DataLoader

from batdetect2.postprocess import PostprocessorProtocol
from batdetect2.evaluate.match import match_sound_events_and_raw_predictions
from batdetect2.evaluate.types import Match, MetricsProtocol
from batdetect2.targets.types import TargetProtocol
from batdetect2.train.dataset import LabeledDataset, TrainExample
from batdetect2.types import ModelOutput
from batdetect2.train.lightning import TrainingModule
from batdetect2.train.types import ModelOutput


class ValidationMetrics(Callback):
    def __init__(self, postprocessor: PostprocessorProtocol):
    def __init__(self, metrics: List[MetricsProtocol]):
        super().__init__()
        self.postprocessor = postprocessor
        self.predictions = []
        if len(metrics) == 0:
            raise ValueError("At least one metric needs to be provided")

        self.matches: List[Match] = []
        self.metrics = metrics

    def on_validation_epoch_end(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
    ) -> None:
        for metric in self.metrics:
            value = metric(self.matches)
            pl_module.log(f"val/metric/{metric.name}", value, prog_bar=True)
        return super().on_validation_epoch_end(trainer, pl_module)

    def on_validation_epoch_start(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
    ) -> None:
        self.predictions = []
        self.matches = []
        return super().on_validation_epoch_start(trainer, pl_module)

    def on_validation_batch_end(  # type: ignore
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        pl_module: TrainingModule,
        outputs: ModelOutput,
        batch: TrainExample,
        batch_idx: int,

@@ -32,24 +53,73 @@ class ValidationMetrics(Callback):
    ) -> None:
        dataloaders = trainer.val_dataloaders
        assert isinstance(dataloaders, DataLoader)
        dataset = dataloaders.dataset
        assert isinstance(dataset, LabeledDataset)
        clip_annotation = dataset.get_clip_annotation(batch_idx)
        # clip_prediction = postprocess_model_outputs(
        #     outputs,
        #     clips=[clip_annotation.clip],
        #     classes=self.class_names,
        #     decoder=self.decoder,
        #     config=self.config.postprocessing,
        # )[0]
        #
        # matches = match_predictions_and_annotations(
        #     clip_annotation,
        #     clip_prediction,
        # )
        #
        # self.validation_predictions.extend(matches)
        # return super().on_validation_batch_end(
        #     trainer, pl_module, outputs, batch, batch_idx, dataloader_idx
        # )
        clip_annotations = [
            _get_subclip(
                dataset.get_clip_annotation(example_id),
                start_time=start_time.item(),
                end_time=end_time.item(),
                targets=pl_module.targets,
            )
            for example_id, start_time, end_time in zip(
                batch.idx,
                batch.start_time,
                batch.end_time,
            )
        ]

        clips = [clip_annotation.clip for clip_annotation in clip_annotations]

        raw_predictions = pl_module.postprocessor.get_raw_predictions(
            outputs,
            clips,
        )

        for clip_annotation, clip_predictions in zip(
            clip_annotations, raw_predictions
        ):
            self.matches.extend(
                match_sound_events_and_raw_predictions(
                    sound_events=clip_annotation.sound_events,
                    raw_predictions=clip_predictions,
                    targets=pl_module.targets,
                )
            )


def _is_in_subclip(
    sound_event_annotation: data.SoundEventAnnotation,
    targets: TargetProtocol,
    start_time: float,
    end_time: float,
) -> bool:
    time, _ = targets.get_position(sound_event_annotation)
    return start_time <= time <= end_time


def _get_subclip(
    clip_annotation: data.ClipAnnotation,
    start_time: float,
    end_time: float,
    targets: TargetProtocol,
) -> data.ClipAnnotation:
    return data.ClipAnnotation(
        clip=data.Clip(
            recording=clip_annotation.clip.recording,
            start_time=start_time,
            end_time=end_time,
        ),
        sound_events=[
            sound_event_annotation
            for sound_event_annotation in clip_annotation.sound_events
            if _is_in_subclip(
                sound_event_annotation,
                targets,
                start_time=start_time,
                end_time=end_time,
            )
        ],
    )
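
A sketch of how the rewritten callback might be wired into training; the module paths for `ValidationMetrics` and the metric class are assumed from this commit's layout, and `model`, `train_files`, and `val_files` are placeholders.

from batdetect2.evaluate.metrics import DetectionAveragePrecision
from batdetect2.train import train
from batdetect2.train.callbacks import ValidationMetrics

train(
    detector=model,
    train_examples=train_files,
    val_examples=val_files,
    callbacks=[ValidationMetrics(metrics=[DetectionAveragePrecision()])],
)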

View File

@@ -42,8 +42,8 @@ class LabeledDataset(Dataset):
            class_heatmap=self.to_tensor(dataset["class"]),
            size_heatmap=self.to_tensor(dataset["size"]),
            idx=torch.tensor(idx),
            start_time=start_time,
            end_time=end_time,
            start_time=torch.tensor(start_time),
            end_time=torch.tensor(end_time),
        )

    @classmethod

View File

@@ -23,13 +23,13 @@ parameter specific to this module is the Gaussian smoothing sigma (`sigma`)
defined in `LabelConfig`.
"""

import logging
from collections.abc import Iterable
from functools import partial
from typing import Optional

import numpy as np
import xarray as xr
from loguru import logger
from scipy.ndimage import gaussian_filter
from soundevent import arrays, data

@@ -52,8 +52,6 @@ __all__ = [
SIZE_DIMENSION = "dimension"
"""Dimension name for the size heatmap."""

logger = logging.getLogger(__name__)


class LabelConfig(BaseConfig):
    """Configuration parameters for heatmap generation.

@@ -137,12 +135,27 @@ def generate_clip_label(
        A NamedTuple containing the generated 'detection', 'classes', and 'size'
        heatmaps for this clip.
    """
    logger.debug(
        "Will generate heatmaps for clip annotation {uuid} with {num} annotated sound events",
        uuid=clip_annotation.uuid,
        num=len(clip_annotation.sound_events)
    )

    sound_events = []
    for sound_event_annotation in clip_annotation.sound_events:
        if not targets.filter(sound_event_annotation):
            logger.debug(
                "Sound event {sound_event} did not pass the filter. Tags: {tags}",
                sound_event=sound_event_annotation,
                tags=sound_event_annotation.tags,
            )
            continue

        sound_events.append(targets.transform(sound_event_annotation))

    return generate_heatmaps(
        (
            targets.transform(sound_event_annotation)
            for sound_event_annotation in clip_annotation.sound_events
            if targets.filter(sound_event_annotation)
        ),
        sound_events,
        spec=spec,
        targets=targets,
        target_sigma=config.sigma,

View File

@@ -58,7 +58,9 @@ class TrainingModule(L.LightningModule):

        return losses.total

    def validation_step(self, batch: TrainExample, batch_idx: int) -> None:
    def validation_step(  # type: ignore
        self, batch: TrainExample, batch_idx: int
    ) -> ModelOutput:
        outputs = self.forward(batch.spec)
        losses = self.loss(outputs, batch)

@@ -67,6 +69,8 @@ class TrainingModule(L.LightningModule):
        self.log("val/loss/size", losses.total, logger=True)
        self.log("val/loss/classification", losses.total, logger=True)

        return outputs

    def configure_optimizers(self):
        optimizer = Adam(self.parameters(), lr=self.learning_rate)
        scheduler = CosineAnnealingLR(optimizer, T_max=self.t_max)

View File

@@ -1,12 +1,16 @@
from typing import List, Optional

from lightning import Trainer
from lightning.pytorch.callbacks import Callback
from soundevent import data
from torch.utils.data import DataLoader

from batdetect2.models.types import DetectionModel
from batdetect2.postprocess import build_postprocessor
from batdetect2.postprocess.types import PostprocessorProtocol
from batdetect2.preprocess import build_preprocessor
from batdetect2.preprocess.types import PreprocessorProtocol
from batdetect2.targets import build_targets
from batdetect2.targets.types import TargetProtocol
from batdetect2.train.augmentations import (
    build_augmentations,

@@ -19,25 +23,40 @@ from batdetect2.train.losses import build_loss

__all__ = [
    "train",
    "build_val_dataset",
    "build_train_dataset",
]


def train(
    detector: DetectionModel,
    targets: TargetProtocol,
    preprocessor: PreprocessorProtocol,
    postprocessor: PostprocessorProtocol,
    train_examples: List[data.PathLike],
    targets: Optional[TargetProtocol] = None,
    preprocessor: Optional[PreprocessorProtocol] = None,
    postprocessor: Optional[PostprocessorProtocol] = None,
    val_examples: Optional[List[data.PathLike]] = None,
    config: Optional[TrainingConfig] = None,
    callbacks: Optional[List[Callback]] = None,
) -> None:
    config = config or TrainingConfig()

    train_dataset = build_dataset(
    if preprocessor is None:
        preprocessor = build_preprocessor()

    if targets is None:
        targets = build_targets()

    if postprocessor is None:
        postprocessor = build_postprocessor(
            targets,
            min_freq=preprocessor.min_freq,
            max_freq=preprocessor.max_freq,
        )

    train_dataset = build_train_dataset(
        train_examples,
        preprocessor,
        config=config,
        train=True,
    )

    loss = build_loss(config.loss)

@@ -52,7 +71,13 @@ def train(
        t_max=config.optimizer.t_max,
    )

    trainer = Trainer(**config.trainer.model_dump())
    trainer = Trainer(
        **config.trainer.model_dump(exclude_none=True),
        callbacks=callbacks,
        num_sanity_val_steps=0,
        # enable_model_summary=False,
        # enable_progress_bar=False,
    )

    train_dataloader = DataLoader(
        train_dataset,

@@ -62,11 +87,9 @@ def train(

    val_dataloader = None
    if val_examples:
        val_dataset = build_dataset(
        val_dataset = build_val_dataset(
            val_examples,
            preprocessor,
            config=config,
            train=False,
        )
        val_dataloader = DataLoader(
            val_dataset,

@@ -81,32 +104,38 @@ def train(
    )


def build_dataset(
def build_train_dataset(
    examples: List[data.PathLike],
    preprocessor: PreprocessorProtocol,
    config: Optional[TrainingConfig] = None,
    train: bool = True,
):
) -> LabeledDataset:
    config = config or TrainingConfig()

    clipper = build_clipper(config.cliping, random=train)
    clipper = build_clipper(config.cliping, random=True)

    augmentations = None

    random_example_source = RandomExampleSource(
        examples,
        clipper=clipper,
    )

    if train:
        random_example_source = RandomExampleSource(
            examples,
            clipper=clipper,
        )

        augmentations = build_augmentations(
            preprocessor,
            config=config.augmentations,
            example_source=random_example_source,
        )
    augmentations = build_augmentations(
        preprocessor,
        config=config.augmentations,
        example_source=random_example_source,
    )

    return LabeledDataset(
        examples,
        clipper=clipper,
        augmentation=augmentations,
    )


def build_val_dataset(
    examples: List[data.PathLike],
    config: Optional[TrainingConfig] = None,
    train: bool = True,
) -> LabeledDataset:
    config = config or TrainingConfig()
    clipper = build_clipper(config.cliping, random=train)
    return LabeledDataset(examples, clipper=clipper)
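
A short usage sketch of the split builders, assuming `files` is a list of preprocessed example paths and `preprocessor` and `config` are already built:

train_dataset = build_train_dataset(files, preprocessor, config=config)
val_dataset = build_val_dataset(files, config=config)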

View File

@@ -57,8 +57,8 @@ class TrainExample(NamedTuple):
    class_heatmap: torch.Tensor
    size_heatmap: torch.Tensor
    idx: torch.Tensor
    start_time: float
    end_time: float
    start_time: torch.Tensor
    end_time: torch.Tensor

class Losses(NamedTuple):

View File

@@ -1,4 +1,5 @@
import numpy as np
import xarray as xr


def extend_width(

@@ -59,3 +60,10 @@ def adjust_width(
        for index in range(dims)
    ]
    return array[tuple(slices)]


def iterate_over_array(array: xr.DataArray):
    dim_name = array.dims[0]
    coords = array.coords[dim_name]
    for value, coord in zip(array.values, coords.values):
        yield coord, float(value)
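
A minimal usage sketch of the relocated helper: iterating a 1-D labelled `DataArray` yields `(coordinate, value)` pairs, as the class-score dictionaries in `match.py` expect. The array below is invented for illustration.

import xarray as xr

from batdetect2.utils.arrays import iterate_over_array

scores = xr.DataArray(
    [0.1, 0.7],
    coords={"category": ["myotis", "pipistrellus"]},
    dims="category",
)
print(dict(iterate_over_array(scores)))  # {'myotis': 0.1, 'pipistrellus': 0.7}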