Run lint fixes

This commit is contained in:
mbsantiago 2025-12-08 17:19:33 +00:00
parent 2563f26ed3
commit 113f438e74
110 changed files with 127 additions and 156 deletions

View File

@ -98,7 +98,7 @@ consult the API documentation in the code.
""" """
import warnings import warnings
from typing import List, Optional, Tuple from typing import List, Tuple
import numpy as np import numpy as np
import torch import torch

View File

@ -1,6 +1,6 @@
import json import json
from pathlib import Path from pathlib import Path
from typing import Dict, List, Optional, Sequence, Tuple from typing import Dict, List, Sequence, Tuple
import numpy as np import numpy as np
import torch import torch

View File

@ -1,4 +1,4 @@
from typing import Annotated, List, Literal, Optional, Union from typing import Annotated, List, Literal
import numpy as np import numpy as np
from loguru import logger from loguru import logger

View File

@ -1,4 +1,3 @@
from typing import Optional
import numpy as np import numpy as np
from numpy.typing import DTypeLike from numpy.typing import DTypeLike

View File

@ -1,5 +1,4 @@
from pathlib import Path from pathlib import Path
from typing import Optional
import click import click

View File

@ -1,5 +1,4 @@
from pathlib import Path from pathlib import Path
from typing import Optional
import click import click
from loguru import logger from loguru import logger

View File

@ -1,5 +1,4 @@
from pathlib import Path from pathlib import Path
from typing import Optional
import click import click
from loguru import logger from loguru import logger

View File

@ -1,4 +1,4 @@
from typing import Literal, Optional from typing import Literal
from pydantic import Field from pydantic import Field
from soundevent.data import PathLike from soundevent.data import PathLike

View File

@ -1,4 +1,3 @@
from typing import Optional
import numpy as np import numpy as np
import torch import torch

View File

@ -8,7 +8,7 @@ configuration data from files, with optional support for accessing nested
configuration sections. configuration sections.
""" """
from typing import Any, Optional, Type, TypeVar from typing import Any, Type, TypeVar
import yaml import yaml
from deepmerge.merger import Merger from deepmerge.merger import Merger

View File

@ -13,7 +13,7 @@ format-specific loading function to retrieve the annotations as a standard
`soundevent.data.AnnotationSet`. `soundevent.data.AnnotationSet`.
""" """
from typing import Annotated, Optional, Union from typing import Annotated
from pydantic import Field from pydantic import Field
from soundevent import data from soundevent import data

View File

@ -12,7 +12,7 @@ that meet specific status criteria (e.g., completed, verified, without issues).
""" """
from pathlib import Path from pathlib import Path
from typing import Literal, Optional from typing import Literal
from uuid import uuid5 from uuid import uuid5
from pydantic import Field from pydantic import Field

View File

@ -1,5 +1,5 @@
from collections.abc import Callable from collections.abc import Callable
from typing import Annotated, List, Literal, Sequence, Union from typing import Annotated, List, Literal, Sequence
from pydantic import Field from pydantic import Field
from soundevent import data from soundevent import data

View File

@ -19,7 +19,7 @@ The core components are:
""" """
from pathlib import Path from pathlib import Path
from typing import List, Optional, Sequence from typing import List, Sequence
from loguru import logger from loguru import logger
from pydantic import Field from pydantic import Field

View File

@ -1,5 +1,5 @@
from collections.abc import Generator from collections.abc import Generator
from typing import Optional, Tuple from typing import Tuple
from soundevent import data from soundevent import data

View File

