{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "cfb0b360-a204-4c27-a18f-3902e8758879", "metadata": { "execution": { "iopub.execute_input": "2024-07-16T00:27:20.598611Z", "iopub.status.busy": "2024-07-16T00:27:20.596274Z", "iopub.status.idle": "2024-07-16T00:27:20.670888Z", "shell.execute_reply": "2024-07-16T00:27:20.668193Z", "shell.execute_reply.started": "2024-07-16T00:27:20.598423Z" } }, "outputs": [], "source": [ "%load_ext autoreload\n", "%autoreload 2" ] }, { "cell_type": "code", "execution_count": 2, "id": "326c5432-94e6-4abf-a332-fe902559461b", "metadata": { "execution": { "iopub.execute_input": "2024-07-16T00:27:20.676278Z", "iopub.status.busy": "2024-07-16T00:27:20.675545Z", "iopub.status.idle": "2024-07-16T00:27:25.872556Z", "shell.execute_reply": "2024-07-16T00:27:25.871725Z", "shell.execute_reply.started": "2024-07-16T00:27:20.676206Z" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/santiago/Software/bat_detectors/batdetect2/.venv/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
# %%
from pathlib import Path
from typing import List, Optional

import pytorch_lightning as pl
from soundevent import data
from torch.utils.data import DataLoader

from batdetect2.data.labels import ClassMapper
from batdetect2.models.detectors import DetectorModel
from batdetect2.train.augmentations import (
    add_echo,
    select_random_subclip,
    warp_spectrogram,
)
from batdetect2.train.dataset import LabeledDataset, get_files
from batdetect2.train.preprocess import PreprocessingConfig

# %% [markdown]
# ## Training Datasets

# %%
# Resolved relative to the parent of the current working directory, so the
# result depends on where the kernel was started.
data_dir = Path.cwd().parent / "example_data"

# %%
files = get_files(data_dir / "preprocessed")

# %%
# Seed the random number generators before anything stochastic is created
# (shuffled dataloader, model weights) so that Restart & Run All gives
# repeatable results.
SEED = 42
pl.seed_everything(SEED, workers=True)

# %%
train_dataset = LabeledDataset(files)

# %%
train_dataloader = DataLoader(
    train_dataset,
    shuffle=True,
    batch_size=32,
    num_workers=4,
)


# %%
class Mapper(ClassMapper):
    """Translate sound event annotations to class names and back.

    ``encode`` turns an annotated sound event into one of ``class_labels``
    (or ``None`` when the event should be ignored); ``decode`` turns a class
    name back into the tags that describe it.
    """

    # List of all possible classes
    class_labels = [
        "Eptesicus serotinus",
        "Myotis mystacinus",
        "Pipistrellus pipistrellus",
        "Rhinolophus ferrumequinum",
        "social",
    ]

    def encode(self, x: data.SoundEventAnnotation) -> Optional[str]:
        """Return the class name for an annotation, or ``None`` to skip it.

        Args:
            x: The sound event annotation whose tags are inspected.

        Returns:
            ``"social"`` for events tagged ``event=Social``; the value of the
            ``class`` tag for events tagged ``event=Echolocation``; ``None``
            for any other event type, or when a required tag is missing.
        """
        event_tag = data.find_tag(x.tags, "event")

        # `find_tag` yields None when no tag matches; treat an annotation
        # without an event type like any other ignored call instead of
        # failing with AttributeError.
        if event_tag is None:
            return None

        if event_tag.value == "Social":
            return "social"

        if event_tag.value != "Echolocation":
            # Ignore all other types of calls
            return None

        species_tag = data.find_tag(x.tags, "class")

        # An echolocation call without a species tag cannot be assigned to a
        # class, so it is skipped as well.
        if species_tag is None:
            return None

        return species_tag.value

    def decode(self, class_name: str) -> List[data.Tag]:
        """Return the tags that represent ``class_name``.

        Args:
            class_name: A class name, normally one of ``class_labels``.

        Returns:
            A single ``event`` tag for ``"social"``; otherwise a single
            ``class`` tag carrying ``class_name`` as its value.
        """
        if class_name == "social":
            # NOTE(review): encode() matches the event value "Social"
            # (capitalised) while this emits "social" - confirm which casing
            # downstream consumers expect.
            return [data.Tag(key="event", value="social")]

        return [data.Tag(key="class", value=class_name)]


# %%
detector = DetectorModel(class_mapper=Mapper())

# %%
# Short demonstration run: at most 100 training batches per epoch, 2 epochs,
# logging on every step.
trainer = pl.Trainer(
    limit_train_batches=100,
    max_epochs=2,
    log_every_n_steps=1,
)
# %%
# Fit the detector on the training dataloader; no validation loader is passed.
trainer.fit(model=detector, train_dataloaders=train_dataloader)

# %%
# Take the clip annotation at index 0 of the training dataset as a sample.
clip_annotation = train_dataset.get_clip_annotation(0)

# %%
# Run the fitted detector on the clip that the sample annotation refers to.
clip = clip_annotation.clip
predictions = detector.compute_clip_predictions(clip)

# %%
print(f"Num predicted soundevents: {len(predictions.sound_events)}")