Skip to content

Commit e835e8e

Browse files
Add files via upload
1 parent c05dc21 commit e835e8e

4 files changed

Lines changed: 143 additions & 0 deletions

File tree

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .common import *
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
"""
2+
Copyright 2017-2018 Fizyr (https://fizyr.com)
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
17+
import keras
18+
from ..utils.coco_eval import evaluate_coco
19+
20+
21+
class CocoEval(keras.callbacks.Callback):
22+
def __init__(self, generator, threshold=0.05):
23+
self.generator = generator
24+
self.threshold = threshold
25+
26+
super(CocoEval, self).__init__()
27+
28+
def on_epoch_end(self, epoch, logs={}):
29+
evaluate_coco(self.generator, self.model, self.threshold)
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
import keras.callbacks
2+
3+
4+
class RedirectModel(keras.callbacks.Callback):
5+
"""Callback which wraps another callback, but executed on a different model.
6+
# Arguments
7+
callback: callback to wrap.
8+
model: model to use when executing callbacks.
9+
# Example
10+
```python
11+
model = keras.models.load_model('model.h5')
12+
model_checkpoint = ModelCheckpoint(filepath='snapshot.h5')
13+
parallel_model = multi_gpu_model(model, gpus=2)
14+
parallel_model.fit(X_train, Y_train, callbacks=[RedirectModel(model_checkpoint, model)])
15+
```
16+
"""
17+
18+
def __init__(self,
19+
callback,
20+
model):
21+
super(RedirectModel, self).__init__()
22+
23+
self.callback = callback
24+
self.redirect_model = model
25+
26+
def on_epoch_begin(self, epoch, logs=None):
27+
self.callback.on_epoch_begin(epoch, logs=logs)
28+
29+
def on_epoch_end(self, epoch, logs=None):
30+
self.callback.on_epoch_end(epoch, logs=logs)
31+
32+
def on_batch_begin(self, batch, logs=None):
33+
self.callback.on_batch_begin(batch, logs=logs)
34+
35+
def on_batch_end(self, batch, logs=None):
36+
self.callback.on_batch_end(batch, logs=logs)
37+
38+
def on_train_begin(self, logs=None):
39+
# overwrite the model with our custom model
40+
self.callback.set_model(self.redirect_model)
41+
42+
self.callback.on_train_begin(logs=logs)
43+
44+
def on_train_end(self, logs=None):
45+
self.callback.on_train_end(logs=logs)
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
"""
2+
Copyright 2017-2018 Fizyr (https://fizyr.com)
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
17+
import keras
18+
from ..utils.eval import evaluate
19+
20+
21+
class Evaluate(keras.callbacks.Callback):
22+
def __init__(self, generator, iou_threshold=0.5, score_threshold=0.05, max_detections=100, save_path=None, tensorboard=None, verbose=1):
23+
""" Evaluate a given dataset using a given model at the end of every epoch during training.
24+
25+
# Arguments
26+
generator : The generator that represents the dataset to evaluate.
27+
iou_threshold : The threshold used to consider when a detection is positive or negative.
28+
score_threshold : The score confidence threshold to use for detections.
29+
max_detections : The maximum number of detections to use per image.
30+
save_path : The path to save images with visualized detections to.
31+
tensorboard : Instance of keras.callbacks.TensorBoard used to log the mAP value.
32+
verbose : Set the verbosity level, by default this is set to 1.
33+
"""
34+
self.generator = generator
35+
self.iou_threshold = iou_threshold
36+
self.score_threshold = score_threshold
37+
self.max_detections = max_detections
38+
self.save_path = save_path
39+
self.tensorboard = tensorboard
40+
self.verbose = verbose
41+
42+
super(Evaluate, self).__init__()
43+
44+
def on_epoch_end(self, epoch, logs={}):
45+
# run evaluation
46+
average_precisions = evaluate(
47+
self.generator,
48+
self.model,
49+
iou_threshold=self.iou_threshold,
50+
score_threshold=self.score_threshold,
51+
max_detections=self.max_detections,
52+
save_path=self.save_path
53+
)
54+
55+
self.mean_ap = sum(average_precisions.values()) / len(average_precisions)
56+
57+
if self.tensorboard is not None and self.tensorboard.writer is not None:
58+
import tensorflow as tf
59+
summary = tf.Summary()
60+
summary_value = summary.value.add()
61+
summary_value.simple_value = self.mean_ap
62+
summary_value.tag = "mAP"
63+
self.tensorboard.writer.add_summary(summary, epoch)
64+
65+
if self.verbose == 1:
66+
for label, average_precision in average_precisions.items():
67+
print(self.generator.label_to_name(label), '{:.4f}'.format(average_precision))
68+
print('mAP: {:.4f}'.format(self.mean_ap))

0 commit comments

Comments
 (0)