Create separate preproc/postproc/target instances for model in api

This commit is contained in:
Santiago Martinez Balvanera 2026-03-19 00:23:55 +00:00
parent fb3dc3eaf0
commit 32d8c4a9e5
4 changed files with 13 additions and 11 deletions

View File

@ -484,11 +484,19 @@ class BatDetect2API:
transform=output_transform,
)
# NOTE: Build separate instances of preprocessor and postprocessor
# to avoid device mismatch errors
model = build_model(
config=config.model,
targets=targets,
preprocessor=preprocessor,
postprocessor=postprocessor,
targets=build_targets(config=config.model.targets),
preprocessor=build_preprocessor(
input_samplerate=audio_loader.samplerate,
config=config.model.preprocess,
),
postprocessor=build_postprocessor(
preprocessor,
config=config.model.postprocess,
),
)
return cls(

View File

@ -29,6 +29,3 @@ class PreprocessorProtocol(Protocol):
def process_numpy(self, wav: np.ndarray) -> np.ndarray:
return self(torch.tensor(wav)).numpy()
def cpu(self) -> "PreprocessorProtocol":
return self

View File

@ -272,7 +272,7 @@ def build_train_dataset(
audio_loader=audio_loader,
labeller=labeller,
clipper=clipper,
preprocessor=preprocessor.cpu(),
preprocessor=preprocessor,
audio_augmentation=audio_augmentation,
spectrogram_augmentation=spectrogram_augmentation,
)
@ -305,7 +305,7 @@ def build_val_dataset(
clip_annotations,
audio_loader=audio_loader,
labeller=labeller,
preprocessor=preprocessor.cpu(),
preprocessor=preprocessor,
clipper=clipper,
)

View File

@ -73,9 +73,6 @@ def run_train(
audio_loader = audio_loader or build_audio_loader(config=audio_config)
if model is not None:
preprocessor = preprocessor or model.preprocessor
preprocessor = preprocessor or build_preprocessor(
input_samplerate=audio_loader.samplerate,
config=model_config.preprocess,