mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-04-04 15:20:19 +02:00
Run lint fixes
This commit is contained in:
parent
2563f26ed3
commit
113f438e74
@ -98,7 +98,7 @@ consult the API documentation in the code.
|
||||
"""
|
||||
|
||||
import warnings
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Sequence, Tuple
|
||||
from typing import Dict, List, Sequence, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Annotated, List, Literal, Optional, Union
|
||||
from typing import Annotated, List, Literal
|
||||
|
||||
import numpy as np
|
||||
from loguru import logger
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
from numpy.typing import DTypeLike
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import click
|
||||
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import click
|
||||
from loguru import logger
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import click
|
||||
from loguru import logger
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Literal, Optional
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import Field
|
||||
from soundevent.data import PathLike
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
@ -8,7 +8,7 @@ configuration data from files, with optional support for accessing nested
|
||||
configuration sections.
|
||||
"""
|
||||
|
||||
from typing import Any, Optional, Type, TypeVar
|
||||
from typing import Any, Type, TypeVar
|
||||
|
||||
import yaml
|
||||
from deepmerge.merger import Merger
|
||||
|
||||
@ -13,7 +13,7 @@ format-specific loading function to retrieve the annotations as a standard
|
||||
`soundevent.data.AnnotationSet`.
|
||||
"""
|
||||
|
||||
from typing import Annotated, Optional, Union
|
||||
from typing import Annotated
|
||||
|
||||
from pydantic import Field
|
||||
from soundevent import data
|
||||
|
||||
@ -12,7 +12,7 @@ that meet specific status criteria (e.g., completed, verified, without issues).
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Literal, Optional
|
||||
from typing import Literal
|
||||
from uuid import uuid5
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Callable
|
||||
from typing import Annotated, List, Literal, Sequence, Union
|
||||
from typing import Annotated, List, Literal, Sequence
|
||||
|
||||
from pydantic import Field
|
||||
from soundevent import data
|
||||
|
||||
@ -19,7 +19,7 @@ The core components are:
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Sequence
|
||||
from typing import List, Sequence
|
||||
|
||||
from loguru import logger
|
||||
from pydantic import Field
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Optional, Tuple
|
||||
from typing import Tuple
|
||||
|
||||
from soundevent import data
|
||||
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Annotated, Optional, Union
|
||||
from typing import Annotated
|
||||
|
||||
from pydantic import Field
|
||||
from soundevent.data import PathLike
|
||||
@ -24,7 +24,10 @@ __all__ = [
|
||||
|
||||
|
||||
OutputFormatConfig = Annotated[
|
||||
BatDetect2OutputConfig | ParquetOutputConfig | SoundEventOutputConfig | RawOutputConfig,
|
||||
BatDetect2OutputConfig
|
||||
| ParquetOutputConfig
|
||||
| SoundEventOutputConfig
|
||||
| RawOutputConfig,
|
||||
Field(discriminator="name"),
|
||||
]
|
||||
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import List, Literal, Optional, Sequence, TypedDict
|
||||
from typing import List, Literal, Sequence, TypedDict
|
||||
|
||||
import numpy as np
|
||||
from soundevent import data
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import List, Literal, Optional, Sequence
|
||||
from typing import List, Literal, Sequence
|
||||
from uuid import UUID
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import List, Literal, Optional, Sequence
|
||||
from typing import List, Literal, Sequence
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from pathlib import Path
|
||||
from typing import List, Literal, Optional, Sequence
|
||||
from typing import List, Literal, Sequence
|
||||
|
||||
import numpy as np
|
||||
from soundevent import data, io
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Optional, Tuple
|
||||
from typing import Tuple
|
||||
|
||||
from sklearn.model_selection import train_test_split
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Callable
|
||||
from typing import Annotated, Dict, List, Literal, Optional, Union
|
||||
from typing import Annotated, Dict, List, Literal
|
||||
|
||||
from pydantic import Field
|
||||
from soundevent import data
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
"""Functions to compute features from predictions."""
|
||||
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Dict, List
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import datetime
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Union
|
||||
from typing import List
|
||||
|
||||
from pydantic import BaseModel, Field, computed_field
|
||||
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
"""Post-processing of the output of the model."""
|
||||
|
||||
from typing import List, Tuple, Union
|
||||
from typing import List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Annotated, Literal, Optional, Union
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from pydantic import Field
|
||||
from soundevent import data
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import List, Optional
|
||||
from typing import List
|
||||
|
||||
from pydantic import Field
|
||||
from soundevent import data
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import List, NamedTuple, Optional, Sequence
|
||||
from typing import List, NamedTuple, Sequence
|
||||
|
||||
import torch
|
||||
from loguru import logger
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union
|
||||
from typing import Any, Dict, Iterable, List, Sequence, Tuple
|
||||
|
||||
from matplotlib.figure import Figure
|
||||
from soundevent import data
|
||||
@ -36,7 +36,7 @@ class Evaluator:
|
||||
def compute_metrics(self, eval_outputs: List[Any]) -> Dict[str, float]:
|
||||
results = {}
|
||||
|
||||
for task, outputs in zip(self.tasks, eval_outputs):
|
||||
for task, outputs in zip(self.tasks, eval_outputs, strict=False):
|
||||
results.update(task.compute_metrics(outputs))
|
||||
|
||||
return results
|
||||
@ -45,7 +45,7 @@ class Evaluator:
|
||||
self,
|
||||
eval_outputs: List[Any],
|
||||
) -> Iterable[Tuple[str, Figure]]:
|
||||
for task, outputs in zip(self.tasks, eval_outputs):
|
||||
for task, outputs in zip(self.tasks, eval_outputs, strict=False):
|
||||
for name, fig in task.generate_plots(outputs):
|
||||
yield name, fig
|
||||
|
||||
|
||||
@ -357,7 +357,7 @@ def train_rf_model(x_train, y_train, num_classes, seed=2001):
|
||||
clf = RandomForestClassifier(random_state=seed, n_jobs=-1)
|
||||
clf.fit(x_train, y_train)
|
||||
y_pred = clf.predict(x_train)
|
||||
tr_acc = (y_pred == y_train).mean()
|
||||
(y_pred == y_train).mean()
|
||||
# print('Train acc', round(tr_acc*100, 2))
|
||||
return clf, un_train_class
|
||||
|
||||
@ -450,7 +450,7 @@ def add_root_path_back(data_sets, ann_path, wav_path):
|
||||
|
||||
|
||||
def check_classes_in_train(gt_list, class_names):
|
||||
num_gt_total = np.sum([gg["start_times"].shape[0] for gg in gt_list])
|
||||
np.sum([gg["start_times"].shape[0] for gg in gt_list])
|
||||
num_with_no_class = 0
|
||||
for gt in gt_list:
|
||||
for cc in gt["class_names"]:
|
||||
@ -569,7 +569,7 @@ if __name__ == "__main__":
|
||||
num_with_no_class = check_classes_in_train(gt_test, class_names)
|
||||
if total_num_calls == num_with_no_class:
|
||||
print("Classes from the test set are not in the train set.")
|
||||
assert False
|
||||
raise AssertionError()
|
||||
|
||||
# only need the train data if evaluating Sonobat or Tadarida
|
||||
if args["sb_ip_dir"] != "" or args["td_ip_dir"] != "":
|
||||
@ -743,7 +743,7 @@ if __name__ == "__main__":
|
||||
# check if the class names are the same
|
||||
if params_bd["class_names"] != class_names:
|
||||
print("Warning: Class names are not the same as the trained model")
|
||||
assert False
|
||||
raise AssertionError()
|
||||
|
||||
run_config = {
|
||||
**bd_args,
|
||||
@ -753,7 +753,7 @@ if __name__ == "__main__":
|
||||
|
||||
preds_bd = []
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
for ii, gg in enumerate(gt_test):
|
||||
for gg in gt_test:
|
||||
pred = du.process_file(
|
||||
gg["file_path"],
|
||||
model,
|
||||
|
||||
@ -47,7 +47,7 @@ class EvaluationModule(LightningModule):
|
||||
),
|
||||
)
|
||||
for clip_annotation, clip_dets in zip(
|
||||
clip_annotations, clip_detections
|
||||
clip_annotations, clip_detections, strict=False
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Callable, Iterable, Mapping
|
||||
from typing import Annotated, List, Literal, Optional, Sequence, Tuple, Union
|
||||
from typing import Annotated, List, Literal, Sequence, Tuple
|
||||
|
||||
import numpy as np
|
||||
from pydantic import Field
|
||||
@ -94,7 +94,7 @@ def match(
|
||||
class_name: score
|
||||
for class_name, score in zip(
|
||||
targets.class_names,
|
||||
prediction.class_scores,
|
||||
prediction.class_scores, strict=False,
|
||||
)
|
||||
}
|
||||
if prediction is not None
|
||||
@ -563,7 +563,7 @@ def select_optimal_matches(
|
||||
maximize=True,
|
||||
)
|
||||
|
||||
for gt_idx, pred_idx in zip(assiged_rows, assigned_columns):
|
||||
for gt_idx, pred_idx in zip(assiged_rows, assigned_columns, strict=False):
|
||||
affinity = float(affinity_matrix[gt_idx, pred_idx])
|
||||
|
||||
if affinity <= affinity_threshold:
|
||||
|
||||
@ -7,10 +7,8 @@ from typing import (
|
||||
List,
|
||||
Literal,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from typing import Annotated, Callable, Dict, Literal, Sequence, Set, Union
|
||||
from typing import Annotated, Callable, Dict, Literal, Sequence, Set
|
||||
|
||||
import numpy as np
|
||||
from pydantic import Field
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Annotated, Callable, Dict, Literal, Sequence, Union
|
||||
from typing import Annotated, Callable, Dict, Literal, Sequence
|
||||
|
||||
import numpy as np
|
||||
from pydantic import Field
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Optional, Tuple
|
||||
from typing import Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
@ -5,9 +5,7 @@ from typing import (
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Sequence,
|
||||
Union,
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from typing import (
|
||||
Annotated,
|
||||
@ -6,9 +5,7 @@ from typing import (
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Sequence,
|
||||
Union,
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
from typing import Optional
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.figure import Figure
|
||||
|
||||
@ -3,10 +3,8 @@ from typing import (
|
||||
Callable,
|
||||
Iterable,
|
||||
Literal,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
@ -3,10 +3,8 @@ from typing import (
|
||||
Callable,
|
||||
Iterable,
|
||||
Literal,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
@ -3,10 +3,8 @@ from typing import (
|
||||
Callable,
|
||||
Iterable,
|
||||
Literal,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
|
||||
import pandas as pd
|
||||
|
||||
@ -4,10 +4,8 @@ from typing import (
|
||||
Callable,
|
||||
Iterable,
|
||||
Literal,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
@ -8,10 +8,8 @@ from typing import (
|
||||
Iterable,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
@ -405,7 +403,7 @@ def get_binned_sample(matches: List[MatchEval], n_examples: int = 5):
|
||||
return matches
|
||||
|
||||
indices, pred_scores = zip(
|
||||
*[(index, match.score) for index, match in enumerate(matches)]
|
||||
*[(index, match.score) for index, match in enumerate(matches)], strict=False
|
||||
)
|
||||
|
||||
bins = pd.qcut(pred_scores, q=n_examples, labels=False, duplicates="drop")
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Annotated, Callable, Literal, Sequence, Union
|
||||
from typing import Annotated, Callable, Literal, Sequence
|
||||
|
||||
import pandas as pd
|
||||
from pydantic import Field
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Annotated, Optional, Sequence, Union
|
||||
from typing import Annotated, Optional, Sequence
|
||||
|
||||
from pydantic import Field
|
||||
from soundevent import data
|
||||
@ -26,7 +26,11 @@ __all__ = [
|
||||
|
||||
|
||||
TaskConfig = Annotated[
|
||||
ClassificationTaskConfig | DetectionTaskConfig | ClipDetectionTaskConfig | ClipClassificationTaskConfig | TopClassDetectionTaskConfig,
|
||||
ClassificationTaskConfig
|
||||
| DetectionTaskConfig
|
||||
| ClipDetectionTaskConfig
|
||||
| ClipClassificationTaskConfig
|
||||
| TopClassDetectionTaskConfig,
|
||||
Field(discriminator="name"),
|
||||
]
|
||||
|
||||
|
||||
@ -4,7 +4,6 @@ from typing import (
|
||||
Generic,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
TypeVar,
|
||||
@ -101,7 +100,7 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]):
|
||||
) -> List[T_Output]:
|
||||
return [
|
||||
self.evaluate_clip(clip_annotation, preds)
|
||||
for clip_annotation, preds in zip(clip_annotations, predictions)
|
||||
for clip_annotation, preds in zip(clip_annotations, predictions, strict=False)
|
||||
]
|
||||
|
||||
def evaluate_clip(
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import argparse
|
||||
import os
|
||||
import warnings
|
||||
from typing import List, Optional
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
import torch.utils.data
|
||||
@ -88,7 +88,7 @@ def select_device(warn=True) -> str:
|
||||
if warn:
|
||||
warnings.warn(
|
||||
"No GPU available, using the CPU instead. Please consider using a GPU "
|
||||
"to speed up training."
|
||||
"to speed up training.", stacklevel=2
|
||||
)
|
||||
|
||||
return "cpu"
|
||||
|
||||
@ -2,7 +2,7 @@ import argparse
|
||||
import json
|
||||
import os
|
||||
from collections import Counter
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import List, Tuple
|
||||
|
||||
import numpy as np
|
||||
from sklearn.model_selection import StratifiedGroupKFold
|
||||
@ -162,7 +162,7 @@ def main():
|
||||
# change the names of the classes
|
||||
ip_names = args.input_class_names.split(";")
|
||||
op_names = args.output_class_names.split(";")
|
||||
name_dict = dict(zip(ip_names, op_names))
|
||||
name_dict = dict(zip(ip_names, op_names, strict=False))
|
||||
|
||||
# load annotations
|
||||
data_all = tu.load_set_of_anns(
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import List, NamedTuple, Optional, Sequence
|
||||
from typing import List, NamedTuple, Sequence
|
||||
|
||||
import torch
|
||||
from loguru import logger
|
||||
|
||||
@ -39,7 +39,7 @@ class InferenceModule(LightningModule):
|
||||
targets=self.model.targets,
|
||||
),
|
||||
)
|
||||
for clip, clip_dets in zip(clips, clip_detections)
|
||||
for clip, clip_dets in zip(clips, clip_detections, strict=False)
|
||||
]
|
||||
|
||||
return predictions
|
||||
|
||||
@ -9,10 +9,8 @@ from typing import (
|
||||
Dict,
|
||||
Generic,
|
||||
Literal,
|
||||
Optional,
|
||||
Protocol,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -26,7 +26,7 @@ for creating a standard BatDetect2 model instance is the `build_model` function
|
||||
provided here.
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
@ -14,7 +14,7 @@ A factory function `build_bottleneck` constructs the appropriate bottleneck
|
||||
module based on the provided configuration.
|
||||
"""
|
||||
|
||||
from typing import Annotated, List, Optional, Union
|
||||
from typing import Annotated, List
|
||||
|
||||
import torch
|
||||
from pydantic import Field
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
from typing import Optional
|
||||
|
||||
from soundevent import data
|
||||
|
||||
|
||||
@ -18,7 +18,7 @@ The `Decoder`'s `forward` method is designed to accept skip connection tensors
|
||||
at each stage.
|
||||
"""
|
||||
|
||||
from typing import Annotated, List, Optional, Union
|
||||
from typing import Annotated, List
|
||||
|
||||
import torch
|
||||
from pydantic import Field
|
||||
@ -182,7 +182,7 @@ class Decoder(nn.Module):
|
||||
f"but got {len(residuals)}."
|
||||
)
|
||||
|
||||
for layer, res in zip(self.layers, residuals[::-1]):
|
||||
for layer, res in zip(self.layers, residuals[::-1], strict=False):
|
||||
x = layer(x + res)
|
||||
|
||||
return x
|
||||
|
||||
@ -14,7 +14,6 @@ logic for preprocessing inputs and postprocessing/decoding outputs resides in
|
||||
the `batdetect2.preprocess` and `batdetect2.postprocess` packages, respectively.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from loguru import logger
|
||||
|
||||
@ -20,7 +20,7 @@ bottleneck output. A default configuration (`DEFAULT_ENCODER_CONFIG`) is also
|
||||
provided.
|
||||
"""
|
||||
|
||||
from typing import Annotated, List, Optional, Union
|
||||
from typing import Annotated, List
|
||||
|
||||
import torch
|
||||
from pydantic import Field
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Optional, Tuple
|
||||
from typing import Tuple
|
||||
|
||||
from matplotlib.axes import Axes
|
||||
from soundevent import data, plot
|
||||
@ -68,6 +68,6 @@ def plot_anchor_points(
|
||||
position, _ = targets.encode_roi(sound_event)
|
||||
positions.append(position)
|
||||
|
||||
X, Y = zip(*positions)
|
||||
X, Y = zip(*positions, strict=False)
|
||||
ax.scatter(X, Y, s=size, c=color, marker=marker, alpha=alpha)
|
||||
return ax
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Iterable, Optional, Tuple
|
||||
from typing import Iterable, Tuple
|
||||
|
||||
from matplotlib.axes import Axes
|
||||
from soundevent import data
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Optional, Tuple
|
||||
from typing import Tuple
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import torch
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
"""General plotting utilities."""
|
||||
|
||||
from typing import Optional, Tuple, Union
|
||||
from typing import Tuple
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
from typing import Optional
|
||||
|
||||
from matplotlib import axes, patches
|
||||
from soundevent.plot import plot_geometry
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Optional, Sequence
|
||||
from typing import Sequence
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.figure import Figure
|
||||
@ -36,7 +36,7 @@ def plot_match_gallery(
|
||||
sharey="row",
|
||||
)
|
||||
|
||||
for tp_ax, tp_match in zip(axes[0], true_positives[:n_examples]):
|
||||
for tp_ax, tp_match in zip(axes[0], true_positives[:n_examples], strict=False):
|
||||
try:
|
||||
plot_true_positive_match(
|
||||
tp_match,
|
||||
@ -53,7 +53,7 @@ def plot_match_gallery(
|
||||
):
|
||||
continue
|
||||
|
||||
for fp_ax, fp_match in zip(axes[1], false_positives[:n_examples]):
|
||||
for fp_ax, fp_match in zip(axes[1], false_positives[:n_examples], strict=False):
|
||||
try:
|
||||
plot_false_positive_match(
|
||||
fp_match,
|
||||
@ -70,7 +70,7 @@ def plot_match_gallery(
|
||||
):
|
||||
continue
|
||||
|
||||
for fn_ax, fn_match in zip(axes[2], false_negatives[:n_examples]):
|
||||
for fn_ax, fn_match in zip(axes[2], false_negatives[:n_examples], strict=False):
|
||||
try:
|
||||
plot_false_negative_match(
|
||||
fn_match,
|
||||
@ -87,7 +87,7 @@ def plot_match_gallery(
|
||||
):
|
||||
continue
|
||||
|
||||
for ct_ax, ct_match in zip(axes[3], cross_triggers[:n_examples]):
|
||||
for ct_ax, ct_match in zip(axes[3], cross_triggers[:n_examples], strict=False):
|
||||
try:
|
||||
plot_cross_trigger_match(
|
||||
ct_match,
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
"""Plot heatmaps"""
|
||||
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from typing import List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
"""Plot functions to visualize detections and spectrograms."""
|
||||
|
||||
from typing import List, Optional, Tuple, Union, cast
|
||||
from typing import List, Tuple, cast
|
||||
|
||||
import matplotlib.ticker as tick
|
||||
import numpy as np
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Optional, Protocol, Tuple, Union
|
||||
from typing import Protocol, Tuple
|
||||
|
||||
from matplotlib.axes import Axes
|
||||
from soundevent import data, plot
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
from typing import Dict, Tuple
|
||||
|
||||
import numpy as np
|
||||
import seaborn as sns
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field
|
||||
from soundevent import data
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
"""Decodes extracted detection data into standard soundevent predictions."""
|
||||
|
||||
from typing import List, Optional
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
from soundevent import data
|
||||
@ -39,7 +39,7 @@ def to_raw_predictions(
|
||||
detections.times,
|
||||
detections.frequencies,
|
||||
detections.sizes,
|
||||
detections.features,
|
||||
detections.features, strict=False,
|
||||
):
|
||||
highest_scoring_class = targets.class_names[class_scores.argmax()]
|
||||
|
||||
|
||||
@ -15,7 +15,7 @@ precise time-frequency location of each detection. The final output aggregates
|
||||
all extracted information into a structured `xarray.Dataset`.
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
@ -11,7 +11,7 @@ activations that have lower scores than a local maximum. This helps prevent
|
||||
multiple, overlapping detections originating from the same sound event.
|
||||
"""
|
||||
|
||||
from typing import Tuple, Union
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
from loguru import logger
|
||||
|
||||
@ -12,7 +12,7 @@ classification probability maps, size prediction maps, and potentially
|
||||
intermediate features.
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Optional, Union
|
||||
from typing import Dict, List
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Annotated, Literal, Union
|
||||
from typing import Annotated, Literal
|
||||
|
||||
import torch
|
||||
from pydantic import Field
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import List, Optional
|
||||
from typing import List
|
||||
|
||||
from pydantic import Field
|
||||
from soundevent.data import PathLike
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from loguru import logger
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
"""Computes spectrograms from audio waveforms with configurable parameters."""
|
||||
|
||||
from typing import Annotated, Callable, Literal, Optional, Union
|
||||
from typing import Annotated, Callable, Literal
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Dict, List
|
||||
|
||||
from pydantic import Field, PrivateAttr, computed_field, model_validator
|
||||
from soundevent import data
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections import Counter
|
||||
from typing import List, Optional
|
||||
from typing import List
|
||||
|
||||
from pydantic import Field, field_validator
|
||||
from soundevent import data
|
||||
|
||||
@ -20,7 +20,7 @@ selecting and configuring the desired mapper. This module separates the
|
||||
*geometric* aspect of target definition from *semantic* classification.
|
||||
"""
|
||||
|
||||
from typing import Annotated, Literal, Optional, Tuple, Union
|
||||
from typing import Annotated, Literal, Tuple
|
||||
|
||||
import numpy as np
|
||||
from pydantic import Field
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Iterable, List, Optional, Tuple
|
||||
from typing import Iterable, List, Tuple
|
||||
|
||||
from loguru import logger
|
||||
from soundevent import data
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
|
||||
import warnings
|
||||
from collections.abc import Sequence
|
||||
from typing import Annotated, Callable, List, Literal, Optional, Tuple, Union
|
||||
from typing import Annotated, Callable, List, Literal, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -394,7 +394,7 @@ class MaskTime(torch.nn.Module):
|
||||
size=num_masks,
|
||||
)
|
||||
masks = [
|
||||
(start, start + size) for start, size in zip(mask_start, mask_size)
|
||||
(start, start + size) for start, size in zip(mask_start, mask_size, strict=False)
|
||||
]
|
||||
return mask_time(spec, masks), clip_annotation
|
||||
|
||||
@ -460,7 +460,7 @@ class MaskFrequency(torch.nn.Module):
|
||||
size=num_masks,
|
||||
)
|
||||
masks = [
|
||||
(start, start + size) for start, size in zip(mask_start, mask_size)
|
||||
(start, start + size) for start, size in zip(mask_start, mask_size, strict=False)
|
||||
]
|
||||
return mask_frequency(spec, masks), clip_annotation
|
||||
|
||||
|
||||
@ -107,7 +107,7 @@ class ValidationMetrics(Callback):
|
||||
),
|
||||
)
|
||||
for clip_annotation, clip_dets in zip(
|
||||
clip_annotations, clip_detections
|
||||
clip_annotations, clip_detections, strict=False
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from lightning.pytorch.callbacks import Callback, ModelCheckpoint
|
||||
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
from typing import Optional, Union
|
||||
|
||||
from pydantic import Field
|
||||
from soundevent import data
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import List, Optional, Sequence
|
||||
from typing import List, Sequence
|
||||
|
||||
import torch
|
||||
from loguru import logger
|
||||
|
||||
@ -6,7 +6,6 @@ the specific multi-channel heatmap formats required by the neural network.
|
||||
"""
|
||||
|
||||
from functools import partial
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
@ -1,12 +1,12 @@
|
||||
from typing import NamedTuple, Optional
|
||||
from typing import NamedTuple
|
||||
|
||||
import torch
|
||||
from batdetect2.models.types import DetectionModel
|
||||
from soundevent import data
|
||||
from torch.optim import Adam
|
||||
from torch.optim.lr_scheduler import CosineAnnealingLR
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from batdetect2.models.types import DetectionModel
|
||||
from batdetect2.train.dataset import LabeledDataset
|
||||
|
||||
|
||||
@ -26,7 +26,7 @@ def train_loop(
|
||||
learning_rate: float = 1e-4,
|
||||
):
|
||||
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
|
||||
validation_loader = DataLoader(validation_dataset, batch_size=32)
|
||||
DataLoader(validation_dataset, batch_size=32)
|
||||
|
||||
model.to(device)
|
||||
|
||||
@ -36,8 +36,8 @@ def train_loop(
|
||||
num_epochs * len(train_loader),
|
||||
)
|
||||
|
||||
for epoch in range(num_epochs):
|
||||
train_loss = train_single_epoch(
|
||||
for _epoch in range(num_epochs):
|
||||
train_single_epoch(
|
||||
model,
|
||||
train_loader,
|
||||
optimizer,
|
||||
@ -60,9 +60,9 @@ def train_single_epoch(
|
||||
optimizer.zero_grad()
|
||||
|
||||
spec = batch.spec.to(device)
|
||||
detection_heatmap = batch.detection_heatmap.to(device)
|
||||
class_heatmap = batch.class_heatmap.to(device)
|
||||
size_heatmap = batch.size_heatmap.to(device)
|
||||
batch.detection_heatmap.to(device)
|
||||
batch.class_heatmap.to(device)
|
||||
batch.size_heatmap.to(device)
|
||||
|
||||
outputs = model(spec)
|
||||
|
||||
|
||||
@ -2,6 +2,10 @@ import argparse
|
||||
import json
|
||||
import warnings
|
||||
|
||||
import batdetect2.train.audio_dataloader as adl
|
||||
import batdetect2.train.evaluate as evl
|
||||
import batdetect2.train.train_split as ts
|
||||
import batdetect2.train.train_utils as tu
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -9,10 +13,6 @@ import torch.utils.data
|
||||
from torch.optim.lr_scheduler import CosineAnnealingLR
|
||||
|
||||
import batdetect2.detector.post_process as pp
|
||||
import batdetect2.train.audio_dataloader as adl
|
||||
import batdetect2.train.evaluate as evl
|
||||
import batdetect2.train.train_split as ts
|
||||
import batdetect2.train.train_utils as tu
|
||||
import batdetect2.utils.plot_utils as pu
|
||||
from batdetect2.detector import models, parameters
|
||||
from batdetect2.train import losses
|
||||
|
||||
@ -10,7 +10,7 @@ def get_train_test_data(ann_dir, wav_dir, split_name, load_extra=True):
|
||||
train_sets, test_sets = split_same(ann_dir, wav_dir, load_extra)
|
||||
else:
|
||||
print("Split not defined")
|
||||
assert False
|
||||
raise AssertionError()
|
||||
|
||||
return train_sets, test_sets
|
||||
|
||||
|
||||
@ -2,7 +2,7 @@ import json
|
||||
import sys
|
||||
from collections import Counter
|
||||
from pathlib import Path
|
||||
from typing import Dict, Generator, List, Optional, Tuple
|
||||
from typing import Dict, Generator, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import TYPE_CHECKING, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Tuple
|
||||
|
||||
import lightning as L
|
||||
import torch
|
||||
|
||||
@ -18,7 +18,6 @@ The primary entry points are:
|
||||
- `LossConfig`: The Pydantic model for configuring loss weights and parameters.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
"""Types used in the code base."""
|
||||
|
||||
from typing import Any, List, NamedTuple, Optional, TypedDict
|
||||
from typing import Any, List, NamedTuple, TypedDict
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Generic, List, Optional, Protocol, Sequence, TypeVar
|
||||
from typing import Generic, List, Protocol, Sequence, TypeVar
|
||||
|
||||
from soundevent.data import PathLike
|
||||
|
||||
|
||||
@ -4,7 +4,6 @@ from typing import (
|
||||
Generic,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Protocol,
|
||||
Sequence,
|
||||
Tuple,
|
||||
|
||||
@ -12,7 +12,7 @@ system that deal with model predictions.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List, NamedTuple, Optional, Protocol, Sequence
|
||||
from typing import List, NamedTuple, Protocol, Sequence
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
@ -10,7 +10,7 @@ pipeline can interact consistently, regardless of the specific underlying
|
||||
implementation (e.g., different libraries or custom configurations).
|
||||
"""
|
||||
|
||||
from typing import Optional, Protocol
|
||||
from typing import Protocol
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
@ -13,7 +13,7 @@ throughout BatDetect2.
|
||||
"""
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import List, Optional, Protocol
|
||||
from typing import List, Protocol
|
||||
|
||||
import numpy as np
|
||||
from soundevent import data
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user