Adding notebook folder

This commit is contained in:
mbsantiago 2024-07-16 01:32:01 +01:00
parent db26e0d4dd
commit d34b360b8b
3 changed files with 1215 additions and 0 deletions

1
.gitignore vendored
View File

@ -110,5 +110,6 @@ experiments/*
!batdetect2_notebook.ipynb
!batdetect2/models/*.pth.tar
!tests/data/*.wav
!notebooks/*.ipynb
notebooks/lightning_logs
example_data/preprocessed

File diff suppressed because one or more lines are too long

394
notebooks/Training.ipynb Normal file
View File

@ -0,0 +1,394 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "cfb0b360-a204-4c27-a18f-3902e8758879",
"metadata": {
"execution": {
"iopub.execute_input": "2024-07-16T00:27:20.598611Z",
"iopub.status.busy": "2024-07-16T00:27:20.596274Z",
"iopub.status.idle": "2024-07-16T00:27:20.670888Z",
"shell.execute_reply": "2024-07-16T00:27:20.668193Z",
"shell.execute_reply.started": "2024-07-16T00:27:20.598423Z"
}
},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "326c5432-94e6-4abf-a332-fe902559461b",
"metadata": {
"execution": {
"iopub.execute_input": "2024-07-16T00:27:20.676278Z",
"iopub.status.busy": "2024-07-16T00:27:20.675545Z",
"iopub.status.idle": "2024-07-16T00:27:25.872556Z",
"shell.execute_reply": "2024-07-16T00:27:25.871725Z",
"shell.execute_reply.started": "2024-07-16T00:27:20.676206Z"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/santiago/Software/bat_detectors/batdetect2/.venv/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [
"from pathlib import Path\n",
"from typing import List, Optional\n",
"\n",
"import pytorch_lightning as pl\n",
"from soundevent import data\n",
"from torch.utils.data import DataLoader\n",
"\n",
"from batdetect2.data.labels import ClassMapper\n",
"from batdetect2.models.detectors import DetectorModel\n",
"from batdetect2.train.augmentations import (\n",
" add_echo,\n",
" select_random_subclip,\n",
" warp_spectrogram,\n",
")\n",
"from batdetect2.train.dataset import LabeledDataset, get_files\n",
"from batdetect2.train.preprocess import PreprocessingConfig"
]
},
{
"cell_type": "markdown",
"id": "fa202af2-5c0d-4b5d-91a3-097ef5cd4272",
"metadata": {},
"source": [
"## Training Datasets"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "cfd97d83-8c2b-46c8-9eae-cea59f53bc61",
"metadata": {
"execution": {
"iopub.execute_input": "2024-07-16T00:27:25.874255Z",
"iopub.status.busy": "2024-07-16T00:27:25.873473Z",
"iopub.status.idle": "2024-07-16T00:27:25.912952Z",
"shell.execute_reply": "2024-07-16T00:27:25.911844Z",
"shell.execute_reply.started": "2024-07-16T00:27:25.874206Z"
}
},
"outputs": [],
"source": [
"# The example data folder is resolved next to the current working directory,\n",
"# i.e. as a sibling of the folder the kernel was started in.\n",
"data_dir = Path.cwd().parent.joinpath(\"example_data\")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "d5131ae9-2efd-4758-b6e5-189a6d90789b",
"metadata": {
"execution": {
"iopub.execute_input": "2024-07-16T00:27:25.914456Z",
"iopub.status.busy": "2024-07-16T00:27:25.914027Z",
"iopub.status.idle": "2024-07-16T00:27:25.954939Z",
"shell.execute_reply": "2024-07-16T00:27:25.953906Z",
"shell.execute_reply.started": "2024-07-16T00:27:25.914410Z"
}
},
"outputs": [],
"source": [
"# Gather the training files from the `preprocessed` subfolder of the example data.\n",
"files = get_files(data_dir.joinpath(\"preprocessed\"))"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "bc733d3d-7829-4e90-896d-a0dc76b33288",
"metadata": {
"execution": {
"iopub.execute_input": "2024-07-16T00:27:25.956758Z",
"iopub.status.busy": "2024-07-16T00:27:25.956260Z",
"iopub.status.idle": "2024-07-16T00:27:25.997664Z",
"shell.execute_reply": "2024-07-16T00:27:25.996074Z",
"shell.execute_reply.started": "2024-07-16T00:27:25.956705Z"
}
},
"outputs": [],
"source": [
"# Wrap the listed files in the dataset that feeds the training dataloader.\n",
"train_dataset = LabeledDataset(files)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "dfbb94ab-7b12-4689-9c15-4dc34cd17cb2",
"metadata": {
"execution": {
"iopub.execute_input": "2024-07-16T00:27:26.003195Z",
"iopub.status.busy": "2024-07-16T00:27:26.002783Z",
"iopub.status.idle": "2024-07-16T00:27:26.054400Z",
"shell.execute_reply": "2024-07-16T00:27:26.053294Z",
"shell.execute_reply.started": "2024-07-16T00:27:26.003158Z"
}
},
"outputs": [],
"source": [
"# Shuffled batches of 32 items, loaded by 4 worker processes.\n",
"train_dataloader = DataLoader(dataset=train_dataset, batch_size=32, num_workers=4, shuffle=True)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "e2eedaa9-6be3-481a-8786-7618515d98f8",
"metadata": {
"execution": {
"iopub.execute_input": "2024-07-16T00:27:26.056060Z",
"iopub.status.busy": "2024-07-16T00:27:26.055706Z",
"iopub.status.idle": "2024-07-16T00:27:26.103227Z",
"shell.execute_reply": "2024-07-16T00:27:26.102190Z",
"shell.execute_reply.started": "2024-07-16T00:27:26.056025Z"
}
},
"outputs": [],
"source": [
"class Mapper(ClassMapper):\n",
"    \"\"\"Translate between sound event annotations and class names.\n",
"\n",
"    Social calls are grouped under a single ``social`` class, echolocation\n",
"    calls take the value of their ``class`` tag, and annotations of any\n",
"    other event type are ignored.\n",
"    \"\"\"\n",
"\n",
"    # List of all possible classes\n",
"    class_labels = [\n",
"        \"Eptesicus serotinus\",\n",
"        \"Myotis mystacinus\",\n",
"        \"Pipistrellus pipistrellus\",\n",
"        \"Rhinolophus ferrumequinum\",\n",
"        \"social\",\n",
"    ]\n",
"\n",
"    def encode(self, x: data.SoundEventAnnotation) -> Optional[str]:\n",
"        \"\"\"Return the class name for ``x``, or ``None`` to ignore it.\"\"\"\n",
"        event_tag = data.find_tag(x.tags, \"event\")\n",
"\n",
"        if event_tag is None:\n",
"            # find_tag returns None when no tag has the requested key;\n",
"            # an annotation without an event type cannot be classified.\n",
"            return None\n",
"\n",
"        if event_tag.value == \"Social\":\n",
"            return \"social\"\n",
"\n",
"        if event_tag.value != \"Echolocation\":\n",
"            # Ignore all other types of calls\n",
"            return None\n",
"\n",
"        species_tag = data.find_tag(x.tags, \"class\")\n",
"\n",
"        if species_tag is None:\n",
"            # Echolocation call without a species label: nothing to encode.\n",
"            return None\n",
"\n",
"        return species_tag.value\n",
"\n",
"    def decode(self, class_name: str) -> List[data.Tag]:\n",
"        \"\"\"Return the tags that describe the class ``class_name``.\"\"\"\n",
"        if class_name == \"social\":\n",
"            # NOTE(review): encode() matches the capitalised value \"Social\",\n",
"            # so this lower-case tag does not round-trip through encode().\n",
"            # Confirm which spelling the annotations use.\n",
"            return [data.Tag(key=\"event\", value=\"social\")]\n",
"\n",
"        return [data.Tag(key=\"class\", value=class_name)]"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "1ff6072c-511e-42fe-a74f-282f269b80f0",
"metadata": {
"execution": {
"iopub.execute_input": "2024-07-16T00:27:26.104877Z",
"iopub.status.busy": "2024-07-16T00:27:26.104538Z",
"iopub.status.idle": "2024-07-16T00:27:26.159676Z",
"shell.execute_reply": "2024-07-16T00:27:26.157914Z",
"shell.execute_reply.started": "2024-07-16T00:27:26.104843Z"
}
},
"outputs": [],
"source": [
"# Seed the Python, NumPy and PyTorch random number generators before the\n",
"# model is created, so that weight initialisation and batch shuffling are\n",
"# reproducible when the notebook is re-run from a fresh kernel.\n",
"pl.seed_everything(42, workers=True)\n",
"\n",
"# Build the detector with the class mapping defined in this notebook.\n",
"detector = DetectorModel(class_mapper=Mapper())"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "3a763ee6-15bc-4105-a409-f06e0ad21a06",
"metadata": {
"execution": {
"iopub.execute_input": "2024-07-16T00:27:26.162346Z",
"iopub.status.busy": "2024-07-16T00:27:26.161885Z",
"iopub.status.idle": "2024-07-16T00:27:26.374668Z",
"shell.execute_reply": "2024-07-16T00:27:26.373691Z",
"shell.execute_reply.started": "2024-07-16T00:27:26.162305Z"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"GPU available: False, used: False\n",
"TPU available: False, using: 0 TPU cores\n",
"IPU available: False, using: 0 IPUs\n",
"HPU available: False, using: 0 HPUs\n"
]
}
],
"source": [
"# Short demonstration run: at most 100 training batches per epoch, 2 epochs,\n",
"# logging at every training step.\n",
"trainer = pl.Trainer(max_epochs=2, limit_train_batches=100, log_every_n_steps=1)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "0b86d49d-3314-4257-94f5-f964855be385",
"metadata": {
"execution": {
"iopub.execute_input": "2024-07-16T00:27:26.375918Z",
"iopub.status.busy": "2024-07-16T00:27:26.375632Z",
"iopub.status.idle": "2024-07-16T00:27:28.829650Z",
"shell.execute_reply": "2024-07-16T00:27:28.828219Z",
"shell.execute_reply.started": "2024-07-16T00:27:26.375889Z"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n",
" | Name | Type | Params\n",
"------------------------------------------------\n",
"0 | feature_extractor | Net2DFast | 119 K \n",
"1 | classifier | Conv2d | 54 \n",
"2 | bbox | Conv2d | 18 \n",
"------------------------------------------------\n",
"119 K Trainable params\n",
"448 Non-trainable params\n",
"119 K Total params\n",
"0.480 Total estimated model params size (MB)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 1.59it/s, v_num=13]"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"`Trainer.fit` stopped: `max_epochs=2` reached.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 1.54it/s, v_num=13]\n"
]
}
],
"source": [
"# Train the detector on the training dataloader; no validation dataloader is passed.\n",
"trainer.fit(detector, train_dataloaders=train_dataloader)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "2f6924db-e520-49a1-bbe8-6c4956e46314",
"metadata": {
"execution": {
"iopub.execute_input": "2024-07-16T00:27:28.832222Z",
"iopub.status.busy": "2024-07-16T00:27:28.831642Z",
"iopub.status.idle": "2024-07-16T00:27:29.000595Z",
"shell.execute_reply": "2024-07-16T00:27:28.998078Z",
"shell.execute_reply.started": "2024-07-16T00:27:28.832157Z"
}
},
"outputs": [],
"source": [
"# Fetch the clip annotation of dataset item 0 to try the trained detector on.\n",
"clip_annotation = train_dataset.get_clip_annotation(0)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "23943e13-6875-49b8-9f18-2ba6528aa673",
"metadata": {
"execution": {
"iopub.execute_input": "2024-07-16T00:27:29.004279Z",
"iopub.status.busy": "2024-07-16T00:27:29.003486Z",
"iopub.status.idle": "2024-07-16T00:27:29.595626Z",
"shell.execute_reply": "2024-07-16T00:27:29.594734Z",
"shell.execute_reply.started": "2024-07-16T00:27:29.004200Z"
}
},
"outputs": [],
"source": [
"# Run the detector on the clip referenced by the annotation.\n",
"predictions = detector.compute_clip_predictions(clip_annotation.clip)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "eadd36ef-a04a-4665-b703-cec84cf1673b",
"metadata": {
"execution": {
"iopub.execute_input": "2024-07-16T00:28:47.178783Z",
"iopub.status.busy": "2024-07-16T00:28:47.178143Z",
"iopub.status.idle": "2024-07-16T00:28:47.246613Z",
"shell.execute_reply": "2024-07-16T00:28:47.245496Z",
"shell.execute_reply.started": "2024-07-16T00:28:47.178729Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Num predicted soundevents: 50\n"
]
}
],
"source": [
"# Report how many sound events were predicted for the clip.\n",
"print(f\"Num predicted soundevents: {len(predictions.sound_events)}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d3883c04-d91a-4d1d-b677-196c0179dde1",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "batdetect2-dev",
"language": "python",
"name": "batdetect2-dev"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.18"
}
},
"nbformat": 4,
"nbformat_minor": 5
}