@ -1,4 +1,4 @@
from typing import Annotated, Optional, Union from typing import Annotated
from pydantic import Field from pydantic import Field
from soundevent.data import PathLike from soundevent.data import PathLike
@ -24,7 +24,10 @@ __all__ = [
OutputFormatConfig = Annotated[ OutputFormatConfig = Annotated[
BatDetect2OutputConfig | ParquetOutputConfig | SoundEventOutputConfig | RawOutputConfig, BatDetect2OutputConfig
| ParquetOutputConfig
| SoundEventOutputConfig
| RawOutputConfig,
Field(discriminator="name"), Field(discriminator="name"),
] ]

View File

@ -1,6 +1,6 @@
import json import json
from pathlib import Path from pathlib import Path
from typing import List, Literal, Optional, Sequence, TypedDict from typing import List, Literal, Sequence, TypedDict
import numpy as np import numpy as np
from soundevent import data from soundevent import data

View File

@ -1,6 +1,5 @@
import json
from pathlib import Path from pathlib import Path
from typing import List, Literal, Optional, Sequence from typing import List, Literal, Sequence
from uuid import UUID from uuid import UUID
import numpy as np import numpy as np

View File

@ -1,6 +1,6 @@
from collections import defaultdict from collections import defaultdict
from pathlib import Path from pathlib import Path
from typing import List, Literal, Optional, Sequence from typing import List, Literal, Sequence
from uuid import UUID, uuid4 from uuid import UUID, uuid4
import numpy as np import numpy as np

View File

@ -1,5 +1,5 @@
from pathlib import Path from pathlib import Path
from typing import List, Literal, Optional, Sequence from typing import List, Literal, Sequence
import numpy as np import numpy as np
from soundevent import data, io from soundevent import data, io

View File

@ -1,4 +1,4 @@
from typing import Optional, Tuple from typing import Tuple
from sklearn.model_selection import train_test_split from sklearn.model_selection import train_test_split

View File

@ -1,5 +1,5 @@
from collections.abc import Callable from collections.abc import Callable
from typing import Annotated, Dict, List, Literal, Optional, Union from typing import Annotated, Dict, List, Literal
from pydantic import Field from pydantic import Field
from soundevent import data from soundevent import data

View File

@ -1,6 +1,6 @@
"""Functions to compute features from predictions.""" """Functions to compute features from predictions."""
from typing import Dict, List, Optional from typing import Dict, List
import numpy as np import numpy as np

View File

@ -1,7 +1,7 @@
import datetime import datetime
import os import os
from pathlib import Path from pathlib import Path
from typing import List, Optional, Union from typing import List
from pydantic import BaseModel, Field, computed_field from pydantic import BaseModel, Field, computed_field

View File

@ -1,6 +1,6 @@
"""Post-processing of the output of the model.""" """Post-processing of the output of the model."""
from typing import List, Tuple, Union from typing import List, Tuple
import numpy as np import numpy as np
import torch import torch

View File

@ -1,4 +1,4 @@
from typing import Annotated, Literal, Optional, Union from typing import Annotated, Literal
from pydantic import Field from pydantic import Field
from soundevent import data from soundevent import data

View File

@ -1,4 +1,4 @@
from typing import List, Optional from typing import List
from pydantic import Field from pydantic import Field
from soundevent import data from soundevent import data

View File

@ -1,4 +1,4 @@
from typing import List, NamedTuple, Optional, Sequence from typing import List, NamedTuple, Sequence
import torch import torch
from loguru import logger from loguru import logger

View File

@ -1,4 +1,4 @@
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union from typing import Any, Dict, Iterable, List, Sequence, Tuple
from matplotlib.figure import Figure from matplotlib.figure import Figure
from soundevent import data from soundevent import data
@ -36,7 +36,7 @@ class Evaluator:
def compute_metrics(self, eval_outputs: List[Any]) -> Dict[str, float]: def compute_metrics(self, eval_outputs: List[Any]) -> Dict[str, float]:
results = {} results = {}
for task, outputs in zip(self.tasks, eval_outputs): for task, outputs in zip(self.tasks, eval_outputs, strict=False):
results.update(task.compute_metrics(outputs)) results.update(task.compute_metrics(outputs))
return results return results
@ -45,7 +45,7 @@ class Evaluator:
self, self,
eval_outputs: List[Any], eval_outputs: List[Any],
) -> Iterable[Tuple[str, Figure]]: ) -> Iterable[Tuple[str, Figure]]:
for task, outputs in zip(self.tasks, eval_outputs): for task, outputs in zip(self.tasks, eval_outputs, strict=False):
for name, fig in task.generate_plots(outputs): for name, fig in task.generate_plots(outputs):
yield name, fig yield name, fig

View File

