Skip to content

Commit 77f2b21

Browse files
committed
* Update conll_train script
1 parent fab5386 commit 77f2b21

1 file changed

Lines changed: 8 additions & 10 deletions

File tree

bin/tagger/conll_train.py

Lines changed: 8 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -87,15 +87,15 @@ def _parse_line(line):
8787

8888

8989
def score_model(nlp, gold_tuples, verbose=False):
90-
scorer = Scorer()
90+
correct = 0.0
91+
total = 0.0
9192
for words, gold_tags in gold_tuples:
9293
tokens = nlp.tokenizer.tokens_from_list(words)
9394
nlp.tagger(tokens)
9495
for token, gold in zip(tokens, gold_tags):
95-
scorer.tags.tp += token.tag_ == gold
96-
scorer.tags.fp += token.tag_ != gold
97-
scorer.tags.fn += token.tag_ != gold
98-
return scorer.tags_acc
96+
correct += token.tag_ == gold
97+
total += 1
98+
return (correct / total) * 100
9999

100100

101101
def train(Language, train_sents, dev_sents, model_dir, n_iter=15, seed=0,
@@ -116,8 +116,6 @@ def train(Language, train_sents, dev_sents, model_dir, n_iter=15, seed=0,
116116
random.shuffle(train_sents)
117117
heldout_sents = train_sents[:int(nr_train * 0.1)]
118118
train_sents = train_sents[len(heldout_sents):]
119-
#train_sents = train_sents[:500]
120-
#assert len(heldout_sents) < len(train_sents)
121119
prev_score = 0.0
122120
variance = 0.001
123121
last_good_learn_rate = nlp.tagger.model.eta
@@ -130,15 +128,15 @@ def train(Language, train_sents, dev_sents, model_dir, n_iter=15, seed=0,
130128
acc += nlp.tagger.train(tokens, gold_tags)
131129
total += len(tokens)
132130
n += 1
133-
if n and n % 10000 == 0:
131+
if n and n % 20000 == 0:
134132
dev_score = score_model(nlp, heldout_sents)
135133
eval_score = score_model(nlp, dev_sents)
136-
if dev_score > prev_score:
134+
if dev_score >= prev_score:
137135
nlp.tagger.model.keep_update()
138136
prev_score = dev_score
139137
variance = 0.001
140138
last_good_learn_rate = nlp.tagger.model.eta
141-
nlp.tagger.model.eta *= 1.05
139+
nlp.tagger.model.eta *= 1.01
142140
print('%d:\t%.3f\t%.3f\t%.3f\t%.4f' % (n, acc/total, dev_score, eval_score, nlp.tagger.model.eta))
143141
else:
144142
nlp.tagger.model.backtrack()

0 commit comments

Comments (0)