mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 22:51:58 +02:00
Fix import error
This commit is contained in:
parent
f99653d68f
commit
55eff0cebd
@ -13,11 +13,10 @@ from batdetect2.preprocess import (
|
||||
)
|
||||
from batdetect2.preprocess.spectrogram import get_spectrogram_resolution
|
||||
from batdetect2.targets import (
|
||||
HeatmapsConfig,
|
||||
LabelConfig,
|
||||
TagInfo,
|
||||
TargetConfig,
|
||||
)
|
||||
from batdetect2.targets.labels import LabelConfig
|
||||
from batdetect2.train.preprocess import (
|
||||
TrainPreprocessingConfig,
|
||||
)
|
||||
@ -77,13 +76,12 @@ def get_training_preprocessing_config(
|
||||
preprocessing=preprocessing,
|
||||
target=TargetConfig(
|
||||
classes=[
|
||||
TagInfo(key="class", value=class_name, label=class_name)
|
||||
TagInfo(key="class", value=class_name)
|
||||
for class_name in params["class_names"]
|
||||
],
|
||||
generic_class=TagInfo(
|
||||
key="class",
|
||||
value=generic,
|
||||
label=generic,
|
||||
),
|
||||
include=[
|
||||
TagInfo(key="event", value=event)
|
||||
@ -95,12 +93,10 @@ def get_training_preprocessing_config(
|
||||
],
|
||||
),
|
||||
labels=LabelConfig(
|
||||
heatmaps=HeatmapsConfig(
|
||||
position="bottom-left",
|
||||
time_scale=1 / time_bin_width,
|
||||
frequency_scale=1 / freq_bin_width,
|
||||
sigma=params["target_sigma"],
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
|
@ -7,7 +7,6 @@ predicted class labels map back to the original tag system for interpretation.
|
||||
"""
|
||||
|
||||
from batdetect2.targets.labels import (
|
||||
HeatmapsConfig,
|
||||
LabelConfig,
|
||||
generate_heatmaps,
|
||||
load_label_config,
|
||||
@ -47,7 +46,6 @@ from batdetect2.targets.transform import (
|
||||
__all__ = [
|
||||
"DerivationRegistry",
|
||||
"DeriveTagRule",
|
||||
"HeatmapsConfig",
|
||||
"LabelConfig",
|
||||
"MapValueRule",
|
||||
"ReplaceRule",
|
||||
@ -66,6 +64,7 @@ __all__ = [
|
||||
"get_class_names",
|
||||
"get_derivation",
|
||||
"get_tag_from_info",
|
||||
"get_term_from_key",
|
||||
"individual",
|
||||
"load_label_config",
|
||||
"load_target_config",
|
||||
|
@ -51,7 +51,8 @@ class TargetConfig(BaseConfig):
|
||||
|
||||
|
||||
def get_tag_label(tag_info: TagInfo) -> str:
|
||||
return tag_info.label if tag_info.label else tag_info.value
|
||||
# TODO: Review this
|
||||
return tag_info.value
|
||||
|
||||
|
||||
def get_class_names(classes: List[TagInfo]) -> List[str]:
|
||||
|
Loading…
Reference in New Issue
Block a user