mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-04-04 15:20:19 +02:00
Run lint fixes
This commit is contained in:
parent
2563f26ed3
commit
113f438e74
@ -98,7 +98,7 @@ consult the API documentation in the code.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import warnings
|
import warnings
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
import json
|
import json
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Optional, Sequence, Tuple
|
from typing import Dict, List, Sequence, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
from typing import Annotated, List, Literal, Optional, Union
|
from typing import Annotated, List, Literal
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|||||||
@ -1,4 +1,3 @@
|
|||||||
from typing import Optional
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from numpy.typing import DTypeLike
|
from numpy.typing import DTypeLike
|
||||||
|
|||||||
@ -1,5 +1,4 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import click
|
import click
|
||||||
|
|
||||||
|
|||||||
@ -1,5 +1,4 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import click
|
import click
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|||||||
@ -1,5 +1,4 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import click
|
import click
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
from typing import Literal, Optional
|
from typing import Literal
|
||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from soundevent.data import PathLike
|
from soundevent.data import PathLike
|
||||||
|
|||||||
@ -1,4 +1,3 @@
|
|||||||
from typing import Optional
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|||||||
@ -8,7 +8,7 @@ configuration data from files, with optional support for accessing nested
|
|||||||
configuration sections.
|
configuration sections.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Any, Optional, Type, TypeVar
|
from typing import Any, Type, TypeVar
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
from deepmerge.merger import Merger
|
from deepmerge.merger import Merger
|
||||||
|
|||||||
@ -13,7 +13,7 @@ format-specific loading function to retrieve the annotations as a standard
|
|||||||
`soundevent.data.AnnotationSet`.
|
`soundevent.data.AnnotationSet`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Annotated, Optional, Union
|
from typing import Annotated
|
||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|||||||
@ -12,7 +12,7 @@ that meet specific status criteria (e.g., completed, verified, without issues).
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Literal, Optional
|
from typing import Literal
|
||||||
from uuid import uuid5
|
from uuid import uuid5
|
||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from typing import Annotated, List, Literal, Sequence, Union
|
from typing import Annotated, List, Literal, Sequence
|
||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|||||||
@ -19,7 +19,7 @@ The core components are:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Optional, Sequence
|
from typing import List, Sequence
|
||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
from typing import Optional, Tuple
|
from typing import Tuple
|
||||||
|
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
from typing import Annotated, Optional, Union
|
from typing import Annotated
|
||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from soundevent.data import PathLike
|
from soundevent.data import PathLike
|
||||||
@ -24,7 +24,10 @@ __all__ = [
|
|||||||
|
|
||||||
|
|
||||||
OutputFormatConfig = Annotated[
|
OutputFormatConfig = Annotated[
|
||||||
BatDetect2OutputConfig | ParquetOutputConfig | SoundEventOutputConfig | RawOutputConfig,
|
BatDetect2OutputConfig
|
||||||
|
| ParquetOutputConfig
|
||||||
|
| SoundEventOutputConfig
|
||||||
|
| RawOutputConfig,
|
||||||
Field(discriminator="name"),
|
Field(discriminator="name"),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
import json
|
import json
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Literal, Optional, Sequence, TypedDict
|
from typing import List, Literal, Sequence, TypedDict
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|||||||
@ -1,6 +1,5 @@
|
|||||||
import json
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Literal, Optional, Sequence
|
from typing import List, Literal, Sequence
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Literal, Optional, Sequence
|
from typing import List, Literal, Sequence
|
||||||
from uuid import UUID, uuid4
|
from uuid import UUID, uuid4
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Literal, Optional, Sequence
|
from typing import List, Literal, Sequence
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from soundevent import data, io
|
from soundevent import data, io
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
from typing import Optional, Tuple
|
from typing import Tuple
|
||||||
|
|
||||||
from sklearn.model_selection import train_test_split
|
from sklearn.model_selection import train_test_split
|
||||||
|
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from typing import Annotated, Dict, List, Literal, Optional, Union
|
from typing import Annotated, Dict, List, Literal
|
||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
"""Functions to compute features from predictions."""
|
"""Functions to compute features from predictions."""
|
||||||
|
|
||||||
from typing import Dict, List, Optional
|
from typing import Dict, List
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
import datetime
|
import datetime
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Optional, Union
|
from typing import List
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, computed_field
|
from pydantic import BaseModel, Field, computed_field
|
||||||
|
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
"""Post-processing of the output of the model."""
|
"""Post-processing of the output of the model."""
|
||||||
|
|
||||||
from typing import List, Tuple, Union
|
from typing import List, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
from typing import Annotated, Literal, Optional, Union
|
from typing import Annotated, Literal
|
||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
from typing import List, Optional
|
from typing import List
|
||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
from typing import List, NamedTuple, Optional, Sequence
|
from typing import List, NamedTuple, Sequence
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union
|
from typing import Any, Dict, Iterable, List, Sequence, Tuple
|
||||||
|
|
||||||
from matplotlib.figure import Figure
|
from matplotlib.figure import Figure
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
@ -36,7 +36,7 @@ class Evaluator:
|
|||||||
def compute_metrics(self, eval_outputs: List[Any]) -> Dict[str, float]:
|
def compute_metrics(self, eval_outputs: List[Any]) -> Dict[str, float]:
|
||||||
results = {}
|
results = {}
|
||||||
|
|
||||||
for task, outputs in zip(self.tasks, eval_outputs):
|
for task, outputs in zip(self.tasks, eval_outputs, strict=False):
|
||||||
results.update(task.compute_metrics(outputs))
|
results.update(task.compute_metrics(outputs))
|
||||||
|
|
||||||
return results
|
return results
|
||||||
@ -45,7 +45,7 @@ class Evaluator:
|
|||||||
self,
|
self,
|
||||||
eval_outputs: List[Any],
|
eval_outputs: List[Any],
|
||||||
) -> Iterable[Tuple[str, Figure]]:
|
) -> Iterable[Tuple[str, Figure]]:
|
||||||
for task, outputs in zip(self.tasks, eval_outputs):
|
for task, outputs in zip(self.tasks, eval_outputs, strict=False):
|
||||||
for name, fig in task.generate_plots(outputs):
|
for name, fig in task.generate_plots(outputs):
|
||||||
yield name, fig
|
yield name, fig
|
||||||
|
|
||||||
|
|||||||
@ -357,7 +357,7 @@ def train_rf_model(x_train, y_train, num_classes, seed=2001):
|
|||||||
clf = RandomForestClassifier(random_state=seed, n_jobs=-1)
|
clf = RandomForestClassifier(random_state=seed, n_jobs=-1)
|
||||||
clf.fit(x_train, y_train)
|
clf.fit(x_train, y_train)
|
||||||
y_pred = clf.predict(x_train)
|
y_pred = clf.predict(x_train)
|
||||||
tr_acc = (y_pred == y_train).mean()
|
(y_pred == y_train).mean()
|
||||||
# print('Train acc', round(tr_acc*100, 2))
|
# print('Train acc', round(tr_acc*100, 2))
|
||||||
return clf, un_train_class
|
return clf, un_train_class
|
||||||
|
|
||||||
@ -450,7 +450,7 @@ def add_root_path_back(data_sets, ann_path, wav_path):
|
|||||||
|
|
||||||
|
|
||||||
def check_classes_in_train(gt_list, class_names):
|
def check_classes_in_train(gt_list, class_names):
|
||||||
num_gt_total = np.sum([gg["start_times"].shape[0] for gg in gt_list])
|
np.sum([gg["start_times"].shape[0] for gg in gt_list])
|
||||||
num_with_no_class = 0
|
num_with_no_class = 0
|
||||||
for gt in gt_list:
|
for gt in gt_list:
|
||||||
for cc in gt["class_names"]:
|
for cc in gt["class_names"]:
|
||||||
@ -569,7 +569,7 @@ if __name__ == "__main__":
|
|||||||
num_with_no_class = check_classes_in_train(gt_test, class_names)
|
num_with_no_class = check_classes_in_train(gt_test, class_names)
|
||||||
if total_num_calls == num_with_no_class:
|
if total_num_calls == num_with_no_class:
|
||||||
print("Classes from the test set are not in the train set.")
|
print("Classes from the test set are not in the train set.")
|
||||||
assert False
|
raise AssertionError()
|
||||||
|
|
||||||
# only need the train data if evaluating Sonobat or Tadarida
|
# only need the train data if evaluating Sonobat or Tadarida
|
||||||
if args["sb_ip_dir"] != "" or args["td_ip_dir"] != "":
|
if args["sb_ip_dir"] != "" or args["td_ip_dir"] != "":
|
||||||
@ -743,7 +743,7 @@ if __name__ == "__main__":
|
|||||||
# check if the class names are the same
|
# check if the class names are the same
|
||||||
if params_bd["class_names"] != class_names:
|
if params_bd["class_names"] != class_names:
|
||||||
print("Warning: Class names are not the same as the trained model")
|
print("Warning: Class names are not the same as the trained model")
|
||||||
assert False
|
raise AssertionError()
|
||||||
|
|
||||||
run_config = {
|
run_config = {
|
||||||
**bd_args,
|
**bd_args,
|
||||||
@ -753,7 +753,7 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
preds_bd = []
|
preds_bd = []
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
for ii, gg in enumerate(gt_test):
|
for gg in gt_test:
|
||||||
pred = du.process_file(
|
pred = du.process_file(
|
||||||
gg["file_path"],
|
gg["file_path"],
|
||||||
model,
|
model,
|
||||||
|
|||||||
@ -47,7 +47,7 @@ class EvaluationModule(LightningModule):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
for clip_annotation, clip_dets in zip(
|
for clip_annotation, clip_dets in zip(
|
||||||
clip_annotations, clip_detections
|
clip_annotations, clip_detections, strict=False
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
from collections.abc import Callable, Iterable, Mapping
|
from collections.abc import Callable, Iterable, Mapping
|
||||||
from typing import Annotated, List, Literal, Optional, Sequence, Tuple, Union
|
from typing import Annotated, List, Literal, Sequence, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
@ -94,7 +94,7 @@ def match(
|
|||||||
class_name: score
|
class_name: score
|
||||||
for class_name, score in zip(
|
for class_name, score in zip(
|
||||||
targets.class_names,
|
targets.class_names,
|
||||||
prediction.class_scores,
|
prediction.class_scores, strict=False,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
if prediction is not None
|
if prediction is not None
|
||||||
@ -563,7 +563,7 @@ def select_optimal_matches(
|
|||||||
maximize=True,
|
maximize=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
for gt_idx, pred_idx in zip(assiged_rows, assigned_columns):
|
for gt_idx, pred_idx in zip(assiged_rows, assigned_columns, strict=False):
|
||||||
affinity = float(affinity_matrix[gt_idx, pred_idx])
|
affinity = float(affinity_matrix[gt_idx, pred_idx])
|
||||||
|
|
||||||
if affinity <= affinity_threshold:
|
if affinity <= affinity_threshold:
|
||||||
|
|||||||
@ -7,10 +7,8 @@ from typing import (
|
|||||||
List,
|
List,
|
||||||
Literal,
|
Literal,
|
||||||
Mapping,
|
Mapping,
|
||||||
Optional,
|
|
||||||
Sequence,
|
Sequence,
|
||||||
Tuple,
|
Tuple,
|
||||||
Union,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Annotated, Callable, Dict, Literal, Sequence, Set, Union
|
from typing import Annotated, Callable, Dict, Literal, Sequence, Set
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Annotated, Callable, Dict, Literal, Sequence, Union
|
from typing import Annotated, Callable, Dict, Literal, Sequence
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
from typing import Optional, Tuple
|
from typing import Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|||||||
@ -5,9 +5,7 @@ from typing import (
|
|||||||
Dict,
|
Dict,
|
||||||
List,
|
List,
|
||||||
Literal,
|
Literal,
|
||||||
Optional,
|
|
||||||
Sequence,
|
Sequence,
|
||||||
Union,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|||||||
@ -1,4 +1,3 @@
|
|||||||
from collections import defaultdict
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import (
|
from typing import (
|
||||||
Annotated,
|
Annotated,
|
||||||
@ -6,9 +5,7 @@ from typing import (
|
|||||||
Dict,
|
Dict,
|
||||||
List,
|
List,
|
||||||
Literal,
|
Literal,
|
||||||
Optional,
|
|
||||||
Sequence,
|
Sequence,
|
||||||
Union,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|||||||
@ -1,4 +1,3 @@
|
|||||||
from typing import Optional
|
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
from matplotlib.figure import Figure
|
from matplotlib.figure import Figure
|
||||||
|
|||||||
@ -3,10 +3,8 @@ from typing import (
|
|||||||
Callable,
|
Callable,
|
||||||
Iterable,
|
Iterable,
|
||||||
Literal,
|
Literal,
|
||||||
Optional,
|
|
||||||
Sequence,
|
Sequence,
|
||||||
Tuple,
|
Tuple,
|
||||||
Union,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
|||||||
@ -3,10 +3,8 @@ from typing import (
|
|||||||
Callable,
|
Callable,
|
||||||
Iterable,
|
Iterable,
|
||||||
Literal,
|
Literal,
|
||||||
Optional,
|
|
||||||
Sequence,
|
Sequence,
|
||||||
Tuple,
|
Tuple,
|
||||||
Union,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
|||||||
@ -3,10 +3,8 @@ from typing import (
|
|||||||
Callable,
|
Callable,
|
||||||
Iterable,
|
Iterable,
|
||||||
Literal,
|
Literal,
|
||||||
Optional,
|
|
||||||
Sequence,
|
Sequence,
|
||||||
Tuple,
|
Tuple,
|
||||||
Union,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
|||||||
@ -4,10 +4,8 @@ from typing import (
|
|||||||
Callable,
|
Callable,
|
||||||
Iterable,
|
Iterable,
|
||||||
Literal,
|
Literal,
|
||||||
Optional,
|
|
||||||
Sequence,
|
Sequence,
|
||||||
Tuple,
|
Tuple,
|
||||||
Union,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
|||||||
@ -8,10 +8,8 @@ from typing import (
|
|||||||
Iterable,
|
Iterable,
|
||||||
List,
|
List,
|
||||||
Literal,
|
Literal,
|
||||||
Optional,
|
|
||||||
Sequence,
|
Sequence,
|
||||||
Tuple,
|
Tuple,
|
||||||
Union,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
@ -405,7 +403,7 @@ def get_binned_sample(matches: List[MatchEval], n_examples: int = 5):
|
|||||||
return matches
|
return matches
|
||||||
|
|
||||||
indices, pred_scores = zip(
|
indices, pred_scores = zip(
|
||||||
*[(index, match.score) for index, match in enumerate(matches)]
|
*[(index, match.score) for index, match in enumerate(matches)], strict=False
|
||||||
)
|
)
|
||||||
|
|
||||||
bins = pd.qcut(pred_scores, q=n_examples, labels=False, duplicates="drop")
|
bins = pd.qcut(pred_scores, q=n_examples, labels=False, duplicates="drop")
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
from typing import Annotated, Callable, Literal, Sequence, Union
|
from typing import Annotated, Callable, Literal, Sequence
|
||||||
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
from typing import Annotated, Optional, Sequence, Union
|
from typing import Annotated, Optional, Sequence
|
||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
@ -26,7 +26,11 @@ __all__ = [
|
|||||||
|
|
||||||
|
|
||||||
TaskConfig = Annotated[
|
TaskConfig = Annotated[
|
||||||
ClassificationTaskConfig | DetectionTaskConfig | ClipDetectionTaskConfig | ClipClassificationTaskConfig | TopClassDetectionTaskConfig,
|
ClassificationTaskConfig
|
||||||
|
| DetectionTaskConfig
|
||||||
|
| ClipDetectionTaskConfig
|
||||||
|
| ClipClassificationTaskConfig
|
||||||
|
| TopClassDetectionTaskConfig,
|
||||||
Field(discriminator="name"),
|
Field(discriminator="name"),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@ -4,7 +4,6 @@ from typing import (
|
|||||||
Generic,
|
Generic,
|
||||||
Iterable,
|
Iterable,
|
||||||
List,
|
List,
|
||||||
Optional,
|
|
||||||
Sequence,
|
Sequence,
|
||||||
Tuple,
|
Tuple,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
@ -101,7 +100,7 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]):
|
|||||||
) -> List[T_Output]:
|
) -> List[T_Output]:
|
||||||
return [
|
return [
|
||||||
self.evaluate_clip(clip_annotation, preds)
|
self.evaluate_clip(clip_annotation, preds)
|
||||||
for clip_annotation, preds in zip(clip_annotations, predictions)
|
for clip_annotation, preds in zip(clip_annotations, predictions, strict=False)
|
||||||
]
|
]
|
||||||
|
|
||||||
def evaluate_clip(
|
def evaluate_clip(
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
import warnings
|
import warnings
|
||||||
from typing import List, Optional
|
from typing import List
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.utils.data
|
import torch.utils.data
|
||||||
@ -88,7 +88,7 @@ def select_device(warn=True) -> str:
|
|||||||
if warn:
|
if warn:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"No GPU available, using the CPU instead. Please consider using a GPU "
|
"No GPU available, using the CPU instead. Please consider using a GPU "
|
||||||
"to speed up training."
|
"to speed up training.", stacklevel=2
|
||||||
)
|
)
|
||||||
|
|
||||||
return "cpu"
|
return "cpu"
|
||||||
|
|||||||
@ -2,7 +2,7 @@ import argparse
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from collections import Counter
|
from collections import Counter
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from sklearn.model_selection import StratifiedGroupKFold
|
from sklearn.model_selection import StratifiedGroupKFold
|
||||||
@ -162,7 +162,7 @@ def main():
|
|||||||
# change the names of the classes
|
# change the names of the classes
|
||||||
ip_names = args.input_class_names.split(";")
|
ip_names = args.input_class_names.split(";")
|
||||||
op_names = args.output_class_names.split(";")
|
op_names = args.output_class_names.split(";")
|
||||||
name_dict = dict(zip(ip_names, op_names))
|
name_dict = dict(zip(ip_names, op_names, strict=False))
|
||||||
|
|
||||||
# load annotations
|
# load annotations
|
||||||
data_all = tu.load_set_of_anns(
|
data_all = tu.load_set_of_anns(
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
from typing import List, NamedTuple, Optional, Sequence
|
from typing import List, NamedTuple, Sequence
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|||||||
@ -39,7 +39,7 @@ class InferenceModule(LightningModule):
|
|||||||
targets=self.model.targets,
|
targets=self.model.targets,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
for clip, clip_dets in zip(clips, clip_detections)
|
for clip, clip_dets in zip(clips, clip_detections, strict=False)
|
||||||
]
|
]
|
||||||
|
|
||||||
return predictions
|
return predictions
|
||||||
|
|||||||
@ -9,10 +9,8 @@ from typing import (
|
|||||||
Dict,
|
Dict,
|
||||||
Generic,
|
Generic,
|
||||||
Literal,
|
Literal,
|
||||||
Optional,
|
|
||||||
Protocol,
|
Protocol,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
Union,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|||||||
@ -26,7 +26,7 @@ for creating a standard BatDetect2 model instance is the `build_model` function
|
|||||||
provided here.
|
provided here.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import List, Optional
|
from typing import List
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|||||||
@ -14,7 +14,7 @@ A factory function `build_bottleneck` constructs the appropriate bottleneck
|
|||||||
module based on the provided configuration.
|
module based on the provided configuration.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Annotated, List, Optional, Union
|
from typing import Annotated, List
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|||||||
@ -1,4 +1,3 @@
|
|||||||
from typing import Optional
|
|
||||||
|
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
|
|||||||
@ -18,7 +18,7 @@ The `Decoder`'s `forward` method is designed to accept skip connection tensors
|
|||||||
at each stage.
|
at each stage.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Annotated, List, Optional, Union
|
from typing import Annotated, List
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
@ -182,7 +182,7 @@ class Decoder(nn.Module):
|
|||||||
f"but got {len(residuals)}."
|
f"but got {len(residuals)}."
|
||||||
)
|
)
|
||||||
|
|
||||||
for layer, res in zip(self.layers, residuals[::-1]):
|
for layer, res in zip(self.layers, residuals[::-1], strict=False):
|
||||||
x = layer(x + res)
|
x = layer(x + res)
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|||||||
@ -14,7 +14,6 @@ logic for preprocessing inputs and postprocessing/decoding outputs resides in
|
|||||||
the `batdetect2.preprocess` and `batdetect2.postprocess` packages, respectively.
|
the `batdetect2.preprocess` and `batdetect2.postprocess` packages, respectively.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|||||||
@ -20,7 +20,7 @@ bottleneck output. A default configuration (`DEFAULT_ENCODER_CONFIG`) is also
|
|||||||
provided.
|
provided.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Annotated, List, Optional, Union
|
from typing import Annotated, List
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
from typing import Optional, Tuple
|
from typing import Tuple
|
||||||
|
|
||||||
from matplotlib.axes import Axes
|
from matplotlib.axes import Axes
|
||||||
from soundevent import data, plot
|
from soundevent import data, plot
|
||||||
@ -68,6 +68,6 @@ def plot_anchor_points(
|
|||||||
position, _ = targets.encode_roi(sound_event)
|
position, _ = targets.encode_roi(sound_event)
|
||||||
positions.append(position)
|
positions.append(position)
|
||||||
|
|
||||||
X, Y = zip(*positions)
|
X, Y = zip(*positions, strict=False)
|
||||||
ax.scatter(X, Y, s=size, c=color, marker=marker, alpha=alpha)
|
ax.scatter(X, Y, s=size, c=color, marker=marker, alpha=alpha)
|
||||||
return ax
|
return ax
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
from typing import Iterable, Optional, Tuple
|
from typing import Iterable, Tuple
|
||||||
|
|
||||||
from matplotlib.axes import Axes
|
from matplotlib.axes import Axes
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
from typing import Optional, Tuple
|
from typing import Tuple
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import torch
|
import torch
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
"""General plotting utilities."""
|
"""General plotting utilities."""
|
||||||
|
|
||||||
from typing import Optional, Tuple, Union
|
from typing import Tuple
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|||||||
@ -1,4 +1,3 @@
|
|||||||
from typing import Optional
|
|
||||||
|
|
||||||
from matplotlib import axes, patches
|
from matplotlib import axes, patches
|
||||||
from soundevent.plot import plot_geometry
|
from soundevent.plot import plot_geometry
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
from typing import Optional, Sequence
|
from typing import Sequence
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
from matplotlib.figure import Figure
|
from matplotlib.figure import Figure
|
||||||
@ -36,7 +36,7 @@ def plot_match_gallery(
|
|||||||
sharey="row",
|
sharey="row",
|
||||||
)
|
)
|
||||||
|
|
||||||
for tp_ax, tp_match in zip(axes[0], true_positives[:n_examples]):
|
for tp_ax, tp_match in zip(axes[0], true_positives[:n_examples], strict=False):
|
||||||
try:
|
try:
|
||||||
plot_true_positive_match(
|
plot_true_positive_match(
|
||||||
tp_match,
|
tp_match,
|
||||||
@ -53,7 +53,7 @@ def plot_match_gallery(
|
|||||||
):
|
):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
for fp_ax, fp_match in zip(axes[1], false_positives[:n_examples]):
|
for fp_ax, fp_match in zip(axes[1], false_positives[:n_examples], strict=False):
|
||||||
try:
|
try:
|
||||||
plot_false_positive_match(
|
plot_false_positive_match(
|
||||||
fp_match,
|
fp_match,
|
||||||
@ -70,7 +70,7 @@ def plot_match_gallery(
|
|||||||
):
|
):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
for fn_ax, fn_match in zip(axes[2], false_negatives[:n_examples]):
|
for fn_ax, fn_match in zip(axes[2], false_negatives[:n_examples], strict=False):
|
||||||
try:
|
try:
|
||||||
plot_false_negative_match(
|
plot_false_negative_match(
|
||||||
fn_match,
|
fn_match,
|
||||||
@ -87,7 +87,7 @@ def plot_match_gallery(
|
|||||||
):
|
):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
for ct_ax, ct_match in zip(axes[3], cross_triggers[:n_examples]):
|
for ct_ax, ct_match in zip(axes[3], cross_triggers[:n_examples], strict=False):
|
||||||
try:
|
try:
|
||||||
plot_cross_trigger_match(
|
plot_cross_trigger_match(
|
||||||
ct_match,
|
ct_match,
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
"""Plot heatmaps"""
|
"""Plot heatmaps"""
|
||||||
|
|
||||||
from typing import List, Optional, Tuple, Union
|
from typing import List, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
"""Plot functions to visualize detections and spectrograms."""
|
"""Plot functions to visualize detections and spectrograms."""
|
||||||
|
|
||||||
from typing import List, Optional, Tuple, Union, cast
|
from typing import List, Tuple, cast
|
||||||
|
|
||||||
import matplotlib.ticker as tick
|
import matplotlib.ticker as tick
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
from typing import Optional, Protocol, Tuple, Union
|
from typing import Protocol, Tuple
|
||||||
|
|
||||||
from matplotlib.axes import Axes
|
from matplotlib.axes import Axes
|
||||||
from soundevent import data, plot
|
from soundevent import data, plot
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
from typing import Dict, Optional, Tuple, Union
|
from typing import Dict, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import seaborn as sns
|
import seaborn as sns
|
||||||
|
|||||||
@ -1,4 +1,3 @@
|
|||||||
from typing import Optional
|
|
||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
"""Decodes extracted detection data into standard soundevent predictions."""
|
"""Decodes extracted detection data into standard soundevent predictions."""
|
||||||
|
|
||||||
from typing import List, Optional
|
from typing import List
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
@ -39,7 +39,7 @@ def to_raw_predictions(
|
|||||||
detections.times,
|
detections.times,
|
||||||
detections.frequencies,
|
detections.frequencies,
|
||||||
detections.sizes,
|
detections.sizes,
|
||||||
detections.features,
|
detections.features, strict=False,
|
||||||
):
|
):
|
||||||
highest_scoring_class = targets.class_names[class_scores.argmax()]
|
highest_scoring_class = targets.class_names[class_scores.argmax()]
|
||||||
|
|
||||||
|
|||||||
@ -15,7 +15,7 @@ precise time-frequency location of each detection. The final output aggregates
|
|||||||
all extracted information into a structured `xarray.Dataset`.
|
all extracted information into a structured `xarray.Dataset`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import List, Optional
|
from typing import List
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|||||||
@ -11,7 +11,7 @@ activations that have lower scores than a local maximum. This helps prevent
|
|||||||
multiple, overlapping detections originating from the same sound event.
|
multiple, overlapping detections originating from the same sound event.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Tuple, Union
|
from typing import Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
from typing import List, Optional, Tuple, Union
|
from typing import List, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|||||||
@ -12,7 +12,7 @@ classification probability maps, size prediction maps, and potentially
|
|||||||
intermediate features.
|
intermediate features.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Dict, List, Optional, Union
|
from typing import Dict, List
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
from typing import Annotated, Literal, Union
|
from typing import Annotated, Literal
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
from typing import List, Optional
|
from typing import List
|
||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from soundevent.data import PathLike
|
from soundevent.data import PathLike
|
||||||
|
|||||||
@ -1,4 +1,3 @@
|
|||||||
from typing import Optional
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
"""Computes spectrograms from audio waveforms with configurable parameters."""
|
"""Computes spectrograms from audio waveforms with configurable parameters."""
|
||||||
|
|
||||||
from typing import Annotated, Callable, Literal, Optional, Union
|
from typing import Annotated, Callable, Literal
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
from typing import Dict, List, Optional
|
from typing import Dict, List
|
||||||
|
|
||||||
from pydantic import Field, PrivateAttr, computed_field, model_validator
|
from pydantic import Field, PrivateAttr, computed_field, model_validator
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
from collections import Counter
|
from collections import Counter
|
||||||
from typing import List, Optional
|
from typing import List
|
||||||
|
|
||||||
from pydantic import Field, field_validator
|
from pydantic import Field, field_validator
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|||||||
@ -20,7 +20,7 @@ selecting and configuring the desired mapper. This module separates the
|
|||||||
*geometric* aspect of target definition from *semantic* classification.
|
*geometric* aspect of target definition from *semantic* classification.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Annotated, Literal, Optional, Tuple, Union
|
from typing import Annotated, Literal, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
from typing import Iterable, List, Optional, Tuple
|
from typing import Iterable, List, Tuple
|
||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|||||||
@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
import warnings
|
import warnings
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from typing import Annotated, Callable, List, Literal, Optional, Tuple, Union
|
from typing import Annotated, Callable, List, Literal, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@ -394,7 +394,7 @@ class MaskTime(torch.nn.Module):
|
|||||||
size=num_masks,
|
size=num_masks,
|
||||||
)
|
)
|
||||||
masks = [
|
masks = [
|
||||||
(start, start + size) for start, size in zip(mask_start, mask_size)
|
(start, start + size) for start, size in zip(mask_start, mask_size, strict=False)
|
||||||
]
|
]
|
||||||
return mask_time(spec, masks), clip_annotation
|
return mask_time(spec, masks), clip_annotation
|
||||||
|
|
||||||
@ -460,7 +460,7 @@ class MaskFrequency(torch.nn.Module):
|
|||||||
size=num_masks,
|
size=num_masks,
|
||||||
)
|
)
|
||||||
masks = [
|
masks = [
|
||||||
(start, start + size) for start, size in zip(mask_start, mask_size)
|
(start, start + size) for start, size in zip(mask_start, mask_size, strict=False)
|
||||||
]
|
]
|
||||||
return mask_frequency(spec, masks), clip_annotation
|
return mask_frequency(spec, masks), clip_annotation
|
||||||
|
|
||||||
|
|||||||
@ -107,7 +107,7 @@ class ValidationMetrics(Callback):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
for clip_annotation, clip_dets in zip(
|
for clip_annotation, clip_dets in zip(
|
||||||
clip_annotations, clip_detections
|
clip_annotations, clip_detections, strict=False
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@ -1,5 +1,4 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from lightning.pytorch.callbacks import Callback, ModelCheckpoint
|
from lightning.pytorch.callbacks import Callback, ModelCheckpoint
|
||||||
|
|
||||||
|
|||||||
@ -1,4 +1,3 @@
|
|||||||
from typing import Optional, Union
|
|
||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
from typing import List, Optional, Sequence
|
from typing import List, Sequence
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|||||||
@ -6,7 +6,6 @@ the specific multi-channel heatmap formats required by the neural network.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|||||||
@ -1,12 +1,12 @@
|
|||||||
from typing import NamedTuple, Optional
|
from typing import NamedTuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from batdetect2.models.types import DetectionModel
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
from torch.optim import Adam
|
from torch.optim import Adam
|
||||||
from torch.optim.lr_scheduler import CosineAnnealingLR
|
from torch.optim.lr_scheduler import CosineAnnealingLR
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
from batdetect2.models.types import DetectionModel
|
|
||||||
from batdetect2.train.dataset import LabeledDataset
|
from batdetect2.train.dataset import LabeledDataset
|
||||||
|
|
||||||
|
|
||||||
@ -26,7 +26,7 @@ def train_loop(
|
|||||||
learning_rate: float = 1e-4,
|
learning_rate: float = 1e-4,
|
||||||
):
|
):
|
||||||
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
|
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
|
||||||
validation_loader = DataLoader(validation_dataset, batch_size=32)
|
DataLoader(validation_dataset, batch_size=32)
|
||||||
|
|
||||||
model.to(device)
|
model.to(device)
|
||||||
|
|
||||||
@ -36,8 +36,8 @@ def train_loop(
|
|||||||
num_epochs * len(train_loader),
|
num_epochs * len(train_loader),
|
||||||
)
|
)
|
||||||
|
|
||||||
for epoch in range(num_epochs):
|
for _epoch in range(num_epochs):
|
||||||
train_loss = train_single_epoch(
|
train_single_epoch(
|
||||||
model,
|
model,
|
||||||
train_loader,
|
train_loader,
|
||||||
optimizer,
|
optimizer,
|
||||||
@ -60,9 +60,9 @@ def train_single_epoch(
|
|||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
|
|
||||||
spec = batch.spec.to(device)
|
spec = batch.spec.to(device)
|
||||||
detection_heatmap = batch.detection_heatmap.to(device)
|
batch.detection_heatmap.to(device)
|
||||||
class_heatmap = batch.class_heatmap.to(device)
|
batch.class_heatmap.to(device)
|
||||||
size_heatmap = batch.size_heatmap.to(device)
|
batch.size_heatmap.to(device)
|
||||||
|
|
||||||
outputs = model(spec)
|
outputs = model(spec)
|
||||||
|
|
||||||
|
|||||||
@ -2,6 +2,10 @@ import argparse
|
|||||||
import json
|
import json
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
|
import batdetect2.train.audio_dataloader as adl
|
||||||
|
import batdetect2.train.evaluate as evl
|
||||||
|
import batdetect2.train.train_split as ts
|
||||||
|
import batdetect2.train.train_utils as tu
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@ -9,10 +13,6 @@ import torch.utils.data
|
|||||||
from torch.optim.lr_scheduler import CosineAnnealingLR
|
from torch.optim.lr_scheduler import CosineAnnealingLR
|
||||||
|
|
||||||
import batdetect2.detector.post_process as pp
|
import batdetect2.detector.post_process as pp
|
||||||
import batdetect2.train.audio_dataloader as adl
|
|
||||||
import batdetect2.train.evaluate as evl
|
|
||||||
import batdetect2.train.train_split as ts
|
|
||||||
import batdetect2.train.train_utils as tu
|
|
||||||
import batdetect2.utils.plot_utils as pu
|
import batdetect2.utils.plot_utils as pu
|
||||||
from batdetect2.detector import models, parameters
|
from batdetect2.detector import models, parameters
|
||||||
from batdetect2.train import losses
|
from batdetect2.train import losses
|
||||||
|
|||||||
@ -10,7 +10,7 @@ def get_train_test_data(ann_dir, wav_dir, split_name, load_extra=True):
|
|||||||
train_sets, test_sets = split_same(ann_dir, wav_dir, load_extra)
|
train_sets, test_sets = split_same(ann_dir, wav_dir, load_extra)
|
||||||
else:
|
else:
|
||||||
print("Split not defined")
|
print("Split not defined")
|
||||||
assert False
|
raise AssertionError()
|
||||||
|
|
||||||
return train_sets, test_sets
|
return train_sets, test_sets
|
||||||
|
|
||||||
|
|||||||
@ -2,7 +2,7 @@ import json
|
|||||||
import sys
|
import sys
|
||||||
from collections import Counter
|
from collections import Counter
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, Generator, List, Optional, Tuple
|
from typing import Dict, Generator, List, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
from typing import TYPE_CHECKING, Optional, Tuple
|
from typing import TYPE_CHECKING, Tuple
|
||||||
|
|
||||||
import lightning as L
|
import lightning as L
|
||||||
import torch
|
import torch
|
||||||
|
|||||||
@ -18,7 +18,6 @@ The primary entry points are:
|
|||||||
- `LossConfig`: The Pydantic model for configuring loss weights and parameters.
|
- `LossConfig`: The Pydantic model for configuring loss weights and parameters.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
"""Types used in the code base."""
|
"""Types used in the code base."""
|
||||||
|
|
||||||
from typing import Any, List, NamedTuple, Optional, TypedDict
|
from typing import Any, List, NamedTuple, TypedDict
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
from typing import Generic, List, Optional, Protocol, Sequence, TypeVar
|
from typing import Generic, List, Protocol, Sequence, TypeVar
|
||||||
|
|
||||||
from soundevent.data import PathLike
|
from soundevent.data import PathLike
|
||||||
|
|
||||||
|
|||||||
@ -4,7 +4,6 @@ from typing import (
|
|||||||
Generic,
|
Generic,
|
||||||
Iterable,
|
Iterable,
|
||||||
List,
|
List,
|
||||||
Optional,
|
|
||||||
Protocol,
|
Protocol,
|
||||||
Sequence,
|
Sequence,
|
||||||
Tuple,
|
Tuple,
|
||||||
|
|||||||
@ -12,7 +12,7 @@ system that deal with model predictions.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import List, NamedTuple, Optional, Protocol, Sequence
|
from typing import List, NamedTuple, Protocol, Sequence
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|||||||
@ -10,7 +10,7 @@ pipeline can interact consistently, regardless of the specific underlying
|
|||||||
implementation (e.g., different libraries or custom configurations).
|
implementation (e.g., different libraries or custom configurations).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Optional, Protocol
|
from typing import Protocol
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|||||||
@ -13,7 +13,7 @@ throughout BatDetect2.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from typing import List, Optional, Protocol
|
from typing import List, Protocol
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user