@ -357,7 +357,7 @@ def train_rf_model(x_train, y_train, num_classes, seed=2001):
clf = RandomForestClassifier(random_state=seed, n_jobs=-1) clf = RandomForestClassifier(random_state=seed, n_jobs=-1)
clf.fit(x_train, y_train) clf.fit(x_train, y_train)
y_pred = clf.predict(x_train) y_pred = clf.predict(x_train)
tr_acc = (y_pred == y_train).mean() (y_pred == y_train).mean()
# print('Train acc', round(tr_acc*100, 2)) # print('Train acc', round(tr_acc*100, 2))
return clf, un_train_class return clf, un_train_class
@ -450,7 +450,7 @@ def add_root_path_back(data_sets, ann_path, wav_path):
def check_classes_in_train(gt_list, class_names): def check_classes_in_train(gt_list, class_names):
num_gt_total = np.sum([gg["start_times"].shape[0] for gg in gt_list]) np.sum([gg["start_times"].shape[0] for gg in gt_list])
num_with_no_class = 0 num_with_no_class = 0
for gt in gt_list: for gt in gt_list:
for cc in gt["class_names"]: for cc in gt["class_names"]:
@ -569,7 +569,7 @@ if __name__ == "__main__":
num_with_no_class = check_classes_in_train(gt_test, class_names) num_with_no_class = check_classes_in_train(gt_test, class_names)
if total_num_calls == num_with_no_class: if total_num_calls == num_with_no_class:
print("Classes from the test set are not in the train set.") print("Classes from the test set are not in the train set.")
assert False raise AssertionError()
# only need the train data if evaluating Sonobat or Tadarida # only need the train data if evaluating Sonobat or Tadarida
if args["sb_ip_dir"] != "" or args["td_ip_dir"] != "": if args["sb_ip_dir"] != "" or args["td_ip_dir"] != "":
@ -743,7 +743,7 @@ if __name__ == "__main__":
# check if the class names are the same # check if the class names are the same
if params_bd["class_names"] != class_names: if params_bd["class_names"] != class_names:
print("Warning: Class names are not the same as the trained model") print("Warning: Class names are not the same as the trained model")
assert False raise AssertionError()
run_config = { run_config = {
**bd_args, **bd_args,
@ -753,7 +753,7 @@ if __name__ == "__main__":
preds_bd = [] preds_bd = []
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
for ii, gg in enumerate(gt_test): for gg in gt_test:
pred = du.process_file( pred = du.process_file(
gg["file_path"], gg["file_path"],
model, model,

View File

@ -47,7 +47,7 @@ class EvaluationModule(LightningModule):
), ),
) )
for clip_annotation, clip_dets in zip( for clip_annotation, clip_dets in zip(
clip_annotations, clip_detections clip_annotations, clip_detections, strict=False
) )
] ]

View File

@ -1,5 +1,5 @@
from collections.abc import Callable, Iterable, Mapping from collections.abc import Callable, Iterable, Mapping
from typing import Annotated, List, Literal, Optional, Sequence, Tuple, Union from typing import Annotated, List, Literal, Sequence, Tuple
import numpy as np import numpy as np
from pydantic import Field from pydantic import Field
@ -94,7 +94,7 @@ def match(
class_name: score class_name: score
for class_name, score in zip( for class_name, score in zip(
targets.class_names, targets.class_names,
prediction.class_scores, prediction.class_scores, strict=False,
) )
} }
if prediction is not None if prediction is not None
@ -563,7 +563,7 @@ def select_optimal_matches(
maximize=True, maximize=True,
) )
for gt_idx, pred_idx in zip(assiged_rows, assigned_columns): for gt_idx, pred_idx in zip(assiged_rows, assigned_columns, strict=False):
affinity = float(affinity_matrix[gt_idx, pred_idx]) affinity = float(affinity_matrix[gt_idx, pred_idx])
if affinity <= affinity_threshold: if affinity <= affinity_threshold:

View File

@ -7,10 +7,8 @@ from typing import (
List, List,
Literal, Literal,
Mapping, Mapping,
Optional,
Sequence, Sequence,
Tuple, Tuple,
Union,
) )
import numpy as np import numpy as np

View File

@ -1,6 +1,6 @@
from collections import defaultdict from collections import defaultdict
from dataclasses import dataclass from dataclasses import dataclass
from typing import Annotated, Callable, Dict, Literal, Sequence, Set, Union from typing import Annotated, Callable, Dict, Literal, Sequence, Set
import numpy as np import numpy as np
from pydantic import Field from pydantic import Field

View File

@ -1,5 +1,5 @@
from dataclasses import dataclass from dataclasses import dataclass
from typing import Annotated, Callable, Dict, Literal, Sequence, Union from typing import Annotated, Callable, Dict, Literal, Sequence
import numpy as np import numpy as np
from pydantic import Field from pydantic import Field

View File

@ -1,4 +1,4 @@
from typing import Optional, Tuple from typing import Tuple
import numpy as np import numpy as np

View File

@ -5,9 +5,7 @@ from typing import (
Dict, Dict,
List, List,
Literal, Literal,
Optional,
Sequence, Sequence,
Union,
) )
import numpy as np import numpy as np

View File

@ -1,4 +1,3 @@
from collections import defaultdict
from dataclasses import dataclass from dataclasses import dataclass
from typing import ( from typing import (
Annotated, Annotated,
@ -6,9 +5,7 @@ from typing import (
Dict, Dict,
List, List,
Literal, Literal,
Optional,
Sequence, Sequence,
Union,
) )
import numpy as np import numpy as np

View File

@ -1,4 +1,3 @@
from typing import Optional
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from matplotlib.figure import Figure from matplotlib.figure import Figure

View File

@ -3,10 +3,8 @@ from typing import (
Callable, Callable,
Iterable, Iterable,
Literal, Literal,
Optional,
Sequence, Sequence,
Tuple, Tuple,
Union,
) )
import matplotlib.pyplot as plt import matplotlib.pyplot as plt

View File

@ -3,10 +3,8 @@ from typing import (
Callable, Callable,
Iterable, Iterable,
Literal, Literal,
Optional,
Sequence, Sequence,
Tuple, Tuple,
Union,
) )
import matplotlib.pyplot as plt import matplotlib.pyplot as plt

View File

@ -3,10 +3,8 @@ from typing import (
Callable, Callable,
Iterable, Iterable,
Literal, Literal,
Optional,
Sequence, Sequence,
Tuple, Tuple,
Union,
) )
import pandas as pd import pandas as pd

View File

@ -4,10 +4,8 @@ from typing import (
Callable, Callable,
Iterable, Iterable,
Literal, Literal,
Optional,
Sequence, Sequence,
Tuple, Tuple,
Union,
) )
import matplotlib.pyplot as plt import matplotlib.pyplot as plt

View File

@ -8,10 +8,8 @@ from typing import (
Iterable, Iterable,
List, List,
Literal, Literal,
Optional,
Sequence, Sequence,
Tuple, Tuple,
Union,
) )
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
@ -405,7 +403,7 @@ def get_binned_sample(matches: List[MatchEval], n_examples: int = 5):
return matches return matches
indices, pred_scores = zip( indices, pred_scores = zip(
*[(index, match.score) for index, match in enumerate(matches)] *[(index, match.score) for index, match in enumerate(matches)], strict=False
) )
bins = pd.qcut(pred_scores, q=n_examples, labels=False, duplicates="drop") bins = pd.qcut(pred_scores, q=n_examples, labels=False, duplicates="drop")

View File

@ -1,4 +1,4 @@
from typing import Annotated, Callable, Literal, Sequence, Union from typing import Annotated, Callable, Literal, Sequence
import pandas as pd import pandas as pd
from pydantic import Field from pydantic import Field

View File

@ -1,4 +1,4 @@
from typing import Annotated, Optional, Sequence, Union from typing import Annotated, Optional, Sequence
from pydantic import Field from pydantic import Field
from soundevent import data from soundevent import data
@ -26,7 +26,11 @@ __all__ = [
TaskConfig = Annotated[ TaskConfig = Annotated[
ClassificationTaskConfig | DetectionTaskConfig | ClipDetectionTaskConfig | ClipClassificationTaskConfig | TopClassDetectionTaskConfig, ClassificationTaskConfig
| DetectionTaskConfig
| ClipDetectionTaskConfig
| ClipClassificationTaskConfig
| TopClassDetectionTaskConfig,
Field(discriminator="name"), Field(discriminator="name"),
] ]

View File

@ -4,7 +4,6 @@ from typing import (
Generic, Generic,
Iterable, Iterable,
List, List,
Optional,
Sequence, Sequence,
Tuple, Tuple,
TypeVar, TypeVar,
@ -101,7 +100,7 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]):
) -> List[T_Output]: ) -> List[T_Output]:
return [ return [
self.evaluate_clip(clip_annotation, preds) self.evaluate_clip(clip_annotation, preds)
for clip_annotation, preds in zip(clip_annotations, predictions) for clip_annotation, preds in zip(clip_annotations, predictions, strict=False)
] ]
def evaluate_clip( def evaluate_clip(

View File

@ -1,7 +1,7 @@
import argparse import argparse
import os import os
import warnings import warnings
from typing import List, Optional from typing import List
import torch import torch
import torch.utils.data import torch.utils.data
@ -88,7 +88,7 @@ def select_device(warn=True) -> str:
if warn: if warn:
warnings.warn( warnings.warn(
"No GPU available, using the CPU instead. Please consider using a GPU " "No GPU available, using the CPU instead. Please consider using a GPU "
"to speed up training." "to speed up training.", stacklevel=2
) )
return "cpu" return "cpu"

View File

@ -2,7 +2,7 @@ import argparse
import json import json
import os import os
from collections import Counter from collections import Counter
from typing import List, Optional, Tuple from typing import List, Tuple
import numpy as np import numpy as np
from sklearn.model_selection import StratifiedGroupKFold from sklearn.model_selection import StratifiedGroupKFold
@ -162,7 +162,7 @@ def main():
# change the names of the classes # change the names of the classes
ip_names = args.input_class_names.split(";") ip_names = args.input_class_names.split(";")
op_names = args.output_class_names.split(";") op_names = args.output_class_names.split(";")
name_dict = dict(zip(ip_names, op_names)) name_dict = dict(zip(ip_names, op_names, strict=False))
# load annotations # load annotations
data_all = tu.load_set_of_anns( data_all = tu.load_set_of_anns(

View File

@ -1,4 +1,4 @@
from typing import List, NamedTuple, Optional, Sequence from typing import List, NamedTuple, Sequence
import torch import torch
from loguru import logger from loguru import logger

View File

@ -39,7 +39,7 @@ class InferenceModule(LightningModule):
targets=self.model.targets, targets=self.model.targets,
), ),
) )
for clip, clip_dets in zip(clips, clip_detections) for clip, clip_dets in zip(clips, clip_detections, strict=False)
] ]
return predictions return predictions

