# -*- coding: utf-8 -*-
from __future__ import absolute_import
import re
class IobEncoder(object):
    """
    Utility class for encoding tagged token streams using IOB2 encoding.

    Encode input tokens using the ``encode`` method::

        >>> iob_encoder = IobEncoder()
        >>> input_tokens = ["__START_PER__", "John", "__END_PER__", "said"]
        >>> iob_encoder.encode(input_tokens)
        [('John', 'B-PER'), ('said', 'O')]

    Get the result in another format using the ``encode_split`` method::

        >>> input_tokens = ["hello", "__START_PER__", "John", "Doe",
        ...                 "__END_PER__", "__START_PER__", "Mary",
        ...                 "__END_PER__", "said"]
        >>> tokens, tags = iob_encoder.encode_split(input_tokens)
        >>> tokens, tags
        (['hello', 'John', 'Doe', 'Mary', 'said'], ['O', 'B-PER', 'I-PER', 'B-PER', 'O'])

    Note that IobEncoder is stateful. This means you can encode an incomplete
    stream and continue the encoding later::

        >>> iob_encoder = IobEncoder()
        >>> iob_encoder.encode(["__START_PER__", "John"])
        [('John', 'B-PER')]
        >>> iob_encoder.encode(["Mayer", "__END_PER__", "said"])
        [('Mayer', 'I-PER'), ('said', 'O')]

    To reset the internal state, use the ``reset`` method::

        >>> iob_encoder.reset()

    Group results to entities::

        >>> iob_encoder.group(iob_encoder.encode(input_tokens))
        [(['hello'], 'O'), (['John', 'Doe'], 'PER'), (['Mary'], 'PER'), (['said'], 'O')]

    The input token stream is processed by ``InputTokenProcessor()`` by default;
    you can pass another token processing class to customize which tokens
    are considered start/end tags.
    """
    def __init__(self, token_processor=None):
        # InputTokenProcessor (defined elsewhere in this module) implements
        # the default __START_XXX__ / __END_XXX__ boundary-token convention.
        self.token_processor = token_processor or InputTokenProcessor()
        self.reset()

    def reset(self):
        """ Reset the sequence: no entity is currently open. """
        # self.tag is the IOB2 tag that the *next* plain token will receive.
        self.tag = 'O'

    def iter_encode(self, input_tokens):
        """
        Lazily encode ``input_tokens``, yielding ``(token, iob2_tag)`` pairs.

        Boundary tokens (classified as ``'start'``/``'end'`` by the token
        processor) are consumed and not yielded; ``'drop'`` tokens are
        skipped entirely.

        Raises ValueError if a close tag doesn't match the currently open
        tag, or if the token processor returns an unknown token type.
        """
        for token in input_tokens:
            token_type, value = self.token_processor.classify(token)

            if token_type == 'start':
                self.tag = "B-" + value
            elif token_type == 'end':
                if value != self.tag[2:]:
                    raise ValueError(
                        "Invalid tag sequence: close tag '%s' "
                        "doesn't match open tag '%s'." % (value, self.tag)
                    )
                self.tag = "O"
            elif token_type == 'token':
                yield token, self.tag
                if self.tag[0] == 'B':
                    # Only the first token of an entity gets the B- tag;
                    # subsequent tokens continue it with I-.
                    self.tag = "I" + self.tag[1:]
            elif token_type == 'drop':
                continue
            else:
                raise ValueError("Unknown token type '%s' for token '%s'" % (token_type, token))

    def encode(self, input_tokens):
        """ Encode ``input_tokens``; return a list of (token, tag) pairs. """
        return list(self.iter_encode(input_tokens))

    def encode_split(self, input_tokens):
        """ The same as ``encode``, but returns ``(tokens, tags)`` tuple """
        res = self.encode(input_tokens)
        if not res:
            # Return lists, consistent with the non-empty case below.
            return [], []
        tokens, tags = zip(*res)
        return list(tokens), list(tags)

    @classmethod
    def group(cls, data, strict=False):
        """
        Group IOB2-encoded entities. ``data`` should be an iterable
        of ``(info, iob_tag)`` tuples. ``info`` could be any Python object,
        ``iob_tag`` should be a string with a tag.

        Example::

            >>> data = [("hello", "O"), (",", "O"), ("John", "B-PER"),
            ...         ("Doe", "I-PER"), ("Mary", "B-PER"), ("said", "O")]
            >>> for items, tag in IobEncoder.iter_group(data):
            ...     print("%s %s" % (items, tag))
            ['hello', ','] O
            ['John', 'Doe'] PER
            ['Mary'] PER
            ['said'] O

        By default, invalid sequences are fixed::

            >>> data = [("hello", "O"), ("John", "I-PER"), ("Doe", "I-PER")]
            >>> for items, tag in IobEncoder.iter_group(data):
            ...     print("%s %s" % (items, tag))
            ['hello'] O
            ['John', 'Doe'] PER

        Pass 'strict=True' argument to raise an exception for
        invalid sequences::

            >>> for items, tag in IobEncoder.iter_group(data, strict=True):
            ...     print("%s %s" % (items, tag))
            Traceback (most recent call last):
            ...
            ValueError: Invalid sequence: I-PER tag can't start sequence
        """
        return list(cls.iter_group(data, strict))

    @classmethod
    def iter_group(cls, data, strict=False):
        """ Lazy version of ``group``: yield ``(items, tag)`` pairs. """
        buf, tag = [], 'O'

        for info, iob_tag in data:
            if iob_tag.startswith('I-') and tag != iob_tag[2:]:
                # An I- tag that doesn't continue the current entity is
                # invalid; either complain or treat it as a B- tag.
                if strict:
                    raise ValueError("Invalid sequence: %s tag can't start sequence" % iob_tag)
                else:
                    iob_tag = 'B-' + iob_tag[2:]  # fix bad tag

            if iob_tag.startswith('B-'):
                # A new entity starts: flush whatever was accumulated.
                if buf:
                    yield buf, tag
                buf = []
            elif iob_tag == 'O':
                # Flush a finished entity; consecutive 'O' items keep
                # accumulating into the same group.
                if buf and tag != 'O':
                    yield buf, tag
                    buf = []

            tag = 'O' if iob_tag == 'O' else iob_tag[2:]
            buf.append(info)

        if buf:
            yield buf, tag
# FIXME: this hook is incomplete: __START_TAG__ tokens are assumed everywhere.