mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 22:51:58 +02:00
Fix import error
This commit is contained in:
parent
f99653d68f
commit
55eff0cebd
@ -13,11 +13,10 @@ from batdetect2.preprocess import (
|
|||||||
)
|
)
|
||||||
from batdetect2.preprocess.spectrogram import get_spectrogram_resolution
|
from batdetect2.preprocess.spectrogram import get_spectrogram_resolution
|
||||||
from batdetect2.targets import (
|
from batdetect2.targets import (
|
||||||
HeatmapsConfig,
|
LabelConfig,
|
||||||
TagInfo,
|
TagInfo,
|
||||||
TargetConfig,
|
TargetConfig,
|
||||||
)
|
)
|
||||||
from batdetect2.targets.labels import LabelConfig
|
|
||||||
from batdetect2.train.preprocess import (
|
from batdetect2.train.preprocess import (
|
||||||
TrainPreprocessingConfig,
|
TrainPreprocessingConfig,
|
||||||
)
|
)
|
||||||
@ -77,13 +76,12 @@ def get_training_preprocessing_config(
|
|||||||
preprocessing=preprocessing,
|
preprocessing=preprocessing,
|
||||||
target=TargetConfig(
|
target=TargetConfig(
|
||||||
classes=[
|
classes=[
|
||||||
TagInfo(key="class", value=class_name, label=class_name)
|
TagInfo(key="class", value=class_name)
|
||||||
for class_name in params["class_names"]
|
for class_name in params["class_names"]
|
||||||
],
|
],
|
||||||
generic_class=TagInfo(
|
generic_class=TagInfo(
|
||||||
key="class",
|
key="class",
|
||||||
value=generic,
|
value=generic,
|
||||||
label=generic,
|
|
||||||
),
|
),
|
||||||
include=[
|
include=[
|
||||||
TagInfo(key="event", value=event)
|
TagInfo(key="event", value=event)
|
||||||
@ -95,12 +93,10 @@ def get_training_preprocessing_config(
|
|||||||
],
|
],
|
||||||
),
|
),
|
||||||
labels=LabelConfig(
|
labels=LabelConfig(
|
||||||
heatmaps=HeatmapsConfig(
|
position="bottom-left",
|
||||||
position="bottom-left",
|
time_scale=1 / time_bin_width,
|
||||||
time_scale=1 / time_bin_width,
|
frequency_scale=1 / freq_bin_width,
|
||||||
frequency_scale=1 / freq_bin_width,
|
sigma=params["target_sigma"],
|
||||||
sigma=params["target_sigma"],
|
|
||||||
)
|
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -7,7 +7,6 @@ predicted class labels map back to the original tag system for interpretation.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from batdetect2.targets.labels import (
|
from batdetect2.targets.labels import (
|
||||||
HeatmapsConfig,
|
|
||||||
LabelConfig,
|
LabelConfig,
|
||||||
generate_heatmaps,
|
generate_heatmaps,
|
||||||
load_label_config,
|
load_label_config,
|
||||||
@ -47,7 +46,6 @@ from batdetect2.targets.transform import (
|
|||||||
__all__ = [
|
__all__ = [
|
||||||
"DerivationRegistry",
|
"DerivationRegistry",
|
||||||
"DeriveTagRule",
|
"DeriveTagRule",
|
||||||
"HeatmapsConfig",
|
|
||||||
"LabelConfig",
|
"LabelConfig",
|
||||||
"MapValueRule",
|
"MapValueRule",
|
||||||
"ReplaceRule",
|
"ReplaceRule",
|
||||||
@ -66,6 +64,7 @@ __all__ = [
|
|||||||
"get_class_names",
|
"get_class_names",
|
||||||
"get_derivation",
|
"get_derivation",
|
||||||
"get_tag_from_info",
|
"get_tag_from_info",
|
||||||
|
"get_term_from_key",
|
||||||
"individual",
|
"individual",
|
||||||
"load_label_config",
|
"load_label_config",
|
||||||
"load_target_config",
|
"load_target_config",
|
||||||
|
@ -308,17 +308,17 @@ def generate_heatmaps(
|
|||||||
Notes
|
Notes
|
||||||
-----
|
-----
|
||||||
* This function expects `sound_events` to be already filtered and
|
* This function expects `sound_events` to be already filtered and
|
||||||
transformed.
|
transformed.
|
||||||
* It includes error handling to skip individual annotations that cause
|
* It includes error handling to skip individual annotations that cause
|
||||||
issues (e.g., missing geometry, out-of-bounds coordinates, encoder
|
issues (e.g., missing geometry, out-of-bounds coordinates, encoder
|
||||||
errors) allowing the rest of the clip to be processed. Warnings or
|
errors) allowing the rest of the clip to be processed. Warnings or
|
||||||
errors are logged.
|
errors are logged.
|
||||||
* The `time_scale` and `frequency_scale` parameters are crucial and must be
|
* The `time_scale` and `frequency_scale` parameters are crucial and must be
|
||||||
set according to the expectations of the specific BatDetect2 model
|
set according to the expectations of the specific BatDetect2 model
|
||||||
architecture being trained. Consult model documentation for required
|
architecture being trained. Consult model documentation for required
|
||||||
units/scales.
|
units/scales.
|
||||||
* Gaussian filtering and normalization are applied only to detection and
|
* Gaussian filtering and normalization are applied only to detection and
|
||||||
class heatmaps, not the size heatmap.
|
class heatmaps, not the size heatmap.
|
||||||
"""
|
"""
|
||||||
shape = dict(zip(spec.dims, spec.shape))
|
shape = dict(zip(spec.dims, spec.shape))
|
||||||
|
|
||||||
|
@ -51,7 +51,8 @@ class TargetConfig(BaseConfig):
|
|||||||
|
|
||||||
|
|
||||||
def get_tag_label(tag_info: TagInfo) -> str:
|
def get_tag_label(tag_info: TagInfo) -> str:
|
||||||
return tag_info.label if tag_info.label else tag_info.value
|
# TODO: Review this
|
||||||
|
return tag_info.value
|
||||||
|
|
||||||
|
|
||||||
def get_class_names(classes: List[TagInfo]) -> List[str]:
|
def get_class_names(classes: List[TagInfo]) -> List[str]:
|
||||||
|
Loading…
Reference in New Issue
Block a user