Run lint fixes

This commit is contained in:
mbsantiago 2025-12-08 17:19:33 +00:00
parent 2563f26ed3
commit 113f438e74
110 changed files with 127 additions and 156 deletions

View File

@ -98,7 +98,7 @@ consult the API documentation in the code.
"""
import warnings
from typing import List, Optional, Tuple
from typing import List, Tuple
import numpy as np
import torch

View File

@ -1,6 +1,6 @@
import json
from pathlib import Path
from typing import Dict, List, Optional, Sequence, Tuple
from typing import Dict, List, Sequence, Tuple
import numpy as np
import torch

View File

@ -1,4 +1,4 @@
from typing import Annotated, List, Literal, Optional, Union
from typing import Annotated, List, Literal
import numpy as np
from loguru import logger

View File

@ -1,4 +1,3 @@
from typing import Optional
import numpy as np
from numpy.typing import DTypeLike

View File

@ -1,5 +1,4 @@
from pathlib import Path
from typing import Optional
import click

View File

@ -1,5 +1,4 @@
from pathlib import Path
from typing import Optional
import click
from loguru import logger

View File

@ -1,5 +1,4 @@
from pathlib import Path
from typing import Optional
import click
from loguru import logger

View File

@ -1,4 +1,4 @@
from typing import Literal, Optional
from typing import Literal
from pydantic import Field
from soundevent.data import PathLike

View File

@ -1,4 +1,3 @@
from typing import Optional
import numpy as np
import torch

View File

@ -8,7 +8,7 @@ configuration data from files, with optional support for accessing nested
configuration sections.
"""
from typing import Any, Optional, Type, TypeVar
from typing import Any, Type, TypeVar
import yaml
from deepmerge.merger import Merger

View File

@ -13,7 +13,7 @@ format-specific loading function to retrieve the annotations as a standard
`soundevent.data.AnnotationSet`.
"""
from typing import Annotated, Optional, Union
from typing import Annotated
from pydantic import Field
from soundevent import data

View File

@ -12,7 +12,7 @@ that meet specific status criteria (e.g., completed, verified, without issues).
"""
from pathlib import Path
from typing import Literal, Optional
from typing import Literal
from uuid import uuid5
from pydantic import Field

View File

@ -1,5 +1,5 @@
from collections.abc import Callable
from typing import Annotated, List, Literal, Sequence, Union
from typing import Annotated, List, Literal, Sequence
from pydantic import Field
from soundevent import data

View File

