Fix import error

This commit is contained in:
mbsantiago 2025-04-15 20:29:29 +01:00
parent f99653d68f
commit 55eff0cebd
4 changed files with 17 additions and 21 deletions

View File

@ -13,11 +13,10 @@ from batdetect2.preprocess import (
)
from batdetect2.preprocess.spectrogram import get_spectrogram_resolution
from batdetect2.targets import (
HeatmapsConfig,
LabelConfig,
TagInfo,
TargetConfig,
)
from batdetect2.targets.labels import LabelConfig
from batdetect2.train.preprocess import (
TrainPreprocessingConfig,
)
@ -77,13 +76,12 @@ def get_training_preprocessing_config(
preprocessing=preprocessing,
target=TargetConfig(
classes=[
TagInfo(key="class", value=class_name, label=class_name)
TagInfo(key="class", value=class_name)
for class_name in params["class_names"]
],
generic_class=TagInfo(
key="class",
value=generic,
label=generic,
),
include=[
TagInfo(key="event", value=event)
@ -95,12 +93,10 @@ def get_training_preprocessing_config(
],
),
labels=LabelConfig(
heatmaps=HeatmapsConfig(
position="bottom-left",
time_scale=1 / time_bin_width,
frequency_scale=1 / freq_bin_width,
sigma=params["target_sigma"],
)
),
)

View File

@ -7,7 +7,6 @@ predicted class labels map back to the original tag system for interpretation.
"""
from batdetect2.targets.labels import (
HeatmapsConfig,
LabelConfig,
generate_heatmaps,
load_label_config,
@ -47,7 +46,6 @@ from batdetect2.targets.transform import (
__all__ = [
"DerivationRegistry",
"DeriveTagRule",
"HeatmapsConfig",
"LabelConfig",
"MapValueRule",
"ReplaceRule",
@ -66,6 +64,7 @@ __all__ = [
"get_class_names",
"get_derivation",
"get_tag_from_info",
"get_term_from_key",
"individual",
"load_label_config",
"load_target_config",

View File

@ -51,7 +51,8 @@ class TargetConfig(BaseConfig):
def get_tag_label(tag_info: TagInfo) -> str:
return tag_info.label if tag_info.label else tag_info.value
# TODO: Review this
return tag_info.value
def get_class_names(classes: List[TagInfo]) -> List[str]: