Source code for webstruct.sequence_encoding

# -*- coding: utf-8 -*-
from __future__ import absolute_import
import re


class IobEncoder(object):
    """
    Utility class for encoding tagged token streams using IOB2 encoding.

    Encode input tokens using the ``encode`` method::

        >>> iob_encoder = IobEncoder()
        >>> input_tokens = ["__START_PER__", "John", "__END_PER__", "said"]
        >>> iob_encoder.encode(input_tokens)
        [('John', 'B-PER'), ('said', 'O')]

    Get the result in another format using the ``encode_split`` method::

        >>> input_tokens = ["hello", "__START_PER__", "John", "Doe",
        ...                 "__END_PER__", "__START_PER__", "Mary",
        ...                 "__END_PER__", "said"]
        >>> tokens, tags = iob_encoder.encode_split(input_tokens)
        >>> tokens, tags
        (['hello', 'John', 'Doe', 'Mary', 'said'], ['O', 'B-PER', 'I-PER', 'B-PER', 'O'])

    Note that IobEncoder is stateful. This means you can encode an incomplete
    stream and continue the encoding later::

        >>> iob_encoder = IobEncoder()
        >>> iob_encoder.encode(["__START_PER__", "John"])
        [('John', 'B-PER')]
        >>> iob_encoder.encode(["Mayer", "__END_PER__", "said"])
        [('Mayer', 'I-PER'), ('said', 'O')]

    To reset the internal state, use the ``reset`` method::

        >>> iob_encoder.reset()

    Group results to entities::

        >>> iob_encoder.group(iob_encoder.encode(input_tokens))
        [(['hello'], 'O'), (['John', 'Doe'], 'PER'), (['Mary'], 'PER'), (['said'], 'O')]

    Input token stream is processed by ``InputTokenProcessor()`` by default;
    you can pass another token processing class to customize which tokens
    are considered start/end tags.
    """
    def __init__(self, token_processor=None):
        self.token_processor = token_processor or InputTokenProcessor()
        self.reset()
    def reset(self):
        """ Reset the sequence """
        self.tag = 'O'
    def iter_encode(self, input_tokens):
        for token in input_tokens:
            token_type, value = self.token_processor.classify(token)

            if token_type == 'start':
                # an entity opens: the next regular token gets a B- tag
                self.tag = "B-" + value

            elif token_type == 'end':
                if value != self.tag[2:]:
                    raise ValueError(
                        "Invalid tag sequence: close tag '%s' "
                        "doesn't match open tag '%s'." % (value, self.tag)
                    )
                self.tag = "O"

            elif token_type == 'token':
                yield token, self.tag
                if self.tag[0] == 'B':
                    # only the first token of an entity is B-; the rest are I-
                    self.tag = "I" + self.tag[1:]

            elif token_type == 'drop':
                continue

            else:
                raise ValueError("Unknown token type '%s' for token '%s'" % (token_type, token))
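
    # A minimal usage sketch (illustrative, not part of the module):
    # ``iter_encode`` is lazy, so tags can be consumed as tokens arrive,
    # which is what makes the stateful, incremental encoding above work.
    #
    #     >>> enc = IobEncoder()
    #     >>> gen = enc.iter_encode(["__START_ORG__", "Acme", "Inc", "__END_ORG__"])
    #     >>> next(gen)
    #     ('Acme', 'B-ORG')
    #     >>> next(gen)
    #     ('Inc', 'I-ORG')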
    def encode(self, input_tokens):
        return list(self.iter_encode(input_tokens))
    def encode_split(self, input_tokens):
        """ The same as ``encode``, but returns a ``(tokens, tags)`` tuple """
        res = self.encode(input_tokens)
        if not res:
            return (), ()
        tokens, tags = zip(*res)
        return list(tokens), list(tags)
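
    # Usage sketch (illustrative, not part of the module): the parallel
    # lists returned by ``encode_split`` match the (tokens, tags) input
    # format that most sequence labeling toolkits expect.
    #
    #     >>> enc = IobEncoder()
    #     >>> enc.encode_split(["__START_PER__", "Jane", "__END_PER__"])
    #     (['Jane'], ['B-PER'])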
    @classmethod
    def group(cls, data, strict=False):
        """
        Group IOB2-encoded entities. ``data`` should be an iterable
        of ``(info, iob_tag)`` tuples. ``info`` could be any Python object,
        ``iob_tag`` should be a string with a tag.

        Example::

            >>> data = [("hello", "O"), (",", "O"), ("John", "B-PER"),
            ...         ("Doe", "I-PER"), ("Mary", "B-PER"), ("said", "O")]
            >>> for items, tag in IobEncoder.iter_group(data):
            ...     print("%s %s" % (items, tag))
            ['hello', ','] O
            ['John', 'Doe'] PER
            ['Mary'] PER
            ['said'] O

        By default, invalid sequences are fixed::

            >>> data = [("hello", "O"), ("John", "I-PER"), ("Doe", "I-PER")]
            >>> for items, tag in IobEncoder.iter_group(data):
            ...     print("%s %s" % (items, tag))
            ['hello'] O
            ['John', 'Doe'] PER

        Pass ``strict=True`` to raise an exception for invalid sequences::

            >>> for items, tag in IobEncoder.iter_group(data, strict=True):
            ...     print("%s %s" % (items, tag))
            Traceback (most recent call last):
            ...
            ValueError: Invalid sequence: I-PER tag can't start sequence
        """
        return list(cls.iter_group(data, strict))
    @classmethod
    def iter_group(cls, data, strict=False):
        buf, tag = [], 'O'

        for info, iob_tag in data:
            if iob_tag.startswith('I-') and tag != iob_tag[2:]:
                if strict:
                    raise ValueError("Invalid sequence: %s tag can't start sequence" % iob_tag)
                else:
                    iob_tag = 'B-' + iob_tag[2:]  # fix bad tag

            if iob_tag.startswith('B-'):
                if buf:
                    yield buf, tag
                buf = []
            elif iob_tag == 'O':
                if buf and tag != 'O':
                    yield buf, tag
                    buf = []

            tag = 'O' if iob_tag == 'O' else iob_tag[2:]
            buf.append(info)

        if buf:
            yield buf, tag
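
# Illustrative sketch (not part of the module): the ``info`` objects in
# ``iter_group`` input can be anything, so passing (token, index) pairs
# recovers entity spans as token indices.
#
#     >>> data = [(("John", 0), "B-PER"), (("Doe", 1), "I-PER"), (("said", 2), "O")]
#     >>> [(tag, [i for _, i in items])
#     ...  for items, tag in IobEncoder.iter_group(data) if tag != 'O']
#     [('PER', [0, 1])]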
# FIXME: this hook is incomplete: __START_TAG__ tokens are assumed everywhere.
class InputTokenProcessor(object):
    def __init__(self, tagset=None):
        if tagset is not None:
            tag_re = '|'.join(tagset)
        else:
            tag_re = r'\w+?'
        self.tag_re = re.compile(r'__(START|END)_(%s)__' % tag_re)
    def classify(self, token):
        """
        >>> tp = InputTokenProcessor()
        >>> tp.classify('foo')
        ('token', 'foo')
        >>> tp.classify('__START_ORG__')
        ('start', 'ORG')
        >>> tp.classify('__END_ORG__')
        ('end', 'ORG')
        """
        # start/end tags
        m = self.tag_re.match(token)
        if m:
            return m.group(1).lower(), m.group(2)

        # # drop standalone commas and semicolons by default?
        # if token in {',', ';'}:
        #     return 'drop', token

        # regular token
        return 'token', token
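
# Usage sketch (illustrative, not part of the module): restricting ``tagset``
# makes the processor treat unlisted markers as regular tokens, and a custom
# processor plugs into IobEncoder via its ``token_processor`` argument.
#
#     >>> tp = InputTokenProcessor(tagset={'PER', 'ORG'})
#     >>> tp.classify('__START_FUNC__')   # 'FUNC' is not in the tagset
#     ('token', '__START_FUNC__')
#     >>> IobEncoder(token_processor=tp).encode(["__START_ORG__", "Acme", "__END_ORG__"])
#     [('Acme', 'B-ORG')]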