@ -19,7 +19,7 @@ The core components are:
"""
from pathlib import Path
from typing import List, Optional, Sequence
from typing import List, Sequence
from loguru import logger
from pydantic import Field

View File

@ -1,5 +1,5 @@
from collections.abc import Generator
from typing import Optional, Tuple
from typing import Tuple
from soundevent import data

View File

@ -1,4 +1,4 @@
from typing import Annotated, Optional, Union
from typing import Annotated
from pydantic import Field
from soundevent.data import PathLike
@ -24,7 +24,10 @@ __all__ = [
OutputFormatConfig = Annotated[
BatDetect2OutputConfig | ParquetOutputConfig | SoundEventOutputConfig | RawOutputConfig,
BatDetect2OutputConfig
| ParquetOutputConfig
| SoundEventOutputConfig
| RawOutputConfig,
Field(discriminator="name"),
]

View File

@ -1,6 +1,6 @@
import json
from pathlib import Path
from typing import List, Literal, Optional, Sequence, TypedDict
from typing import List, Literal, Sequence, TypedDict
import numpy as np
from soundevent import data

View File

@ -1,6 +1,5 @@
import json
from pathlib import Path
from typing import List, Literal, Optional, Sequence
from typing import List, Literal, Sequence
from uuid import UUID
import numpy as np

View File

@ -1,6 +1,6 @@
from collections import defaultdict
from pathlib import Path
from typing import List, Literal, Optional, Sequence
from typing import List, Literal, Sequence
from uuid import UUID, uuid4
import numpy as np

View File

@ -1,5 +1,5 @@
from pathlib import Path
from typing import List, Literal, Optional, Sequence
from typing import List, Literal, Sequence
import numpy as np
from soundevent import data, io

View File

@ -1,4 +1,4 @@
from typing import Optional, Tuple
from typing import Tuple
from sklearn.model_selection import train_test_split

View File

@ -1,5 +1,5 @@
from collections.abc import Callable
from typing import Annotated, Dict, List, Literal, Optional, Union
from typing import Annotated, Dict, List, Literal
from pydantic import Field
from soundevent import data

View File

@ -1,6 +1,6 @@
"""Functions to compute features from predictions."""
from typing import Dict, List, Optional
from typing import Dict, List
import numpy as np

View File

@ -1,7 +1,7 @@
import datetime
import os
from pathlib import Path
from typing import List, Optional, Union
from typing import List
from pydantic import BaseModel, Field, computed_field

View File

@ -1,6 +1,6 @@
"""Post-processing of the output of the model."""
from typing import List, Tuple, Union
from typing import List, Tuple
import numpy as np
import torch

View File

@ -1,4 +1,4 @@
from typing import Annotated, Literal, Optional, Union
from typing import Annotated, Literal
from pydantic import Field
from soundevent import data

View File

@ -1,4 +1,4 @@
from typing import List, Optional
from typing import List
from pydantic import Field
from soundevent import data

View File

@ -1,4 +1,4 @@
from typing import List, NamedTuple, Optional, Sequence
from typing import List, NamedTuple, Sequence
import torch
from loguru import logger

View File

@ -1,4 +1,4 @@
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union
from typing import Any, Dict, Iterable, List, Sequence, Tuple
from matplotlib.figure import Figure
from soundevent import data
@ -36,7 +36,7 @@ class Evaluator:
def compute_metrics(self, eval_outputs: List[Any]) -> Dict[str, float]:
results = {}
for task, outputs in zip(self.tasks, eval_outputs):
for task, outputs in zip(self.tasks, eval_outputs, strict=False):
results.update(task.compute_metrics(outputs))
return results
@ -45,7 +45,7 @@ class Evaluator:
self,
eval_outputs: List[Any],
) -> Iterable[Tuple[str, Figure]]:
for task, outputs in zip(self.tasks, eval_outputs):
for task, outputs in zip(self.tasks, eval_outputs, strict=False):
for name, fig in task.generate_plots(outputs):
yield name, fig

View File

@ -357,7 +357,7 @@ def train_rf_model(x_train, y_train, num_classes, seed=2001):
clf = RandomForestClassifier(random_state=seed, n_jobs=-1)
clf.fit(x_train, y_train)
y_pred = clf.predict(x_train)
tr_acc = (y_pred == y_train).mean()
(y_pred == y_train).mean()
# print('Train acc', round(tr_acc*100, 2))
return clf, un_train_class
@ -450,7 +450,7 @@ def add_root_path_back(data_sets, ann_path, wav_path):
def check_classes_in_train(gt_list, class_names):
num_gt_total = np.sum([gg["start_times"].shape[0] for gg in gt_list])
np.sum([gg["start_times"].shape[0] for gg in gt_list])
num_with_no_class = 0
for gt in gt_list:
for cc in gt["class_names"]:
@ -569,7 +569,7 @@ if __name__ == "__main__":
num_with_no_class = check_classes_in_train(gt_test, class_names)
if total_num_calls == num_with_no_class:
print("Classes from the test set are not in the train set.")
assert False
raise AssertionError()
# only need the train data if evaluating Sonobat or Tadarida
if args["sb_ip_dir"] != "" or args["td_ip_dir"] != "":
@ -743,7 +743,7 @@ if __name__ == "__main__":
# check if the class names are the same
if params_bd["class_names"] != class_names:
print("Warning: Class names are not the same as the trained model")
assert False
raise AssertionError()
run_config = {
**bd_args,
@ -753,7 +753,7 @@ if __name__ == "__main__":
preds_bd = []
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
for ii, gg in enumerate(gt_test):
for gg in gt_test:
pred = du.process_file(
gg["file_path"],
model,

View File

@ -47,7 +47,7 @@ class EvaluationModule(LightningModule):
),
)
for clip_annotation, clip_dets in zip(
clip_annotations, clip_detections
clip_annotations, clip_detections, strict=False
)
]

View File

@ -1,5 +1,5 @@
from collections.abc import Callable, Iterable, Mapping
from typing import Annotated, List, Literal, Optional, Sequence, Tuple, Union
from typing import Annotated, List, Literal, Sequence, Tuple
import numpy as np
from pydantic import Field
@ -94,7 +94,7 @@ def match(
class_name: score
for class_name, score in zip(
targets.class_names,
prediction.class_scores,
prediction.class_scores, strict=False,
)
}
if prediction is not None
@ -563,7 +563,7 @@ def select_optimal_matches(
maximize=True,
)
for gt_idx, pred_idx in zip(assiged_rows, assigned_columns):
for gt_idx, pred_idx in zip(assiged_rows, assigned_columns, strict=False):
affinity = float(affinity_matrix[gt_idx, pred_idx])
if affinity <= affinity_threshold:

View File

@ -7,10 +7,8 @@ from typing import (
List,
Literal,
Mapping,
Optional,
Sequence,
Tuple,
Union,
)
import numpy as np

View File

@ -1,6 +1,6 @@
from collections import defaultdict
from dataclasses import dataclass
from typing import Annotated, Callable, Dict, Literal, Sequence, Set, Union
from typing import Annotated, Callable, Dict, Literal, Sequence, Set
import numpy as np
from pydantic import Field

View File

@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Annotated, Callable, Dict, Literal, Sequence, Union
from typing import Annotated, Callable, Dict, Literal, Sequence
import numpy as np
from pydantic import Field

View File

@ -1,4 +1,4 @@
from typing import Optional, Tuple
from typing import Tuple
import numpy as np

View File

@ -5,9 +5,7 @@ from typing import (
Dict,
List,
Literal,
Optional,
Sequence,
Union,
)
import numpy as np

View File

@ -1,4 +1,3 @@
from collections import defaultdict
from dataclasses import dataclass
from typing import (
Annotated,
@ -6,9 +5,7 @@ from typing import (
Dict,
List,
Literal,
Optional,
Sequence,
Union,
)
import numpy as np

View File

@ -1,4 +1,3 @@
from typing import Optional
import matplotlib.pyplot as plt
from matplotlib.figure import Figure

View File

@ -3,10 +3,8 @@ from typing import (
Callable,
Iterable,
Literal,
Optional,
Sequence,
Tuple,
Union,
)
import matplotlib.pyplot as plt

View File

@ -3,10 +3,8 @@ from typing import (
Callable,
Iterable,
Literal,
Optional,
Sequence,
Tuple,
Union,
)
import matplotlib.pyplot as plt

View File

@ -3,10 +3,8 @@ from typing import (
Callable,
Iterable,
Literal,
Optional,
Sequence,
Tuple,
Union,
)
import pandas as pd

View File

@ -4,10 +4,8 @@ from typing import (
Callable,
Iterable,
Literal,
Optional,
Sequence,
Tuple,
Union,
)
import matplotlib.pyplot as plt

View File

@ -8,10 +8,8 @@ from typing import (
Iterable,
List,
Literal,
Optional,
Sequence,
Tuple,
Union,
)
import matplotlib.pyplot as plt
@ -405,7 +403,7 @@ def get_binned_sample(matches: List[MatchEval], n_examples: int = 5):
return matches
indices, pred_scores = zip(
*[(index, match.score) for index, match in enumerate(matches)]
*[(index, match.score) for index, match in enumerate(matches)], strict=False
)
bins = pd.qcut(pred_scores, q=n_examples, labels=False, duplicates="drop")

View File

@ -1,4 +1,4 @@
from typing import Annotated, Callable, Literal, Sequence, Union
from typing import Annotated, Callable, Literal, Sequence
import pandas as pd
from pydantic import Field

View File

@ -1,4 +1,4 @@
from typing import Annotated, Optional, Sequence, Union
from typing import Annotated, Optional, Sequence
from pydantic import Field
from soundevent import data
@ -26,7 +26,11 @@ __all__ = [
TaskConfig = Annotated[
ClassificationTaskConfig | DetectionTaskConfig | ClipDetectionTaskConfig | ClipClassificationTaskConfig | TopClassDetectionTaskConfig,
ClassificationTaskConfig
| DetectionTaskConfig
| ClipDetectionTaskConfig
| ClipClassificationTaskConfig
| TopClassDetectionTaskConfig,
Field(discriminator="name"),
]

View File

@ -4,7 +4,6 @@ from typing import (
Generic,
Iterable,
List,
Optional,
Sequence,
Tuple,
TypeVar,
@ -101,7 +100,7 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]):
) -> List[T_Output]:
return [
self.evaluate_clip(clip_annotation, preds)
for clip_annotation, preds in zip(clip_annotations, predictions)
for clip_annotation, preds in zip(clip_annotations, predictions, strict=False)
]
def evaluate_clip(

View File

@ -1,7 +1,7 @@
import argparse
import os
import warnings
from typing import List, Optional
from typing import List
import torch
import torch.utils.data
@ -88,7 +88,7 @@ def select_device(warn=True) -> str:
if warn:
warnings.warn(
"No GPU available, using the CPU instead. Please consider using a GPU "
"to speed up training."
"to speed up training.", stacklevel=2
)
return "cpu"

View File

@ -2,7 +2,7 @@ import argparse
import json
import os
from collections import Counter
from typing import List, Optional, Tuple
from typing import List, Tuple
import numpy as np
from sklearn.model_selection import StratifiedGroupKFold
@ -162,7 +162,7 @@ def main():
# change the names of the classes
ip_names = args.input_class_names.split(";")
op_names = args.output_class_names.split(";")
name_dict = dict(zip(ip_names, op_names))
name_dict = dict(zip(ip_names, op_names, strict=False))
# load annotations
data_all = tu.load_set_of_anns(

View File

@ -1,4 +1,4 @@
from typing import List, NamedTuple, Optional, Sequence
from typing import List, NamedTuple, Sequence
import torch
from loguru import logger

View File

@ -39,7 +39,7 @@ class InferenceModule(LightningModule):
targets=self.model.targets,
),
)
for clip, clip_dets in zip(clips, clip_detections)
for clip, clip_dets in zip(clips, clip_detections, strict=False)
]
return predictions

View File

@ -9,10 +9,8 @@ from typing import (
Dict,
Generic,
Literal,
Optional,
Protocol,
TypeVar,
Union,
)
import numpy as np

View File

@ -26,7 +26,7 @@ for creating a standard BatDetect2 model instance is the `build_model` function
provided here.
"""
from typing import List, Optional
from typing import List
import torch

