From f7d6516550da772648f1eb0e7dc48287c55ebdd5 Mon Sep 17 00:00:00 2001 From: mbsantiago Date: Thu, 23 Jan 2025 14:08:55 +0000 Subject: [PATCH] WIP --- batdetect2/compat/data.py | 7 +++ batdetect2/data.py | 56 +++++++++++++++----- batdetect2/models/__init__.py | 2 +- batdetect2/models/heads.py | 4 +- batdetect2/post_process.py | 14 ++--- batdetect2/preprocess/audio.py | 2 +- batdetect2/preprocess/spectrogram.py | 3 +- batdetect2/train/augmentations.py | 73 +++++++++++++++++++++++--- batdetect2/train/dataset.py | 16 +++++- batdetect2/train/labels.py | 25 +++++---- batdetect2/train/losses.py | 1 + batdetect2/train/modules.py | 36 ++++++++++++- batdetect2/train/preprocess.py | 14 ++++- batdetect2/train/targets.py | 51 +++++++++++++++--- tests/test_train/test_augmentations.py | 66 +++++++++++++++++++++++ uv.lock | 37 +++++++------ 16 files changed, 337 insertions(+), 70 deletions(-) diff --git a/batdetect2/compat/data.py b/batdetect2/compat/data.py index 5139fb2..b05179f 100644 --- a/batdetect2/compat/data.py +++ b/batdetect2/compat/data.py @@ -215,6 +215,7 @@ def annotation_to_sound_event( def file_annotation_to_clip( file_annotation: FileAnnotation, audio_dir: Optional[PathLike] = None, + label_key: str = "class", ) -> data.Clip: """Convert file annotation to recording.""" audio_dir = audio_dir or Path.cwd() @@ -227,6 +228,12 @@ def file_annotation_to_clip( recording = data.Recording.from_file( full_path, time_expansion=file_annotation.time_exp, + tags=[ + data.Tag( + term=data.term_from_key(label_key), + value=file_annotation.label, + ) + ], ) return data.Clip( diff --git a/batdetect2/data.py b/batdetect2/data.py index 611f0a9..8d63602 100644 --- a/batdetect2/data.py +++ b/batdetect2/data.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import List, Literal, Tuple, Union +from typing import List, Literal, Optional, Tuple, Union from pydantic import Field from soundevent import data, io @@ -8,7 +8,11 @@ from batdetect2.compat.data import ( load_annotation_project_from_dir, load_annotation_project_from_file, ) -from batdetect2.configs import BaseConfig +from batdetect2.configs import BaseConfig, load_config + +__all__ = [ + "load_datasets_from_config", +] class BatDetect2AnnotationFiles(BaseConfig): @@ -23,7 +27,7 @@ class BatDetect2AnnotationFile(BaseConfig): class AOEFAnnotationFile(BaseConfig): format: Literal["aoef"] = "aoef" - annotations_file: Path + path: Path AnnotationFormats = Union[ @@ -44,25 +48,39 @@ class DatasetsConfig(BaseConfig): test: List[DatasetInfo] = Field(default_factory=list) -def load_dataset(info: DatasetInfo) -> data.AnnotationProject: +def load_dataset( + info: DatasetInfo, + audio_dir: Optional[Path] = None, + base_dir: Optional[Path] = None, +) -> data.AnnotationProject: + audio_dir = ( + info.audio_dir if base_dir is None else base_dir / info.audio_dir + ) + + path = ( + info.annotations.path + if base_dir is None + else base_dir / info.annotations.path + ) + if info.annotations.format == "batdetect2": return load_annotation_project_from_dir( - info.annotations.path, + path, name=info.name, - audio_dir=info.audio_dir, + audio_dir=audio_dir, ) if info.annotations.format == "batdetect2_file": return load_annotation_project_from_file( - info.annotations.path, + path, name=info.name, - audio_dir=info.audio_dir, + audio_dir=audio_dir, ) if info.annotations.format == "aoef": return io.load( # type: ignore - info.annotations.annotations_file, - audio_dir=info.audio_dir, + info.annotations.path, + audio_dir=audio_dir, ) raise NotImplementedError( @@ -72,16 +90,30 
@@ def load_dataset(info: DatasetInfo) -> data.AnnotationProject: def load_datasets( config: DatasetsConfig, + base_dir: Optional[Path] = None, ) -> Tuple[List[data.ClipAnnotation], List[data.ClipAnnotation]]: test_annotations = [] train_annotations = [] for dataset in config.train: - project = load_dataset(dataset) + project = load_dataset(dataset, base_dir=base_dir) train_annotations.extend(project.clip_annotations) for dataset in config.test: - project = load_dataset(dataset) + project = load_dataset(dataset, base_dir=base_dir) test_annotations.extend(project.clip_annotations) return train_annotations, test_annotations + + +def load_datasets_from_config( + path: data.PathLike, + field: Optional[str] = None, + base_dir: Optional[Path] = None, +): + config = load_config( + path=path, + schema=DatasetsConfig, + field=field, + ) + return load_datasets(config, base_dir=base_dir) diff --git a/batdetect2/models/__init__.py b/batdetect2/models/__init__.py index be06c42..5883cd8 100644 --- a/batdetect2/models/__init__.py +++ b/batdetect2/models/__init__.py @@ -28,7 +28,7 @@ class ModelType(str, Enum): class ModelConfig(BaseConfig): name: ModelType = ModelType.Net2DFast - num_features: int = 128 + num_features: int = 32 def get_backbone( diff --git a/batdetect2/models/heads.py b/batdetect2/models/heads.py index d4281f4..5d7ce3f 100644 --- a/batdetect2/models/heads.py +++ b/batdetect2/models/heads.py @@ -41,8 +41,8 @@ class BBoxHead(nn.Module): self.in_channels = in_channels self.bbox = nn.Conv2d( - self.feature_extractor.out_channels, - 2, + in_channels=self.in_channels, + out_channels=2, kernel_size=1, padding=0, ) diff --git a/batdetect2/post_process.py b/batdetect2/post_process.py index c39994c..871b0a2 100644 --- a/batdetect2/post_process.py +++ b/batdetect2/post_process.py @@ -1,6 +1,6 @@ """Module for postprocessing model outputs.""" -from typing import Callable, List, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import numpy as np import torch @@ -30,15 +30,12 @@ class PostprocessConfig(BaseModel): top_k_per_sec: int = Field(default=TOP_K_PER_SEC, gt=0) -TagFunction = Callable[[int], List[data.Tag]] - - def postprocess_model_outputs( outputs: ModelOutput, clips: List[data.Clip], classes: List[str], decoder: Callable[[str], List[data.Tag]], - config: PostprocessConfig, + config: Optional[PostprocessConfig] = None, ) -> List[data.ClipPrediction]: """Postprocesses model outputs to generate clip predictions. @@ -68,6 +65,9 @@ def postprocess_model_outputs( ValueError If the number of predictions does not match the number of clips. 
""" + + config = config or PostprocessConfig() + num_predictions = len(outputs.detection_probs) if num_predictions == 0: @@ -189,7 +189,7 @@ def compute_sound_events_from_outputs( [ data.PredictedTag( tag=tag, - score=class_score.item(), + score=max(min(class_score.item(), 1), 0), ) for tag in corresponding_tags ] @@ -220,7 +220,7 @@ def compute_sound_events_from_outputs( predictions.append( data.SoundEventPrediction( sound_event=sound_event, - score=score.item(), + score=max(min(score.item(), 1), 0), tags=predicted_tags, ) ) diff --git a/batdetect2/preprocess/audio.py b/batdetect2/preprocess/audio.py index 5c613f3..e5cfab2 100644 --- a/batdetect2/preprocess/audio.py +++ b/batdetect2/preprocess/audio.py @@ -12,7 +12,7 @@ from batdetect2.configs import BaseConfig TARGET_SAMPLERATE_HZ = 256_000 SCALE_RAW_AUDIO = False -DEFAULT_DURATION = 1 +DEFAULT_DURATION = None class ResampleConfig(BaseConfig): diff --git a/batdetect2/preprocess/spectrogram.py b/batdetect2/preprocess/spectrogram.py index 6a619f8..fd0d72a 100644 --- a/batdetect2/preprocess/spectrogram.py +++ b/batdetect2/preprocess/spectrogram.py @@ -71,11 +71,12 @@ def compute_spectrogram( if config.size.divide_factor: # Need to pad the audio to make sure the spectrogram has a # width compatible with the divide factor + resize_factor = config.size.resize_factor or 1 wav = pad_audio( wav, window_duration=config.fft.window_duration, window_overlap=config.fft.window_overlap, - divide_factor=config.size.divide_factor, + divide_factor=int(config.size.divide_factor / resize_factor), ) spec = stft( diff --git a/batdetect2/train/augmentations.py b/batdetect2/train/augmentations.py index 1fc2611..deb8720 100644 --- a/batdetect2/train/augmentations.py +++ b/batdetect2/train/augmentations.py @@ -20,6 +20,37 @@ class AugmentationConfig(BaseConfig): class SubclipConfig(BaseConfig): enable: bool = True duration: Optional[float] = None + width: Optional[int] = 512 + + +def adjust_dataset_width( + example: xr.Dataset, + duration: Optional[float] = None, + width: Optional[int] = None, +) -> xr.Dataset: + step = arrays.get_dim_step(example, "time") # type: ignore + + if width is None: + if duration is None: + raise ValueError("Either duration or width must be provided") + + width = int(np.floor(duration / step)) + + adjusted_arrays = { + name: ops.adjust_dim_width(array, "time", width) + for name, array in example.items() + if name != "audio" + } + + ratio = width / example.spectrogram.sizes["time"] + audio_width = int(example.audio.sizes["audio_time"] * ratio) + adjusted_arrays["audio"] = ops.adjust_dim_width( + example["audio"], + "audio_time", + audio_width, + ) + + return xr.Dataset(data_vars=adjusted_arrays) def select_random_subclip( @@ -30,6 +61,7 @@ def select_random_subclip( ) -> xr.Dataset: """Select a random subclip from a clip.""" step = arrays.get_dim_step(example, "time") # type: ignore + start, stop = arrays.get_dim_range(example, "time") # type: ignore if width is None: if duration is None: @@ -41,15 +73,19 @@ def select_random_subclip( duration = width * step if start_time is None: - start, stop = arrays.get_dim_range(example, "time") # type: ignore - start_time = np.random.uniform(start, stop - duration) + start_time = np.random.uniform(start, max(stop - duration, start)) + + if start_time + duration > stop: + example = adjust_dataset_width(example, width=width) start_index = arrays.get_coord_index( example, # type: ignore "time", start_time, ) + end_index = start_index + width - 1 + start_time = example.time.values[start_index] end_time 
= example.time.values[end_index] @@ -78,8 +114,15 @@ def mix_examples( if weight is None: weight = np.random.uniform(min_weight, max_weight) - audio2 = other["audio"].values - audio1 = ops.adjust_dim_width(example["audio"], "audio_time", len(audio2)) + audio1 = example["audio"] + + audio2 = ops.adjust_dim_width( + other["audio"], "audio_time", len(audio1) + ).values + + if len(audio2) > len(audio1): + audio2 = audio2[: len(audio1)] + combined = weight * audio1 + (1 - weight) * audio2 spec = compute_spectrogram( @@ -104,11 +147,16 @@ def mix_examples( return xr.Dataset( { "audio": combined, - "spectrogram": spec, + "spectrogram": xr.DataArray( + data=spec.data, + dims=example["spectrogram"].dims, + coords=example["spectrogram"].coords, + ), "detection": detection_heatmap, "class": class_heatmap, "size": size_heatmap, - } + }, + attrs=example.attrs, ) @@ -146,7 +194,14 @@ def add_echo( config=config.spectrogram, ) - return example.assign(audio=audio, spectrogram=spectrogram) + return example.assign( + audio=audio, + spectrogram=xr.DataArray( + data=spectrogram.data, + dims=example["spectrogram"].dims, + coords=example["spectrogram"].coords, + ), + ) class VolumeAugmentationConfig(AugmentationConfig): @@ -340,15 +395,17 @@ def augment_example( example = select_random_subclip( example, duration=config.subclip.duration, + width=config.subclip.width, ) - if should_apply(config.mix) and others is not None: + if should_apply(config.mix) and (others is not None): other = others() if config.subclip.enable: other = select_random_subclip( other, duration=config.subclip.duration, + width=config.subclip.width, ) example = mix_examples( diff --git a/batdetect2/train/dataset.py b/batdetect2/train/dataset.py index 1828ef2..3fdcf24 100644 --- a/batdetect2/train/dataset.py +++ b/batdetect2/train/dataset.py @@ -74,8 +74,20 @@ class LabeledDataset(Dataset): ) @classmethod - def from_directory(cls, directory: PathLike, extension: str = ".nc"): - return cls(get_files(directory, extension)) + def from_directory( + cls, + directory: PathLike, + extension: str = ".nc", + augment: bool = False, + preprocessing_config: Optional[PreprocessingConfig] = None, + augmentation_config: Optional[AugmentationsConfig] = None, + ): + return cls( + get_files(directory, extension), + augment=augment, + preprocessing_config=preprocessing_config, + augmentation_config=augmentation_config, + ) def get_random_example(self) -> xr.Dataset: idx = np.random.randint(0, len(self)) diff --git a/batdetect2/train/labels.py b/batdetect2/train/labels.py index 2a38892..326138f 100644 --- a/batdetect2/train/labels.py +++ b/batdetect2/train/labels.py @@ -1,3 +1,4 @@ +from collections.abc import Iterable from typing import Callable, List, Optional, Sequence, Tuple import numpy as np @@ -22,7 +23,7 @@ def generate_heatmaps( sound_events: Sequence[data.SoundEventAnnotation], spec: xr.DataArray, class_names: List[str], - encoder: Callable[[data.SoundEventAnnotation], Optional[str]], + encoder: Callable[[Iterable[data.Tag]], Optional[str]], target_sigma: float = 3.0, position: Positions = "bottom-left", time_scale: float = 1000.0, @@ -64,23 +65,29 @@ def generate_heatmaps( time, frequency = geometry.get_geometry_point(geom, position=position) # Set 1.0 at the position of the sound event in the detection heatmap - detection_heatmap = arrays.set_value_at_pos( - detection_heatmap, - 1.0, - time=time, - frequency=frequency, - ) + try: + detection_heatmap = arrays.set_value_at_pos( + detection_heatmap, + 1.0, + time=time, + frequency=frequency, + ) + except 
KeyError: + # Skip the sound event if the position is outside the spectrogram + continue # Set the size of the sound event at the position in the size heatmap start_time, low_freq, end_time, high_freq = geometry.compute_bounds( geom ) + size = np.array( [ (end_time - start_time) * time_scale, (high_freq - low_freq) * frequency_scale, ] ) + size_heatmap = arrays.set_value_at_pos( size_heatmap, size, @@ -89,14 +96,12 @@ def generate_heatmaps( ) # Get the class name of the sound event - class_name = encoder(sound_event_annotation) + class_name = encoder(sound_event_annotation.tags) if class_name is None: # If the label is None skip the sound event continue - # Set 1.0 at the position and category of the sound event in the class - # heatmap class_heatmap = arrays.set_value_at_pos( class_heatmap, 1.0, diff --git a/batdetect2/train/losses.py b/batdetect2/train/losses.py index e88fbfb..22ac986 100644 --- a/batdetect2/train/losses.py +++ b/batdetect2/train/losses.py @@ -6,6 +6,7 @@ from pydantic import Field from batdetect2.configs import BaseConfig from batdetect2.models.typing import ModelOutput +from batdetect2.plot import detection from batdetect2.train.dataset import TrainExample diff --git a/batdetect2/train/modules.py b/batdetect2/train/modules.py index 5ae5631..be3ef1d 100644 --- a/batdetect2/train/modules.py +++ b/batdetect2/train/modules.py @@ -13,6 +13,7 @@ from batdetect2.models import ( get_backbone, ) from batdetect2.models.typing import ModelOutput +from batdetect2.post_process import PostprocessConfig from batdetect2.preprocess import PreprocessingConfig from batdetect2.train.dataset import TrainExample from batdetect2.train.losses import LossConfig, compute_loss @@ -37,6 +38,9 @@ class ModuleConfig(BaseConfig): preprocessing: PreprocessingConfig = Field( default_factory=PreprocessingConfig ) + postprocessing: PostprocessConfig = Field( + default_factory=PostprocessConfig + ) class DetectorModel(L.LightningModule): @@ -50,8 +54,9 @@ class DetectorModel(L.LightningModule): self.config = config or ModuleConfig() self.save_hyperparameters() + size = self.config.preprocessing.spectrogram.size self.backbone = get_backbone( - input_height=self.config.preprocessing.spectrogram.size.height, + input_height=int(size.height * (size.resize_factor or 1)), config=self.config.backbone, ) @@ -62,11 +67,13 @@ class DetectorModel(L.LightningModule): self.bbox = BBoxHead(in_channels=self.backbone.out_channels) - conf = self.training_config.loss.classification + conf = self.config.train.loss.classification self.class_weights = ( torch.tensor(conf.class_weights) if conf.class_weights else None ) + self.validation_predictions = [] + def forward(self, spec: torch.Tensor) -> ModelOutput: # type: ignore features = self.backbone(spec) detection_probs, classification_probs = self.classifier(features) @@ -86,8 +93,33 @@ class DetectorModel(L.LightningModule): conf=self.config.train.loss, class_weights=self.class_weights, ) + + self.log("train/loss/total", losses.total, prog_bar=True, logger=True) + self.log("train/loss/detection", losses.total, logger=True) + self.log("train/loss/size", losses.total, logger=True) + self.log("train/loss/classification", losses.total, logger=True) + return losses.total + def validation_step(self, batch: TrainExample, batch_idx: int) -> None: + outputs = self.forward(batch.spec) + + losses = compute_loss( + batch, + outputs, + conf=self.config.train.loss, + class_weights=self.class_weights, + ) + + self.log("val/loss/total", losses.total, prog_bar=True, logger=True) + 
self.log("val/loss/detection", losses.total, logger=True) + self.log("val/loss/size", losses.total, logger=True) + self.log("val/loss/classification", losses.total, logger=True) + + dataloaders = self.trainer.val_dataloaders + print(dataloaders) + + def configure_optimizers(self): conf = self.config.train.optimizer optimizer = optim.Adam(self.parameters(), lr=conf.learning_rate) diff --git a/batdetect2/train/preprocess.py b/batdetect2/train/preprocess.py index fbcad4c..2817e6d 100644 --- a/batdetect2/train/preprocess.py +++ b/batdetect2/train/preprocess.py @@ -62,12 +62,16 @@ def generate_train_example( include=config.target.include, exclude=config.target.exclude, ) + selected_events = [ event for event in clip_annotation.sound_events if filter_fn(event) ] + encoder = build_encoder( + config.target.classes, + replacement_rules=config.target.replace, + ) class_names = get_class_names(config.target.classes) - encoder = build_encoder(config.target.classes) detection_heatmap, class_heatmap, size_heatmap = generate_heatmaps( selected_events, @@ -172,5 +176,11 @@ def preprocess_single_annotation( if path.is_file() and replace: path.unlink() - sample = generate_train_example(clip_annotation, config=config) + try: + sample = generate_train_example(clip_annotation, config=config) + except Exception as error: + raise RuntimeError( + f"Failed to process annotation: {clip_annotation.uuid}" + ) from error + save_to_file(sample, path) diff --git a/batdetect2/train/targets.py b/batdetect2/train/targets.py index c92b724..c0771fa 100644 --- a/batdetect2/train/targets.py +++ b/batdetect2/train/targets.py @@ -1,13 +1,22 @@ +from collections.abc import Iterable from functools import partial +from pathlib import Path from typing import Callable, List, Optional, Set from pydantic import Field from soundevent import data -from batdetect2.configs import BaseConfig +from batdetect2.configs import BaseConfig, load_config from batdetect2.terms import TagInfo, get_tag_from_info +class ReplaceConfig(BaseConfig): + """Configuration for replacing tags.""" + + original: TagInfo + replacement: TagInfo + + class TargetConfig(BaseConfig): """Configuration for target generation.""" @@ -23,6 +32,7 @@ class TargetConfig(BaseConfig): include: Optional[List[TagInfo]] = Field( default_factory=lambda: [TagInfo(key="event", value="Echolocation")] ) + exclude: Optional[List[TagInfo]] = Field( default_factory=lambda: [ TagInfo(key="class", value=""), @@ -31,6 +41,8 @@ class TargetConfig(BaseConfig): ] ) + replace: Optional[List[ReplaceConfig]] = None + def build_sound_event_filter( include: Optional[List[TagInfo]] = None, @@ -57,9 +69,24 @@ def get_class_names(classes: List[TagInfo]) -> List[str]: return sorted({get_tag_label(tag) for tag in classes}) +def build_replacer( + rules: List[ReplaceConfig], +) -> Callable[[data.Tag], data.Tag]: + mapping = { + get_tag_from_info(rule.original): get_tag_from_info(rule.replacement) + for rule in rules + } + + def replacer(tag: data.Tag) -> data.Tag: + return mapping.get(tag, tag) + + return replacer + + def build_encoder( classes: List[TagInfo], -) -> Callable[[data.SoundEventAnnotation], Optional[str]]: + replacement_rules: Optional[List[ReplaceConfig]] = None, +) -> Callable[[Iterable[data.Tag]], Optional[str]]: target_tags = set([get_tag_from_info(tag) for tag in classes]) tag_mapping = { @@ -67,12 +94,16 @@ def build_encoder( for tag, tag_info in zip(target_tags, classes) } - def encoder( - sound_event_annotation: data.SoundEventAnnotation, - ) -> Optional[str]: - tags = 
set(sound_event_annotation.tags)
+    replacer = (
+        build_replacer(replacement_rules) if replacement_rules else lambda x: x
+    )
 
-        intersection = tags & target_tags
+    def encoder(
+        tags: Iterable[data.Tag],
+    ) -> Optional[str]:
+        sanitized_tags = {replacer(tag) for tag in tags}
+
+        intersection = sanitized_tags & target_tags
 
         if not intersection:
             return None
@@ -115,6 +146,12 @@ def filter_sound_event(
     return True
 
 
+def load_target_config(
+    path: Path, field: Optional[str] = None
+) -> TargetConfig:
+    return load_config(path, schema=TargetConfig, field=field)
+
+
 DEFAULT_SPECIES_LIST = [
     "Barbastellus barbastellus",
     "Eptesicus serotinus",
diff --git a/tests/test_train/test_augmentations.py b/tests/test_train/test_augmentations.py
index bc2a1a1..cc92621 100644
--- a/tests/test_train/test_augmentations.py
+++ b/tests/test_train/test_augmentations.py
@@ -1,10 +1,12 @@
 from collections.abc import Callable
 
+import numpy as np
 import xarray as xr
 from soundevent import data
 
 from batdetect2.train.augmentations import (
     add_echo,
+    adjust_dataset_width,
     mix_examples,
     select_random_subclip,
 )
@@ -68,3 +70,67 @@ def test_selected_random_subclip_has_the_correct_width(
     subclip = select_random_subclip(original, width=100)
 
     assert subclip["spectrogram"].shape[1] == 100
+
+
+def test_adjust_dataset_width():
+    height = 128
+    width = 512
+    samplerate = 48_000
+
+    times = np.linspace(0, 1, width)
+
+    audio_times = np.linspace(0, 1, samplerate)
+    frequency = np.linspace(0, 24_000, height)
+
+    width_subset = 356
+    audio_width_subset = int(samplerate * width_subset / width)
+
+    times_subset = times[:width_subset]
+    audio_times_subset = audio_times[:audio_width_subset]
+    dimensions = ["width", "height"]
+    class_names = [f"species_{i}" for i in range(17)]
+
+    spectrogram = np.random.random([height, width_subset])
+    sizes = np.random.random([len(dimensions), height, width_subset])
+    classes = np.random.random([len(class_names), height, width_subset])
+    audio = np.random.random([int(samplerate * width_subset / width)])
+
+    dataset = xr.Dataset(
+        data_vars={
+            "audio": (("audio_time",), audio),
+            "spectrogram": (("frequency", "time"), spectrogram),
+            "sizes": (("dimension", "frequency", "time"), sizes),
+            "classes": (("class", "frequency", "time"), classes),
+        },
+        coords={
+            "audio_time": audio_times_subset,
+            "time": times_subset,
+            "frequency": frequency,
+            "dimension": dimensions,
+            "class": class_names,
+        },
+    )
+
+    adjusted = adjust_dataset_width(dataset, width=width)
+
+    # Spectrogram was adjusted correctly
+    assert np.isclose(adjusted["spectrogram"].time, times).all()
+    assert (adjusted["spectrogram"].frequency == frequency).all()
+
+    # Sizes were adjusted correctly
+    assert np.isclose(adjusted["sizes"].time, times).all()
+    assert (adjusted["sizes"].frequency == frequency).all()
+    assert list(adjusted["sizes"].dimension.values) == dimensions
+
+    # Classes were adjusted correctly
+    assert np.isclose(adjusted["classes"].time, times).all()
+    assert (adjusted["classes"].frequency == frequency).all()
+    assert list(adjusted["classes"]["class"].values) == class_names
+
+    # Audio time was adjusted correctly
+    assert np.isclose(
+        len(adjusted["audio"].audio_time), len(audio_times), atol=2
+    )
+    assert np.isclose(
+        adjusted["audio"].audio_time[-1], audio_times[-1], atol=1e-3
+    )
diff --git a/uv.lock b/uv.lock
index c7d404e..6ca08f6 100644
--- a/uv.lock
+++ b/uv.lock
@@ -548,7 +548,7 @@ name = "click"
 version = "8.1.7"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
-    { name = "colorama", marker = 
"platform_system == 'Windows'" }, + { name = "colorama", marker = "sys_platform == 'win32'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/96/d3/f04c7bfcf5c1862a2a5b845c6b2b360488cf47af55dfa79c98f6a6bf98b5/click-8.1.7.tar.gz", hash = "sha256:ca9853ad459e787e2192211578cc907e7594e294c7ccc834310722b41b9ca6de", size = 336121 } wheels = [ @@ -630,6 +630,9 @@ name = "configobj" version = "5.0.9" source = { registry = "https://pypi.org/simple" } sdist = { url = "https://files.pythonhosted.org/packages/f5/c4/c7f9e41bc2e5f8eeae4a08a01c91b2aea3dfab40a3e14b25e87e7db8d501/configobj-5.0.9.tar.gz", hash = "sha256:03c881bbf23aa07bccf1b837005975993c4ab4427ba57f959afdd9d1a2386848", size = 101518 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a6/c4/0679472c60052c27efa612b4cd3ddd2a23e885dcdc73461781d2c802d39e/configobj-5.0.9-py2.py3-none-any.whl", hash = "sha256:1ba10c5b6ee16229c79a05047aeda2b55eb4e80d7c7d8ecf17ec1ca600c79882", size = 35615 }, +] [[package]] name = "contourpy" @@ -1433,7 +1436,7 @@ name = "ipykernel" version = "6.29.5" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "appnope", marker = "platform_system == 'Darwin'" }, + { name = "appnope", marker = "sys_platform == 'darwin'" }, { name = "comm" }, { name = "debugpy" }, { name = "ipython" }, @@ -3175,6 +3178,7 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/7f/b7/20c6f3c0b656fe609675d69bc135c03aac9e3865912444be6339207b6648/ruamel.yaml.clib-0.2.12-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f66efbc1caa63c088dead1c4170d148eabc9b80d95fb75b6c92ac0aad2437d76", size = 686712 }, { url = "https://files.pythonhosted.org/packages/cd/11/d12dbf683471f888d354dac59593873c2b45feb193c5e3e0f2ebf85e68b9/ruamel.yaml.clib-0.2.12-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:22353049ba4181685023b25b5b51a574bce33e7f51c759371a7422dcae5402a6", size = 663936 }, { url = "https://files.pythonhosted.org/packages/72/14/4c268f5077db5c83f743ee1daeb236269fa8577133a5cfa49f8b382baf13/ruamel.yaml.clib-0.2.12-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:932205970b9f9991b34f55136be327501903f7c66830e9760a8ffb15b07f05cd", size = 696580 }, + { url = "https://files.pythonhosted.org/packages/30/fc/8cd12f189c6405a4c1cf37bd633aa740a9538c8e40497c231072d0fef5cf/ruamel.yaml.clib-0.2.12-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:a52d48f4e7bf9005e8f0a89209bf9a73f7190ddf0489eee5eb51377385f59f2a", size = 663393 }, { url = "https://files.pythonhosted.org/packages/80/29/c0a017b704aaf3cbf704989785cd9c5d5b8ccec2dae6ac0c53833c84e677/ruamel.yaml.clib-0.2.12-cp310-cp310-win32.whl", hash = "sha256:3eac5a91891ceb88138c113f9db04f3cebdae277f5d44eaa3651a4f573e6a5da", size = 100326 }, { url = "https://files.pythonhosted.org/packages/3a/65/fa39d74db4e2d0cd252355732d966a460a41cd01c6353b820a0952432839/ruamel.yaml.clib-0.2.12-cp310-cp310-win_amd64.whl", hash = "sha256:ab007f2f5a87bd08ab1499bdf96f3d5c6ad4dcfa364884cb4549aa0154b13a28", size = 118079 }, { url = "https://files.pythonhosted.org/packages/fb/8f/683c6ad562f558cbc4f7c029abcd9599148c51c54b5ef0f24f2638da9fbb/ruamel.yaml.clib-0.2.12-cp311-cp311-macosx_13_0_arm64.whl", hash = "sha256:4a6679521a58256a90b0d89e03992c15144c5f3858f40d7c18886023d7943db6", size = 132224 }, @@ -3183,6 +3187,7 @@ wheels = [ { url = 
"https://files.pythonhosted.org/packages/86/29/88c2567bc893c84d88b4c48027367c3562ae69121d568e8a3f3a8d363f4d/ruamel.yaml.clib-0.2.12-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:811ea1594b8a0fb466172c384267a4e5e367298af6b228931f273b111f17ef52", size = 703012 }, { url = "https://files.pythonhosted.org/packages/11/46/879763c619b5470820f0cd6ca97d134771e502776bc2b844d2adb6e37753/ruamel.yaml.clib-0.2.12-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:cf12567a7b565cbf65d438dec6cfbe2917d3c1bdddfce84a9930b7d35ea59642", size = 704352 }, { url = "https://files.pythonhosted.org/packages/02/80/ece7e6034256a4186bbe50dee28cd032d816974941a6abf6a9d65e4228a7/ruamel.yaml.clib-0.2.12-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:7dd5adc8b930b12c8fc5b99e2d535a09889941aa0d0bd06f4749e9a9397c71d2", size = 737344 }, + { url = "https://files.pythonhosted.org/packages/f0/ca/e4106ac7e80efbabdf4bf91d3d32fc424e41418458251712f5672eada9ce/ruamel.yaml.clib-0.2.12-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:1492a6051dab8d912fc2adeef0e8c72216b24d57bd896ea607cb90bb0c4981d3", size = 714498 }, { url = "https://files.pythonhosted.org/packages/67/58/b1f60a1d591b771298ffa0428237afb092c7f29ae23bad93420b1eb10703/ruamel.yaml.clib-0.2.12-cp311-cp311-win32.whl", hash = "sha256:bd0a08f0bab19093c54e18a14a10b4322e1eacc5217056f3c063bd2f59853ce4", size = 100205 }, { url = "https://files.pythonhosted.org/packages/b4/4f/b52f634c9548a9291a70dfce26ca7ebce388235c93588a1068028ea23fcc/ruamel.yaml.clib-0.2.12-cp311-cp311-win_amd64.whl", hash = "sha256:a274fb2cb086c7a3dea4322ec27f4cb5cc4b6298adb583ab0e211a4682f241eb", size = 118185 }, { url = "https://files.pythonhosted.org/packages/48/41/e7a405afbdc26af961678474a55373e1b323605a4f5e2ddd4a80ea80f628/ruamel.yaml.clib-0.2.12-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:20b0f8dc160ba83b6dcc0e256846e1a02d044e13f7ea74a3d1d56ede4e48c632", size = 133433 }, @@ -3191,6 +3196,7 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/52/a9/d39f3c5ada0a3bb2870d7db41901125dbe2434fa4f12ca8c5b83a42d7c53/ruamel.yaml.clib-0.2.12-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:749c16fcc4a2b09f28843cda5a193e0283e47454b63ec4b81eaa2242f50e4ccd", size = 706497 }, { url = "https://files.pythonhosted.org/packages/b0/fa/097e38135dadd9ac25aecf2a54be17ddf6e4c23e43d538492a90ab3d71c6/ruamel.yaml.clib-0.2.12-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:bf165fef1f223beae7333275156ab2022cffe255dcc51c27f066b4370da81e31", size = 698042 }, { url = "https://files.pythonhosted.org/packages/ec/d5/a659ca6f503b9379b930f13bc6b130c9f176469b73b9834296822a83a132/ruamel.yaml.clib-0.2.12-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:32621c177bbf782ca5a18ba4d7af0f1082a3f6e517ac2a18b3974d4edf349680", size = 745831 }, + { url = "https://files.pythonhosted.org/packages/db/5d/36619b61ffa2429eeaefaab4f3374666adf36ad8ac6330d855848d7d36fd/ruamel.yaml.clib-0.2.12-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:b82a7c94a498853aa0b272fd5bc67f29008da798d4f93a2f9f289feb8426a58d", size = 715692 }, { url = "https://files.pythonhosted.org/packages/b1/82/85cb92f15a4231c89b95dfe08b09eb6adca929ef7df7e17ab59902b6f589/ruamel.yaml.clib-0.2.12-cp312-cp312-win32.whl", hash = "sha256:e8c4ebfcfd57177b572e2040777b8abc537cdef58a2120e830124946aa9b42c5", size = 98777 }, { url = 
"https://files.pythonhosted.org/packages/d7/8f/c3654f6f1ddb75daf3922c3d8fc6005b1ab56671ad56ffb874d908bfa668/ruamel.yaml.clib-0.2.12-cp312-cp312-win_amd64.whl", hash = "sha256:0467c5965282c62203273b838ae77c0d29d7638c8a4e3a1c8bdd3602c10904e4", size = 115523 }, { url = "https://files.pythonhosted.org/packages/e5/46/ccdef7a84ad745c37cb3d9a81790f28fbc9adf9c237dba682017b123294e/ruamel.yaml.clib-0.2.12-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:fc4b630cd3fa2cf7fce38afa91d7cfe844a9f75d7f0f36393fa98815e911d987", size = 131834 }, @@ -3199,6 +3205,7 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/da/1c/23497017c554fc06ff5701b29355522cff850f626337fff35d9ab352cb18/ruamel.yaml.clib-0.2.12-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e2f1c3765db32be59d18ab3953f43ab62a761327aafc1594a2a1fbe038b8b8a7", size = 689072 }, { url = "https://files.pythonhosted.org/packages/68/e6/f3d4ff3223f9ea49c3b7169ec0268e42bd49f87c70c0e3e853895e4a7ae2/ruamel.yaml.clib-0.2.12-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:d85252669dc32f98ebcd5d36768f5d4faeaeaa2d655ac0473be490ecdae3c285", size = 667091 }, { url = "https://files.pythonhosted.org/packages/84/62/ead07043527642491e5011b143f44b81ef80f1025a96069b7210e0f2f0f3/ruamel.yaml.clib-0.2.12-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:e143ada795c341b56de9418c58d028989093ee611aa27ffb9b7f609c00d813ed", size = 699111 }, + { url = "https://files.pythonhosted.org/packages/52/b3/fe4d84446f7e4887e3bea7ceff0a7df23790b5ed625f830e79ace88ebefb/ruamel.yaml.clib-0.2.12-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:2c59aa6170b990d8d2719323e628aaf36f3bfbc1c26279c0eeeb24d05d2d11c7", size = 666365 }, { url = "https://files.pythonhosted.org/packages/6e/b3/7feb99a00bfaa5c6868617bb7651308afde85e5a0b23cd187fe5de65feeb/ruamel.yaml.clib-0.2.12-cp39-cp39-win32.whl", hash = "sha256:beffaed67936fbbeffd10966a4eb53c402fafd3d6833770516bf7314bc6ffa12", size = 100863 }, { url = "https://files.pythonhosted.org/packages/93/07/de635108684b7a5bb06e432b0930c5a04b6c59efe73bd966d8db3cc208f2/ruamel.yaml.clib-0.2.12-cp39-cp39-win_amd64.whl", hash = "sha256:040ae85536960525ea62868b642bdb0c2cc6021c9f9d507810c0c604e66f5a7b", size = 118653 }, ] @@ -3632,19 +3639,19 @@ dependencies = [ { name = "fsspec" }, { name = "jinja2" }, { name = "networkx" }, - { name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-cuda-cupti-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-cuda-nvrtc-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-cuda-runtime-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-cudnn-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-cufft-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-curand-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-cusolver-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, + { name 
= "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cuda-cupti-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cuda-nvrtc-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cuda-runtime-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cudnn-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cufft-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-curand-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cusolver-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "sympy" }, - { name = "triton", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, + { name = "triton", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "typing-extensions" }, ] wheels = [ @@ -3758,7 +3765,7 @@ name = "tqdm" version = "4.67.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "colorama", marker = "platform_system == 'Windows'" }, + { name = "colorama", marker = "sys_platform == 'win32'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/e8/4f/0153c21dc5779a49a0598c445b1978126b1344bab9ee71e53e44877e14e0/tqdm-4.67.0.tar.gz", hash = "sha256:fe5a6f95e6fe0b9755e9469b77b9c3cf850048224ecaa8293d7d2d31f97d869a", size = 169739 } wheels = [