{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "cfb0b360-a204-4c27-a18f-3902e8758879", "metadata": { "execution": { "iopub.execute_input": "2024-11-19T17:33:02.699871Z", "iopub.status.busy": "2024-11-19T17:33:02.699590Z", "iopub.status.idle": "2024-11-19T17:33:02.710312Z", "shell.execute_reply": "2024-11-19T17:33:02.709798Z", "shell.execute_reply.started": "2024-11-19T17:33:02.699839Z" } }, "outputs": [], "source": [ "%load_ext autoreload\n", "%autoreload 2" ] }, { "cell_type": "code", "execution_count": 2, "id": "326c5432-94e6-4abf-a332-fe902559461b", "metadata": { "execution": { "iopub.execute_input": "2024-11-19T17:33:02.711324Z", "iopub.status.busy": "2024-11-19T17:33:02.711067Z", "iopub.status.idle": "2024-11-19T17:33:09.092380Z", "shell.execute_reply": "2024-11-19T17:33:09.091830Z", "shell.execute_reply.started": "2024-11-19T17:33:02.711304Z" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/santiago/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] } ], "source": [ "from pathlib import Path\n", "from typing import List, Optional\n", "import torch\n", "\n", "import pytorch_lightning as pl\n", "from batdetect2.train.modules import DetectorModel\n", "from batdetect2.train.augmentations import (\n", " add_echo,\n", " select_random_subclip,\n", " warp_spectrogram,\n", ")\n", "from batdetect2.train.dataset import LabeledDataset, get_files\n", "from batdetect2.train.preprocess import PreprocessingConfig\n", "from soundevent import data\n", "import matplotlib.pyplot as plt\n", "from soundevent.types import ClassMapper\n", "from torch.utils.data import DataLoader" ] }, { "cell_type": "markdown", "id": "9402a473-0b25-4123-9fa8-ad1f71a4237a", "metadata": { "execution": { "iopub.execute_input": "2024-11-18T22:39:12.395329Z", "iopub.status.busy": "2024-11-18T22:39:12.393444Z", "iopub.status.idle": "2024-11-18T22:39:12.405938Z", "shell.execute_reply": "2024-11-18T22:39:12.402980Z", "shell.execute_reply.started": "2024-11-18T22:39:12.395236Z" } }, "source": [ "## Training Datasets" ] }, { "cell_type": "code", "execution_count": 3, "id": "cfd97d83-8c2b-46c8-9eae-cea59f53bc61", "metadata": { "execution": { "iopub.execute_input": "2024-11-19T17:33:09.093487Z", "iopub.status.busy": "2024-11-19T17:33:09.092990Z", "iopub.status.idle": "2024-11-19T17:33:09.121636Z", "shell.execute_reply": "2024-11-19T17:33:09.121143Z", "shell.execute_reply.started": "2024-11-19T17:33:09.093459Z" } }, "outputs": [], "source": [ "data_dir = Path.cwd().parent / \"example_data\"" ] }, { "cell_type": "code", "execution_count": 4, "id": "d5131ae9-2efd-4758-b6e5-189a6d90789b", "metadata": { "execution": { "iopub.execute_input": "2024-11-19T17:33:09.122685Z", "iopub.status.busy": "2024-11-19T17:33:09.122270Z", "iopub.status.idle": "2024-11-19T17:33:09.151386Z", "shell.execute_reply": "2024-11-19T17:33:09.150788Z", 
"shell.execute_reply.started": "2024-11-19T17:33:09.122661Z" } }, "outputs": [], "source": [ "files = get_files(data_dir / \"preprocessed\")" ] },
{ "cell_type": "code", "execution_count": 5, "id": "bc733d3d-7829-4e90-896d-a0dc76b33288", "metadata": { "execution": { "iopub.execute_input": "2024-11-19T17:33:09.152327Z", "iopub.status.busy": "2024-11-19T17:33:09.152060Z", "iopub.status.idle": "2024-11-19T17:33:09.184041Z", "shell.execute_reply": "2024-11-19T17:33:09.183372Z", "shell.execute_reply.started": "2024-11-19T17:33:09.152305Z" } }, "outputs": [], "source": [ "train_dataset = LabeledDataset(files)" ] },
{ "cell_type": "code", "execution_count": 6, "id": "dfbb94ab-7b12-4689-9c15-4dc34cd17cb2", "metadata": { "execution": { "iopub.execute_input": "2024-11-19T17:33:09.186393Z", "iopub.status.busy": "2024-11-19T17:33:09.186117Z", "iopub.status.idle": "2024-11-19T17:33:09.220175Z", "shell.execute_reply": "2024-11-19T17:33:09.219322Z", "shell.execute_reply.started": "2024-11-19T17:33:09.186375Z" } }, "outputs": [], "source": [
"train_dataloader = DataLoader(\n",
"    train_dataset,\n",
"    shuffle=True,\n",
"    batch_size=32,\n",
"    num_workers=4,\n",
")"
] },
{ "cell_type": "code", "execution_count": 7, "id": "e2eedaa9-6be3-481a-8786-7618515d98f8", "metadata": { "execution": { "iopub.execute_input": "2024-11-19T17:33:09.221653Z", "iopub.status.busy": "2024-11-19T17:33:09.221242Z", "iopub.status.idle": "2024-11-19T17:33:09.260977Z", "shell.execute_reply": "2024-11-19T17:33:09.260375Z", "shell.execute_reply.started": "2024-11-19T17:33:09.221616Z" } }, "outputs": [], "source": [
"class Mapper(ClassMapper):\n",
"    \"\"\"Translate annotation tags into training class names and back.\n",
"\n",
"    Annotations whose ``event`` tag is ``Social`` map to the single class\n",
"    ``social``. Annotations whose ``event`` tag is ``Echolocation`` map to\n",
"    the value of their ``class`` tag. Every other annotation is skipped.\n",
"    \"\"\"\n",
"\n",
"    # List of all possible classes. \"social\" is listed because encode()\n",
"    # can return it and decode() accepts it.\n",
"    class_labels = [\n",
"        \"Eptesicus serotinus\",\n",
"        \"Myotis mystacinus\",\n",
"        \"Pipistrellus pipistrellus\",\n",
"        \"Rhinolophus ferrumequinum\",\n",
"        \"social\",\n",
"    ]\n",
"\n",
"    def encode(self, x: data.SoundEventAnnotation) -> Optional[str]:\n",
"        \"\"\"Return the class name for an annotation, or None to skip it.\n",
"\n",
"        Args:\n",
"            x: The sound event annotation whose tags are inspected.\n",
"\n",
"        Returns:\n",
"            ``\"social\"`` for social calls, the value of the ``class`` tag\n",
"            for echolocation calls, and None when the annotation has no\n",
"            ``event`` tag, has another event type, or has no ``class`` tag.\n",
"        \"\"\"\n",
"        event_tag = data.find_tag(x.tags, \"event\")\n",
"\n",
"        # find_tag gives back None when no tag has the requested key, so\n",
"        # guard before reading ``.value``.\n",
"        if event_tag is None:\n",
"            return None\n",
"\n",
"        if event_tag.value == \"Social\":\n",
"            return \"social\"\n",
"\n",
"        if event_tag.value != \"Echolocation\":\n",
"            # Ignore all other types of calls\n",
"            return None\n",
"\n",
"        species_tag = data.find_tag(x.tags, \"class\")\n",
"\n",
"        # An echolocation call without a species tag cannot be classified.\n",
"        if species_tag is None:\n",
"            return None\n",
"\n",
"        return species_tag.value\n",
"\n",
"    def decode(self, class_name: str) -> List[data.Tag]:\n",
"        \"\"\"Return the tags that describe a class name.\n",
"\n",
"        Args:\n",
"            class_name: A class name as returned by ``encode``.\n",
"\n",
"        Returns:\n",
"            A one-element list holding an ``event`` tag for ``\"social\"`` or\n",
"            a ``class`` tag carrying ``class_name`` otherwise.\n",
"        \"\"\"\n",
"        if class_name == \"social\":\n",
"            # Same spelling encode() tests for, so the tag round-trips.\n",
"            return [data.Tag(key=\"event\", value=\"Social\")]\n",
"\n",
"        return [data.Tag(key=\"class\", value=class_name)]"
] },
{ "cell_type": "code", "execution_count": 8, "id": "1ff6072c-511e-42fe-a74f-282f269b80f0", "metadata": { "execution": { "iopub.execute_input": "2024-11-19T17:33:09.262337Z", "iopub.status.busy": "2024-11-19T17:33:09.261775Z", "iopub.status.idle": "2024-11-19T17:33:09.309793Z", "shell.execute_reply": "2024-11-19T17:33:09.309216Z", "shell.execute_reply.started": "2024-11-19T17:33:09.262307Z" } }, "outputs": [], "source": [ "detector = DetectorModel(class_mapper=Mapper())" ] },
{ "cell_type": "code", "execution_count": 9, "id": "3a763ee6-15bc-4105-a409-f06e0ad21a06", "metadata": { "execution": { "iopub.execute_input": "2024-11-19T17:33:09.310695Z", "iopub.status.busy": "2024-11-19T17:33:09.310438Z", "iopub.status.idle": "2024-11-19T17:33:09.366636Z", "shell.execute_reply": "2024-11-19T17:33:09.366059Z", "shell.execute_reply.started": "2024-11-19T17:33:09.310669Z" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "GPU available: False, used: False\n", "TPU available: False, using: 0 TPU cores\n", "HPU available: False, using: 0 HPUs\n" ] } ], "source": [
"trainer = pl.Trainer(\n",
"    limit_train_batches=100,\n",
"    max_epochs=2,\n",
"    log_every_n_steps=1,\n",
")"
] },
{ "cell_type": "code", "execution_count": 10, "id": "0b86d49d-3314-4257-94f5-f964855be385", "metadata": { "execution": { "iopub.execute_input": "2024-11-19T17:33:09.367499Z", "iopub.status.busy": "2024-11-19T17:33:09.367242Z", "iopub.status.idle": "2024-11-19T17:33:10.811300Z", "shell.execute_reply": "2024-11-19T17:33:10.809823Z",
"shell.execute_reply.started": "2024-11-19T17:33:09.367473Z" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "\n", " | Name | Type | Params | Mode \n", "--------------------------------------------------------\n", "0 | feature_extractor | Net2DFast | 119 K | train\n", "1 | classifier | Conv2d | 54 | train\n", "2 | bbox | Conv2d | 18 | train\n", "--------------------------------------------------------\n", "119 K Trainable params\n", "448 Non-trainable params\n", "119 K Total params\n", "0.480 Total estimated model params size (MB)\n", "32 Modules in train mode\n", "0 Modules in eval mode\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 0: 0%| | 0/1 [00:00 1\u001b[0m \u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdetector\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrain_dataloaders\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtrain_dataloader\u001b[49m\u001b[43m)\u001b[49m\n", "File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/trainer/trainer.py:538\u001b[0m, in \u001b[0;36mTrainer.fit\u001b[0;34m(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)\u001b[0m\n\u001b[1;32m 536\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mstatus \u001b[38;5;241m=\u001b[39m TrainerStatus\u001b[38;5;241m.\u001b[39mRUNNING\n\u001b[1;32m 537\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtraining \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[0;32m--> 538\u001b[0m \u001b[43mcall\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_and_handle_interrupt\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 539\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m 
\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_fit_impl\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrain_dataloaders\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mval_dataloaders\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdatamodule\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mckpt_path\u001b[49m\n\u001b[1;32m 540\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n", "File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/trainer/call.py:47\u001b[0m, in \u001b[0;36m_call_and_handle_interrupt\u001b[0;34m(trainer, trainer_fn, *args, **kwargs)\u001b[0m\n\u001b[1;32m 45\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mstrategy\u001b[38;5;241m.\u001b[39mlauncher \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 46\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mstrategy\u001b[38;5;241m.\u001b[39mlauncher\u001b[38;5;241m.\u001b[39mlaunch(trainer_fn, \u001b[38;5;241m*\u001b[39margs, trainer\u001b[38;5;241m=\u001b[39mtrainer, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m---> 47\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mtrainer_fn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 49\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m _TunerExitException:\n\u001b[1;32m 50\u001b[0m _call_teardown_hook(trainer)\n", "File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/trainer/trainer.py:574\u001b[0m, in 
\u001b[0;36mTrainer._fit_impl\u001b[0;34m(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)\u001b[0m\n\u001b[1;32m 567\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mfn \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 568\u001b[0m ckpt_path \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_checkpoint_connector\u001b[38;5;241m.\u001b[39m_select_ckpt_path(\n\u001b[1;32m 569\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mfn,\n\u001b[1;32m 570\u001b[0m ckpt_path,\n\u001b[1;32m 571\u001b[0m model_provided\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m,\n\u001b[1;32m 572\u001b[0m model_connected\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlightning_module \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 573\u001b[0m )\n\u001b[0;32m--> 574\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_run\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mckpt_path\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mckpt_path\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 576\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mstopped\n\u001b[1;32m 577\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtraining \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n", "File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/trainer/trainer.py:981\u001b[0m, in \u001b[0;36mTrainer._run\u001b[0;34m(self, model, ckpt_path)\u001b[0m\n\u001b[1;32m 976\u001b[0m 
\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_signal_connector\u001b[38;5;241m.\u001b[39mregister_signal_handlers()\n\u001b[1;32m 978\u001b[0m \u001b[38;5;66;03m# ----------------------------\u001b[39;00m\n\u001b[1;32m 979\u001b[0m \u001b[38;5;66;03m# RUN THE TRAINER\u001b[39;00m\n\u001b[1;32m 980\u001b[0m \u001b[38;5;66;03m# ----------------------------\u001b[39;00m\n\u001b[0;32m--> 981\u001b[0m results \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_run_stage\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 983\u001b[0m \u001b[38;5;66;03m# ----------------------------\u001b[39;00m\n\u001b[1;32m 984\u001b[0m \u001b[38;5;66;03m# POST-Training CLEAN UP\u001b[39;00m\n\u001b[1;32m 985\u001b[0m \u001b[38;5;66;03m# ----------------------------\u001b[39;00m\n\u001b[1;32m 986\u001b[0m log\u001b[38;5;241m.\u001b[39mdebug(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m: trainer tearing down\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", "File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/trainer/trainer.py:1025\u001b[0m, in \u001b[0;36mTrainer._run_stage\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1023\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_run_sanity_check()\n\u001b[1;32m 1024\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mautograd\u001b[38;5;241m.\u001b[39mset_detect_anomaly(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_detect_anomaly):\n\u001b[0;32m-> 1025\u001b[0m 
\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit_loop\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1026\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 1027\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mUnexpected state \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n", "File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/loops/fit_loop.py:205\u001b[0m, in \u001b[0;36m_FitLoop.run\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 203\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 204\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mon_advance_start()\n\u001b[0;32m--> 205\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43madvance\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 206\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mon_advance_end()\n\u001b[1;32m 207\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_restarting \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n", "File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/loops/fit_loop.py:363\u001b[0m, in \u001b[0;36m_FitLoop.advance\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 361\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrainer\u001b[38;5;241m.\u001b[39mprofiler\u001b[38;5;241m.\u001b[39mprofile(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mrun_training_epoch\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n\u001b[1;32m 362\u001b[0m 
\u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_data_fetcher \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m--> 363\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mepoch_loop\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_data_fetcher\u001b[49m\u001b[43m)\u001b[49m\n", "File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/loops/training_epoch_loop.py:140\u001b[0m, in \u001b[0;36m_TrainingEpochLoop.run\u001b[0;34m(self, data_fetcher)\u001b[0m\n\u001b[1;32m 138\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdone:\n\u001b[1;32m 139\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 140\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43madvance\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata_fetcher\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 141\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mon_advance_end(data_fetcher)\n\u001b[1;32m 142\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_restarting \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n", "File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/loops/training_epoch_loop.py:250\u001b[0m, in \u001b[0;36m_TrainingEpochLoop.advance\u001b[0;34m(self, data_fetcher)\u001b[0m\n\u001b[1;32m 247\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mprofiler\u001b[38;5;241m.\u001b[39mprofile(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mrun_training_batch\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n\u001b[1;32m 248\u001b[0m 
\u001b[38;5;28;01mif\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mlightning_module\u001b[38;5;241m.\u001b[39mautomatic_optimization:\n\u001b[1;32m 249\u001b[0m \u001b[38;5;66;03m# in automatic optimization, there can only be one optimizer\u001b[39;00m\n\u001b[0;32m--> 250\u001b[0m batch_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mautomatic_optimization\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43moptimizers\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbatch_idx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 251\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 252\u001b[0m batch_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmanual_optimization\u001b[38;5;241m.\u001b[39mrun(kwargs)\n", "File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/loops/optimization/automatic.py:190\u001b[0m, in \u001b[0;36m_AutomaticOptimization.run\u001b[0;34m(self, optimizer, batch_idx, kwargs)\u001b[0m\n\u001b[1;32m 183\u001b[0m closure()\n\u001b[1;32m 185\u001b[0m \u001b[38;5;66;03m# ------------------------------\u001b[39;00m\n\u001b[1;32m 186\u001b[0m \u001b[38;5;66;03m# BACKWARD PASS\u001b[39;00m\n\u001b[1;32m 187\u001b[0m \u001b[38;5;66;03m# ------------------------------\u001b[39;00m\n\u001b[1;32m 188\u001b[0m \u001b[38;5;66;03m# gradient update with accumulated gradients\u001b[39;00m\n\u001b[1;32m 189\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 190\u001b[0m 
\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_optimizer_step\u001b[49m\u001b[43m(\u001b[49m\u001b[43mbatch_idx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mclosure\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 192\u001b[0m result \u001b[38;5;241m=\u001b[39m closure\u001b[38;5;241m.\u001b[39mconsume_result()\n\u001b[1;32m 193\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m result\u001b[38;5;241m.\u001b[39mloss \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n", "File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/loops/optimization/automatic.py:268\u001b[0m, in \u001b[0;36m_AutomaticOptimization._optimizer_step\u001b[0;34m(self, batch_idx, train_step_and_backward_closure)\u001b[0m\n\u001b[1;32m 265\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moptim_progress\u001b[38;5;241m.\u001b[39moptimizer\u001b[38;5;241m.\u001b[39mstep\u001b[38;5;241m.\u001b[39mincrement_ready()\n\u001b[1;32m 267\u001b[0m \u001b[38;5;66;03m# model hook\u001b[39;00m\n\u001b[0;32m--> 268\u001b[0m \u001b[43mcall\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_lightning_module_hook\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 269\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrainer\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 270\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43moptimizer_step\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 271\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcurrent_epoch\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 272\u001b[0m \u001b[43m \u001b[49m\u001b[43mbatch_idx\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 273\u001b[0m \u001b[43m \u001b[49m\u001b[43moptimizer\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 274\u001b[0m \u001b[43m 
\u001b[49m\u001b[43mtrain_step_and_backward_closure\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 275\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 277\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m should_accumulate:\n\u001b[1;32m 278\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moptim_progress\u001b[38;5;241m.\u001b[39moptimizer\u001b[38;5;241m.\u001b[39mstep\u001b[38;5;241m.\u001b[39mincrement_completed()\n", "File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/trainer/call.py:167\u001b[0m, in \u001b[0;36m_call_lightning_module_hook\u001b[0;34m(trainer, hook_name, pl_module, *args, **kwargs)\u001b[0m\n\u001b[1;32m 164\u001b[0m pl_module\u001b[38;5;241m.\u001b[39m_current_fx_name \u001b[38;5;241m=\u001b[39m hook_name\n\u001b[1;32m 166\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mprofiler\u001b[38;5;241m.\u001b[39mprofile(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m[LightningModule]\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mpl_module\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mhook_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m):\n\u001b[0;32m--> 167\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 169\u001b[0m \u001b[38;5;66;03m# restore current_fx when nested context\u001b[39;00m\n\u001b[1;32m 170\u001b[0m pl_module\u001b[38;5;241m.\u001b[39m_current_fx_name \u001b[38;5;241m=\u001b[39m prev_fx_name\n", "File 
\u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/core/module.py:1306\u001b[0m, in \u001b[0;36mLightningModule.optimizer_step\u001b[0;34m(self, epoch, batch_idx, optimizer, optimizer_closure)\u001b[0m\n\u001b[1;32m 1275\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21moptimizer_step\u001b[39m(\n\u001b[1;32m 1276\u001b[0m \u001b[38;5;28mself\u001b[39m,\n\u001b[1;32m 1277\u001b[0m epoch: \u001b[38;5;28mint\u001b[39m,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1280\u001b[0m optimizer_closure: Optional[Callable[[], Any]] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 1281\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 1282\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124mr\u001b[39m\u001b[38;5;124;03m\"\"\"Override this method to adjust the default way the :class:`~pytorch_lightning.trainer.trainer.Trainer` calls\u001b[39;00m\n\u001b[1;32m 1283\u001b[0m \u001b[38;5;124;03m the optimizer.\u001b[39;00m\n\u001b[1;32m 1284\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1304\u001b[0m \n\u001b[1;32m 1305\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m-> 1306\u001b[0m \u001b[43moptimizer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43mclosure\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moptimizer_closure\u001b[49m\u001b[43m)\u001b[49m\n", "File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/core/optimizer.py:153\u001b[0m, in \u001b[0;36mLightningOptimizer.step\u001b[0;34m(self, closure, **kwargs)\u001b[0m\n\u001b[1;32m 150\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m MisconfigurationException(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mWhen `optimizer.step(closure)` is called, the closure should be callable\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 152\u001b[0m 
\u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_strategy \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m--> 153\u001b[0m step_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_strategy\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43moptimizer_step\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_optimizer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mclosure\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 155\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_on_after_step()\n\u001b[1;32m 157\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m step_output\n", "File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/strategies/strategy.py:238\u001b[0m, in \u001b[0;36mStrategy.optimizer_step\u001b[0;34m(self, optimizer, closure, model, **kwargs)\u001b[0m\n\u001b[1;32m 236\u001b[0m \u001b[38;5;66;03m# TODO(fabric): remove assertion once strategy's optimizer_step typing is fixed\u001b[39;00m\n\u001b[1;32m 237\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(model, pl\u001b[38;5;241m.\u001b[39mLightningModule)\n\u001b[0;32m--> 238\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mprecision_plugin\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43moptimizer_step\u001b[49m\u001b[43m(\u001b[49m\u001b[43moptimizer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m 
\u001b[49m\u001b[43mclosure\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mclosure\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", "File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/plugins/precision/precision.py:122\u001b[0m, in \u001b[0;36mPrecision.optimizer_step\u001b[0;34m(self, optimizer, model, closure, **kwargs)\u001b[0m\n\u001b[1;32m 120\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Hook to run the optimizer step.\"\"\"\u001b[39;00m\n\u001b[1;32m 121\u001b[0m closure \u001b[38;5;241m=\u001b[39m partial(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_wrap_closure, model, optimizer, closure)\n\u001b[0;32m--> 122\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43moptimizer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43mclosure\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mclosure\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", "File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/torch/optim/lr_scheduler.py:130\u001b[0m, in \u001b[0;36mLRScheduler.__init__..patch_track_step_called..wrap_step..wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 128\u001b[0m opt \u001b[38;5;241m=\u001b[39m opt_ref()\n\u001b[1;32m 129\u001b[0m opt\u001b[38;5;241m.\u001b[39m_opt_called \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m \u001b[38;5;66;03m# type: ignore[union-attr]\u001b[39;00m\n\u001b[0;32m--> 130\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;21;43m__get__\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mopt\u001b[49m\u001b[43m,\u001b[49m\u001b[43m 
\u001b[49m\u001b[43mopt\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;18;43m__class__\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", "File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/torch/optim/optimizer.py:484\u001b[0m, in \u001b[0;36mOptimizer.profile_hook_step..wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 479\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 480\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\n\u001b[1;32m 481\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mfunc\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m must return None or a tuple of (new_args, new_kwargs), but got \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mresult\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 482\u001b[0m )\n\u001b[0;32m--> 484\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 485\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_optimizer_step_code()\n\u001b[1;32m 487\u001b[0m \u001b[38;5;66;03m# call optimizer step post hooks\u001b[39;00m\n", "File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/torch/optim/optimizer.py:89\u001b[0m, in \u001b[0;36m_use_grad_for_differentiable.._use_grad\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 87\u001b[0m 
torch\u001b[38;5;241m.\u001b[39mset_grad_enabled(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdefaults[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mdifferentiable\u001b[39m\u001b[38;5;124m\"\u001b[39m])\n\u001b[1;32m 88\u001b[0m torch\u001b[38;5;241m.\u001b[39m_dynamo\u001b[38;5;241m.\u001b[39mgraph_break()\n\u001b[0;32m---> 89\u001b[0m ret \u001b[38;5;241m=\u001b[39m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 90\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m 91\u001b[0m torch\u001b[38;5;241m.\u001b[39m_dynamo\u001b[38;5;241m.\u001b[39mgraph_break()\n", "File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/torch/optim/adam.py:205\u001b[0m, in \u001b[0;36mAdam.step\u001b[0;34m(self, closure)\u001b[0m\n\u001b[1;32m 203\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m closure \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 204\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39menable_grad():\n\u001b[0;32m--> 205\u001b[0m loss \u001b[38;5;241m=\u001b[39m \u001b[43mclosure\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 207\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m group \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mparam_groups:\n\u001b[1;32m 208\u001b[0m params_with_grad: List[Tensor] \u001b[38;5;241m=\u001b[39m []\n", "File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/plugins/precision/precision.py:108\u001b[0m, in \u001b[0;36mPrecision._wrap_closure\u001b[0;34m(self, model, optimizer, 
closure)\u001b[0m\n\u001b[1;32m 95\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_wrap_closure\u001b[39m(\n\u001b[1;32m 96\u001b[0m \u001b[38;5;28mself\u001b[39m,\n\u001b[1;32m 97\u001b[0m model: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mpl.LightningModule\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 98\u001b[0m optimizer: Steppable,\n\u001b[1;32m 99\u001b[0m closure: Callable[[], Any],\n\u001b[1;32m 100\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Any:\n\u001b[1;32m 101\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"This double-closure allows makes sure the ``closure`` is executed before the ``on_before_optimizer_step``\u001b[39;00m\n\u001b[1;32m 102\u001b[0m \u001b[38;5;124;03m hook is called.\u001b[39;00m\n\u001b[1;32m 103\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 106\u001b[0m \n\u001b[1;32m 107\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 108\u001b[0m closure_result \u001b[38;5;241m=\u001b[39m \u001b[43mclosure\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 109\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_after_closure(model, optimizer)\n\u001b[1;32m 110\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m closure_result\n", "File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/loops/optimization/automatic.py:144\u001b[0m, in \u001b[0;36mClosure.__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 142\u001b[0m \u001b[38;5;129m@override\u001b[39m\n\u001b[1;32m 143\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__call__\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39margs: Any, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs: Any) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Optional[Tensor]:\n\u001b[0;32m--> 144\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_result \u001b[38;5;241m=\u001b[39m 
\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mclosure\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 145\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_result\u001b[38;5;241m.\u001b[39mloss\n", "File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/torch/utils/_contextlib.py:116\u001b[0m, in \u001b[0;36mcontext_decorator..decorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 113\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(func)\n\u001b[1;32m 114\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdecorate_context\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 115\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m ctx_factory():\n\u001b[0;32m--> 116\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", "File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/loops/optimization/automatic.py:129\u001b[0m, in \u001b[0;36mClosure.closure\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 126\u001b[0m \u001b[38;5;129m@override\u001b[39m\n\u001b[1;32m 127\u001b[0m \u001b[38;5;129m@torch\u001b[39m\u001b[38;5;241m.\u001b[39menable_grad()\n\u001b[1;32m 128\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mclosure\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39margs: Any, 
\u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs: Any) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m ClosureResult:\n\u001b[0;32m--> 129\u001b[0m step_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_step_fn\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 131\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m step_output\u001b[38;5;241m.\u001b[39mclosure_loss \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 132\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mwarning_cache\u001b[38;5;241m.\u001b[39mwarn(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m`training_step` returned `None`. If this was on purpose, ignore this warning...\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", "File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/loops/optimization/automatic.py:317\u001b[0m, in \u001b[0;36m_AutomaticOptimization._training_step\u001b[0;34m(self, kwargs)\u001b[0m\n\u001b[1;32m 306\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Performs the actual train step with the tied hooks.\u001b[39;00m\n\u001b[1;32m 307\u001b[0m \n\u001b[1;32m 308\u001b[0m \u001b[38;5;124;03mArgs:\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 313\u001b[0m \n\u001b[1;32m 314\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 315\u001b[0m trainer \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrainer\n\u001b[0;32m--> 317\u001b[0m training_step_output \u001b[38;5;241m=\u001b[39m \u001b[43mcall\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_strategy_hook\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtrainer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mtraining_step\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m 
\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mvalues\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 318\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrainer\u001b[38;5;241m.\u001b[39mstrategy\u001b[38;5;241m.\u001b[39mpost_training_step() \u001b[38;5;66;03m# unused hook - call anyway for backward compatibility\u001b[39;00m\n\u001b[1;32m 320\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m training_step_output \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mworld_size \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m1\u001b[39m:\n", "File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/trainer/call.py:319\u001b[0m, in \u001b[0;36m_call_strategy_hook\u001b[0;34m(trainer, hook_name, *args, **kwargs)\u001b[0m\n\u001b[1;32m 316\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 318\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mprofiler\u001b[38;5;241m.\u001b[39mprofile(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m[Strategy]\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mtrainer\u001b[38;5;241m.\u001b[39mstrategy\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mhook_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m):\n\u001b[0;32m--> 319\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m 
\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 321\u001b[0m \u001b[38;5;66;03m# restore current_fx when nested context\u001b[39;00m\n\u001b[1;32m 322\u001b[0m pl_module\u001b[38;5;241m.\u001b[39m_current_fx_name \u001b[38;5;241m=\u001b[39m prev_fx_name\n", "File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/strategies/strategy.py:390\u001b[0m, in \u001b[0;36mStrategy.training_step\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 388\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel \u001b[38;5;241m!=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlightning_module:\n\u001b[1;32m 389\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_redirection(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlightning_module, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtraining_step\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m--> 390\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlightning_module\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtraining_step\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", "File \u001b[0;32m~/Software/bat_detectors/batdetect2/batdetect2/train/modules.py:167\u001b[0m, in \u001b[0;36mDetectorModel.training_step\u001b[0;34m(self, batch)\u001b[0m\n\u001b[1;32m 165\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m 
\u001b[38;5;21mtraining_step\u001b[39m(\u001b[38;5;28mself\u001b[39m, batch: TrainExample):\n\u001b[1;32m 166\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mforward(batch\u001b[38;5;241m.\u001b[39mspec)\n\u001b[0;32m--> 167\u001b[0m loss \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcompute_loss\u001b[49m\u001b[43m(\u001b[49m\u001b[43moutputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbatch\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 168\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m loss\n", "File \u001b[0;32m~/Software/bat_detectors/batdetect2/batdetect2/train/modules.py:150\u001b[0m, in \u001b[0;36mDetectorModel.compute_loss\u001b[0;34m(self, outputs, batch)\u001b[0m\n\u001b[1;32m 147\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mclass props shape\u001b[39m\u001b[38;5;124m\"\u001b[39m, outputs\u001b[38;5;241m.\u001b[39mclass_probs\u001b[38;5;241m.\u001b[39mshape)\n\u001b[1;32m 149\u001b[0m valid_mask \u001b[38;5;241m=\u001b[39m batch\u001b[38;5;241m.\u001b[39mclass_heatmap\u001b[38;5;241m.\u001b[39many(dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m, keepdim\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\u001b[38;5;241m.\u001b[39mfloat()\n\u001b[0;32m--> 150\u001b[0m classification_loss \u001b[38;5;241m=\u001b[39m \u001b[43mlosses\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfocal_loss\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 151\u001b[0m \u001b[43m \u001b[49m\u001b[43moutputs\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mclass_probs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 152\u001b[0m \u001b[43m \u001b[49m\u001b[43mbatch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mclass_heatmap\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 153\u001b[0m \u001b[43m 
\u001b[49m\u001b[43mweights\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mclass_weights\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 154\u001b[0m \u001b[43m \u001b[49m\u001b[43mvalid_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mvalid_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 155\u001b[0m \u001b[43m \u001b[49m\u001b[43mbeta\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mconf\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mclassification\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfocal\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbeta\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 156\u001b[0m \u001b[43m \u001b[49m\u001b[43malpha\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mconf\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mclassification\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfocal\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43malpha\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 157\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 159\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m (\n\u001b[1;32m 160\u001b[0m detection_loss \u001b[38;5;241m*\u001b[39m conf\u001b[38;5;241m.\u001b[39mdetection\u001b[38;5;241m.\u001b[39mweight\n\u001b[1;32m 161\u001b[0m \u001b[38;5;241m+\u001b[39m size_loss \u001b[38;5;241m*\u001b[39m conf\u001b[38;5;241m.\u001b[39msize\u001b[38;5;241m.\u001b[39mweight\n\u001b[1;32m 162\u001b[0m \u001b[38;5;241m+\u001b[39m classification_loss \u001b[38;5;241m*\u001b[39m conf\u001b[38;5;241m.\u001b[39mclassification\u001b[38;5;241m.\u001b[39mweight\n\u001b[1;32m 163\u001b[0m )\n", "File \u001b[0;32m~/Software/bat_detectors/batdetect2/batdetect2/train/losses.py:38\u001b[0m, in \u001b[0;36mfocal_loss\u001b[0;34m(pred, gt, weights, valid_mask, eps, beta, alpha)\u001b[0m\n\u001b[1;32m 35\u001b[0m pos_inds \u001b[38;5;241m=\u001b[39m 
gt\u001b[38;5;241m.\u001b[39meq(\u001b[38;5;241m1\u001b[39m)\u001b[38;5;241m.\u001b[39mfloat()\n\u001b[1;32m 36\u001b[0m neg_inds \u001b[38;5;241m=\u001b[39m gt\u001b[38;5;241m.\u001b[39mlt(\u001b[38;5;241m1\u001b[39m)\u001b[38;5;241m.\u001b[39mfloat()\n\u001b[0;32m---> 38\u001b[0m pos_loss \u001b[38;5;241m=\u001b[39m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlog\u001b[49m\u001b[43m(\u001b[49m\u001b[43mpred\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m+\u001b[39;49m\u001b[43m \u001b[49m\u001b[43meps\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpow\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mpred\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43malpha\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mpos_inds\u001b[49m\n\u001b[1;32m 39\u001b[0m neg_loss \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 40\u001b[0m torch\u001b[38;5;241m.\u001b[39mlog(\u001b[38;5;241m1\u001b[39m \u001b[38;5;241m-\u001b[39m pred \u001b[38;5;241m+\u001b[39m eps)\n\u001b[1;32m 41\u001b[0m \u001b[38;5;241m*\u001b[39m torch\u001b[38;5;241m.\u001b[39mpow(pred, alpha)\n\u001b[1;32m 42\u001b[0m \u001b[38;5;241m*\u001b[39m torch\u001b[38;5;241m.\u001b[39mpow(\u001b[38;5;241m1\u001b[39m \u001b[38;5;241m-\u001b[39m gt, beta)\n\u001b[1;32m 43\u001b[0m \u001b[38;5;241m*\u001b[39m neg_inds\n\u001b[1;32m 44\u001b[0m )\n\u001b[1;32m 46\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m weights \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n", "\u001b[0;31mRuntimeError\u001b[0m: The size of tensor a (5) must match the size of tensor b (4) at non-singleton dimension 1" ] } ], "source": [ "trainer.fit(detector, 
train_dataloaders=train_dataloader)" ] }, { "cell_type": "code", "execution_count": null, "id": "2f6924db-e520-49a1-bbe8-6c4956e46314", "metadata": { "execution": { "iopub.status.busy": "2024-11-19T17:33:10.811729Z", "iopub.status.idle": "2024-11-19T17:33:10.811955Z", "shell.execute_reply": "2024-11-19T17:33:10.811858Z", "shell.execute_reply.started": "2024-11-19T17:33:10.811849Z" } }, "outputs": [], "source": [ "# Fetch the clip annotation stored at index 0 of the training dataset.\n", "clip_annotation = train_dataset.get_clip_annotation(0)" ] }, { "cell_type": "code", "execution_count": null, "id": "23943e13-6875-49b8-9f18-2ba6528aa673", "metadata": { "execution": { "iopub.status.busy": "2024-11-19T17:33:10.812924Z", "iopub.status.idle": "2024-11-19T17:33:10.813260Z", "shell.execute_reply": "2024-11-19T17:33:10.813104Z", "shell.execute_reply.started": "2024-11-19T17:33:10.813087Z" } }, "outputs": [], "source": [ "spec = detector.compute_spectrogram(clip_annotation.clip)\n", "\n", "# Inference only: prepend two singleton axes (presumably batch and channel -\n", "# TODO confirm against the model input contract) and skip autograd tracking.\n", "with torch.no_grad():\n", "    outputs = detector(torch.tensor(spec.values).unsqueeze(0).unsqueeze(0))" ] }, { "cell_type": "code", "execution_count": null, "id": "dd1fe346-0873-4b14-ae1b-92ef1f4f27a5", "metadata": { "execution": { "iopub.status.busy": "2024-11-19T17:33:10.814343Z", "iopub.status.idle": "2024-11-19T17:33:10.814806Z", "shell.execute_reply": "2024-11-19T17:33:10.814628Z", "shell.execute_reply.started": "2024-11-19T17:33:10.814611Z" } }, "outputs": [], "source": [ "# Draw the spectrogram, then the detection probabilities on the same axes.\n", "_, ax = plt.subplots(figsize=(15, 5))\n", "spec.plot(ax=ax, add_colorbar=False)\n", "# Trailing semicolon keeps the QuadMesh repr out of the cell output.\n", "ax.pcolormesh(spec.time, spec.frequency, outputs.detection_probs.detach().squeeze());" ] }, { "cell_type": "code", "execution_count": null, "id": "eadd36ef-a04a-4665-b703-cec84cf1673b", "metadata": { "execution": { "iopub.status.busy": "2024-11-19T17:33:10.815603Z", "iopub.status.idle": "2024-11-19T17:33:10.816065Z", "shell.execute_reply": "2024-11-19T17:33:10.815894Z", "shell.execute_reply.started": "2024-11-19T17:33:10.815877Z" } }, "outputs": [], "source": [ "# NOTE(review): `predictions` is not assigned in the cells visible around here -\n", "# confirm an earlier cell creates it, otherwise this raises NameError on Run All.\n", "print(f\"Num predicted soundevents: {len(predictions.sound_events)}\")" ] }, {
"cell_type": "code", "execution_count": null, "id": "e4e54f3e-6ddc-4fe5-8ce0-b527ff6f18ae", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "batdetect2-dev", "language": "python", "name": "batdetect2-dev" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.5" } }, "nbformat": 4, "nbformat_minor": 5 }