View File

@ -14,7 +14,7 @@ A factory function `build_bottleneck` constructs the appropriate bottleneck
module based on the provided configuration.
"""
from typing import Annotated, List, Optional, Union
from typing import Annotated, List
import torch
from pydantic import Field

View File

@ -1,4 +1,3 @@
from typing import Optional
from soundevent import data

View File

@ -18,7 +18,7 @@ The `Decoder`'s `forward` method is designed to accept skip connection tensors
at each stage.
"""
from typing import Annotated, List, Optional, Union
from typing import Annotated, List
import torch
from pydantic import Field
@ -182,7 +182,7 @@ class Decoder(nn.Module):
f"but got {len(residuals)}."
)
for layer, res in zip(self.layers, residuals[::-1]):
for layer, res in zip(self.layers, residuals[::-1], strict=False):
x = layer(x + res)
return x

View File

@ -14,7 +14,6 @@ logic for preprocessing inputs and postprocessing/decoding outputs resides in
the `batdetect2.preprocess` and `batdetect2.postprocess` packages, respectively.
"""
from typing import Optional
import torch
from loguru import logger

View File

@ -20,7 +20,7 @@ bottleneck output. A default configuration (`DEFAULT_ENCODER_CONFIG`) is also
provided.
"""
from typing import Annotated, List, Optional, Union
from typing import Annotated, List
import torch
from pydantic import Field

