mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-04-04 15:20:19 +02:00
Create separate preproc/postproc/target instances for model in api
This commit is contained in:
parent
fb3dc3eaf0
commit
32d8c4a9e5
@ -484,11 +484,19 @@ class BatDetect2API:
|
|||||||
transform=output_transform,
|
transform=output_transform,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# NOTE: Build separate instances of preprocessor and postprocessor
|
||||||
|
# to avoid device mismatch errors
|
||||||
model = build_model(
|
model = build_model(
|
||||||
config=config.model,
|
config=config.model,
|
||||||
targets=targets,
|
targets=build_targets(config=config.model.targets),
|
||||||
preprocessor=preprocessor,
|
preprocessor=build_preprocessor(
|
||||||
postprocessor=postprocessor,
|
input_samplerate=audio_loader.samplerate,
|
||||||
|
config=config.model.preprocess,
|
||||||
|
),
|
||||||
|
postprocessor=build_postprocessor(
|
||||||
|
preprocessor,
|
||||||
|
config=config.model.postprocess,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
|
|||||||
@ -29,6 +29,3 @@ class PreprocessorProtocol(Protocol):
|
|||||||
|
|
||||||
def process_numpy(self, wav: np.ndarray) -> np.ndarray:
|
def process_numpy(self, wav: np.ndarray) -> np.ndarray:
|
||||||
return self(torch.tensor(wav)).numpy()
|
return self(torch.tensor(wav)).numpy()
|
||||||
|
|
||||||
def cpu(self) -> "PreprocessorProtocol":
|
|
||||||
return self
|
|
||||||
|
|||||||
@ -272,7 +272,7 @@ def build_train_dataset(
|
|||||||
audio_loader=audio_loader,
|
audio_loader=audio_loader,
|
||||||
labeller=labeller,
|
labeller=labeller,
|
||||||
clipper=clipper,
|
clipper=clipper,
|
||||||
preprocessor=preprocessor.cpu(),
|
preprocessor=preprocessor,
|
||||||
audio_augmentation=audio_augmentation,
|
audio_augmentation=audio_augmentation,
|
||||||
spectrogram_augmentation=spectrogram_augmentation,
|
spectrogram_augmentation=spectrogram_augmentation,
|
||||||
)
|
)
|
||||||
@ -305,7 +305,7 @@ def build_val_dataset(
|
|||||||
clip_annotations,
|
clip_annotations,
|
||||||
audio_loader=audio_loader,
|
audio_loader=audio_loader,
|
||||||
labeller=labeller,
|
labeller=labeller,
|
||||||
preprocessor=preprocessor.cpu(),
|
preprocessor=preprocessor,
|
||||||
clipper=clipper,
|
clipper=clipper,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -73,9 +73,6 @@ def run_train(
|
|||||||
|
|
||||||
audio_loader = audio_loader or build_audio_loader(config=audio_config)
|
audio_loader = audio_loader or build_audio_loader(config=audio_config)
|
||||||
|
|
||||||
if model is not None:
|
|
||||||
preprocessor = preprocessor or model.preprocessor
|
|
||||||
|
|
||||||
preprocessor = preprocessor or build_preprocessor(
|
preprocessor = preprocessor or build_preprocessor(
|
||||||
input_samplerate=audio_loader.samplerate,
|
input_samplerate=audio_loader.samplerate,
|
||||||
config=model_config.preprocess,
|
config=model_config.preprocess,
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user