mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 14:41:58 +02:00
395 lines
12 KiB
Plaintext
395 lines
12 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"id": "cfb0b360-a204-4c27-a18f-3902e8758879",
|
|
"metadata": {
|
|
"execution": {
|
|
"iopub.execute_input": "2024-07-16T00:27:20.598611Z",
|
|
"iopub.status.busy": "2024-07-16T00:27:20.596274Z",
|
|
"iopub.status.idle": "2024-07-16T00:27:20.670888Z",
|
|
"shell.execute_reply": "2024-07-16T00:27:20.668193Z",
|
|
"shell.execute_reply.started": "2024-07-16T00:27:20.598423Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"%load_ext autoreload\n",
|
|
"%autoreload 2"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"id": "326c5432-94e6-4abf-a332-fe902559461b",
|
|
"metadata": {
|
|
"execution": {
|
|
"iopub.execute_input": "2024-07-16T00:27:20.676278Z",
|
|
"iopub.status.busy": "2024-07-16T00:27:20.675545Z",
|
|
"iopub.status.idle": "2024-07-16T00:27:25.872556Z",
|
|
"shell.execute_reply": "2024-07-16T00:27:25.871725Z",
|
|
"shell.execute_reply.started": "2024-07-16T00:27:20.676206Z"
|
|
}
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"/home/santiago/Software/bat_detectors/batdetect2/.venv/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
|
" from .autonotebook import tqdm as notebook_tqdm\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"from pathlib import Path\n",
|
|
"from typing import List, Optional\n",
|
|
"\n",
|
|
"import pytorch_lightning as pl\n",
|
|
"from soundevent import data\n",
|
|
"from torch.utils.data import DataLoader\n",
|
|
"\n",
|
|
"from batdetect2.data.labels import ClassMapper\n",
|
|
"from batdetect2.models.detectors import DetectorModel\n",
|
|
"from batdetect2.train.augmentations import (\n",
|
|
" add_echo,\n",
|
|
" select_random_subclip,\n",
|
|
" warp_spectrogram,\n",
|
|
")\n",
|
|
"from batdetect2.train.dataset import LabeledDataset, get_files\n",
|
|
"from batdetect2.train.preprocess import PreprocessingConfig"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "fa202af2-5c0d-4b5d-91a3-097ef5cd4272",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Training Datasets"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"id": "cfd97d83-8c2b-46c8-9eae-cea59f53bc61",
|
|
"metadata": {
|
|
"execution": {
|
|
"iopub.execute_input": "2024-07-16T00:27:25.874255Z",
|
|
"iopub.status.busy": "2024-07-16T00:27:25.873473Z",
|
|
"iopub.status.idle": "2024-07-16T00:27:25.912952Z",
|
|
"shell.execute_reply": "2024-07-16T00:27:25.911844Z",
|
|
"shell.execute_reply.started": "2024-07-16T00:27:25.874206Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"data_dir = Path.cwd().parent / \"example_data\""
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"id": "d5131ae9-2efd-4758-b6e5-189a6d90789b",
|
|
"metadata": {
|
|
"execution": {
|
|
"iopub.execute_input": "2024-07-16T00:27:25.914456Z",
|
|
"iopub.status.busy": "2024-07-16T00:27:25.914027Z",
|
|
"iopub.status.idle": "2024-07-16T00:27:25.954939Z",
|
|
"shell.execute_reply": "2024-07-16T00:27:25.953906Z",
|
|
"shell.execute_reply.started": "2024-07-16T00:27:25.914410Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"files = get_files(data_dir / \"preprocessed\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"id": "bc733d3d-7829-4e90-896d-a0dc76b33288",
|
|
"metadata": {
|
|
"execution": {
|
|
"iopub.execute_input": "2024-07-16T00:27:25.956758Z",
|
|
"iopub.status.busy": "2024-07-16T00:27:25.956260Z",
|
|
"iopub.status.idle": "2024-07-16T00:27:25.997664Z",
|
|
"shell.execute_reply": "2024-07-16T00:27:25.996074Z",
|
|
"shell.execute_reply.started": "2024-07-16T00:27:25.956705Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"train_dataset = LabeledDataset(files)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"id": "dfbb94ab-7b12-4689-9c15-4dc34cd17cb2",
|
|
"metadata": {
|
|
"execution": {
|
|
"iopub.execute_input": "2024-07-16T00:27:26.003195Z",
|
|
"iopub.status.busy": "2024-07-16T00:27:26.002783Z",
|
|
"iopub.status.idle": "2024-07-16T00:27:26.054400Z",
|
|
"shell.execute_reply": "2024-07-16T00:27:26.053294Z",
|
|
"shell.execute_reply.started": "2024-07-16T00:27:26.003158Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"train_dataloader = DataLoader(\n",
|
|
" train_dataset,\n",
|
|
" shuffle=True,\n",
|
|
" batch_size=32,\n",
|
|
" num_workers=4,\n",
|
|
")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 7,
|
|
"id": "e2eedaa9-6be3-481a-8786-7618515d98f8",
|
|
"metadata": {
|
|
"execution": {
|
|
"iopub.execute_input": "2024-07-16T00:27:26.056060Z",
|
|
"iopub.status.busy": "2024-07-16T00:27:26.055706Z",
|
|
"iopub.status.idle": "2024-07-16T00:27:26.103227Z",
|
|
"shell.execute_reply": "2024-07-16T00:27:26.102190Z",
|
|
"shell.execute_reply.started": "2024-07-16T00:27:26.056025Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# List of all possible classes\n",
|
|
"class Mapper(ClassMapper):\n",
|
|
" class_labels = [\n",
|
|
" \"Eptesicus serotinus\",\n",
|
|
" \"Myotis mystacinus\",\n",
|
|
" \"Pipistrellus pipistrellus\",\n",
|
|
" \"Rhinolophus ferrumequinum\",\n",
|
|
" \"social\",\n",
|
|
" ]\n",
|
|
"\n",
|
|
" def encode(self, x: data.SoundEventAnnotation) -> Optional[str]:\n",
|
|
" event_tag = data.find_tag(x.tags, \"event\")\n",
|
|
"\n",
|
|
" if event_tag.value == \"Social\":\n",
|
|
" return \"social\"\n",
|
|
"\n",
|
|
" if event_tag.value != \"Echolocation\":\n",
|
|
" # Ignore all other types of calls\n",
|
|
" return None\n",
|
|
"\n",
|
|
" species_tag = data.find_tag(x.tags, \"class\")\n",
|
|
" return species_tag.value\n",
|
|
"\n",
|
|
" def decode(self, class_name: str) -> List[data.Tag]:\n",
|
|
" if class_name == \"social\":\n",
|
|
" return [data.Tag(key=\"event\", value=\"social\")]\n",
|
|
"\n",
|
|
" return [data.Tag(key=\"class\", value=class_name)]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 8,
|
|
"id": "1ff6072c-511e-42fe-a74f-282f269b80f0",
|
|
"metadata": {
|
|
"execution": {
|
|
"iopub.execute_input": "2024-07-16T00:27:26.104877Z",
|
|
"iopub.status.busy": "2024-07-16T00:27:26.104538Z",
|
|
"iopub.status.idle": "2024-07-16T00:27:26.159676Z",
|
|
"shell.execute_reply": "2024-07-16T00:27:26.157914Z",
|
|
"shell.execute_reply.started": "2024-07-16T00:27:26.104843Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"detector = DetectorModel(class_mapper=Mapper())"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 9,
|
|
"id": "3a763ee6-15bc-4105-a409-f06e0ad21a06",
|
|
"metadata": {
|
|
"execution": {
|
|
"iopub.execute_input": "2024-07-16T00:27:26.162346Z",
|
|
"iopub.status.busy": "2024-07-16T00:27:26.161885Z",
|
|
"iopub.status.idle": "2024-07-16T00:27:26.374668Z",
|
|
"shell.execute_reply": "2024-07-16T00:27:26.373691Z",
|
|
"shell.execute_reply.started": "2024-07-16T00:27:26.162305Z"
|
|
}
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"GPU available: False, used: False\n",
|
|
"TPU available: False, using: 0 TPU cores\n",
|
|
"IPU available: False, using: 0 IPUs\n",
|
|
"HPU available: False, using: 0 HPUs\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"trainer = pl.Trainer(\n",
|
|
" limit_train_batches=100,\n",
|
|
" max_epochs=2,\n",
|
|
" log_every_n_steps=1,\n",
|
|
")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 10,
|
|
"id": "0b86d49d-3314-4257-94f5-f964855be385",
|
|
"metadata": {
|
|
"execution": {
|
|
"iopub.execute_input": "2024-07-16T00:27:26.375918Z",
|
|
"iopub.status.busy": "2024-07-16T00:27:26.375632Z",
|
|
"iopub.status.idle": "2024-07-16T00:27:28.829650Z",
|
|
"shell.execute_reply": "2024-07-16T00:27:28.828219Z",
|
|
"shell.execute_reply.started": "2024-07-16T00:27:26.375889Z"
|
|
}
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\n",
|
|
" | Name | Type | Params\n",
|
|
"------------------------------------------------\n",
|
|
"0 | feature_extractor | Net2DFast | 119 K \n",
|
|
"1 | classifier | Conv2d | 54 \n",
|
|
"2 | bbox | Conv2d | 18 \n",
|
|
"------------------------------------------------\n",
|
|
"119 K Trainable params\n",
|
|
"448 Non-trainable params\n",
|
|
"119 K Total params\n",
|
|
"0.480 Total estimated model params size (MB)\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Epoch 1: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 1.59it/s, v_num=13]"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"`Trainer.fit` stopped: `max_epochs=2` reached.\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Epoch 1: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 1.54it/s, v_num=13]\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"trainer.fit(detector, train_dataloaders=train_dataloader)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 11,
|
|
"id": "2f6924db-e520-49a1-bbe8-6c4956e46314",
|
|
"metadata": {
|
|
"execution": {
|
|
"iopub.execute_input": "2024-07-16T00:27:28.832222Z",
|
|
"iopub.status.busy": "2024-07-16T00:27:28.831642Z",
|
|
"iopub.status.idle": "2024-07-16T00:27:29.000595Z",
|
|
"shell.execute_reply": "2024-07-16T00:27:28.998078Z",
|
|
"shell.execute_reply.started": "2024-07-16T00:27:28.832157Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"clip_annotation = train_dataset.get_clip_annotation(0)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 12,
|
|
"id": "23943e13-6875-49b8-9f18-2ba6528aa673",
|
|
"metadata": {
|
|
"execution": {
|
|
"iopub.execute_input": "2024-07-16T00:27:29.004279Z",
|
|
"iopub.status.busy": "2024-07-16T00:27:29.003486Z",
|
|
"iopub.status.idle": "2024-07-16T00:27:29.595626Z",
|
|
"shell.execute_reply": "2024-07-16T00:27:29.594734Z",
|
|
"shell.execute_reply.started": "2024-07-16T00:27:29.004200Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"predictions = detector.compute_clip_predictions(clip_annotation.clip)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 18,
|
|
"id": "eadd36ef-a04a-4665-b703-cec84cf1673b",
|
|
"metadata": {
|
|
"execution": {
|
|
"iopub.execute_input": "2024-07-16T00:28:47.178783Z",
|
|
"iopub.status.busy": "2024-07-16T00:28:47.178143Z",
|
|
"iopub.status.idle": "2024-07-16T00:28:47.246613Z",
|
|
"shell.execute_reply": "2024-07-16T00:28:47.245496Z",
|
|
"shell.execute_reply.started": "2024-07-16T00:28:47.178729Z"
|
|
}
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Num predicted soundevents: 50\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"print(f\"Num predicted soundevents: {len(predictions.sound_events)}\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "d3883c04-d91a-4d1d-b677-196c0179dde1",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "batdetect2-dev",
|
|
"language": "python",
|
|
"name": "batdetect2-dev"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.9.18"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|