View File

@ -1,4 +1,4 @@
from typing import Optional, Tuple
from typing import Tuple
from matplotlib.axes import Axes
from soundevent import data, plot
@ -68,6 +68,6 @@ def plot_anchor_points(
position, _ = targets.encode_roi(sound_event)
positions.append(position)
X, Y = zip(*positions)
X, Y = zip(*positions, strict=False)
ax.scatter(X, Y, s=size, c=color, marker=marker, alpha=alpha)
return ax

View File

@ -1,4 +1,4 @@
from typing import Iterable, Optional, Tuple
from typing import Iterable, Tuple
from matplotlib.axes import Axes
from soundevent import data

View File

@ -1,4 +1,4 @@
from typing import Optional, Tuple
from typing import Tuple
import matplotlib.pyplot as plt
import torch

View File

@ -1,6 +1,6 @@
"""General plotting utilities."""
from typing import Optional, Tuple, Union
from typing import Tuple
import matplotlib.pyplot as plt
import numpy as np

View File

@ -1,4 +1,3 @@
from typing import Optional
from matplotlib import axes, patches
from soundevent.plot import plot_geometry

View File

@ -1,4 +1,4 @@
from typing import Optional, Sequence
from typing import Sequence
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
@ -36,7 +36,7 @@ def plot_match_gallery(
sharey="row",
)
for tp_ax, tp_match in zip(axes[0], true_positives[:n_examples]):
for tp_ax, tp_match in zip(axes[0], true_positives[:n_examples], strict=False):
try:
plot_true_positive_match(
tp_match,
@ -53,7 +53,7 @@ def plot_match_gallery(
):
continue
for fp_ax, fp_match in zip(axes[1], false_positives[:n_examples]):
for fp_ax, fp_match in zip(axes[1], false_positives[:n_examples], strict=False):
try:
plot_false_positive_match(
fp_match,
@ -70,7 +70,7 @@ def plot_match_gallery(
):
continue
for fn_ax, fn_match in zip(axes[2], false_negatives[:n_examples]):
for fn_ax, fn_match in zip(axes[2], false_negatives[:n_examples], strict=False):
try:
plot_false_negative_match(
fn_match,
@ -87,7 +87,7 @@ def plot_match_gallery(
):
continue
for ct_ax, ct_match in zip(axes[3], cross_triggers[:n_examples]):
for ct_ax, ct_match in zip(axes[3], cross_triggers[:n_examples], strict=False):
try:
plot_cross_trigger_match(
ct_match,

View File

@ -1,6 +1,6 @@
"""Plot heatmaps"""
from typing import List, Optional, Tuple, Union
from typing import List, Tuple
import numpy as np
import torch

View File

@ -1,6 +1,6 @@
"""Plot functions to visualize detections and spectrograms."""
from typing import List, Optional, Tuple, Union, cast
from typing import List, Tuple, cast
import matplotlib.ticker as tick
import numpy as np

View File

@ -1,4 +1,4 @@
from typing import Optional, Protocol, Tuple, Union
from typing import Protocol, Tuple
from matplotlib.axes import Axes
from soundevent import data, plot

View File

@ -1,4 +1,4 @@
from typing import Dict, Optional, Tuple, Union
from typing import Dict, Tuple
import numpy as np
import seaborn as sns

View File

@ -1,4 +1,3 @@
from typing import Optional
from pydantic import Field
from soundevent import data

View File

@ -1,6 +1,6 @@
"""Decodes extracted detection data into standard soundevent predictions."""
from typing import List, Optional
from typing import List
import numpy as np
from soundevent import data
@ -39,7 +39,7 @@ def to_raw_predictions(
detections.times,
detections.frequencies,
detections.sizes,
detections.features,
detections.features, strict=False,
):
highest_scoring_class = targets.class_names[class_scores.argmax()]

View File

@ -15,7 +15,7 @@ precise time-frequency location of each detection. The final output aggregates
all extracted information into a structured `xarray.Dataset`.
"""
from typing import List, Optional
from typing import List
import torch

View File

@ -11,7 +11,7 @@ activations that have lower scores than a local maximum. This helps prevent
multiple, overlapping detections originating from the same sound event.
"""
from typing import Tuple, Union
from typing import Tuple
import torch

View File

@ -1,4 +1,4 @@
from typing import List, Optional, Tuple, Union
from typing import List, Tuple
import torch
from loguru import logger

View File

@ -12,7 +12,7 @@ classification probability maps, size prediction maps, and potentially
intermediate features.
"""
from typing import Dict, List, Optional, Union
from typing import Dict, List
import numpy as np
import torch

View File

@ -1,4 +1,4 @@
from typing import Annotated, Literal, Union
from typing import Annotated, Literal
import torch
from pydantic import Field

View File

@ -1,4 +1,4 @@
from typing import List, Optional
from typing import List
from pydantic import Field
from soundevent.data import PathLike

View File

@ -1,4 +1,3 @@
from typing import Optional
import torch
from loguru import logger

View File

@ -1,6 +1,6 @@
"""Computes spectrograms from audio waveforms with configurable parameters."""
from typing import Annotated, Callable, Literal, Optional, Union
from typing import Annotated, Callable, Literal
import numpy as np
import torch

View File

@ -1,4 +1,4 @@
from typing import Dict, List, Optional
from typing import Dict, List
from pydantic import Field, PrivateAttr, computed_field, model_validator
from soundevent import data

View File

@ -1,5 +1,5 @@
from collections import Counter
from typing import List, Optional
from typing import List
from pydantic import Field, field_validator
from soundevent import data

View File

@ -20,7 +20,7 @@ selecting and configuring the desired mapper. This module separates the
*geometric* aspect of target definition from *semantic* classification.
"""
from typing import Annotated, Literal, Optional, Tuple, Union
from typing import Annotated, Literal, Tuple
import numpy as np
from pydantic import Field

View File

@ -1,4 +1,4 @@
from typing import Iterable, List, Optional, Tuple
from typing import Iterable, List, Tuple
from loguru import logger
from soundevent import data

View File

@ -2,7 +2,7 @@
import warnings
from collections.abc import Sequence
from typing import Annotated, Callable, List, Literal, Optional, Tuple, Union
from typing import Annotated, Callable, List, Literal, Tuple
import numpy as np
import torch
@ -394,7 +394,7 @@ class MaskTime(torch.nn.Module):
size=num_masks,
)
masks = [
(start, start + size) for start, size in zip(mask_start, mask_size)
(start, start + size) for start, size in zip(mask_start, mask_size, strict=False)
]
return mask_time(spec, masks), clip_annotation
@ -460,7 +460,7 @@ class MaskFrequency(torch.nn.Module):
size=num_masks,
)
masks = [
(start, start + size) for start, size in zip(mask_start, mask_size)
(start, start + size) for start, size in zip(mask_start, mask_size, strict=False)
]
return mask_frequency(spec, masks), clip_annotation

View File

@ -107,7 +107,7 @@ class ValidationMetrics(Callback):
),
)
for clip_annotation, clip_dets in zip(
clip_annotations, clip_detections
clip_annotations, clip_detections, strict=False
)
]

