Fix import error

This commit is contained in:
mbsantiago 2025-04-15 20:29:29 +01:00
parent f99653d68f
commit 55eff0cebd
4 changed files with 17 additions and 21 deletions

View File

@ -13,11 +13,10 @@ from batdetect2.preprocess import (
) )
from batdetect2.preprocess.spectrogram import get_spectrogram_resolution from batdetect2.preprocess.spectrogram import get_spectrogram_resolution
from batdetect2.targets import ( from batdetect2.targets import (
HeatmapsConfig, LabelConfig,
TagInfo, TagInfo,
TargetConfig, TargetConfig,
) )
from batdetect2.targets.labels import LabelConfig
from batdetect2.train.preprocess import ( from batdetect2.train.preprocess import (
TrainPreprocessingConfig, TrainPreprocessingConfig,
) )
@ -77,13 +76,12 @@ def get_training_preprocessing_config(
preprocessing=preprocessing, preprocessing=preprocessing,
target=TargetConfig( target=TargetConfig(
classes=[ classes=[
TagInfo(key="class", value=class_name, label=class_name) TagInfo(key="class", value=class_name)
for class_name in params["class_names"] for class_name in params["class_names"]
], ],
generic_class=TagInfo( generic_class=TagInfo(
key="class", key="class",
value=generic, value=generic,
label=generic,
), ),
include=[ include=[
TagInfo(key="event", value=event) TagInfo(key="event", value=event)
@ -95,12 +93,10 @@ def get_training_preprocessing_config(
], ],
), ),
labels=LabelConfig( labels=LabelConfig(
heatmaps=HeatmapsConfig(
position="bottom-left", position="bottom-left",
time_scale=1 / time_bin_width, time_scale=1 / time_bin_width,
frequency_scale=1 / freq_bin_width, frequency_scale=1 / freq_bin_width,
sigma=params["target_sigma"], sigma=params["target_sigma"],
)
), ),
) )

View File

@ -7,7 +7,6 @@ predicted class labels map back to the original tag system for interpretation.
""" """
from batdetect2.targets.labels import ( from batdetect2.targets.labels import (
HeatmapsConfig,
LabelConfig, LabelConfig,
generate_heatmaps, generate_heatmaps,
load_label_config, load_label_config,
@ -47,7 +46,6 @@ from batdetect2.targets.transform import (
__all__ = [ __all__ = [
"DerivationRegistry", "DerivationRegistry",
"DeriveTagRule", "DeriveTagRule",
"HeatmapsConfig",
"LabelConfig", "LabelConfig",
"MapValueRule", "MapValueRule",
"ReplaceRule", "ReplaceRule",
@ -66,6 +64,7 @@ __all__ = [
"get_class_names", "get_class_names",
"get_derivation", "get_derivation",
"get_tag_from_info", "get_tag_from_info",
"get_term_from_key",
"individual", "individual",
"load_label_config", "load_label_config",
"load_target_config", "load_target_config",

View File

@ -51,7 +51,8 @@ class TargetConfig(BaseConfig):
def get_tag_label(tag_info: TagInfo) -> str: def get_tag_label(tag_info: TagInfo) -> str:
return tag_info.label if tag_info.label else tag_info.value # TODO: Review this
return tag_info.value
def get_class_names(classes: List[TagInfo]) -> List[str]: def get_class_names(classes: List[TagInfo]) -> List[str]: