mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-04-04 15:20:19 +02:00
Allow training on existing model
This commit is contained in:
parent
0bf809e376
commit
9fa703b34b
@ -91,6 +91,7 @@ class BatDetect2API:
|
|||||||
run_train(
|
run_train(
|
||||||
train_annotations=train_annotations,
|
train_annotations=train_annotations,
|
||||||
val_annotations=val_annotations,
|
val_annotations=val_annotations,
|
||||||
|
model=self.model,
|
||||||
targets=self.targets,
|
targets=self.targets,
|
||||||
model_config=self.config.model,
|
model_config=self.config.model,
|
||||||
train_config=self.config.train,
|
train_config=self.config.train,
|
||||||
@ -116,7 +117,7 @@ class BatDetect2API:
|
|||||||
experiment_name: str | None = None,
|
experiment_name: str | None = None,
|
||||||
run_name: str | None = None,
|
run_name: str | None = None,
|
||||||
save_predictions: bool = True,
|
save_predictions: bool = True,
|
||||||
) -> tuple[dict[str, float], list[list[Detection]]]:
|
) -> tuple[dict[str, float], list[ClipDetections]]:
|
||||||
return run_evaluate(
|
return run_evaluate(
|
||||||
self.model,
|
self.model,
|
||||||
test_annotations,
|
test_annotations,
|
||||||
@ -183,8 +184,17 @@ class BatDetect2API:
|
|||||||
|
|
||||||
def process_file(self, audio_file: data.PathLike) -> ClipDetections:
|
def process_file(self, audio_file: data.PathLike) -> ClipDetections:
|
||||||
recording = data.Recording.from_file(audio_file, compute_hash=False)
|
recording = data.Recording.from_file(audio_file, compute_hash=False)
|
||||||
wav = self.audio_loader.load_recording(recording)
|
|
||||||
detections = self.process_audio(wav)
|
predictions = self.process_files(
|
||||||
|
[audio_file],
|
||||||
|
batch_size=self.config.inference.loader.batch_size,
|
||||||
|
)
|
||||||
|
detections = [
|
||||||
|
detection
|
||||||
|
for prediction in predictions
|
||||||
|
for detection in prediction.detections
|
||||||
|
]
|
||||||
|
|
||||||
return ClipDetections(
|
return ClipDetections(
|
||||||
clip=data.Clip(
|
clip=data.Clip(
|
||||||
uuid=recording.uuid,
|
uuid=recording.uuid,
|
||||||
@ -215,7 +225,7 @@ class BatDetect2API:
|
|||||||
|
|
||||||
outputs = self.model.detector(spec)
|
outputs = self.model.detector(spec)
|
||||||
|
|
||||||
detections = self.model.postprocessor(
|
detections = self.postprocessor(
|
||||||
outputs,
|
outputs,
|
||||||
)[0]
|
)[0]
|
||||||
return self.output_transform.to_detections(
|
return self.output_transform.to_detections(
|
||||||
@ -239,10 +249,14 @@ class BatDetect2API:
|
|||||||
return process_file_list(
|
return process_file_list(
|
||||||
self.model,
|
self.model,
|
||||||
audio_files,
|
audio_files,
|
||||||
config=self.config,
|
|
||||||
targets=self.targets,
|
targets=self.targets,
|
||||||
audio_loader=self.audio_loader,
|
audio_loader=self.audio_loader,
|
||||||
|
audio_config=self.config.audio,
|
||||||
preprocessor=self.preprocessor,
|
preprocessor=self.preprocessor,
|
||||||
|
inference_config=self.config.inference,
|
||||||
|
output_config=self.config.outputs,
|
||||||
|
output_transform=self.output_transform,
|
||||||
|
batch_size=batch_size,
|
||||||
num_workers=num_workers,
|
num_workers=num_workers,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -257,8 +271,11 @@ class BatDetect2API:
|
|||||||
clips,
|
clips,
|
||||||
targets=self.targets,
|
targets=self.targets,
|
||||||
audio_loader=self.audio_loader,
|
audio_loader=self.audio_loader,
|
||||||
|
audio_config=self.config.audio,
|
||||||
preprocessor=self.preprocessor,
|
preprocessor=self.preprocessor,
|
||||||
config=self.config,
|
inference_config=self.config.inference,
|
||||||
|
output_config=self.config.outputs,
|
||||||
|
output_transform=self.output_transform,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
num_workers=num_workers,
|
num_workers=num_workers,
|
||||||
)
|
)
|
||||||
@ -294,7 +311,7 @@ class BatDetect2API:
|
|||||||
def from_config(
|
def from_config(
|
||||||
cls,
|
cls,
|
||||||
config: BatDetect2Config,
|
config: BatDetect2Config,
|
||||||
):
|
) -> "BatDetect2API":
|
||||||
targets = build_targets(config=config.model.targets)
|
targets = build_targets(config=config.model.targets)
|
||||||
|
|
||||||
audio_loader = build_audio_loader(config=config.audio)
|
audio_loader = build_audio_loader(config=config.audio)
|
||||||
@ -309,12 +326,6 @@ class BatDetect2API:
|
|||||||
config=config.model.postprocess,
|
config=config.model.postprocess,
|
||||||
)
|
)
|
||||||
|
|
||||||
evaluator = build_evaluator(config=config.evaluation, targets=targets)
|
|
||||||
|
|
||||||
# NOTE: Better to have a separate instance of preprocessor and
|
|
||||||
# postprocessor as these may be moved to another device.
|
|
||||||
model = build_model(config=config.model)
|
|
||||||
|
|
||||||
formatter = build_output_formatter(
|
formatter = build_output_formatter(
|
||||||
targets,
|
targets,
|
||||||
config=config.outputs.format,
|
config=config.outputs.format,
|
||||||
@ -324,6 +335,19 @@ class BatDetect2API:
|
|||||||
targets=targets,
|
targets=targets,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
evaluator = build_evaluator(
|
||||||
|
config=config.evaluation,
|
||||||
|
targets=targets,
|
||||||
|
transform=output_transform,
|
||||||
|
)
|
||||||
|
|
||||||
|
model = build_model(
|
||||||
|
config=config.model,
|
||||||
|
targets=targets,
|
||||||
|
preprocessor=preprocessor,
|
||||||
|
postprocessor=postprocessor,
|
||||||
|
)
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
config=config,
|
config=config,
|
||||||
targets=targets,
|
targets=targets,
|
||||||
@ -341,7 +365,7 @@ class BatDetect2API:
|
|||||||
cls,
|
cls,
|
||||||
path: data.PathLike,
|
path: data.PathLike,
|
||||||
config: BatDetect2Config | None = None,
|
config: BatDetect2Config | None = None,
|
||||||
):
|
) -> "BatDetect2API":
|
||||||
from batdetect2.audio import AudioConfig
|
from batdetect2.audio import AudioConfig
|
||||||
|
|
||||||
model, model_config = load_model_from_checkpoint(path)
|
model, model_config = load_model_from_checkpoint(path)
|
||||||
@ -368,8 +392,6 @@ class BatDetect2API:
|
|||||||
config=config.model.postprocess,
|
config=config.model.postprocess,
|
||||||
)
|
)
|
||||||
|
|
||||||
evaluator = build_evaluator(config=config.evaluation, targets=targets)
|
|
||||||
|
|
||||||
formatter = build_output_formatter(
|
formatter = build_output_formatter(
|
||||||
targets,
|
targets,
|
||||||
config=config.outputs.format,
|
config=config.outputs.format,
|
||||||
@ -379,6 +401,16 @@ class BatDetect2API:
|
|||||||
targets=targets,
|
targets=targets,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
evaluator = build_evaluator(
|
||||||
|
config=config.evaluation,
|
||||||
|
targets=targets,
|
||||||
|
transform=output_transform,
|
||||||
|
)
|
||||||
|
|
||||||
|
model.preprocessor = preprocessor
|
||||||
|
model.postprocessor = postprocessor
|
||||||
|
model.targets = targets
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
config=config,
|
config=config,
|
||||||
targets=targets,
|
targets=targets,
|
||||||
|
|||||||
@ -26,7 +26,7 @@ def evaluate_command(
|
|||||||
base_dir: Path,
|
base_dir: Path,
|
||||||
config_path: Path | None,
|
config_path: Path | None,
|
||||||
output_dir: Path = DEFAULT_OUTPUT_DIR,
|
output_dir: Path = DEFAULT_OUTPUT_DIR,
|
||||||
num_workers: int | None = None,
|
num_workers: int = 0,
|
||||||
experiment_name: str | None = None,
|
experiment_name: str | None = None,
|
||||||
run_name: str | None = None,
|
run_name: str | None = None,
|
||||||
):
|
):
|
||||||
|
|||||||
@ -14,7 +14,7 @@ from batdetect2.logging import build_logger
|
|||||||
from batdetect2.models import Model
|
from batdetect2.models import Model
|
||||||
from batdetect2.outputs import OutputsConfig, build_output_transform
|
from batdetect2.outputs import OutputsConfig, build_output_transform
|
||||||
from batdetect2.outputs.types import OutputFormatterProtocol
|
from batdetect2.outputs.types import OutputFormatterProtocol
|
||||||
from batdetect2.postprocess.types import Detection
|
from batdetect2.postprocess.types import ClipDetections
|
||||||
from batdetect2.preprocess.types import PreprocessorProtocol
|
from batdetect2.preprocess.types import PreprocessorProtocol
|
||||||
from batdetect2.targets.types import TargetProtocol
|
from batdetect2.targets.types import TargetProtocol
|
||||||
|
|
||||||
@ -35,7 +35,7 @@ def run_evaluate(
|
|||||||
output_dir: data.PathLike = DEFAULT_EVAL_DIR,
|
output_dir: data.PathLike = DEFAULT_EVAL_DIR,
|
||||||
experiment_name: str | None = None,
|
experiment_name: str | None = None,
|
||||||
run_name: str | None = None,
|
run_name: str | None = None,
|
||||||
) -> tuple[dict[str, float], list[list[Detection]]]:
|
) -> tuple[dict[str, float], list[ClipDetections]]:
|
||||||
|
|
||||||
audio_config = audio_config or AudioConfig()
|
audio_config = audio_config or AudioConfig()
|
||||||
evaluation_config = evaluation_config or EvaluationConfig()
|
evaluation_config = evaluation_config or EvaluationConfig()
|
||||||
|
|||||||
@ -6,8 +6,8 @@ from soundevent import data
|
|||||||
from batdetect2.audio import AudioConfig
|
from batdetect2.audio import AudioConfig
|
||||||
from batdetect2.audio.loader import build_audio_loader
|
from batdetect2.audio.loader import build_audio_loader
|
||||||
from batdetect2.audio.types import AudioLoader
|
from batdetect2.audio.types import AudioLoader
|
||||||
from batdetect2.inference import InferenceConfig
|
|
||||||
from batdetect2.inference.clips import get_clips_from_files
|
from batdetect2.inference.clips import get_clips_from_files
|
||||||
|
from batdetect2.inference.config import InferenceConfig
|
||||||
from batdetect2.inference.dataset import build_inference_loader
|
from batdetect2.inference.dataset import build_inference_loader
|
||||||
from batdetect2.inference.lightning import InferenceModule
|
from batdetect2.inference.lightning import InferenceModule
|
||||||
from batdetect2.models import Model
|
from batdetect2.models import Model
|
||||||
|
|||||||
@ -38,10 +38,10 @@ def get_recording_clips(
|
|||||||
discard_empty: bool = True,
|
discard_empty: bool = True,
|
||||||
) -> Sequence[data.Clip]:
|
) -> Sequence[data.Clip]:
|
||||||
start_time = 0
|
start_time = 0
|
||||||
duration = recording.duration
|
recording_duration = recording.duration
|
||||||
hop = duration * (1 - overlap)
|
hop = duration * (1 - overlap)
|
||||||
|
|
||||||
num_clips = int(np.ceil(duration / hop))
|
num_clips = int(np.ceil(recording_duration / hop))
|
||||||
|
|
||||||
if num_clips == 0:
|
if num_clips == 0:
|
||||||
# This should only happen if the clip's duration is zero,
|
# This should only happen if the clip's duration is zero,
|
||||||
@ -53,8 +53,8 @@ def get_recording_clips(
|
|||||||
start = start_time + i * hop
|
start = start_time + i * hop
|
||||||
end = start + duration
|
end = start + duration
|
||||||
|
|
||||||
if end > duration:
|
if end > recording_duration:
|
||||||
empty_duration = end - duration
|
empty_duration = end - recording_duration
|
||||||
|
|
||||||
if empty_duration > max_empty and discard_empty:
|
if empty_duration > max_empty and discard_empty:
|
||||||
# Discard clips that contain too much empty space
|
# Discard clips that contain too much empty space
|
||||||
|
|||||||
75
src/batdetect2/postprocess/remapping.py
Normal file
75
src/batdetect2/postprocess/remapping.py
Normal file
@ -0,0 +1,75 @@
|
|||||||
|
"""Remaps raw model output tensors to coordinate-aware xarray DataArrays.
|
||||||
|
|
||||||
|
This module provides utility functions to convert the raw numerical outputs
|
||||||
|
(typically PyTorch tensors) from the BatDetect2 DNN model into
|
||||||
|
`xarray.DataArray` objects. This step adds coordinate information
|
||||||
|
(time in seconds, frequency in Hz) back to the model's predictions, making them
|
||||||
|
interpretable in the context of the original audio signal and facilitating
|
||||||
|
subsequent processing steps.
|
||||||
|
|
||||||
|
A single generic helper, `to_xarray`, covers the common BatDetect2 output
|
||||||
|
types: detection heatmaps, classification probability maps, size prediction
|
||||||
|
maps, and potentially intermediate features.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Dict, List
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import xarray as xr
|
||||||
|
from soundevent.arrays import Dimensions
|
||||||
|
|
||||||
|
from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"to_xarray",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def to_xarray(
    array: torch.Tensor | np.ndarray,
    start_time: float,
    end_time: float,
    min_freq: float = MIN_FREQ,
    max_freq: float = MAX_FREQ,
    name: str = "xarray",
    extra_dims: List[str] | None = None,
    extra_coords: Dict[str, np.ndarray] | None = None,
) -> xr.DataArray:
    """Wrap a raw output array in a DataArray with frequency/time coordinates.

    The last axis of ``array`` is labelled as time and the second-to-last as
    frequency. Any leading axes are labelled with ``extra_dims``.

    Parameters
    ----------
    array
        Values to wrap. A torch tensor is detached, moved to the CPU and
        converted to a NumPy array first.
    start_time, end_time
        Bounds of the time axis. Coordinates are evenly spaced, begin at
        ``start_time`` and exclude ``end_time``.
    min_freq, max_freq
        Bounds of the frequency axis. Coordinates are evenly spaced, begin
        at ``min_freq`` and exclude ``max_freq``.
    name
        Name given to the resulting DataArray.
    extra_dims
        Names for the leading axes. When omitted they are named ``dim_0``,
        ``dim_1``, ... in order.
    extra_coords
        Additional coordinates to attach alongside frequency and time.

    Returns
    -------
    xr.DataArray
        The input values with named dimensions and coordinates.

    Raises
    ------
    ValueError
        If ``array`` has fewer than two dimensions.
    """
    values = (
        array.detach().cpu().numpy()
        if isinstance(array, torch.Tensor)
        else array
    )

    num_leading = values.ndim - 2
    if num_leading < 0:
        raise ValueError(
            "Input array must have at least 2 dimensions, "
            f"got shape {values.shape}"
        )

    # Trailing axes are (frequency, time); everything before is "extra".
    *_, num_freq_bins, num_time_bins = values.shape

    # endpoint=False: each coordinate marks the start of its bin, so the
    # upper bound itself never appears as a coordinate.
    time_coords = np.linspace(
        start_time,
        end_time,
        num_time_bins,
        endpoint=False,
    )
    freq_coords = np.linspace(
        min_freq,
        max_freq,
        num_freq_bins,
        endpoint=False,
    )

    leading_names = (
        [f"dim_{index}" for index in range(num_leading)]
        if extra_dims is None
        else extra_dims
    )
    coords = dict({} if extra_coords is None else extra_coords)
    coords[Dimensions.frequency.value] = freq_coords
    coords[Dimensions.time.value] = time_coords

    dims = [
        *leading_names,
        Dimensions.frequency.value,
        Dimensions.time.value,
    ]

    return xr.DataArray(data=values, dims=dims, coords=coords, name=name)
|
||||||
@ -111,6 +111,7 @@ def load_model_from_checkpoint(
|
|||||||
def build_training_module(
|
def build_training_module(
|
||||||
model_config: ModelConfig | None = None,
|
model_config: ModelConfig | None = None,
|
||||||
train_config: TrainingConfig | None = None,
|
train_config: TrainingConfig | None = None,
|
||||||
|
model: Model | None = None,
|
||||||
) -> TrainingModule:
|
) -> TrainingModule:
|
||||||
if model_config is None:
|
if model_config is None:
|
||||||
model_config = ModelConfig()
|
model_config = ModelConfig()
|
||||||
@ -121,4 +122,5 @@ def build_training_module(
|
|||||||
return TrainingModule(
|
return TrainingModule(
|
||||||
model_config=model_config.model_dump(mode="json"),
|
model_config=model_config.model_dump(mode="json"),
|
||||||
train_config=train_config.model_dump(mode="json"),
|
train_config=train_config.model_dump(mode="json"),
|
||||||
|
model=model,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -11,7 +11,7 @@ from batdetect2.audio.types import AudioLoader
|
|||||||
from batdetect2.evaluate import build_evaluator
|
from batdetect2.evaluate import build_evaluator
|
||||||
from batdetect2.evaluate.types import EvaluatorProtocol
|
from batdetect2.evaluate.types import EvaluatorProtocol
|
||||||
from batdetect2.logging import build_logger
|
from batdetect2.logging import build_logger
|
||||||
from batdetect2.models import ModelConfig
|
from batdetect2.models import Model, ModelConfig, build_model
|
||||||
from batdetect2.preprocess import build_preprocessor
|
from batdetect2.preprocess import build_preprocessor
|
||||||
from batdetect2.preprocess.types import PreprocessorProtocol
|
from batdetect2.preprocess.types import PreprocessorProtocol
|
||||||
from batdetect2.targets import build_targets
|
from batdetect2.targets import build_targets
|
||||||
@ -33,6 +33,7 @@ __all__ = [
|
|||||||
def run_train(
|
def run_train(
|
||||||
train_annotations: Sequence[data.ClipAnnotation],
|
train_annotations: Sequence[data.ClipAnnotation],
|
||||||
val_annotations: Sequence[data.ClipAnnotation] | None = None,
|
val_annotations: Sequence[data.ClipAnnotation] | None = None,
|
||||||
|
model: Model | None = None,
|
||||||
targets: Optional["TargetProtocol"] = None,
|
targets: Optional["TargetProtocol"] = None,
|
||||||
preprocessor: Optional["PreprocessorProtocol"] = None,
|
preprocessor: Optional["PreprocessorProtocol"] = None,
|
||||||
audio_loader: Optional["AudioLoader"] = None,
|
audio_loader: Optional["AudioLoader"] = None,
|
||||||
@ -57,10 +58,19 @@ def run_train(
|
|||||||
audio_config = audio_config or AudioConfig()
|
audio_config = audio_config or AudioConfig()
|
||||||
train_config = train_config or TrainingConfig()
|
train_config = train_config or TrainingConfig()
|
||||||
|
|
||||||
|
if model is not None:
|
||||||
|
_validate_model_compatibility(model=model, model_config=model_config)
|
||||||
|
|
||||||
|
if model is not None:
|
||||||
|
targets = targets or model.targets
|
||||||
|
|
||||||
targets = targets or build_targets(config=model_config.targets)
|
targets = targets or build_targets(config=model_config.targets)
|
||||||
|
|
||||||
audio_loader = audio_loader or build_audio_loader(config=audio_config)
|
audio_loader = audio_loader or build_audio_loader(config=audio_config)
|
||||||
|
|
||||||
|
if model is not None:
|
||||||
|
preprocessor = preprocessor or model.preprocessor
|
||||||
|
|
||||||
preprocessor = preprocessor or build_preprocessor(
|
preprocessor = preprocessor or build_preprocessor(
|
||||||
input_samplerate=audio_loader.samplerate,
|
input_samplerate=audio_loader.samplerate,
|
||||||
config=model_config.preprocess,
|
config=model_config.preprocess,
|
||||||
@ -98,6 +108,7 @@ def run_train(
|
|||||||
module = build_training_module(
|
module = build_training_module(
|
||||||
model_config=model_config,
|
model_config=model_config,
|
||||||
train_config=train_config,
|
train_config=train_config,
|
||||||
|
model=model,
|
||||||
)
|
)
|
||||||
|
|
||||||
trainer = trainer or build_trainer(
|
trainer = trainer or build_trainer(
|
||||||
@ -124,6 +135,49 @@ def run_train(
|
|||||||
return module
|
return module
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_model_compatibility(
    model: Model,
    model_config: ModelConfig,
) -> None:
    """Check that ``model`` has the state layout described by ``model_config``.

    A reference model is built from ``model_config`` and its ``state_dict``
    is compared against the one of ``model``: both must expose exactly the
    same keys, and every entry must have the same shape. Only the first
    problem found is reported.

    Parameters
    ----------
    model
        The already-built model that is about to be trained.
    model_config
        The configuration the model is expected to match.

    Raises
    ------
    ValueError
        If a state key is missing from ``model``, if ``model`` has a state
        key the reference lacks, or if any shared entry differs in shape.
    """

    def _shape_table(module) -> dict:
        # Map every state entry name to its shape as a plain tuple.
        return {
            name: tuple(tensor.shape)
            for name, tensor in module.state_dict().items()
        }

    prefix = "Provided model is incompatible with model_config: "

    wanted = _shape_table(build_model(config=model_config))
    found = _shape_table(model)

    # Report the alphabetically first offender so the message is stable.
    absent = wanted.keys() - found.keys()
    if absent:
        raise ValueError(prefix + f"missing state key '{min(absent)}'.")

    surplus = found.keys() - wanted.keys()
    if surplus:
        raise ValueError(prefix + f"unexpected state key '{min(surplus)}'.")

    # Key sets are identical from here on; only shapes can still differ.
    for name, wanted_shape in wanted.items():
        found_shape = found[name]
        if found_shape == wanted_shape:
            continue

        raise ValueError(
            prefix
            + f"shape mismatch for '{name}' (expected {wanted_shape}, "
            + f"got {found_shape})."
        )
|
||||||
|
|
||||||
|
|
||||||
def build_trainer(
|
def build_trainer(
|
||||||
config: TrainingConfig,
|
config: TrainingConfig,
|
||||||
evaluator: "EvaluatorProtocol",
|
evaluator: "EvaluatorProtocol",
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import lightning as L
|
import lightning as L
|
||||||
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
from deepdiff import DeepDiff
|
from deepdiff import DeepDiff
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
@ -10,7 +11,8 @@ from torch.optim.lr_scheduler import CosineAnnealingLR
|
|||||||
from batdetect2.api_v2 import BatDetect2API
|
from batdetect2.api_v2 import BatDetect2API
|
||||||
from batdetect2.audio.types import AudioLoader
|
from batdetect2.audio.types import AudioLoader
|
||||||
from batdetect2.config import BatDetect2Config
|
from batdetect2.config import BatDetect2Config
|
||||||
from batdetect2.models import ModelConfig
|
from batdetect2.models import ModelConfig, build_model
|
||||||
|
from batdetect2.targets.classes import TargetClassConfig
|
||||||
from batdetect2.train import (
|
from batdetect2.train import (
|
||||||
TrainingConfig,
|
TrainingConfig,
|
||||||
TrainingModule,
|
TrainingModule,
|
||||||
@ -223,3 +225,40 @@ def test_train_smoke_produces_loadable_checkpoint(
|
|||||||
).unsqueeze(0)
|
).unsqueeze(0)
|
||||||
outputs = model(wav.unsqueeze(0))
|
outputs = model(wav.unsqueeze(0))
|
||||||
assert outputs is not None
|
assert outputs is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_training_module_uses_provided_model() -> None:
    """The training module must hold the exact model instance it was given."""
    provided_model = build_model(ModelConfig())
    model_config = ModelConfig()
    train_config = TrainingConfig()

    training_module = build_training_module(
        model_config=model_config,
        train_config=train_config,
        model=provided_model,
    )

    # Identity, not equality: the module must not build its own copy.
    assert training_module.model is provided_model
|
||||||
|
|
||||||
|
|
||||||
|
def test_run_train_rejects_incompatible_model_config(
    example_annotations: list[data.ClipAnnotation],
) -> None:
    """``run_train`` must raise when the model and model_config disagree."""
    model = build_model(ModelConfig())

    # Add one more classification target than the model was built with, so
    # the config no longer describes the provided model.
    incompatible_config = ModelConfig()
    extra_class = TargetClassConfig(
        name="dummy_class",
        tags=[data.Tag(key="class", value="Dummy class")],
    )
    incompatible_config.targets.classification_targets.append(extra_class)

    expected_message = "Provided model is incompatible with model_config"

    with pytest.raises(ValueError, match=expected_message):
        run_train(
            train_annotations=example_annotations[:1],
            val_annotations=None,
            model=model,
            model_config=incompatible_config,
            train_config=TrainingConfig(),
        )
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user