From 5cc5767eff66e75c51fdd1bedd99e69386caf10c Mon Sep 17 00:00:00 2001 From: mbsantiago Date: Wed, 6 May 2026 12:50:53 +0100 Subject: [PATCH] fix: rename detector heads and refresh bundled checkpoint --- src/batdetect2/api_v2.py | 9 +-- src/batdetect2/cli/finetune.py | 14 ++-- src/batdetect2/models/__init__.py | 24 +++---- .../checkpoints/batdetect2_uk_same.ckpt | Bin 7614445 -> 7610085 bytes src/batdetect2/models/detectors.py | 61 ++++++++---------- src/batdetect2/models/heads.py | 11 ++-- src/batdetect2/train/checkpoints.py | 1 + tests/test_api_v2/test_api_v2.py | 6 +- tests/test_api_v2/test_finetune.py | 6 +- tests/test_models/test_detectors.py | 36 ++++++++--- tests/test_train/test_checkpoints.py | 10 +-- tests/test_train/test_lightning.py | 2 +- 12 files changed, 97 insertions(+), 83 deletions(-) diff --git a/src/batdetect2/api_v2.py b/src/batdetect2/api_v2.py index 53230c7..ec068b1 100644 --- a/src/batdetect2/api_v2.py +++ b/src/batdetect2/api_v2.py @@ -19,7 +19,8 @@ if TYPE_CHECKING: LoggerConfig, LoggingCallback, ) - from batdetect2.models import Model, ModelConfig + from batdetect2.models import ModelConfig + from batdetect2.models.types import ModelProtocol from batdetect2.outputs import ( OutputFormatConfig, OutputFormatterProtocol, @@ -88,7 +89,7 @@ class BatDetect2API: evaluator: EvaluatorProtocol, formatter: OutputFormatterProtocol, output_transform: OutputTransformProtocol, - model: Model, + model: ModelProtocol, ): """Create a fully configured API instance. @@ -128,7 +129,7 @@ class BatDetect2API: Default formatter used to save predictions. output_transform : OutputTransformProtocol Transform that converts model outputs into detections. - model : Model + model : ModelProtocol Model instance. """ self.model_config = model_config @@ -1177,5 +1178,5 @@ class BatDetect2API: parameter.requires_grad = True if trainable in {"heads", "bbox_head"}: - for parameter in detector.bbox_head.parameters(): + for parameter in detector.size_head.parameters(): parameter.requires_grad = True diff --git a/src/batdetect2/cli/finetune.py b/src/batdetect2/cli/finetune.py index 33b1c8b..f0f4791 100644 --- a/src/batdetect2/cli/finetune.py +++ b/src/batdetect2/cli/finetune.py @@ -47,24 +47,24 @@ __all__ = ["finetune_command"] @click.option( "--training-config", type=click.Path(exists=True), - help="Path to a training config file.", + help="Path to training config file.", ) @click.option( "--audio-config", type=click.Path(exists=True), - help="Path to an audio config file.", + help="Path to audio config file.", ) @click.option( "--logging-config", type=click.Path(exists=True), - help="Path to a logging config file.", + help="Path to logging config file.", ) @click.option( "--trainable", type=click.Choice(["all", "heads", "classifier_head", "bbox_head"]), default="heads", show_default=True, - help="Which model parameters stay trainable during fine-tuning.", + help="Which model parameters remain trainable during fine-tuning.", ) @click.option( "--ckpt-dir", @@ -127,11 +127,7 @@ def finetune_command( experiment_name: str | None = None, run_name: str | None = None, ): - """Fine-tune a checkpoint on a new target definition. - - Use this command when you want to adapt an existing model to a new class - list or ROI mapping. - """ + """Fine-tune a BatDetect2 checkpoint on a new target definition.""" from batdetect2.api_v2 import BatDetect2API from batdetect2.audio import AudioConfig from batdetect2.data import load_dataset, load_dataset_config diff --git a/src/batdetect2/models/__init__.py b/src/batdetect2/models/__init__.py index cc3ab69..ee96d93 100644 --- a/src/batdetect2/models/__init__.py +++ b/src/batdetect2/models/__init__.py @@ -62,7 +62,7 @@ from batdetect2.models.encoder import ( build_encoder, ) from batdetect2.models.heads import BBoxHead, ClassifierHead, DetectorHead -from batdetect2.models.types import DetectionModel +from batdetect2.models.types import DetectorProtocol, ModelProtocol from batdetect2.postprocess.config import PostprocessConfig from batdetect2.postprocess.types import ( ClipDetectionsTensor, @@ -149,7 +149,7 @@ class Model(torch.nn.Module): Attributes ---------- - detector : DetectionModel + detector : DetectorProtocol The neural network that processes spectrograms and produces raw detection, classification, and bounding-box outputs. preprocessor : PreprocessorProtocol @@ -164,7 +164,7 @@ class Model(torch.nn.Module): Size-dimension names corresponding to the model size outputs. """ - detector: DetectionModel + detector: DetectorProtocol preprocessor: PreprocessorProtocol postprocessor: PostprocessorProtocol class_names: list[str] @@ -173,7 +173,7 @@ class Model(torch.nn.Module): def __init__( self, - detector: DetectionModel, + detector: DetectorProtocol, preprocessor: PreprocessorProtocol, postprocessor: PostprocessorProtocol, class_names: list[str], @@ -224,7 +224,7 @@ def build_model( dimension_names: list[str] | None = None, preprocessor: PreprocessorProtocol | None = None, postprocessor: PostprocessorProtocol | None = None, -) -> Model: +) -> ModelProtocol: """Build a complete, ready-to-use BatDetect2 model. Assembles a ``Model`` instance from a ``ModelConfig`` and optional @@ -256,7 +256,7 @@ def build_model( Returns ------- - Model + ModelProtocol A fully assembled ``Model`` instance ready for inference or training. """ @@ -285,8 +285,8 @@ def build_model( config=config.postprocess, ) detector = build_detector( - num_classes=len(class_names), - num_sizes=len(dimension_names), + class_names=class_names, + dimension_names=dimension_names, config=config.architecture, ) return Model( @@ -300,14 +300,14 @@ def build_model( def build_model_with_new_targets( - model: Model, + model: ModelProtocol, targets: TargetProtocol, roi_mapper: ROIMapperProtocol, -) -> Model: +) -> ModelProtocol: """Build a new model with a different target set.""" detector = build_detector( - num_classes=len(targets.class_names), - num_sizes=len(roi_mapper.dimension_names), + class_names=targets.class_names, + dimension_names=roi_mapper.dimension_names, backbone=model.detector.backbone, ) diff --git a/src/batdetect2/models/checkpoints/batdetect2_uk_same.ckpt b/src/batdetect2/models/checkpoints/batdetect2_uk_same.ckpt index 49b64a436947f4f9fe2c8171ba78464ac4e016e1..b849167acd928e93911d2f322372a49a30011edf 100644 GIT binary patch delta 18024 zcma)j3t$_?wYDwCiSr_HVkL2GCw@PZ`27lbC=RicI0=d&;5>}Hme;mqS<;SHNgS#= zEu;z1QebZp^!Cz1=>ysVy}f_k-dm{X54|m~7z#9y6haFvh4Siy_J{WN|IUoG^6uLD ze}}U>JLfy!oH=J^c4utfd37q|jfbZ)ZdsjCmbN+i=t?zbMaXkPc8$o=VC|r!N*jWs zzRr&Jv|IPr-rB$5){`seMig6RZ>@IghILu)K!6Q~l#E;7DbAi#S=l=;EiEmi2ADJ~ zE16nhy8ZIGz3Hf{Q|9$%;O-8Fm8|ZQsxp6ncUon4I{uyPPFIx$mAIos<%Ku6-2tCZ zcB`I%KeXU-HYhWB@G_5EU7#$|M$sR0;FyxOokkuM_6$Rb@kuQmehTqCg8STB2RduxpWwc?MUTm3l$Y z5GQCPf+lkH{*3(wXp7P=pc~`RO9{G(pha^IW^%W*i8*U7$)ad(2dJuwl;4`O%@DaZ zWhWx{k2%GNTw$g$s_lvsx?`E8(0wWMvWi4sZOCDLoze>j9IA4V9B`ac4v_ln*30|btegZEvHTRU zJeFwV)2eb^k8(Z6#Ip;U^WyX1ls+i9VTG0EjjD1}kMh~x&19inyL6#l z`J55os@x2G)BJKo{qX#~MvvQ++u;SKPLw=8c}ls1Jie1WK0E(_5!n`FNY>soc=;=MiNY!673$Kf=n_1-s4B1ZD8JUuBAbt_&9y7PF`Bn2uTz7Uhz8$O zmEZR$Z)y7%653vSRAS28MuT?c5ApWiNwoJawf7!{aMu!J2pj8^KZ_9lMGxT-An7Be{QMMsBdHoguXSq0Qv*jy@n~+*u<7` zJyrs?93{Ac#en6Y2(OtPcXIq<_EkpSX11E!Y2C|+*R$*zqrQa|3;hMI|7cmCQQyj{ zh2FQE>~CBymrWmAjcu%f>oL4oBT9OBa&q<>nc7*KU^<#p1Eza&dJOueI_41i`S$hD zx7v*nYN}^Dh5k!+^5j45y+->6<`nvh6{O#>!WfCBMn*43WZ)sLe`bX-5=~94U+A~3 zBzxCNVa%vnWb5mI_a? zJ5Z#tbSLgKmiqICjl2!)E^en?$fI^Y${RBBHnOjCJ@I}6r8w_5Q6%2`a3@|Z-)|U0 z%}wk9VTZ^)DD1>4RbF6>Omj2)502DsFQCXgR&dbhK@0nlV1fq^3nuE>BPfywKgOLr zaITSzcCG9&-R|+E-B06o{~fn`eT~Z)kv8@ex6^vpQbdlf9X9f|v*)>0@nEk#SbEG_bb>(|tu1 zV0yO5I2KzP*&l^|Q}J5pj}#v=deX!`68a~L$^K8pCjMq-!`m0Ls=tKnPnMYMTi61j zf3Jk>SC;lRPhS^XTG=wLrAM-g42#h>M%Ofj@s3*FT2Xf z+|G)*o->!Am|!kNnV8FPC+5e>j4^4gW7XVJt1Yhwvs8YiAttT$tU>6XD5pNZReq(R z<<>9o$lTxq_HKsxU^gwTbN(dQTTh` z46V(qN9g~`^((6+!%S;!VTZV$#?4`r^l|fVRr`#6Y-Rl%d0X{5G_=0j_&RHCVeIBr5X2Z5R zcAOh_+f+7=5}fb0onj|Yq@Z4lI|cQgn(m3|4Yt*@liY4^8Vo-r3{gJCJ{`BbPFPm; zu~pyT{vppW zyTO(=0qu=ky8zmdG#2#9OoHi6sN9|QW=0w}lbCzNqhjk7W-|LM=lmSEngc688e=!( zkwXvclhqE%JvtEZ%k1-sT8}>%ReavCxsQDT*UAXHoo?QnZQ941b9cvAoKAcPYP$=0 zgQGG-TQ15YyOXLjdb0ozN`s>Ai&U2$j5+1yq#o-lJULE*vaUdF|815k# zyI3F`WOtF3@b$|CimVz58EB6~Too-{7!s5)4s81B=nX6|!$=o2trw+FY~ysCpTyr9@RS zmiODVi(6J~{uUrzTi*u;KVd1%voTAM@ireo5#4`~Zj`Ee$4dDjQqL4q zP8ANiXsojD0>yhX;PH9x-6Jd|qw6l)fte@bakNt0(_P%3H$z%NCaa`fuPw6IlLK8OV zmi_F%k%evc+@C?!we_Vt>F0~Rq#c$#e)e;l=GtbjeS)+ORMJmxCN9CiF-*%Rxz$Cg zMfM}?DLhlOzq+kx)zc`4WKI=Qhg9|qo+Iq6O}nsdm2JwVeYkDyx@Wolo5G%t@M9kT zVBnZ*$j_dGj$6~RzJ=VC)X7AF9uxOSh z_c?Cy5?P3z({jUPzJP~VPV|9kpH&UK@%LmO(JRUvsoC?#SDkiMnYm&ZWm}aO@oqsGd+EcvQ*l@rOff9#mn; zD?*t?REbc|$AecyZ~=-Dwve=XLyJglYG*Ok3C1uYa|s!U!4mI)PSFlIZ){C+leUQ{tDs|?yp2K!d8)1Zzh-2CihoU zonR#Q^T;6SzS}Q{kj#7n3HJ+7jIcEX(|OmD8oZ0AyI-hDUHSGRsB~Y8$tvM%2^k4r zOHqulGScdt<)k+GT0wP!k$kNrgQTxPPtcE9Q$--*Yc+}yRzonIcO9w0>zMBAdQ|q> zF;<+OpeLj<+2;#~3c)rK4tYFm1JrsaYKbnKh{DZ(TP?RBdm>JI&V9v!D|TkMlEw!k^NGrbYC}-Rl?WJWF&muf?|Yi zC9TfsAhpTYZB!>1$=41tNcxJ&NdezZAmQs}C`Q;0g6X`Sq&EB7rJdiEYu^c#?&~hH zO8B~)jD)YYRs2ZSwUn)d@!O^#~ald>z2@*he7Y>oq7w zn2TUKZ$GKQYgjWRR8C(rz~diOF@y)8(A{?vQQUpkAn^$IWfUW9h_re$!=yI3KSFha zk=*x?LDGH9L;T0Q1bWk^{cU@Gwwb`Mb?zy)2Y}Fh3dTwJ8%-fw;S)nK!a}6geNstn z@+nMpf{}bWN(M=vFjNQd>2ZueBBaMrjIeQn>AWXM4PL{vx>hULn`i$7RJyMbvWmxV z;*(?~e4Ris!cLM_=lm3@O}>7b>I5VCdL0=geWmz~53}nDBz!%EVuXE$U^?$Fd~;4DH>$1@;@D)1AGM%;L_@yNRs+Xw#N-7Nlot>pM$U@|_61so02T1ZlgK5J=l_ zXF;~f?Js0#kM1k9-#(o^nq<$lf%pO^{;0ELLMQf*yM3~U-6e?WUH@eiu_q%_y+3!8 z+SK4ZsN=8jedBHw-$$WB5QVVu@UX8!rFZdOViEzQ50()r1o_{|2(6yn zNv+HUit+Dkoo_Fn>_w|b6vmewEZ?i z`^Ob|S-%4!!rsi#3a_+J{2mW|>@8d?z3I4mm(V4jE>(2l9|YT}#9v|j-N08^zNGQB zg3nw&%lM%F0fBkWDc&-FD-6OHe+VZHE;>63vA3b=V}HoNaa1M$bUx-8#Qx|VGNo;q z%r^ik!_G8ucJ(fF5%wOg;$TZCeF2F+`rqfu54e*0j;#`QE&R;qkEra<5`T~|5*`}D zEdCR!&hUMoG!Y(#XDZ)ggxH@s&0jbTrAGEk13r0>eaH=HQz3;1JpuMtJomA`5e{1t zx`~X@PGrn;+$B&SK_T{zEHEI^afUk1K1OBN)}P0>-wO_L3%@IVB>8u){Y4}H;aQ}8 zZ(R8guKeu~|H1P8KX{HeOeb2>;Gj#U(eh8!dq2+57G1S^A`RQKuC3~c*Sa3axz@|+ z0p&kl!IwDEk-4k1QO3#WFwPTMI&Axx4OjFsgKnaiwAAr)7}wE&FN~wWInbfO=^dj{ zIp*;%-~-U(A7+`_gI5(y%muQ`!JEfKi?De~R+Q3?Ni4{+tZWo{9Al~yn@Xj5LU-sW3UXd#n5$UfcKIac=_?e5(}?a z5dKgQdIbd^UBpR$KckLG!4O*t%NxS7_QKVx2D4F)uw}_cE`mc9ZH2bKw<>+5c5UyP zovWaauv}cdqUF^n#}u`izOo8RcNbLrhl6kkfIo1<;KvJ$9fbM}HS}^{ zLD?bVp|=u#(*sP19Y#eTJA!M3_2J50%D*Nz`x-)$13n3nVy-zFYSa6*{%h8DOVFLc z*&eBbzJMBH1E}d^Zd|eM4dPbm<-wpq?lr{1E4nXpjOhNBAwrpYJWO?0^1b% zt9PuxRlIqG;G#1g6eG-wtN0T_kKgSJ;{`v8ntiv}gruXgOAQ2FqwF~9 zFu`d$40(L`;J~|UoO94j;y|3&s_X>WLNGD?qwHFGLRYiXaHh+~U4EA=X5uHbC;C?} zi;$WY!BMk;~D5LQpd zzMKxmS#b!yz&Y+ti9`6~eD|cpAp>!~dsE_&ggD>VQsR&Yj?)}!Lwc$9!N3YjSCNHc zbhS7d`$qcxv-cEv;7q4u-%N+tAQ_N3ld*4RR^K+e9a0nT(WP^%ZL{N$oOq8mr@}Sp znQ=H$;`GeKI#S~F%s6(W#3eE_dn6l>6duXbj-7KM^B`G}`H%&Wg^)#%#gI!NOCU=j z*^p(B<&YeR9kK$lGV0j5s(RLlH6(I783Sz2>^LMio`Jn7afu91#`dSgB{OVI^;u6_P5nUvTM^~(#vqhvSk*CwKP|;*8 zRD6HwtkG*q%)V&sj^f!Dm8L|(qOm(u!4kQM#=e*emPkc3_OGd6iA+ReUrGf_Bx3f! zH6aHgaQp>5Yv2+on2dd;`2K&HJ!pvu6OA>O%pSDnL>{8CmQ=803R+Xak|}6Q1xu!& zJryjGg4u)CjPUcIO*$?uslH{_Xf!9{AB{bks>Q@Sh{m2u1xw^08hbhwESZ95Qovf0 zDL9)7mPo5mw-e`n0L7E{gkXA?=q#d#maw%jJWHV$7WGlo0*#_x= zY=>M1*#YTZYdm;NEmqVP8{g58W6_6_-2Ow8Lu7>nN4nht=4nvMW z`XJXpT#$Z<1Q~$1A%hSZG6Wfhj6ggPFJu(rgN#A^kN_kIQ6Mbpcq24>I9n5=crpf9 zif|@Yl1ax~DRIfM`)*2HV(d;j{*)4z9ISs$iAxOD*=J*GVnLsb{XLZ8+#k(UaNgx{SuO z0=~DIukUT5XS3B@TP{BKb8la!=H%D6G_}>$wYD|2HMKOhG}kvaG&j*E`Ih%F#p;%+ zs}?7>%V}w4(F4oW(rIj`mZ@dg>#)gA{&ek6Gq-tunYu2!HVHxLSu3h-xmueoPA(JB zo*oO=9n017?2-h4-(R$%-dnDgm{86G7Ow3%7Sz-=7F2Wwl*@v8dj?d$1+~>~ZBeqI zv>8wX7S!)%K)EfbOIKK0>>aeAJ`dCqV^pW2MLFtz;HLbRC)ml(DSW(Uf3+fdsYN|<~vi0?W|Ig304GqzE?P|WM z7tU4-=b1cnse^v zAWJaXc`!^JnA&G%Ja1*}z1$3{TVrk3X$GAF$Q1jjgGJ^M^s*J@JZwSLuC;Oeo=2bsrLQ}ayY%lSDoV@-*r*>e`qNh`>Ci<#+#89-&x@0J_p zV|A&eIcK>UG%*9H!VG$T22iCLR8?keuF4GhygGk|_% z20b?e=wUOcqQctTBWBPiW&r)z40?73P|OS}ue3CG&H}m?Ak!-2d{msRPn$MVR@Btv z7F1c4mFuT9<`wIN6*cvQ1$EYnay}`3?=j7<(rQbKQ)exxaVyF>)oAwX87pe)c?+tf z#u8NL3(@p-hFNpm3YvP!40_rMa=u~)6|b{2H}z{X=$IAc{GA!}lod4fwi#5k-qKv} zAIzYmR?yU6%%CT&pgTS^g9OsF#-XbceN>e&b;M7X#oo0NxhJ#W@`(9a^y16~6yxoNV_O_d3X<9ixQKItn zbNZ8&e$(pNXrMo>Ul^#9X7^_(M@*$N$Q_F;&Fhe|JFUvo(<+sjv!_)^IsIC?Rhr+Q z(Ox7=xoD@oXrHv8-Ly~2Bg-3-*Duym>!o5rEeTUgiMo)eN0Qg+s0~tupjL*dRYa{O zYK^%=M{Shq1l1a*))TdXsMnh}>!?lAGC^G)rmi6BF+?p$IVF|*s+^HhFf)Uqa~wz& z6$IXpvPu`C7HKU)^i4{MDOt%+)rW7HWCtdvE(CUY>hblF03;(Pb+u9_9I(mKDdd1{ zpL8lYa2nb`sJF>dSBJE{b2$(7tn?YRV%SIyHQ?MxowP$3afXd_lMxp=x+`s6nXW;r z)GxT6FxN}m0ph-w)}iOtOT&UIg}IEl0pdERY}Rudq}_tMC(J#axIZTDOH)oM6h5r# z^}=^5k;zQ)x=#pbxWPrBWta6o#0*{=H5Wu8^J}Y zZjz;&JEU8b;RWJJJ#*?j<>0h*Wuh>pLAq6!E^D21n=o{H*w7th=uTniE?K&}L%K&P zD$*L7sq4ShD*X=qzhhd~G-~@^S^8s#bf03M)u@c*&QS8wvx@HzHz^`+t(P7^lZf~c zG$|s!db&QS)&}Vz7`=Xa8G`!M^bPtJ8>PnqE}b!tw>Tk7Pj^VqC_4&-lNV%UC}+-` zu`)b7pVjxbNqVk7V?z>Je;%!)uupn{diy)I{nWyQVCu@uJR+#;Eg#8Y&-xY1WCrj^lNFS)t6^>hSb~>5#q2B9d z(nsiV&8&F{$Ld+#x)EGoEB#x<^9wbekIY)73t4@g^bHK;X3_{gHgmm2+pBsjOUB#` z{7_~s#XHlk>qmV(o62!a4Qv`pXa=(bn~oy7!)D-<`rk3TQ*XC{&Ea~=k=gmM`|9kI z^>`!875LgYg!j(buE(2LiNIgt_*Zkf^msF?6u4(D>F=NGUL^V@23ma!vvM3`jn$*1 zhG%xx2ECCIs6d?&EH9|XcqC9D4JB6h;b0|VH=AWjIKnpul;Gy{xp_}KQ zqQ{r9vjsjsmppCGb?WfOT6TfJZ{_%7xeh&E$1W9k>jKj6TF|G*t?ZD%U*h;z3wm`K zXsl<~ahx)6JxbvW+<+ow;6{8>1`g)=bYrNof!!wb5TV$R}QRlO%B z^_~ptJr&jq6h=jC8GDZFDW4XSKQoGYbrEfYQKoQ~B73m|>RL9=pIE|gRQBudw zvSR%_&}3yF2us%$(|rD9v3_hd)w9n9zO1AW_$ei)>OE;--w6EC64L*yB+7mxOUA1l zPf1rP>4!?A^qbgJfqzs=`dJI3^qbinfuFUI^skP_TUf5ZXD%ZBrbSWyFJmPFzj_hr zKNy8K*Ro21S1+yuzJ2jYy8boSF{{8ISxoxxM&nktT;TRH!iUTBvr}_DYZLg39RI3J z-~Z+Ywoc%^<)nXh`6-Rs`MkXm)#?0k$hzuceml7EPX2MNM8+)vj~7rH<_t z`1e(XBUsz8pW#+!`*7`GXX6?SvHkey%)sK=G31t=w0B|$P({lyb`A>VoxqaVxhRI% zc|@k2(oQ!Ec>F$gJ}TNxfzA}v4f+EC_H)j=09S7Wo6WJ2KwxRWbGq9>G)J$GUC5aS zwaj^8D@%LavfCx20hN6bXJ4#kFNm;RgU&#})9Z0Lbt(#1xg@M|BKI>DIz#>*_aOTP zM=#YHoEZVb?fV%U<*duJtZZx_axLs~j$ENd3W1bZ*e^NqD=o4ZNQH%6$&o`^1iR5` zmGf%ZRUA63g_aPc95SU@*wq~UwHB@C=&R=G7WNyCUZbMPflgtzj$O;4>r@D-5@w%D zh1u&ldV>}fW@n|t?2R0~NkzkEtLCzsIdY4Nm;&ma`htvf3%ix$w`uUsI=I^~m)*{h zJ2Z$QEsm^uXU%1Ia{MkWey}vMbhlX8-5kG1#o<#O7P_$fTaNrrMc~t<<2hO=y<)DQ^Mao8A|HzU1R0OeH3Vs|Zi4lwGa7k(FX$k8t!+ElO6KEbK2F z`KyYAEmme(*<&1fT!nZ}t1m7C~TG+E3drrl8I=SVkbJ_D8d_e^x8O_TP*?f`XN42;xoP*sHGWin6UsiD# zFTp5WpcIu(wXnZ&^c59F2rI(qYzuppqpxYvsxUg&!d~a-8(Nf}8?v0Ggo>wF*qfa3 zcP*na%*eN}e{l3IBQ&c(3cA@IM{m%_-sZG-v^2GuykX|&q<1;#JuRu4w{*BH%fjC0 z=m%Pqp2P*p;WCRz&xf4xk%qya<-=v!7WOg6K2b45ut2%JqS(Sd<;Z6$!lQS4MYV-&GdYHKp0GB%uz$E<@XMU|^ zmKrjbI)gnPze8qDU!d2|h63zAT=g5SnvN^I%PN$6s!9qqq;EOtznm22R1v3mlC!jy zeaBhFJQ=SkoSLqo1_RxlnPB(t8Fn+ru#-8l7{Du@WpL_Lv$DC`GBSoRpfy#AB@0pW2rvq@}@YF+Z@E~)unOtiY*P>;D54Kq*ZaH=9Mz`GN zbPaU-eQq{8QtR;z2W6~(=v0@@L2diq&S}cq)w5R4r5h7pgSdSzyvZ5MLRE-aNH-1k z2c3J!^Le=MV%fNshgc4Mcr(q)l}qM#%qJ$f=|>x`ey7jp9t^Nt)V3FOPDA~Go1x0V zsIC8%H1`We*l6Lbc78hhL#jY|QBSPib>_G(c?JnHC!a<{T3+ifM- zD9GGYaM(eMJ5~=C?@G64(BpGE8Ec5tfD&*yF=`u8)5V%_EpJb5Pj(}n&2$rPX`lr+ zqNO5z483ICO9s8d#xhci*jP?CB0N!=|-9un-8yPM^%yp?Xd0OPABn6@1sw=k)kkJFot9Tlp7%x^Nq> z<~yGCX63B9dD$B{`UVXezGc~ndrH%7b;Z_ACjP?YBowWJE-mlM%uhEdMUQZu$bA!FyAiW(XcV&Xe`Z%GV z2%-xpC3QPx=0SnSj5cO{K8Nh=3otJ#yVwA(AvTC>Ibwx*b-=xc4VjhvhMbi?Qll-} zPUj$2L4LZ0B(=fhFbYAZh^kX|^*cPhOhR3VFriCp_88Mp<8-ooZ#yBV~!PL>2%a$+W0X( zB6gSWVP_DRCgh%7uH}A4Z>SJeOtz1W8K_=olc*iTXG6|wvoiXpa-XllRC#$6{Zl&JXt1}kP}Y>hY4TX#QUlYSgf zCK_`s=b$p^^om%XOSZzXR9aW27KsR+kLnQnIe}_J7vRpjU+Gy{n$6z^VHcf)E_M+? zUbAv&WqxKLEkFF*-s=Z$;gUxxC-vCkDyM~m)eVlSF>7daMd7~M&(nht>Vb8D+1Mg;H zh&T5>IS;8tzvPGW{zT67l?ML6$9F~sg6RJHS4g7+(L z*finnb~5~9zb6Rc%v}`|&ij%c%s1` zXOKNZ3=z*~Q4F!?NMDV^^8`mZ@&fAkm<~GvGu;^1exJv~UIeDPa+FM=@1WDu+V2cF zl>1g>OnC`4n_3rNkM|t-T)bR;7 z>>bSbdStlv768??x5 z)CY`M{9_*xN%Y}k6hrJ2qN#oOlwjDMp`8mpQ#Kz{V)-W^)#J}eD?E@!{zXc{*MFlJ zVqXxd+WeB>C||##I$@D~{hAaeeZ^?##!K!$L=wJ!gJOt%OElH)e+h=&W3|41hf3{m zfSvw;?Do)b_#ZH}2j7#eNDq>*HVOB@gd(@rfAbQXJfOv`LJk%Y4=Pz-mSEUjp^f%RRPr~3eGt1i z8PSF%4?bBd5Y>MhS&I0-iVO+=+fc-wgHW}d;|Y%P{{*TN7RmqBq%i3}MvKqyVrz&b z{9lV=h@D6@)$Tfiqy2A3CHG$)EwCg92c2v^5Y_(;WGUkRMlvM)--Kd_*$GwK*-UVh z{~c5(ERz3QNMX`{9{&J4iAci#ttf`r$wX7_{)k}MU8#*ft#mu-m`4-;0G8iQ7gBHv zK(z~}lCek^P9s~Q3)@f(u`WW@rnVCt)ddID35&s?n^T$J!*+0WCs(KMg?ik*&fuWT zx_P~e>P1IsbK~@3^~@aItvKjoJ;3~5 z*K+Sqm679%YrXuf&tIGTeZ;49K2`E5h(E^s@!(_Zq!i`Q@j1(fxEUXr;cMz$o*ucM z`El>1otm4kKyX5cEns**H4H4oq!i_|<1HhM-vqci^$^=8Y-;$KlZ?vtbbg`{=nwYx zV)hE6>Hy#RIfH#eSo`oDNPz9+R(5eK6o0qR**)m)VY@k>mil}j&-UQHi=9q9Y*_G7 z9nK4FRE9je9fEcS5FRfa7Ir)7^hDmnLMh4}CoD)lGevpggav&gpsZJ?wAKBLl712=`#%{7D0%AAtb}jY{_OSS#sC1Bfbm zik+m3UBE5f7!`yeXJ8<}E(F$YhWUdA=J|obMKR{R!r&i-LGOjigKKh37b`EXDQH$e z4zWunZPBitmZSb5MMv8at*tO! zuAIHLVEz@{_Ahbuie`R=A|*t*e{HViO5E@#)zu&L4Ty1kNcm`OzV#}ucQ{N23ndH_ z=McM^bm%w(FCgag`1&0DEcw^UiWBoje#6zSnN*__2$w&A^E?L~**bBu*T=3U6O=UB z0k7G0xNUz_oU-t*qp#;yZXhdSPagSVncc{dn>a#i=kS;FH=}}%D3upZ%slZH&bgIy za7pUMH{|R#V%W$V{P4D&gTwt!b~~>Z)tGj?;(X^0)Ov4LYStBx+=*g{-Gytpmp5Yz zlf9yyyE&pW?~O7lPE7fvf~oEvZt%CHCJtL%E;m1q$C=RYP%9!G#1YXL?v$t2<(q!5 ze7r8l^amxay}0CFRJNyv&tP|Y@MQ}7qfYBS<;3><75Ag=0AEt|F=r2*Cpc&vIdDe9 z9sstHzkmikLmrvFS;2J2r!rVOqFmjcXBt-?XkTD@PI5 zW&5Li@X&dgThu|tI{;MzT7M>{7ydXp zdwa2i!9({Ppb&eWEQ;xl9ufWmR4Fe2H;xX#+(CMqo z&0bP2->}H^vT|fY!MeYJazHd-Aw4<o=wh$o z8e*^G$|r!$8|P=fK}>RB(1|d!Hx>KFT+`o`fsKVD{{VCVXLp1S5BlW*dkZyP>}^~z z^}U0SaxV`UJtAIRXuKjC?{W^`*vCfWJ<^H_!24A95-w|63&T+kdS9k{IKh9skpL*= zn@W!TfcT?fBdcgZ8asG}A)@33F|HO?LA)*yWSRueZSN^#vXUe~ThuFWBOnZ*$3#G!I zH}WNE@RZ>EgN}c{Ld8b@tn)8Fz9t+C5XwCJ58cr77I5x#J7oW`V}N~wI(&sj4^XcM z-vQCd;J4fcJ!hPVU#OAUe@PdTgqe1LeMfhAohT34Gnf1iH~0n&iz|N2!oM&K@Z}Zz zo_I8&4&WV&Q{SXi#k9G&)P&^o0bpI)h%ZZ%aT{D}Z=aI>WX@efx9mttx@CuA@5nRP zBsKFt=_nNQarjRb3;-qtH#EMvsh|ZjKubBrsxi|M+DlNHW9>Nf{fzEVc6DxZ22RM0?Cum4?S9UY?Lx zl|8;B0o$t0s;xE=TbtE!Ya+HbtG32OY)w`ZMc#^>^2qORZApVnfuut+AX6dJAk!f; zATuGeAeoTakU5aKkSvG=G7pkH+S-y+6)&n*&A1pFZ^@}liOa^wP;+2AHW6DhI>yFV zC1UGF!|FtA-Dp^sfL*T{4T*A4kFoIoH~Q-}<6!@|JqI!$k_%Y?$%EuW3Lu4$qOoy% zF{b%AW4vB7R>o|n7RR59>UA@~X^Gg{OxwB=v9*~V-=2uA&9rStBDN;e3A2Gb^K9IG zq^BzG%-o>K@&0jF@z}VlPN>9l-v}TtV#Ch_&Ev4nmmq=D?ncM|n7{|k^EWzv zUjniwf1~5~Cm`$6_dtAdlP-Nn5|B0NOBA;zWRJ&ff9r##kcE&%kj0~|50=G!zcp!6 zIX3=aS^W1~lO~t@Z4Z}0${`hy%2C_HRh98F)}(nf$Hqa9`zmeH*H zsO_&hgLI#joDsK#MX`b*AlTc<9^Kck3?+U zKz}z8TQkrTJ?G7lm(bXej{;S3mrBi=K|N~wGQjX?{@{+LIkJiv`~J&7{AEOoW(1G! z{W=hLr__QBiuwHeuLD)-KQKF}=M<5p^XOiaTy1Vj%H=z->DoQ%=;J1NzG)tIU)h%? z%UOA~%bMz28mvtXb&dGf+)~@pSW5>bUw*aeuF=|LIcLV?kqM5dMzF4Qd8G3`b4Un7wzVCv1UiFiM88#y#epbSf2d`1K!#UdEsn*7$=se89c~W z#1iet8WBH=B~G*&5m!!)Y1-aqL_8QvoH$+(4T5R4@?b`F%D%?Pa>KM37mgYd&yFS9 z*X9_6>1Zr*V!aWua(YbD_RU7bGh&GoTaAcM#uDwP8WD?T#56szT@Vf8SzkMPOS+sF z70*tGk>!VCE!%e(5$DZ}b-`0?5YPTtqJ6-KcyBCmV$g^o9$04j2%djEFBA5SxvN z)mgDkw-^!6G$1ZBB2E|(mm3jFEHO z5%E3);!{S%nc1;TPZ$w*#1fAh67Py7+MgC*#py>+N={7E6VDqFPmU$pUl8B@=q=t5 zOPqMwi1T!r~0w0(kVbWH~#Z_J8GA%#$?f7WEg0we-V% tut08Iz9nBUTjt2TU~f;OUmZv~uSVSC|4Nf&Z%?J~>ytY1AMpkE{{bjpkS+iK diff --git a/src/batdetect2/models/detectors.py b/src/batdetect2/models/detectors.py index a3894ce..586beca 100644 --- a/src/batdetect2/models/detectors.py +++ b/src/batdetect2/models/detectors.py @@ -6,8 +6,8 @@ bounding-box size regression. Components ---------- -- ``Detector`` – the ``torch.nn.Module`` that wires together a backbone - (``BackboneModel``) with a ``ClassifierHead`` and a ``BBoxHead`` to +- ``Detector`` - the ``torch.nn.Module`` that wires together a backbone + (``BackboneProtocol``) with a ``ClassifierHead`` and a ``BBoxHead`` to produce a ``ModelOutput`` tuple from an input spectrogram. - ``build_detector`` – factory function that builds a ready-to-use ``Detector`` from a backbone configuration and a target class count. @@ -18,15 +18,16 @@ preprocessing and output postprocessing are handled by """ import torch -from loguru import logger -from batdetect2.models.backbones import ( - BackboneConfig, - UNetBackboneConfig, - build_backbone, -) +from batdetect2.models.backbones import BackboneConfig, build_backbone from batdetect2.models.heads import BBoxHead, ClassifierHead -from batdetect2.models.types import BackboneModel, DetectionModel, ModelOutput +from batdetect2.models.types import ( + BackboneProtocol, + ClassifierHeadProtocol, + DetectorProtocol, + ModelOutput, + SizeHeadProtocol, +) __all__ = [ "Detector", @@ -34,7 +35,7 @@ __all__ = [ ] -class Detector(DetectionModel): +class Detector(torch.nn.Module): """Complete BatDetect2 detection and classification model. Combines a backbone feature extractor with two prediction heads: @@ -51,7 +52,7 @@ class Detector(DetectionModel): Attributes ---------- - backbone : BackboneModel + backbone : BackboneProtocol The feature extraction backbone. num_classes : int Number of target classes (inferred from the classifier head). @@ -61,13 +62,13 @@ class Detector(DetectionModel): Produces duration and bandwidth predictions from backbone features. """ - backbone: BackboneModel + backbone: BackboneProtocol def __init__( self, - backbone: BackboneModel, - classifier_head: ClassifierHead, - bbox_head: BBoxHead, + backbone: BackboneProtocol, + classifier_head: ClassifierHeadProtocol, + size_head: SizeHeadProtocol, ): """Initialise the Detector model. @@ -76,7 +77,7 @@ class Detector(DetectionModel): Parameters ---------- - backbone : BackboneModel + backbone : BackboneProtocol An initialised backbone module (e.g. built by ``build_backbone``). classifier_head : ClassifierHead @@ -90,7 +91,7 @@ class Detector(DetectionModel): self.backbone = backbone self.num_classes = classifier_head.num_classes self.classifier_head = classifier_head - self.bbox_head = bbox_head + self.size_head = size_head def forward(self, spec: torch.Tensor) -> ModelOutput: """Run the complete detection model on an input spectrogram. @@ -125,7 +126,7 @@ class Detector(DetectionModel): features = self.backbone(spec) classification = self.classifier_head(features) detection = classification.sum(dim=1, keepdim=True) - size_preds = self.bbox_head(features) + size_preds = self.size_head(features) return ModelOutput( detection_probs=detection, size_preds=size_preds, @@ -135,11 +136,11 @@ class Detector(DetectionModel): def build_detector( - num_classes: int, - num_sizes: int = 2, + class_names: list[str], + dimension_names: list[str], config: BackboneConfig | None = None, - backbone: BackboneModel | None = None, -) -> DetectionModel: + backbone: BackboneProtocol | None = None, +) -> DetectorProtocol: """Build a complete BatDetect2 detection model. Constructs a backbone from ``config``, attaches a ``ClassifierHead`` @@ -158,7 +159,7 @@ def build_detector( Returns ------- - DetectionModel + DetectorProtocol An initialised ``Detector`` instance ready for training or inference. @@ -168,24 +169,18 @@ def build_detector( If ``num_classes`` is not positive, or if the backbone configuration is invalid. """ - if backbone is None: - config = config or UNetBackboneConfig() - logger.opt(lazy=True).debug( - "Building model with config: \n{}", - lambda: config.to_yaml_string(), # type: ignore - ) - backbone = build_backbone(config=config) + backbone = backbone or build_backbone(config=config) classifier_head = ClassifierHead( - num_classes=num_classes, + class_names=class_names, in_channels=backbone.out_channels, ) bbox_head = BBoxHead( in_channels=backbone.out_channels, - num_sizes=num_sizes, + dimension_names=dimension_names, ) return Detector( backbone=backbone, classifier_head=classifier_head, - bbox_head=bbox_head, + size_head=bbox_head, ) diff --git a/src/batdetect2/models/heads.py b/src/batdetect2/models/heads.py index ba7b437..250ddb6 100644 --- a/src/batdetect2/models/heads.py +++ b/src/batdetect2/models/heads.py @@ -54,12 +54,14 @@ class ClassifierHead(nn.Module): 1×1 convolution with ``num_classes + 1`` output channels. """ - def __init__(self, num_classes: int, in_channels: int): + def __init__(self, class_names: list[str], in_channels: int): """Initialise the ClassifierHead.""" super().__init__() - self.num_classes = num_classes + self.class_names = class_names + self.num_classes = len(class_names) self.in_channels = in_channels + self.classifier = nn.Conv2d( self.in_channels, self.num_classes + 1, @@ -165,11 +167,12 @@ class BBoxHead(nn.Module): 1×1 convolution with 2 output channels (duration, bandwidth). """ - def __init__(self, in_channels: int, num_sizes: int = 2): + def __init__(self, dimension_names: list[str], in_channels: int): """Initialise the BBoxHead.""" super().__init__() self.in_channels = in_channels - self.num_sizes = num_sizes + self.dimension_names = dimension_names + self.num_sizes = len(dimension_names) self.bbox = nn.Conv2d( in_channels=self.in_channels, diff --git a/src/batdetect2/train/checkpoints.py b/src/batdetect2/train/checkpoints.py index a443743..be1c165 100644 --- a/src/batdetect2/train/checkpoints.py +++ b/src/batdetect2/train/checkpoints.py @@ -34,6 +34,7 @@ class CheckpointConfig(BaseConfig): monitor: str | None = None mode: str = "max" save_top_k: int = 1 + # Save distributable inference checkpoints by default. save_weights_only: bool = True filename: str | None = None save_last: bool | Literal["link"] = "link" diff --git a/tests/test_api_v2/test_api_v2.py b/tests/test_api_v2/test_api_v2.py index a4b2758..9f7a109 100644 --- a/tests/test_api_v2/test_api_v2.py +++ b/tests/test_api_v2/test_api_v2.py @@ -299,10 +299,10 @@ def test_checkpoint_with_same_targets_config_keeps_heads_unchanged( value, ) - for key, value in source_detector.bbox_head.state_dict().items(): - assert key in detector.bbox_head.state_dict() + for key, value in source_detector.size_head.state_dict().items(): + assert key in detector.size_head.state_dict() torch.testing.assert_close( - detector.bbox_head.state_dict()[key], + detector.size_head.state_dict()[key], value, ) diff --git a/tests/test_api_v2/test_finetune.py b/tests/test_api_v2/test_finetune.py index 8d8c6a2..5d8b223 100644 --- a/tests/test_api_v2/test_finetune.py +++ b/tests/test_api_v2/test_finetune.py @@ -18,7 +18,7 @@ def test_user_can_finetune_only_heads( api = BatDetect2API.from_config() source_classifier_head = api.model.detector.classifier_head - source_bbox_head = api.model.detector.bbox_head + source_size_head = api.model.detector.size_head source_backbone = api.model.detector.backbone finetune_dir = tmp_path / "heads_only" @@ -39,7 +39,7 @@ def test_user_can_finetune_only_heads( backbone_params = list(detector.backbone.parameters()) classifier_params = list(detector.classifier_head.parameters()) - bbox_params = list(detector.bbox_head.parameters()) + bbox_params = list(detector.size_head.parameters()) assert backbone_params assert classifier_params @@ -50,7 +50,7 @@ def test_user_can_finetune_only_heads( assert finetuned_api is not api assert detector.backbone is source_backbone assert detector.classifier_head is not source_classifier_head - assert detector.bbox_head is not source_bbox_head + assert detector.size_head is not source_size_head assert list(finetune_dir.rglob("*.ckpt")) diff --git a/tests/test_models/test_detectors.py b/tests/test_models/test_detectors.py index f5ce769..35d39a9 100644 --- a/tests/test_models/test_detectors.py +++ b/tests/test_models/test_detectors.py @@ -1,6 +1,7 @@ import numpy as np import pytest import torch +from typing import cast from batdetect2.models import UNetBackbone from batdetect2.models.backbones import UNetBackboneConfig @@ -19,12 +20,15 @@ def dummy_spectrogram() -> torch.Tensor: def test_build_detector_default(): """Test building the default detector without a config.""" num_classes = 5 - model = build_detector(num_classes=num_classes) + model = build_detector( + class_names=[f"class_{i}" for i in range(num_classes)], + dimension_names=["width", "height"], + ) assert isinstance(model, Detector) assert model.num_classes == num_classes assert isinstance(model.classifier_head, ClassifierHead) - assert isinstance(model.bbox_head, BBoxHead) + assert isinstance(model.size_head, BBoxHead) def test_build_detector_custom_config(): @@ -32,13 +36,19 @@ def test_build_detector_custom_config(): num_classes = 3 config = UNetBackboneConfig(in_channels=2, input_height=128) - model = build_detector(num_classes=num_classes, config=config) + model = build_detector( + class_names=[f"class_{i}" for i in range(num_classes)], + dimension_names=["width", "height"], + config=config, + ) assert isinstance(model, Detector) assert model.backbone.input_height == 128 - assert isinstance(model.backbone.encoder, Encoder) - assert model.backbone.encoder.in_channels == 2 + backbone = cast(UNetBackbone, model.backbone) + + assert isinstance(backbone.encoder, Encoder) + assert backbone.encoder.in_channels == 2 def test_build_detector_custom_size_channels(): @@ -47,8 +57,8 @@ def test_build_detector_custom_size_channels(): config = UNetBackboneConfig(in_channels=1, input_height=128) model = build_detector( - num_classes=num_classes, - num_sizes=num_sizes, + class_names=[f"class_{i}" for i in range(num_classes)], + dimension_names=[f"size_{i}" for i in range(num_sizes)], config=config, ) @@ -62,7 +72,11 @@ def test_detector_forward_pass_shapes(dummy_spectrogram): num_classes = 4 # Build model matching the dummy input shape config = UNetBackboneConfig(in_channels=1, input_height=256) - model = build_detector(num_classes=num_classes, config=config) + model = build_detector( + class_names=[f"class_{i}" for i in range(num_classes)], + dimension_names=["width", "height"], + config=config, + ) # Process the spectrogram through the model # PyTorch expects shape (Batch, Channels, Height, Width) @@ -132,7 +146,11 @@ def test_detector_forward_pass_with_preprocessor(sample_preprocessor): config = UNetBackboneConfig( in_channels=spec.shape[1], input_height=spec.shape[2] ) - model = build_detector(num_classes=3, config=config) + model = build_detector( + class_names=["class_0", "class_1", "class_2"], + dimension_names=["width", "height"], + config=config, + ) # Process output = model(spec) diff --git a/tests/test_train/test_checkpoints.py b/tests/test_train/test_checkpoints.py index 77a2856..4cffa48 100644 --- a/tests/test_train/test_checkpoints.py +++ b/tests/test_train/test_checkpoints.py @@ -8,7 +8,7 @@ from soundevent import data from batdetect2.train import TrainingConfig, run_train from batdetect2.train.checkpoints import ( - DEFAULT_BUNDLED_CHECKPOINT, + DEFAULT_CHECKPOINT, get_bundled_checkpoint_names, resolve_checkpoint_path, ) @@ -145,7 +145,7 @@ def test_resolve_checkpoint_path_returns_local_path_unchanged( def test_get_bundled_checkpoint_names_lists_supported_aliases() -> None: assert get_bundled_checkpoint_names() == ( - DEFAULT_BUNDLED_CHECKPOINT, + DEFAULT_CHECKPOINT, "batdetect2_uk_same", ) @@ -153,11 +153,11 @@ def test_get_bundled_checkpoint_names_lists_supported_aliases() -> None: def test_resolve_checkpoint_path_uses_default_bundled_alias() -> None: resolved = resolve_checkpoint_path() - assert resolved == resolve_checkpoint_path(DEFAULT_BUNDLED_CHECKPOINT) + assert resolved == resolve_checkpoint_path(DEFAULT_CHECKPOINT) def test_resolve_checkpoint_path_accepts_bundled_alias() -> None: - resolved = resolve_checkpoint_path(DEFAULT_BUNDLED_CHECKPOINT) + resolved = resolve_checkpoint_path(DEFAULT_CHECKPOINT) assert resolved.name == "batdetect2_uk_same.ckpt" assert resolved.exists() @@ -227,6 +227,6 @@ def test_resolve_checkpoint_path_rejects_incomplete_huggingface_uri() -> None: def test_resolve_checkpoint_path_rejects_missing_local_path() -> None: with pytest.raises( FileNotFoundError, - match="bundled checkpoint alias", + match="checkpoint alias", ): resolve_checkpoint_path("missing.ckpt") diff --git a/tests/test_train/test_lightning.py b/tests/test_train/test_lightning.py index 9ab9e02..756329d 100644 --- a/tests/test_train/test_lightning.py +++ b/tests/test_train/test_lightning.py @@ -368,7 +368,7 @@ def test_build_model_with_new_targets_reuses_backbone_and_rebuilds_heads() -> ( assert ( rebuilt_detector.classifier_head is not source_detector.classifier_head ) - assert rebuilt_detector.bbox_head is not source_detector.bbox_head + assert rebuilt_detector.size_head is not source_detector.size_head assert rebuilt_model.class_names == ["single_class"] assert rebuilt_model.dimension_names == ["width", "height"]