View File

@ -9,10 +9,8 @@ from typing import (
Dict, Dict,
Generic, Generic,
Literal, Literal,
Optional,
Protocol, Protocol,
TypeVar, TypeVar,
Union,
) )
import numpy as np import numpy as np

View File

@ -26,7 +26,7 @@ for creating a standard BatDetect2 model instance is the `build_model` function
provided here. provided here.
""" """
from typing import List, Optional from typing import List
import torch import torch

View File

@ -14,7 +14,7 @@ A factory function `build_bottleneck` constructs the appropriate bottleneck
module based on the provided configuration. module based on the provided configuration.
""" """
from typing import Annotated, List, Optional, Union from typing import Annotated, List
import torch import torch
from pydantic import Field from pydantic import Field

View File

@ -1,4 +1,3 @@
from typing import Optional
from soundevent import data from soundevent import data

View File

@ -18,7 +18,7 @@ The `Decoder`'s `forward` method is designed to accept skip connection tensors
at each stage. at each stage.
""" """
from typing import Annotated, List, Optional, Union from typing import Annotated, List
import torch import torch
from pydantic import Field from pydantic import Field
@ -182,7 +182,7 @@ class Decoder(nn.Module):
f"but got {len(residuals)}." f"but got {len(residuals)}."
) )
for layer, res in zip(self.layers, residuals[::-1]): for layer, res in zip(self.layers, residuals[::-1], strict=False):
x = layer(x + res) x = layer(x + res)
return x return x

View File

@ -14,7 +14,6 @@ logic for preprocessing inputs and postprocessing/decoding outputs resides in
the `batdetect2.preprocess` and `batdetect2.postprocess` packages, respectively. the `batdetect2.preprocess` and `batdetect2.postprocess` packages, respectively.
""" """
from typing import Optional
import torch import torch
from loguru import logger from loguru import logger

View File

@ -20,7 +20,7 @@ bottleneck output. A default configuration (`DEFAULT_ENCODER_CONFIG`) is also
provided. provided.
""" """
from typing import Annotated, List, Optional, Union from typing import Annotated, List
import torch import torch
from pydantic import Field from pydantic import Field

View File

@ -1,4 +1,4 @@
from typing import Optional, Tuple from typing import Tuple
from matplotlib.axes import Axes from matplotlib.axes import Axes
from soundevent import data, plot from soundevent import data, plot
@ -68,6 +68,6 @@ def plot_anchor_points(
position, _ = targets.encode_roi(sound_event) position, _ = targets.encode_roi(sound_event)
positions.append(position) positions.append(position)
X, Y = zip(*positions) X, Y = zip(*positions, strict=False)
ax.scatter(X, Y, s=size, c=color, marker=marker, alpha=alpha) ax.scatter(X, Y, s=size, c=color, marker=marker, alpha=alpha)
return ax return ax

View File

@ -1,4 +1,4 @@
from typing import Iterable, Optional, Tuple from typing import Iterable, Tuple
from matplotlib.axes import Axes from matplotlib.axes import Axes
from soundevent import data from soundevent import data

View File

@ -1,4 +1,4 @@
from typing import Optional, Tuple from typing import Tuple
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import torch import torch

View File

@ -1,6 +1,6 @@
"""General plotting utilities.""" """General plotting utilities."""
from typing import Optional, Tuple, Union from typing import Tuple
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np

View File

@ -1,4 +1,3 @@
from typing import Optional
from matplotlib import axes, patches from matplotlib import axes, patches
from soundevent.plot import plot_geometry from soundevent.plot import plot_geometry

View File

@ -1,4 +1,4 @@
from typing import Optional, Sequence from typing import Sequence
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from matplotlib.figure import Figure from matplotlib.figure import Figure
@ -36,7 +36,7 @@ def plot_match_gallery(
sharey="row", sharey="row",
) )
for tp_ax, tp_match in zip(axes[0], true_positives[:n_examples]): for tp_ax, tp_match in zip(axes[0], true_positives[:n_examples], strict=False):
try: try:
plot_true_positive_match( plot_true_positive_match(
tp_match, tp_match,
@ -53,7 +53,7 @@ def plot_match_gallery(
): ):
continue continue
for fp_ax, fp_match in zip(axes[1], false_positives[:n_examples]): for fp_ax, fp_match in zip(axes[1], false_positives[:n_examples], strict=False):
try: try:
plot_false_positive_match( plot_false_positive_match(
fp_match, fp_match,
@ -70,7 +70,7 @@ def plot_match_gallery(
): ):
continue continue
for fn_ax, fn_match in zip(axes[2], false_negatives[:n_examples]): for fn_ax, fn_match in zip(axes[2], false_negatives[:n_examples], strict=False):
try: try:
plot_false_negative_match( plot_false_negative_match(
fn_match, fn_match,
@ -87,7 +87,7 @@ def plot_match_gallery(
): ):
continue continue
for ct_ax, ct_match in zip(axes[3], cross_triggers[:n_examples]): for ct_ax, ct_match in zip(axes[3], cross_triggers[:n_examples], strict=False):
try: try:
plot_cross_trigger_match( plot_cross_trigger_match(
ct_match, ct_match,

View File

@ -1,6 +1,6 @@
"""Plot heatmaps""" """Plot heatmaps"""
from typing import List, Optional, Tuple, Union from typing import List, Tuple
import numpy as np import numpy as np
import torch import torch

View File

@ -1,6 +1,6 @@
"""Plot functions to visualize detections and spectrograms.""" """Plot functions to visualize detections and spectrograms."""
from typing import List, Optional, Tuple, Union, cast from typing import List, Tuple, cast
import matplotlib.ticker as tick import matplotlib.ticker as tick
import numpy as np import numpy as np

View File

@ -1,4 +1,4 @@
from typing import Optional, Protocol, Tuple, Union from typing import Protocol, Tuple
from matplotlib.axes import Axes from matplotlib.axes import Axes
from soundevent import data, plot from soundevent import data, plot

View File

@ -1,4 +1,4 @@
from typing import Dict, Optional, Tuple, Union from typing import Dict, Tuple
import numpy as np import numpy as np
import seaborn as sns import seaborn as sns

View File

@ -1,4 +1,3 @@
from typing import Optional
from pydantic import Field from pydantic import Field
from soundevent import data from soundevent import data

View File

@ -1,6 +1,6 @@
"""Decodes extracted detection data into standard soundevent predictions.""" """Decodes extracted detection data into standard soundevent predictions."""
from typing import List, Optional from typing import List
import numpy as np import numpy as np
from soundevent import data from soundevent import data
@ -39,7 +39,7 @@ def to_raw_predictions(
detections.times, detections.times,
detections.frequencies, detections.frequencies,
detections.sizes, detections.sizes,
detections.features, detections.features, strict=False,
): ):
highest_scoring_class = targets.class_names[class_scores.argmax()] highest_scoring_class = targets.class_names[class_scores.argmax()]

View File

@ -15,7 +15,7 @@ precise time-frequency location of each detection. The final output aggregates
all extracted information into a structured `xarray.Dataset`. all extracted information into a structured `xarray.Dataset`.
""" """
from typing import List, Optional from typing import List
import torch import torch

View File

@ -11,7 +11,7 @@ activations that have lower scores than a local maximum. This helps prevent
multiple, overlapping detections originating from the same sound event. multiple, overlapping detections originating from the same sound event.
""" """
from typing import Tuple, Union from typing import Tuple
import torch import torch

View File

@ -1,4 +1,4 @@
from typing import List, Optional, Tuple, Union from typing import List, Tuple
import torch import torch
from loguru import logger from loguru import logger

View File

@ -12,7 +12,7 @@ classification probability maps, size prediction maps, and potentially
intermediate features. intermediate features.
""" """
from typing import Dict, List, Optional, Union from typing import Dict, List
import numpy as np import numpy as np
import torch import torch

View File

@ -1,4 +1,4 @@
from typing import Annotated, Literal, Union from typing import Annotated, Literal
import torch import torch
from pydantic import Field from pydantic import Field

View File

@ -1,4 +1,4 @@
from typing import List, Optional from typing import List
from pydantic import Field from pydantic import Field
from soundevent.data import PathLike from soundevent.data import PathLike

View File

@ -1,4 +1,3 @@
from typing import Optional
import torch import torch
from loguru import logger from loguru import logger

View File

@ -1,6 +1,6 @@
"""Computes spectrograms from audio waveforms with configurable parameters.""" """Computes spectrograms from audio waveforms with configurable parameters."""
from typing import Annotated, Callable, Literal, Optional, Union from typing import Annotated, Callable, Literal
import numpy as np import numpy as np
import torch import torch

View File

@ -1,4 +1,4 @@
from typing import Dict, List, Optional from typing import Dict, List
from pydantic import Field, PrivateAttr, computed_field, model_validator from pydantic import Field, PrivateAttr, computed_field, model_validator
from soundevent import data from soundevent import data

View File

@ -1,5 +1,5 @@
from collections import Counter from collections import Counter
from typing import List, Optional from typing import List
from pydantic import Field, field_validator from pydantic import Field, field_validator
from soundevent import data from soundevent import data

View File

@ -20,7 +20,7 @@ selecting and configuring the desired mapper. This module separates the
*geometric* aspect of target definition from *semantic* classification. *geometric* aspect of target definition from *semantic* classification.
""" """
from typing import Annotated, Literal, Optional, Tuple, Union from typing import Annotated, Literal, Tuple
import numpy as np import numpy as np
from pydantic import Field from pydantic import Field

View File

@ -1,4 +1,4 @@
from typing import Iterable, List, Optional, Tuple from typing import Iterable, List, Tuple
from loguru import logger from loguru import logger
from soundevent import data from soundevent import data

View File

@ -2,7 +2,7 @@
import warnings import warnings
from collections.abc import Sequence from collections.abc import Sequence
from typing import Annotated, Callable, List, Literal, Optional, Tuple, Union from typing import Annotated, Callable, List, Literal, Tuple
import numpy as np import numpy as np
import torch import torch
@ -394,7 +394,7 @@ class MaskTime(torch.nn.Module):
size=num_masks, size=num_masks,
) )
masks = [ masks = [
(start, start + size) for start, size in zip(mask_start, mask_size) (start, start + size) for start, size in zip(mask_start, mask_size, strict=False)
] ]
return mask_time(spec, masks), clip_annotation return mask_time(spec, masks), clip_annotation
@ -460,7 +460,7 @@ class MaskFrequency(torch.nn.Module):
size=num_masks, size=num_masks,
) )
masks = [ masks = [
(start, start + size) for start, size in zip(mask_start, mask_size) (start, start + size) for start, size in zip(mask_start, mask_size, strict=False)
] ]
return mask_frequency(spec, masks), clip_annotation return mask_frequency(spec, masks), clip_annotation

