Fix import error

This commit is contained in:
mbsantiago 2025-04-15 20:29:29 +01:00
parent f99653d68f
commit 55eff0cebd
4 changed files with 17 additions and 21 deletions

View File

@ -13,11 +13,10 @@ from batdetect2.preprocess import (
) )
from batdetect2.preprocess.spectrogram import get_spectrogram_resolution from batdetect2.preprocess.spectrogram import get_spectrogram_resolution
from batdetect2.targets import ( from batdetect2.targets import (
HeatmapsConfig, LabelConfig,
TagInfo, TagInfo,
TargetConfig, TargetConfig,
) )
from batdetect2.targets.labels import LabelConfig
from batdetect2.train.preprocess import ( from batdetect2.train.preprocess import (
TrainPreprocessingConfig, TrainPreprocessingConfig,
) )
@ -77,13 +76,12 @@ def get_training_preprocessing_config(
preprocessing=preprocessing, preprocessing=preprocessing,
target=TargetConfig( target=TargetConfig(
classes=[ classes=[
TagInfo(key="class", value=class_name, label=class_name) TagInfo(key="class", value=class_name)
for class_name in params["class_names"] for class_name in params["class_names"]
], ],
generic_class=TagInfo( generic_class=TagInfo(
key="class", key="class",
value=generic, value=generic,
label=generic,
), ),
include=[ include=[
TagInfo(key="event", value=event) TagInfo(key="event", value=event)
@ -95,12 +93,10 @@ def get_training_preprocessing_config(
], ],
), ),
labels=LabelConfig( labels=LabelConfig(
heatmaps=HeatmapsConfig( position="bottom-left",
position="bottom-left", time_scale=1 / time_bin_width,
time_scale=1 / time_bin_width, frequency_scale=1 / freq_bin_width,
frequency_scale=1 / freq_bin_width, sigma=params["target_sigma"],
sigma=params["target_sigma"],
)
), ),
) )

View File

@ -7,7 +7,6 @@ predicted class labels map back to the original tag system for interpretation.
""" """
from batdetect2.targets.labels import ( from batdetect2.targets.labels import (
HeatmapsConfig,
LabelConfig, LabelConfig,
generate_heatmaps, generate_heatmaps,
load_label_config, load_label_config,
@ -47,7 +46,6 @@ from batdetect2.targets.transform import (
__all__ = [ __all__ = [
"DerivationRegistry", "DerivationRegistry",
"DeriveTagRule", "DeriveTagRule",
"HeatmapsConfig",
"LabelConfig", "LabelConfig",
"MapValueRule", "MapValueRule",
"ReplaceRule", "ReplaceRule",
@ -66,6 +64,7 @@ __all__ = [
"get_class_names", "get_class_names",
"get_derivation", "get_derivation",
"get_tag_from_info", "get_tag_from_info",
"get_term_from_key",
"individual", "individual",
"load_label_config", "load_label_config",
"load_target_config", "load_target_config",

View File

@ -308,17 +308,17 @@ def generate_heatmaps(
Notes Notes
----- -----
* This function expects `sound_events` to be already filtered and * This function expects `sound_events` to be already filtered and
transformed. transformed.
* It includes error handling to skip individual annotations that cause * It includes error handling to skip individual annotations that cause
issues (e.g., missing geometry, out-of-bounds coordinates, encoder issues (e.g., missing geometry, out-of-bounds coordinates, encoder
errors) allowing the rest of the clip to be processed. Warnings or errors) allowing the rest of the clip to be processed. Warnings or
errors are logged. errors are logged.
* The `time_scale` and `frequency_scale` parameters are crucial and must be * The `time_scale` and `frequency_scale` parameters are crucial and must be
set according to the expectations of the specific BatDetect2 model set according to the expectations of the specific BatDetect2 model
architecture being trained. Consult model documentation for required architecture being trained. Consult model documentation for required
units/scales. units/scales.
* Gaussian filtering and normalization are applied only to detection and * Gaussian filtering and normalization are applied only to detection and
class heatmaps, not the size heatmap. class heatmaps, not the size heatmap.
""" """
shape = dict(zip(spec.dims, spec.shape)) shape = dict(zip(spec.dims, spec.shape))

View File

@ -51,7 +51,8 @@ class TargetConfig(BaseConfig):
def get_tag_label(tag_info: TagInfo) -> str: def get_tag_label(tag_info: TagInfo) -> str:
return tag_info.label if tag_info.label else tag_info.value # TODO: Review this
return tag_info.value
def get_class_names(classes: List[TagInfo]) -> List[str]: def get_class_names(classes: List[TagInfo]) -> List[str]: