mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-10 17:19:34 +01:00
LR Scheduler takes num of total batches
This commit is contained in:
parent
34ef9e92a1
commit
93e89ecc46
@ -1,6 +1,6 @@
|
|||||||
from collections.abc import Callable, Iterable, Mapping
|
from collections.abc import Callable, Iterable, Mapping
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import List, Literal, Optional, Tuple
|
from typing import List, Literal, Optional, Protocol, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
@ -23,6 +23,16 @@ MatchingGeometry = Literal["bbox", "interval", "timestamp"]
|
|||||||
"""The geometry representation to use for matching."""
|
"""The geometry representation to use for matching."""
|
||||||
|
|
||||||
|
|
||||||
|
class AffinityFunction(Protocol):
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
geometry1: data.Geometry,
|
||||||
|
geometry2: data.Geometry,
|
||||||
|
time_buffer: float = 0.01,
|
||||||
|
freq_buffer: float = 1000,
|
||||||
|
) -> float: ...
|
||||||
|
|
||||||
|
|
||||||
class MatchConfig(BaseConfig):
|
class MatchConfig(BaseConfig):
|
||||||
"""Configuration for matching geometries.
|
"""Configuration for matching geometries.
|
||||||
|
|
||||||
@ -74,6 +84,65 @@ _geometry_cast_functions: Mapping[
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _timestamp_affinity(
|
||||||
|
geometry1: data.Geometry,
|
||||||
|
geometry2: data.Geometry,
|
||||||
|
time_buffer: float = 0.01,
|
||||||
|
freq_buffer: float = 1000,
|
||||||
|
) -> float:
|
||||||
|
assert isinstance(geometry1, data.TimeStamp)
|
||||||
|
assert isinstance(geometry2, data.TimeStamp)
|
||||||
|
|
||||||
|
start_time1 = geometry1.coordinates
|
||||||
|
start_time2 = geometry2.coordinates
|
||||||
|
|
||||||
|
a = min(start_time1, start_time2)
|
||||||
|
b = max(start_time1, start_time2)
|
||||||
|
|
||||||
|
if b - a >= 2 * time_buffer:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
intersection = a - b + 2 * time_buffer
|
||||||
|
union = b - a + 2 * time_buffer
|
||||||
|
return intersection / union
|
||||||
|
|
||||||
|
|
||||||
|
def _interval_affinity(
|
||||||
|
geometry1: data.Geometry,
|
||||||
|
geometry2: data.Geometry,
|
||||||
|
time_buffer: float = 0.01,
|
||||||
|
freq_buffer: float = 1000,
|
||||||
|
) -> float:
|
||||||
|
assert isinstance(geometry1, data.TimeInterval)
|
||||||
|
assert isinstance(geometry2, data.TimeInterval)
|
||||||
|
|
||||||
|
start_time1, end_time1 = geometry1.coordinates
|
||||||
|
start_time2, end_time2 = geometry1.coordinates
|
||||||
|
|
||||||
|
start_time1 -= time_buffer
|
||||||
|
start_time2 -= time_buffer
|
||||||
|
end_time1 += time_buffer
|
||||||
|
end_time2 += time_buffer
|
||||||
|
|
||||||
|
intersection = max(
|
||||||
|
0, min(end_time1, end_time2) - max(start_time1, start_time2)
|
||||||
|
)
|
||||||
|
union = (
|
||||||
|
(end_time1 - start_time1) + (end_time2 - start_time2) - intersection
|
||||||
|
)
|
||||||
|
|
||||||
|
if union == 0:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
return intersection / union
|
||||||
|
|
||||||
|
|
||||||
|
_affinity_functions: Mapping[MatchingGeometry, AffinityFunction] = {
|
||||||
|
"timestamp": _timestamp_affinity,
|
||||||
|
"interval": _interval_affinity,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def match_geometries(
|
def match_geometries(
|
||||||
source: List[data.Geometry],
|
source: List[data.Geometry],
|
||||||
target: List[data.Geometry],
|
target: List[data.Geometry],
|
||||||
@ -81,6 +150,10 @@ def match_geometries(
|
|||||||
scores: Optional[List[float]] = None,
|
scores: Optional[List[float]] = None,
|
||||||
) -> Iterable[Tuple[Optional[int], Optional[int], float]]:
|
) -> Iterable[Tuple[Optional[int], Optional[int], float]]:
|
||||||
geometry_cast = _geometry_cast_functions[config.geometry]
|
geometry_cast = _geometry_cast_functions[config.geometry]
|
||||||
|
affinity_function = _affinity_functions.get(
|
||||||
|
config.geometry,
|
||||||
|
compute_affinity,
|
||||||
|
)
|
||||||
|
|
||||||
if config.strategy == "optimal":
|
if config.strategy == "optimal":
|
||||||
return optimal_match(
|
return optimal_match(
|
||||||
@ -98,6 +171,7 @@ def match_geometries(
|
|||||||
time_buffer=config.time_buffer,
|
time_buffer=config.time_buffer,
|
||||||
freq_buffer=config.frequency_buffer,
|
freq_buffer=config.frequency_buffer,
|
||||||
affinity_threshold=config.affinity_threshold,
|
affinity_threshold=config.affinity_threshold,
|
||||||
|
affinity_function=affinity_function,
|
||||||
scores=scores,
|
scores=scores,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -111,6 +185,7 @@ def greedy_match(
|
|||||||
target: List[data.Geometry],
|
target: List[data.Geometry],
|
||||||
scores: Optional[List[float]] = None,
|
scores: Optional[List[float]] = None,
|
||||||
affinity_threshold: float = 0.5,
|
affinity_threshold: float = 0.5,
|
||||||
|
affinity_function: AffinityFunction = compute_affinity,
|
||||||
time_buffer: float = 0.001,
|
time_buffer: float = 0.001,
|
||||||
freq_buffer: float = 1000,
|
freq_buffer: float = 1000,
|
||||||
) -> Iterable[Tuple[Optional[int], Optional[int], float]]:
|
) -> Iterable[Tuple[Optional[int], Optional[int], float]]:
|
||||||
@ -168,7 +243,7 @@ def greedy_match(
|
|||||||
|
|
||||||
affinities = np.array(
|
affinities = np.array(
|
||||||
[
|
[
|
||||||
compute_affinity(
|
affinity_function(
|
||||||
source_geometry,
|
source_geometry,
|
||||||
target_geometry,
|
target_geometry,
|
||||||
time_buffer=time_buffer,
|
time_buffer=time_buffer,
|
||||||
|
|||||||
@ -14,7 +14,7 @@ from batdetect2.evaluate.metrics import (
|
|||||||
ClassificationMeanAveragePrecision,
|
ClassificationMeanAveragePrecision,
|
||||||
DetectionAveragePrecision,
|
DetectionAveragePrecision,
|
||||||
)
|
)
|
||||||
from batdetect2.models import build_model
|
from batdetect2.models import Model, build_model
|
||||||
from batdetect2.train.augmentations import (
|
from batdetect2.train.augmentations import (
|
||||||
RandomExampleSource,
|
RandomExampleSource,
|
||||||
build_augmentations,
|
build_augmentations,
|
||||||
@ -55,17 +55,13 @@ def train(
|
|||||||
):
|
):
|
||||||
config = config or FullTrainingConfig()
|
config = config or FullTrainingConfig()
|
||||||
|
|
||||||
if model_path is not None:
|
model = build_model(config=config)
|
||||||
logger.debug("Loading model from: {path}", path=model_path)
|
|
||||||
module = TrainingModule.load_from_checkpoint(model_path) # type: ignore
|
|
||||||
else:
|
|
||||||
module = build_training_module(config)
|
|
||||||
|
|
||||||
trainer = build_trainer(config, targets=module.model.targets)
|
trainer = build_trainer(config, targets=model.targets)
|
||||||
|
|
||||||
train_dataloader = build_train_loader(
|
train_dataloader = build_train_loader(
|
||||||
train_examples,
|
train_examples,
|
||||||
preprocessor=module.model.preprocessor,
|
preprocessor=model.preprocessor,
|
||||||
config=config.train,
|
config=config.train,
|
||||||
num_workers=train_workers,
|
num_workers=train_workers,
|
||||||
)
|
)
|
||||||
@ -73,7 +69,7 @@ def train(
|
|||||||
val_dataloader = (
|
val_dataloader = (
|
||||||
build_val_loader(
|
build_val_loader(
|
||||||
val_examples,
|
val_examples,
|
||||||
preprocessor=module.model.preprocessor,
|
preprocessor=model.preprocessor,
|
||||||
config=config.train,
|
config=config.train,
|
||||||
num_workers=val_workers,
|
num_workers=val_workers,
|
||||||
)
|
)
|
||||||
@ -81,6 +77,16 @@ def train(
|
|||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if model_path is not None:
|
||||||
|
logger.debug("Loading model from: {path}", path=model_path)
|
||||||
|
module = TrainingModule.load_from_checkpoint(model_path) # type: ignore
|
||||||
|
else:
|
||||||
|
module = build_training_module(
|
||||||
|
model,
|
||||||
|
config,
|
||||||
|
batches_per_epoch=len(train_dataloader),
|
||||||
|
)
|
||||||
|
|
||||||
logger.info("Starting main training loop...")
|
logger.info("Starting main training loop...")
|
||||||
trainer.fit(
|
trainer.fit(
|
||||||
module,
|
module,
|
||||||
@ -90,14 +96,17 @@ def train(
|
|||||||
logger.info("Training complete.")
|
logger.info("Training complete.")
|
||||||
|
|
||||||
|
|
||||||
def build_training_module(config: FullTrainingConfig) -> TrainingModule:
|
def build_training_module(
|
||||||
model = build_model(config=config)
|
model: Model,
|
||||||
|
config: FullTrainingConfig,
|
||||||
|
batches_per_epoch: int,
|
||||||
|
) -> TrainingModule:
|
||||||
loss = build_loss(config=config.train.loss)
|
loss = build_loss(config=config.train.loss)
|
||||||
return TrainingModule(
|
return TrainingModule(
|
||||||
model=model,
|
model=model,
|
||||||
loss=loss,
|
loss=loss,
|
||||||
learning_rate=config.train.learning_rate,
|
learning_rate=config.train.learning_rate,
|
||||||
t_max=config.train.t_max,
|
t_max=config.train.t_max * batches_per_epoch,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user