View File

@ -107,7 +107,7 @@ class ValidationMetrics(Callback):
), ),
) )
for clip_annotation, clip_dets in zip( for clip_annotation, clip_dets in zip(
clip_annotations, clip_detections clip_annotations, clip_detections, strict=False
) )
] ]

View File

@ -1,5 +1,4 @@
from pathlib import Path from pathlib import Path
from typing import Optional
from lightning.pytorch.callbacks import Callback, ModelCheckpoint from lightning.pytorch.callbacks import Callback, ModelCheckpoint

View File

@ -1,4 +1,3 @@
from typing import Optional, Union
from pydantic import Field from pydantic import Field
from soundevent import data from soundevent import data

View File

@ -1,4 +1,4 @@
from typing import List, Optional, Sequence from typing import List, Sequence
import torch import torch
from loguru import logger from loguru import logger

View File

@ -6,7 +6,6 @@ the specific multi-channel heatmap formats required by the neural network.
""" """
from functools import partial from functools import partial
from typing import Optional
import numpy as np import numpy as np
import torch import torch

View File

@ -1,12 +1,12 @@
from typing import NamedTuple, Optional from typing import NamedTuple
import torch import torch
from batdetect2.models.types import DetectionModel
from soundevent import data from soundevent import data
from torch.optim import Adam from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from batdetect2.models.types import DetectionModel
from batdetect2.train.dataset import LabeledDataset from batdetect2.train.dataset import LabeledDataset
@ -26,7 +26,7 @@ def train_loop(
learning_rate: float = 1e-4, learning_rate: float = 1e-4,
): ):
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True) train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
validation_loader = DataLoader(validation_dataset, batch_size=32) DataLoader(validation_dataset, batch_size=32)
model.to(device) model.to(device)
@ -36,8 +36,8 @@ def train_loop(
num_epochs * len(train_loader), num_epochs * len(train_loader),
) )
for epoch in range(num_epochs): for _epoch in range(num_epochs):
train_loss = train_single_epoch( train_single_epoch(
model, model,
train_loader, train_loader,
optimizer, optimizer,
@ -60,9 +60,9 @@ def train_single_epoch(
optimizer.zero_grad() optimizer.zero_grad()
spec = batch.spec.to(device) spec = batch.spec.to(device)
detection_heatmap = batch.detection_heatmap.to(device) batch.detection_heatmap.to(device)
class_heatmap = batch.class_heatmap.to(device) batch.class_heatmap.to(device)
size_heatmap = batch.size_heatmap.to(device) batch.size_heatmap.to(device)
outputs = model(spec) outputs = model(spec)

View File

@ -2,6 +2,10 @@ import argparse
import json import json
import warnings import warnings
import batdetect2.train.audio_dataloader as adl
import batdetect2.train.evaluate as evl
import batdetect2.train.train_split as ts
import batdetect2.train.train_utils as tu
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
import torch import torch
@ -9,10 +13,6 @@ import torch.utils.data
from torch.optim.lr_scheduler import CosineAnnealingLR from torch.optim.lr_scheduler import CosineAnnealingLR
import batdetect2.detector.post_process as pp import batdetect2.detector.post_process as pp
import batdetect2.train.audio_dataloader as adl
import batdetect2.train.evaluate as evl
import batdetect2.train.train_split as ts
import batdetect2.train.train_utils as tu
import batdetect2.utils.plot_utils as pu import batdetect2.utils.plot_utils as pu
from batdetect2.detector import models, parameters from batdetect2.detector import models, parameters
from batdetect2.train import losses from batdetect2.train import losses

View File

@ -10,7 +10,7 @@ def get_train_test_data(ann_dir, wav_dir, split_name, load_extra=True):
train_sets, test_sets = split_same(ann_dir, wav_dir, load_extra) train_sets, test_sets = split_same(ann_dir, wav_dir, load_extra)
else: else:
print("Split not defined") print("Split not defined")
assert False raise AssertionError()
return train_sets, test_sets return train_sets, test_sets

View File

@ -2,7 +2,7 @@ import json
import sys import sys
from collections import Counter from collections import Counter
from pathlib import Path from pathlib import Path
from typing import Dict, Generator, List, Optional, Tuple from typing import Dict, Generator, List, Tuple
import numpy as np import numpy as np

View File

@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Optional, Tuple from typing import TYPE_CHECKING, Tuple
import lightning as L import lightning as L
import torch import torch

View File

@ -18,7 +18,6 @@ The primary entry points are:
- `LossConfig`: The Pydantic model for configuring loss weights and parameters. - `LossConfig`: The Pydantic model for configuring loss weights and parameters.
""" """
from typing import Optional
import numpy as np import numpy as np
import torch import torch

View File

@ -1,6 +1,6 @@
"""Types used in the code base.""" """Types used in the code base."""
from typing import Any, List, NamedTuple, Optional, TypedDict from typing import Any, List, NamedTuple, TypedDict
import numpy as np import numpy as np
import torch import torch

View File

@ -1,4 +1,4 @@
from typing import Generic, List, Optional, Protocol, Sequence, TypeVar from typing import Generic, List, Protocol, Sequence, TypeVar
from soundevent.data import PathLike from soundevent.data import PathLike

View File

@ -4,7 +4,6 @@ from typing import (
Generic, Generic,
Iterable, Iterable,
List, List,
Optional,
Protocol, Protocol,
Sequence, Sequence,
Tuple, Tuple,

View File

@ -12,7 +12,7 @@ system that deal with model predictions.
""" """
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, NamedTuple, Optional, Protocol, Sequence from typing import List, NamedTuple, Protocol, Sequence
import numpy as np import numpy as np
import torch import torch

View File

@ -10,7 +10,7 @@ pipeline can interact consistently, regardless of the specific underlying
implementation (e.g., different libraries or custom configurations). implementation (e.g., different libraries or custom configurations).
""" """
from typing import Optional, Protocol from typing import Protocol
import numpy as np import numpy as np
import torch import torch

View File

@ -13,7 +13,7 @@ throughout BatDetect2.
""" """
from collections.abc import Callable from collections.abc import Callable
from typing import List, Optional, Protocol from typing import List, Protocol
import numpy as np import numpy as np
from soundevent import data from soundevent import data

Some files were not shown because too many files have changed in this diff Show More