Temporary remove compat params module

This commit is contained in:
mbsantiago 2025-04-22 09:01:58 +01:00
parent 285c6a3347
commit 9fc713d390

View File

@ -1,152 +1,152 @@
from batdetect2.preprocess import ( # from batdetect2.preprocess import (
AmplitudeScaleConfig, # AmplitudeScaleConfig,
AudioConfig, # AudioConfig,
FrequencyConfig, # FrequencyConfig,
LogScaleConfig, # LogScaleConfig,
PcenConfig, # PcenConfig,
PreprocessingConfig, # PreprocessingConfig,
ResampleConfig, # ResampleConfig,
Scales, # Scales,
SpecSizeConfig, # SpecSizeConfig,
SpectrogramConfig, # SpectrogramConfig,
STFTConfig, # STFTConfig,
) # )
from batdetect2.preprocess.spectrogram import get_spectrogram_resolution # from batdetect2.preprocess.spectrogram import get_spectrogram_resolution
from batdetect2.targets import ( # from batdetect2.targets import (
LabelConfig, # LabelConfig,
TagInfo, # TagInfo,
TargetConfig, # TargetConfig,
) # )
from batdetect2.train.preprocess import ( # from batdetect2.train.preprocess import (
TrainPreprocessingConfig, # TrainPreprocessingConfig,
) # )
#
#
def get_spectrogram_scale(scale: str) -> Scales: # def get_spectrogram_scale(scale: str) -> Scales:
if scale == "pcen": # if scale == "pcen":
return PcenConfig() # return PcenConfig()
if scale == "log": # if scale == "log":
return LogScaleConfig() # return LogScaleConfig()
return AmplitudeScaleConfig() # return AmplitudeScaleConfig()
#
#
def get_preprocessing_config(params: dict) -> PreprocessingConfig: # def get_preprocessing_config(params: dict) -> PreprocessingConfig:
return PreprocessingConfig( # return PreprocessingConfig(
audio=AudioConfig( # audio=AudioConfig(
resample=ResampleConfig( # resample=ResampleConfig(
samplerate=params["target_samp_rate"], # samplerate=params["target_samp_rate"],
method="poly", # method="poly",
), # ),
scale=params["scale_raw_audio"], # scale=params["scale_raw_audio"],
center=params["scale_raw_audio"], # center=params["scale_raw_audio"],
duration=None, # duration=None,
), # ),
spectrogram=SpectrogramConfig( # spectrogram=SpectrogramConfig(
stft=STFTConfig( # stft=STFTConfig(
window_duration=params["fft_win_length"], # window_duration=params["fft_win_length"],
window_overlap=params["fft_overlap"], # window_overlap=params["fft_overlap"],
window_fn="hann", # window_fn="hann",
), # ),
frequencies=FrequencyConfig( # frequencies=FrequencyConfig(
min_freq=params["min_freq"], # min_freq=params["min_freq"],
max_freq=params["max_freq"], # max_freq=params["max_freq"],
), # ),
scale=get_spectrogram_scale(params["spec_scale"]), # scale=get_spectrogram_scale(params["spec_scale"]),
spectral_mean_substraction=params["denoise_spec_avg"], # spectral_mean_substraction=params["denoise_spec_avg"],
size=SpecSizeConfig( # size=SpecSizeConfig(
height=params["spec_height"], # height=params["spec_height"],
resize_factor=params["resize_factor"], # resize_factor=params["resize_factor"],
), # ),
peak_normalize=params["max_scale_spec"], # peak_normalize=params["max_scale_spec"],
), # ),
) # )
#
#
def get_training_preprocessing_config( # def get_training_preprocessing_config(
params: dict, # params: dict,
) -> TrainPreprocessingConfig: # ) -> TrainPreprocessingConfig:
generic = params["generic_class"][0] # generic = params["generic_class"][0]
preprocessing = get_preprocessing_config(params) # preprocessing = get_preprocessing_config(params)
#
freq_bin_width, time_bin_width = get_spectrogram_resolution( # freq_bin_width, time_bin_width = get_spectrogram_resolution(
preprocessing.spectrogram # preprocessing.spectrogram
) # )
#
return TrainPreprocessingConfig( # return TrainPreprocessingConfig(
preprocessing=preprocessing, # preprocessing=preprocessing,
target=TargetConfig( # target=TargetConfig(
classes=[ # classes=[
TagInfo(key="class", value=class_name) # TagInfo(key="class", value=class_name)
for class_name in params["class_names"] # for class_name in params["class_names"]
], # ],
generic_class=TagInfo( # generic_class=TagInfo(
key="class", # key="class",
value=generic, # value=generic,
), # ),
include=[ # include=[
TagInfo(key="event", value=event) # TagInfo(key="event", value=event)
for event in params["events_of_interest"] # for event in params["events_of_interest"]
], # ],
exclude=[ # exclude=[
TagInfo(key="class", value=value) # TagInfo(key="class", value=value)
for value in params["classes_to_ignore"] # for value in params["classes_to_ignore"]
], # ],
), # ),
labels=LabelConfig( # labels=LabelConfig(
position="bottom-left", # position="bottom-left",
time_scale=1 / time_bin_width, # time_scale=1 / time_bin_width,
frequency_scale=1 / freq_bin_width, # frequency_scale=1 / freq_bin_width,
sigma=params["target_sigma"], # sigma=params["target_sigma"],
), # ),
) # )
#
#
# 'standardize_classs_names_ip', # # 'standardize_classs_names_ip',
# 'convert_to_genus', # # 'convert_to_genus',
# 'genus_mapping', # # 'genus_mapping',
# 'standardize_classs_names', # # 'standardize_classs_names',
# 'genus_names', # # 'genus_names',
#
# ['data_dir', # # ['data_dir',
# 'ann_dir', # # 'ann_dir',
# 'train_split', # # 'train_split',
# 'model_name', # # 'model_name',
# 'num_filters', # # 'num_filters',
# 'experiment', # # 'experiment',
# 'model_file_name', # # 'model_file_name',
# 'op_im_dir', # # 'op_im_dir',
# 'op_im_dir_test', # # 'op_im_dir_test',
# 'notes', # # 'notes',
# 'spec_divide_factor', # # 'spec_divide_factor',
# 'detection_overlap', # # 'detection_overlap',
# 'ignore_start_end', # # 'ignore_start_end',
# 'detection_threshold', # # 'detection_threshold',
# 'nms_kernel_size', # # 'nms_kernel_size',
# 'nms_top_k_per_sec', # # 'nms_top_k_per_sec',
# 'aug_prob', # # 'aug_prob',
# 'augment_at_train', # # 'augment_at_train',
# 'augment_at_train_combine', # # 'augment_at_train_combine',
# 'echo_max_delay', # # 'echo_max_delay',
# 'stretch_squeeze_delta', # # 'stretch_squeeze_delta',
# 'mask_max_time_perc', # # 'mask_max_time_perc',
# 'mask_max_freq_perc', # # 'mask_max_freq_perc',
# 'spec_amp_scaling', # # 'spec_amp_scaling',
# 'aug_sampling_rates', # # 'aug_sampling_rates',
# 'train_loss', # # 'train_loss',
# 'det_loss_weight', # # 'det_loss_weight',
# 'size_loss_weight', # # 'size_loss_weight',
# 'class_loss_weight', # # 'class_loss_weight',
# 'individual_loss_weight', # # 'individual_loss_weight',
# 'emb_dim', # # 'emb_dim',
# 'lr', # # 'lr',
# 'batch_size', # # 'batch_size',
# 'num_workers', # # 'num_workers',
# 'num_epochs', # # 'num_epochs',
# 'num_eval_epochs', # # 'num_eval_epochs',
# 'device', # # 'device',
# 'save_test_image_during_train', # # 'save_test_image_during_train',
# 'save_test_image_after_train', # # 'save_test_image_after_train',
# 'train_sets', # # 'train_sets',
# 'test_sets', # # 'test_sets',
# 'class_inv_freq', # # 'class_inv_freq',
# 'ip_height'] # # 'ip_height']