LR scheduler takes the total number of batches

mbsantiago 2025-08-28 08:52:11 +01:00
parent 34ef9e92a1
commit 93e89ecc46
2 changed files with 98 additions and 14 deletions


@@ -1,6 +1,6 @@
 from collections.abc import Callable, Iterable, Mapping
 from dataclasses import dataclass, field
-from typing import List, Literal, Optional, Tuple
+from typing import List, Literal, Optional, Protocol, Tuple

 import numpy as np
 from soundevent import data
@@ -23,6 +23,16 @@ MatchingGeometry = Literal["bbox", "interval", "timestamp"]
 """The geometry representation to use for matching."""
+class AffinityFunction(Protocol):
+    def __call__(
+        self,
+        geometry1: data.Geometry,
+        geometry2: data.Geometry,
+        time_buffer: float = 0.01,
+        freq_buffer: float = 1000,
+    ) -> float: ...
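Since AffinityFunction is a typing.Protocol, any callable with this signature satisfies it structurally; no subclassing or registration is needed. A minimal sketch (the constant_affinity helper is hypothetical, shown only to illustrate the contract):

from soundevent import data

def constant_affinity(
    geometry1: data.Geometry,
    geometry2: data.Geometry,
    time_buffer: float = 0.01,
    freq_buffer: float = 1000,
) -> float:
    # Toy affinity that ignores its inputs; it only demonstrates that any
    # callable matching the signature is accepted as an AffinityFunction.
    return 1.0

fn: AffinityFunction = constant_affinity  # passes a static type check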
+
+
 class MatchConfig(BaseConfig):
     """Configuration for matching geometries.
@@ -74,6 +84,65 @@ _geometry_cast_functions: Mapping[
 }
+
+
+def _timestamp_affinity(
+    geometry1: data.Geometry,
+    geometry2: data.Geometry,
+    time_buffer: float = 0.01,
+    freq_buffer: float = 1000,
+) -> float:
+    assert isinstance(geometry1, data.TimeStamp)
+    assert isinstance(geometry2, data.TimeStamp)
+    start_time1 = geometry1.coordinates
+    start_time2 = geometry2.coordinates
+    a = min(start_time1, start_time2)
+    b = max(start_time1, start_time2)
+    if b - a >= 2 * time_buffer:
+        return 0
+    intersection = a - b + 2 * time_buffer
+    union = b - a + 2 * time_buffer
+    return intersection / union
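The function treats each timestamp as a window of width 2 * time_buffer centred on it and returns the 1-D IoU of the two windows: timestamps 2 * time_buffer or more apart score 0, identical timestamps score 1. A small usage sketch, assuming soundevent's TimeStamp stores a single time in seconds in coordinates, as the asserts above rely on:

from soundevent import data

first = data.TimeStamp(coordinates=5.000)
second = data.TimeStamp(coordinates=5.005)

# Buffered windows: [4.990, 5.010] and [4.995, 5.015]
# intersection = 0.015, union = 0.025
print(_timestamp_affinity(first, second, time_buffer=0.01))  # 0.6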
+
+
+def _interval_affinity(
+    geometry1: data.Geometry,
+    geometry2: data.Geometry,
+    time_buffer: float = 0.01,
+    freq_buffer: float = 1000,
+) -> float:
+    assert isinstance(geometry1, data.TimeInterval)
+    assert isinstance(geometry2, data.TimeInterval)
+    start_time1, end_time1 = geometry1.coordinates
+    start_time2, end_time2 = geometry2.coordinates
+    start_time1 -= time_buffer
+    start_time2 -= time_buffer
+    end_time1 += time_buffer
+    end_time2 += time_buffer
+    intersection = max(
+        0, min(end_time1, end_time2) - max(start_time1, start_time2)
+    )
+    union = (
+        (end_time1 - start_time1) + (end_time2 - start_time2) - intersection
+    )
+    if union == 0:
+        return 0
+    return intersection / union
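This is the standard interval IoU after padding both ends of each interval by time_buffer; freq_buffer is accepted only so the signature matches the AffinityFunction protocol. A usage sketch, assuming soundevent's TimeInterval stores [start, end] in coordinates:

from soundevent import data

first = data.TimeInterval(coordinates=[1.0, 2.0])
second = data.TimeInterval(coordinates=[1.5, 2.5])

# Buffered: [0.99, 2.01] and [1.49, 2.51]
# intersection = 0.52, union = 1.02 + 1.02 - 0.52 = 1.52
print(_interval_affinity(first, second, time_buffer=0.01))  # ~0.342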
+
+
+_affinity_functions: Mapping[MatchingGeometry, AffinityFunction] = {
+    "timestamp": _timestamp_affinity,
+    "interval": _interval_affinity,
+}
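match_geometries looks its affinity function up by config.geometry and falls back to the generic compute_affinity (the bbox path) for any geometry not in this mapping, as the .get(...) call in the next hunk shows. A quick sketch of the dispatch:

# "timestamp" and "interval" resolve to the specialised functions above;
# anything else, e.g. "bbox", falls back to the existing compute_affinity.
assert _affinity_functions.get("timestamp", compute_affinity) is _timestamp_affinity
assert _affinity_functions.get("interval", compute_affinity) is _interval_affinity
assert _affinity_functions.get("bbox", compute_affinity) is compute_affinity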
+
+
 def match_geometries(
     source: List[data.Geometry],
     target: List[data.Geometry],
@@ -81,6 +150,10 @@ def match_geometries(
     scores: Optional[List[float]] = None,
 ) -> Iterable[Tuple[Optional[int], Optional[int], float]]:
     geometry_cast = _geometry_cast_functions[config.geometry]
+    affinity_function = _affinity_functions.get(
+        config.geometry,
+        compute_affinity,
+    )

     if config.strategy == "optimal":
         return optimal_match(
@@ -98,6 +171,7 @@ def match_geometries(
         time_buffer=config.time_buffer,
         freq_buffer=config.frequency_buffer,
         affinity_threshold=config.affinity_threshold,
+        affinity_function=affinity_function,
         scores=scores,
     )
@@ -111,6 +185,7 @@ def greedy_match(
     target: List[data.Geometry],
     scores: Optional[List[float]] = None,
     affinity_threshold: float = 0.5,
+    affinity_function: AffinityFunction = compute_affinity,
     time_buffer: float = 0.001,
     freq_buffer: float = 1000,
 ) -> Iterable[Tuple[Optional[int], Optional[int], float]]:
@@ -168,7 +243,7 @@ def greedy_match(
     affinities = np.array(
         [
-            compute_affinity(
+            affinity_function(
                 source_geometry,
                 target_geometry,
                 time_buffer=time_buffer,


@@ -14,7 +14,7 @@ from batdetect2.evaluate.metrics import (
     ClassificationMeanAveragePrecision,
     DetectionAveragePrecision,
 )
-from batdetect2.models import build_model
+from batdetect2.models import Model, build_model
 from batdetect2.train.augmentations import (
     RandomExampleSource,
     build_augmentations,
@@ -55,17 +55,13 @@ def train(
 ):
     config = config or FullTrainingConfig()

-    if model_path is not None:
-        logger.debug("Loading model from: {path}", path=model_path)
-        module = TrainingModule.load_from_checkpoint(model_path)  # type: ignore
-    else:
-        module = build_training_module(config)
+    model = build_model(config=config)

-    trainer = build_trainer(config, targets=module.model.targets)
+    trainer = build_trainer(config, targets=model.targets)

     train_dataloader = build_train_loader(
         train_examples,
-        preprocessor=module.model.preprocessor,
+        preprocessor=model.preprocessor,
         config=config.train,
         num_workers=train_workers,
     )
@@ -73,7 +69,7 @@ def train(
     val_dataloader = (
         build_val_loader(
             val_examples,
-            preprocessor=module.model.preprocessor,
+            preprocessor=model.preprocessor,
             config=config.train,
             num_workers=val_workers,
         )
@@ -81,6 +77,16 @@ def train(
         else None
     )
+
+    if model_path is not None:
+        logger.debug("Loading model from: {path}", path=model_path)
+        module = TrainingModule.load_from_checkpoint(model_path)  # type: ignore
+    else:
+        module = build_training_module(
+            model,
+            config,
+            batches_per_epoch=len(train_dataloader),
+        )
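Building the training module after the data loaders is what makes batches_per_epoch=len(train_dataloader) possible: the length of a PyTorch DataLoader is its number of batches per epoch. A standalone sketch with made-up sizes:

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.zeros(103, 1))
loader = DataLoader(dataset, batch_size=10)

# 103 examples in batches of 10 -> 11 batches (drop_last defaults to False)
print(len(loader))  # 11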
logger.info("Starting main training loop...") logger.info("Starting main training loop...")
trainer.fit( trainer.fit(
module, module,
@ -90,14 +96,17 @@ def train(
logger.info("Training complete.") logger.info("Training complete.")


-def build_training_module(config: FullTrainingConfig) -> TrainingModule:
-    model = build_model(config=config)
+def build_training_module(
+    model: Model,
+    config: FullTrainingConfig,
+    batches_per_epoch: int,
+) -> TrainingModule:
     loss = build_loss(config=config.train.loss)

     return TrainingModule(
         model=model,
         loss=loss,
         learning_rate=config.train.learning_rate,
-        t_max=config.train.t_max,
+        t_max=config.train.t_max * batches_per_epoch,
     )
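This last hunk is the point of the commit: t_max in the config is expressed in epochs, but the name suggests a CosineAnnealingLR-style schedule inside TrainingModule, and if that scheduler is stepped once per batch (for example a Lightning scheduler configured with a per-step interval) its T_max must be counted in optimizer steps. Multiplying by batches_per_epoch makes the cosine decay span the whole run instead of finishing within the first epoch. A hedged sketch with illustrative numbers, not values from the actual config or TrainingModule:

import torch
from torch.optim.lr_scheduler import CosineAnnealingLR

epochs = 200              # stand-in for config.train.t_max
batches_per_epoch = 50    # stand-in for len(train_dataloader)

model = torch.nn.Linear(8, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = CosineAnnealingLR(optimizer, T_max=epochs * batches_per_epoch)

for _ in range(epochs * batches_per_epoch):
    optimizer.step()      # one optimizer step per batch
    scheduler.step()      # scheduler stepped per batch, not per epoch

# The learning rate bottoms out only on the final batch. With T_max=epochs
# the cosine would already have completed after the first 200 batches and
# then risen back towards the base learning rate.
print(optimizer.param_groups[0]["lr"])  # ~0.0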