View File

@ -1,5 +1,4 @@
from pathlib import Path
from typing import Optional
from lightning.pytorch.callbacks import Callback, ModelCheckpoint

View File

@ -1,4 +1,3 @@
from typing import Optional, Union
from pydantic import Field
from soundevent import data

View File

@ -1,4 +1,4 @@
from typing import List, Optional, Sequence
from typing import List, Sequence
import torch
from loguru import logger

View File

@ -6,7 +6,6 @@ the specific multi-channel heatmap formats required by the neural network.
"""
from functools import partial
from typing import Optional
import numpy as np
import torch

View File

@ -1,12 +1,12 @@
from typing import NamedTuple, Optional
from typing import NamedTuple
import torch
from batdetect2.models.types import DetectionModel
from soundevent import data
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader
from batdetect2.models.types import DetectionModel
from batdetect2.train.dataset import LabeledDataset
@ -26,7 +26,7 @@ def train_loop(
learning_rate: float = 1e-4,
):
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
validation_loader = DataLoader(validation_dataset, batch_size=32)
DataLoader(validation_dataset, batch_size=32)
model.to(device)
@ -36,8 +36,8 @@ def train_loop(
num_epochs * len(train_loader),
)
for epoch in range(num_epochs):
train_loss = train_single_epoch(
for _epoch in range(num_epochs):
train_single_epoch(
model,
train_loader,
optimizer,
@ -60,9 +60,9 @@ def train_single_epoch(
optimizer.zero_grad()
spec = batch.spec.to(device)
detection_heatmap = batch.detection_heatmap.to(device)
class_heatmap = batch.class_heatmap.to(device)
size_heatmap = batch.size_heatmap.to(device)
batch.detection_heatmap.to(device)
batch.class_heatmap.to(device)
batch.size_heatmap.to(device)
outputs = model(spec)

View File

@ -2,6 +2,10 @@ import argparse
import json
import warnings
import batdetect2.train.audio_dataloader as adl
import batdetect2.train.evaluate as evl
import batdetect2.train.train_split as ts
import batdetect2.train.train_utils as tu
import matplotlib.pyplot as plt
import numpy as np
import torch
@ -9,10 +13,6 @@ import torch.utils.data
from torch.optim.lr_scheduler import CosineAnnealingLR
import batdetect2.detector.post_process as pp
import batdetect2.train.audio_dataloader as adl
import batdetect2.train.evaluate as evl
import batdetect2.train.train_split as ts
import batdetect2.train.train_utils as tu
import batdetect2.utils.plot_utils as pu
from batdetect2.detector import models, parameters
from batdetect2.train import losses

View File

@ -10,7 +10,7 @@ def get_train_test_data(ann_dir, wav_dir, split_name, load_extra=True):
train_sets, test_sets = split_same(ann_dir, wav_dir, load_extra)
else:
print("Split not defined")
assert False
raise AssertionError()
return train_sets, test_sets

View File

@ -2,7 +2,7 @@ import json
import sys
from collections import Counter
from pathlib import Path
from typing import Dict, Generator, List, Optional, Tuple
from typing import Dict, Generator, List, Tuple
import numpy as np

View File

@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Optional, Tuple
from typing import TYPE_CHECKING, Tuple
import lightning as L
import torch

View File

@ -18,7 +18,6 @@ The primary entry points are:
- `LossConfig`: The Pydantic model for configuring loss weights and parameters.
"""
from typing import Optional
import numpy as np
import torch

