mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-10 00:59:34 +01:00
LR scheduler takes the total number of batches
This commit is contained in:
parent
34ef9e92a1
commit
93e89ecc46
@ -1,6 +1,6 @@
|
||||
from collections.abc import Callable, Iterable, Mapping
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Literal, Optional, Tuple
|
||||
from typing import List, Literal, Optional, Protocol, Tuple
|
||||
|
||||
import numpy as np
|
||||
from soundevent import data
|
||||
@ -23,6 +23,16 @@ MatchingGeometry = Literal["bbox", "interval", "timestamp"]
|
||||
"""The geometry representation to use for matching."""
|
||||
|
||||
|
||||
class AffinityFunction(Protocol):
    """Structural type of a callable that scores how well two geometries match.

    Any callable with this signature can be supplied as an affinity
    function: it receives two geometries plus the time and frequency
    buffers and returns a single float score.
    """

    def __call__(
        self,
        geometry1: data.Geometry,
        geometry2: data.Geometry,
        time_buffer: float = 0.01,
        freq_buffer: float = 1000,
    ) -> float: ...
|
||||
|
||||
|
||||
class MatchConfig(BaseConfig):
|
||||
"""Configuration for matching geometries.
|
||||
|
||||
@ -74,6 +84,65 @@ _geometry_cast_functions: Mapping[
|
||||
}
|
||||
|
||||
|
||||
def _timestamp_affinity(
    geometry1: data.Geometry,
    geometry2: data.Geometry,
    time_buffer: float = 0.01,
    freq_buffer: float = 1000,
) -> float:
    """Compute the overlap ratio of two timestamps widened by a time buffer.

    Each timestamp ``t`` is treated as the interval
    ``[t - time_buffer, t + time_buffer]`` and the intersection-over-union
    of the two resulting intervals is returned.

    Parameters
    ----------
    geometry1, geometry2
        Geometries to compare. Both must be ``data.TimeStamp`` instances.
    time_buffer
        Half-width added on each side of a timestamp.
    freq_buffer
        Unused here; accepted so the signature matches ``AffinityFunction``.

    Returns
    -------
    float
        ``0.0`` when the timestamps are at least ``2 * time_buffer`` apart,
        otherwise the intersection-over-union of the buffered intervals.
    """
    assert isinstance(geometry1, data.TimeStamp)
    assert isinstance(geometry2, data.TimeStamp)

    time1 = geometry1.coordinates
    time2 = geometry2.coordinates

    earliest = min(time1, time2)
    latest = max(time1, time2)
    gap = latest - earliest

    # Buffered intervals no longer overlap (this also covers a zero buffer,
    # which would otherwise make the union zero).
    if gap >= 2 * time_buffer:
        return 0.0

    intersection = 2 * time_buffer - gap
    union = 2 * time_buffer + gap
    return intersection / union
|
||||
|
||||
|
||||
def _interval_affinity(
    geometry1: data.Geometry,
    geometry2: data.Geometry,
    time_buffer: float = 0.01,
    freq_buffer: float = 1000,
) -> float:
    """Compute the temporal intersection-over-union of two time intervals.

    Both intervals are first extended by ``time_buffer`` on each side, then
    the length of their overlap is divided by the length of their union.

    Parameters
    ----------
    geometry1, geometry2
        Geometries to compare. Both must be ``data.TimeInterval`` instances.
    time_buffer
        Amount subtracted from each start time and added to each end time.
    freq_buffer
        Unused here; accepted so the signature matches ``AffinityFunction``.

    Returns
    -------
    float
        Intersection-over-union of the buffered intervals, or ``0.0`` when
        the union has zero length.
    """
    assert isinstance(geometry1, data.TimeInterval)
    assert isinstance(geometry2, data.TimeInterval)

    start_time1, end_time1 = geometry1.coordinates
    # Read the second interval from geometry2; reading geometry1 twice would
    # compare the first geometry with itself and ignore geometry2 entirely.
    start_time2, end_time2 = geometry2.coordinates

    start_time1 -= time_buffer
    start_time2 -= time_buffer
    end_time1 += time_buffer
    end_time2 += time_buffer

    # Clamp at zero so disjoint intervals contribute no overlap.
    intersection = max(
        0, min(end_time1, end_time2) - max(start_time1, start_time2)
    )
    union = (
        (end_time1 - start_time1) + (end_time2 - start_time2) - intersection
    )

    # Guard against division by zero for degenerate (zero-length) intervals.
    if union == 0:
        return 0.0

    return intersection / union
|
||||
|
||||
|
||||
# Affinity functions specialised for a given matching geometry. Geometries
# without an entry here are looked up with a fallback to ``compute_affinity``
# in ``match_geometries``.
_affinity_functions: Mapping[MatchingGeometry, AffinityFunction] = {
    "timestamp": _timestamp_affinity,
    "interval": _interval_affinity,
}
|
||||
|
||||
|
||||
def match_geometries(
|
||||
source: List[data.Geometry],
|
||||
target: List[data.Geometry],
|
||||
@ -81,6 +150,10 @@ def match_geometries(
|
||||
scores: Optional[List[float]] = None,
|
||||
) -> Iterable[Tuple[Optional[int], Optional[int], float]]:
|
||||
geometry_cast = _geometry_cast_functions[config.geometry]
|
||||
affinity_function = _affinity_functions.get(
|
||||
config.geometry,
|
||||
compute_affinity,
|
||||
)
|
||||
|
||||
if config.strategy == "optimal":
|
||||
return optimal_match(
|
||||
@ -98,6 +171,7 @@ def match_geometries(
|
||||
time_buffer=config.time_buffer,
|
||||
freq_buffer=config.frequency_buffer,
|
||||
affinity_threshold=config.affinity_threshold,
|
||||
affinity_function=affinity_function,
|
||||
scores=scores,
|
||||
)
|
||||
|
||||
@ -111,6 +185,7 @@ def greedy_match(
|
||||
target: List[data.Geometry],
|
||||
scores: Optional[List[float]] = None,
|
||||
affinity_threshold: float = 0.5,
|
||||
affinity_function: AffinityFunction = compute_affinity,
|
||||
time_buffer: float = 0.001,
|
||||
freq_buffer: float = 1000,
|
||||
) -> Iterable[Tuple[Optional[int], Optional[int], float]]:
|
||||
@ -168,7 +243,7 @@ def greedy_match(
|
||||
|
||||
affinities = np.array(
|
||||
[
|
||||
compute_affinity(
|
||||
affinity_function(
|
||||
source_geometry,
|
||||
target_geometry,
|
||||
time_buffer=time_buffer,
|
||||
|
||||
@ -14,7 +14,7 @@ from batdetect2.evaluate.metrics import (
|
||||
ClassificationMeanAveragePrecision,
|
||||
DetectionAveragePrecision,
|
||||
)
|
||||
from batdetect2.models import build_model
|
||||
from batdetect2.models import Model, build_model
|
||||
from batdetect2.train.augmentations import (
|
||||
RandomExampleSource,
|
||||
build_augmentations,
|
||||
@ -55,17 +55,13 @@ def train(
|
||||
):
|
||||
config = config or FullTrainingConfig()
|
||||
|
||||
if model_path is not None:
|
||||
logger.debug("Loading model from: {path}", path=model_path)
|
||||
module = TrainingModule.load_from_checkpoint(model_path) # type: ignore
|
||||
else:
|
||||
module = build_training_module(config)
|
||||
model = build_model(config=config)
|
||||
|
||||
trainer = build_trainer(config, targets=module.model.targets)
|
||||
trainer = build_trainer(config, targets=model.targets)
|
||||
|
||||
train_dataloader = build_train_loader(
|
||||
train_examples,
|
||||
preprocessor=module.model.preprocessor,
|
||||
preprocessor=model.preprocessor,
|
||||
config=config.train,
|
||||
num_workers=train_workers,
|
||||
)
|
||||
@ -73,7 +69,7 @@ def train(
|
||||
val_dataloader = (
|
||||
build_val_loader(
|
||||
val_examples,
|
||||
preprocessor=module.model.preprocessor,
|
||||
preprocessor=model.preprocessor,
|
||||
config=config.train,
|
||||
num_workers=val_workers,
|
||||
)
|
||||
@ -81,6 +77,16 @@ def train(
|
||||
else None
|
||||
)
|
||||
|
||||
if model_path is not None:
|
||||
logger.debug("Loading model from: {path}", path=model_path)
|
||||
module = TrainingModule.load_from_checkpoint(model_path) # type: ignore
|
||||
else:
|
||||
module = build_training_module(
|
||||
model,
|
||||
config,
|
||||
batches_per_epoch=len(train_dataloader),
|
||||
)
|
||||
|
||||
logger.info("Starting main training loop...")
|
||||
trainer.fit(
|
||||
module,
|
||||
@ -90,14 +96,17 @@ def train(
|
||||
logger.info("Training complete.")
|
||||
|
||||
|
||||
def build_training_module(
    model: Model,
    config: FullTrainingConfig,
    batches_per_epoch: int,
) -> TrainingModule:
    """Wrap a model in a ``TrainingModule`` configured from ``config``.

    Parameters
    ----------
    model
        The already-built model to train.
    config
        Full training configuration; the loss settings, learning rate and
        ``t_max`` are read from ``config.train``.
    batches_per_epoch
        Number of batches in one pass over the training dataloader.

    Returns
    -------
    TrainingModule
        Module holding the model, the loss built from ``config.train.loss``
        and the optimisation hyper-parameters.
    """
    loss = build_loss(config=config.train.loss)

    # The scheduler horizon is passed as a total batch count:
    # config.train.t_max (presumably expressed in epochs -- confirm against
    # the config docs) scaled by the number of batches per epoch.
    total_batches = config.train.t_max * batches_per_epoch

    return TrainingModule(
        model=model,
        loss=loss,
        learning_rate=config.train.learning_rate,
        t_max=total_batches,
    )
|
||||
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user