View File

@ -1,6 +1,6 @@
"""Types used in the code base."""
from typing import Any, List, NamedTuple, Optional, TypedDict
from typing import Any, List, NamedTuple, TypedDict
import numpy as np
import torch

View File

@ -1,4 +1,4 @@
from typing import Generic, List, Optional, Protocol, Sequence, TypeVar
from typing import Generic, List, Protocol, Sequence, TypeVar
from soundevent.data import PathLike

View File

@ -4,7 +4,6 @@ from typing import (
Generic,
Iterable,
List,
Optional,
Protocol,
Sequence,
Tuple,

View File

@ -12,7 +12,7 @@ system that deal with model predictions.
"""
from dataclasses import dataclass
from typing import List, NamedTuple, Optional, Protocol, Sequence
from typing import List, NamedTuple, Protocol, Sequence
import numpy as np
import torch

View File

@ -10,7 +10,7 @@ pipeline can interact consistently, regardless of the specific underlying
implementation (e.g., different libraries or custom configurations).
"""
from typing import Optional, Protocol
from typing import Protocol
import numpy as np
import torch

View File

@ -13,7 +13,7 @@ throughout BatDetect2.
"""
from collections.abc import Callable
from typing import List, Optional, Protocol
from typing import List, Protocol
import numpy as np
from soundevent import data

Some files were not shown because too many files have changed in this diff Show More