| import torch |
| import nltk |
| from nltk import pos_tag |
| from nltk.tokenize import word_tokenize |
| from nltk.corpus import wordnet |
| from transformers.models.bert.modeling_bert import BertModel, BertPreTrainedModel |
| from torch import nn |
| from itertools import chain |
| from torch.nn import MSELoss, CrossEntropyLoss |
| from cleantext import clean |
| from num2words import num2words |
| import re |
| import string |
| import inflect |
|
|
def _ensure_nltk_resources():
    """Make sure the NLTK data used by this module is installed.

    The module needs the punkt tokenizer data (``word_tokenize``), the
    perceptron tagger (``pos_tag``) and WordNet (``wordnet.synsets``).
    Each resource is probed locally with ``nltk.data.find`` and only
    downloaded when missing, so repeated imports do not call
    ``nltk.download`` (which consults the remote NLTK index) for data that is
    already present.
    """
    # NLTK package id -> resource path understood by nltk.data.find
    resources = {
        "punkt": "tokenizers/punkt",
        "punkt_tab": "tokenizers/punkt_tab",
        "averaged_perceptron_tagger": "taggers/averaged_perceptron_tagger",
        "averaged_perceptron_tagger_eng": "taggers/averaged_perceptron_tagger_eng",
        "wordnet": "corpora/wordnet",
    }
    for package, resource_path in resources.items():
        try:
            nltk.data.find(resource_path)
        except LookupError:
            nltk.download(package)


_ensure_nltk_resources()
| |
|
|
# Every ASCII punctuation character plus common typographic quotes, dashes
# and the ellipsis, kept in sorted order.
punct_chars = sorted(
    set(string.punctuation).union(
        {"’", "‘", "–", "—", "~", "|", "“", "”", "…", "'", "`", "_"}
    )
)
punctuation = "".join(punct_chars)
# Character class matching any single punctuation character above.
replace = re.compile("[{}]".format(re.escape(punctuation)))
|
|
# Word stems that mark math vocabulary. Presumably matched as prefixes of
# tokens (e.g. "multipl" covering multiply/multiplication) — confirm against
# the code that consumes this list.
MATH_PREFIXES = [
    "sum",
    "arc",
    "mass",
    "digit",
    "graph",
    "liter",
    "gram",
    "add",
    "angle",
    "scale",
    "data",
    "array",
    "ruler",
    "meter",
    "total",
    "unit",
    "prism",
    "median",
    "ratio",
    "area",
    # Truncated stems shared by several inflected/derived forms.
    "multipl",
    "divid",
    "subtrac",
    "logarit",
    "algebr",
    "calcul",
    "matri",
    "vect",
    "geometr",
    "statist",
    "probabil",  # probability, probabilistic
    "coeffi",
    "measure",
    "simplif",
]
|
|
# Math vocabulary (single words, phrases and a few stems), all lowercase.
# Every entry must end with a comma: Python silently concatenates adjacent
# string literals, so a missing comma fuses two entries into one bogus term.
MATH_WORDS = [
    "absolute deviation",
    "absolute value",
    "abundant number",
    "accurate",
    "acre",
    "acute",
    "add",
    "addend",
    "addition fact",
    "addition",
    "additive identity",
    "additive inverse",
    "adjacent",
    "algebra",
    "algebraic",
    "algorithm",
    "alternate interior angle",
    "altitude",
    "analog",
    "angle measure",
    "angle",
    "angular",
    "apex",
    "approximate",
    "arc",
    "area model",
    "area",
    "arithmetic fact",
    "arithmetic",
    "array",
    "associative property",
    "associative",
    "astronomical unit",
    "attribute",
    "average",
    "axis",
    "bar graph",
    "base of a parallelogram",
    "base of a prism",
    "base of a pyramid",
    "base of a triangle",
    "base of an exponent",
    "base of",
    "base ten",
    "base",
    "baseline",
    "benchmark fraction",
    "billion",
    "binomial",
    "bisect",
    "bisector",
    "box and whisker plot",
    "box plot",
    "capacity",
    "cartesian coordinate",
    "categorical data",
    "categorical",
    "celsius",
    "census",
    "cent",
    "center of a circle",
    "center of a dilation",
    "center of a sphere",
    "center",
    "centimeter",
    "central angle",
    "centroid",
    "chance experiment",
    "chance",
    "chord",
    "circle graph",
    "circle",
    "circular",
    "circumference",
    "clockwise",
    "coefficient",
    "collinear",
    "column matrix",
    "column",
    "combination",
    "combine",
    "common denominator",
    "common factor",
    "common fraction",
    "common multiple",
    "commutative property",
    "commutative",
    "comparison diagram",
    "comparison story",
    "compass",
    "complement",
    "complementary",
    "compose",
    "composite",
    "concave polygon",
    "concentric circles",
    "concentric",
    "cone",
    "congruent",
    "consecutive",
    "constant function",
    "constant",
    "continuous model of area",
    "continuous model of volume",
    "continuous",
    "contour",
    "conversion fact",
    "conversion factor",
    "convert",
    "convex function",
    "convex polygon",
    "coordinate",
    "coplanar",
    "corresponding",
    "counterclockwise",
    "counting numbers",
    "counting up subtraction",
    "covariance",
    "covariate",
    "cover-up method",
    "cross multiplication",
    "cross product",
    "cross section",
    "cross-section",
    "cube root",
    "cube",
    "cubed",
    "cubic unit",
    "cubic",
    "cubit",
    "cup",
    "curved surface",
    "customary system of measurement",
    "customary unit",
    "cylinder",
    "cylindrical",
    "data",
    "decagon",
    "decimal divisor",
    "decimal expanded form",
    "decimal fraction",
    "decimal point",
    "decimal",
    "decimeter",
    "decompose",
    "deficient number",
    "degree",
    "delta",
    "denominator",
    "density",
    "dependent event",
    "dependent variable",
    "deposit",
    "derivative",
    "determinant",
    "diagonal",
    "diameter",
    "difference",
    "differential",
    "digit",
    "digital",
    "dilation",
    "dimension",
    "discrete model",
    "displacement method",
    "distance",
    "distribution",
    "distributive",
    "divide",
    "divided",
    "divides",
    "dividing",
    "dividend",
    "divisibility test",
    "divisible by",
    "divisible",
    "division",
    "divisor",
    "dodecahedron",
    "dot plot",
    "double number line diagram",
    "double stem plot",
    "doubles fact",
    "edge",
    "egyptian multiplication",
    "elevation",
    "embed figure",
    "end point",
    "endpoint",
    "enlarge",
    "equal group",
    "equal part",
    "equal",
    "equality",
    "equation",
    "equidistant mark",
    "equilateral polygon",
    "equilateral triangle",
    "equilateral",
    "equivalence",
    "equivalent expression",
    "equivalent fraction",
    "equivalent",
    "error bound",
    "error of measurement",
    "estimat",
    "estimate",
    "european subtraction",
    "even number",
    "event",
    "expand",
    "expanded form",
    "expanded notation",
    "expected outcome",
    "expected value",
    "exponent",
    "exponential function",
    "exponential growth",
    "expression",
    "extended fact",
    "face",
    "fact power",
    "fact triangle",
    "factor",
    "factored",
    "factoring",
    "factors",
    "factorial",
    "factors of number",
    "fahrenheit",
    "false number sentence",
    "figurate number",
    "flowchart",
    "fluid ounce",
    "formula",
    "fraction form",
    "fraction",
    "fractional part",
    "fractional unit",
    "frequency",
    "fulcrum",
    "function machine",
    "function",
    "furlong",
    "gallon",
    "gcd",
    "genus",
    "geoboard",
    "geometr",
    "geometric solid",
    "geometry template",
    "girth",
    "golden ratio",
    "golden rectangle",
    "gram",
    "graph key",
    "graph",
    "greatest common divisor",
    "greatest common factor",
    "grouping symbol",
    "half circle",
    "half-circle",
    "hashmark",
    "height of a parallelogram or triangle",
    "height of",
    "height",
    "hemisphere",
    "heptagon",
    "heptagonal",
    "hexagon",
    "hexagonal",
    "hierarchy",
    "histogram",
    "horizontal shift",
    "horizontal stretch",
    "horizontal",
    "hundred",
    "hundredth",
    "hypotenuse",
    "hypothesis",
    "icosahedron",
    "identity function",
    "identity matrix",
    "identity property of",
    "identity property",
    "improper fraction",
    "inch",
    "incircle",
    "indefinite integral",
    "independent event",
    "independent variable",
    "index of location",
    "indirect measurement",
    "inequality",
    "infinity",
    "input",
    "inscribed angle",
    "inscribed polygon",
    "instance of a pattern",
    "integer",
    "intercept",
    "intercepted arc",
    "interior angle",
    "interior of a figure",
    "interpolate",
    "interquartile range",
    "intersect",
    "interval",
    "inverse operation",
    "inverse",
    "iqr",
    "irrational number",
    "irrational root",
    "irrational",
    "isometry transformation",
    "isosceles trapezoid",
    "isosceles triangle",
    "isosceles",
    "joint probability",
    "joint variation",
    "juxtapose",
    "key sequence",
    "kilogram",
    "kilometer",
    "kite",
    "label",
    "landmark",
    "latitude",
    "lattice multiplication",
    "lcm",
    "least common denominator",
    "least common multiple",
    "left to right subtraction",
    "leg of a right triangle",
    "legs",
    "length",
    "like fraction",
    "like terms",
    "line graph",
    "line of reflection",
    "line of symmetry",
    "line plot",
    "line segment",
    "line symmetry",
    "line",
    "linear relationship",
    "lines of latitude",
    "lines of longitude",
    "liter",
    "local maximum",
    "local minimum",
    "locus",
    "logarithm",
    "logarithmic function",
    "logarithmic scale",
    "logic",
    "long division",
    "longitude",
    "lowest term",
    "magnitude estimate",
    "make ten",
    "map legend",
    "map scale",
    "mass",
    "maximum",
    "mean absolute deviation",
    "mean value",
    "mean",
    "measure of center",
    "measure",
    "measurement division",
    "measurement error",
    "measurement unit",
    "median",
    "meridian bar",
    "meter",
    "meters per second",
    "metric system",
    "metric unit",
    "metric",
    "midpoint",
    "mile",
    "milliliter",
    "millimeter",
    "millisecond",
    "minimum",
    "minuend",
    "mirror image",
    "mixed number",
    "mixed unit",
    "mobius",
    "modal",
    "mode",
    "multipl",
    "multiply",
    "multiplied",
    "multiplies",
    "multiple",
    "multiplication",
    "multiplying",
    "multiplication counting principle",
    "multiplication diagram",
    "multiplication fact",
    "multiplication symbol",
    "multiplication use class",
    "multiplicative identity",
    "multiplicative inverse",
    "multiplier",
    "mutually exclusive event",
    "natural number",
    "negative association",
    "negative exponent",
    "negative number",
    "negative rational number",
    "nested parentheses",
    "net score",
    "net weight",
    "net",
    "nonagon",
    "nonconvex polygon",
    "nonlinear",
    "normal distribution",
    "normal span",
    "normal",
    "number bond",
    "number disk",
    "number grid",
    "number line",
    "number path",
    "number sentence",
    "number sequence",
    "numeral",
    "numeration",
    "numerator",
    "numerical data",
    "numerical",
    "obtuse",
    "octagon",
    "octagonal",
    "octahedron",
    "odd number",
    "open proportion",
    "operation symbol",
    "operational",
    "opposite angle",
    "opposite change rule",
    "opposite of a number",
    "opposite side",
    "opposite vertex",
    "opposite",
    "order of magnitude",
    "order of operations",
    "order of rotation symmetry",
    "order of",
    "ordered pair",
    "ordered",
    "ordinal number",
    "orthogonal",
    "ounce",
    "outlier",
    "pace",
    "pan balance",
    "parabola",
    "parallel lines",
    "parallel plane",
    "parallel",
    "parallelogram",
    "parentheses",
    "part to part ratio",
    "part to whole ratio",
    "part whole fraction",
    "partial differences subtraction",
    "partial product",
    "partial products multiplication",
    "partial quotients division",
    "partial sums addition",
    "partition",
    "partitive division",
    "parts and total diagram",
    "pentagon",
    "pentagonal",
    "per capita",
    "per unit rate",
    "per",
    "percent circle",
    "percent",
    "percentage",
    "perfect number",
    "perfect square",
    "perfect triangle",
    "perimeter",
    "permutation",
    "perpendicular",
    "perpetual calendar",
    "pi",
    "picture graph",
    "pie graph",
    "pint",
    "pivot",
    "place value",
    "plane figure",
    "plane",
    "point symmetry",
    "point",
    "polar coordinate",
    "polygon",
    "polyhedron",
    "polynomial",
    "population density",
    "population",
    "positive association",
    "positive number",
    "pound",
    "power",
    "precise",
    "predict",
    "prediction line",
    "preimage",
    "prime factor",
    "prime factorization",
    "prime meridian",
    "prime number",
    "prism",
    "probability meter",
    "probability tree diagram",
    "probability",
    "product",
    "proper factor",
    "proper fraction",
    "property",
    "proportion",
    "proportional",
    "proportionality",
    "protractor",
    "pyramid",
    "pythagorean theorem",
    "quadrangle",
    "quadrant",
    "quadratic",
    "quadrilateral",
    "quart",
    "quarter circle",
    "quarter of",
    "quarter-circle",
    "quartile",
    "quick common denominator",
    "quotient",
    "quotitive division",
    "radian",
    "radius of",
    "radius",
    "random draw",
    "random experiment",
    "random number",
    "random sample",
    "random",
    "range",
    "rank",
    "rate diagram",
    "rate multiplication ",  # NOTE(review): trailing space kept as-is; confirm intent
    "rate of change",
    "rate unit",
    "rate",
    "ratio of",
    "ratio",
    "rational equation",
    "rational number",
    "ray",
    "real number",
    "recall survey",
    "reciprocal",
    "rectang",
    "rectangle",
    "rectangular array",
    "rectangular coordinate grid",
    "rectangular prism",
    "rectangular pyramid",
    "rectangular",
    "rectilinear figure",
    "reflection",
    "reflex angle",
    "region",
    "regular polygon",
    "regular polyhedron",
    "regular tessellation",
    "relation symbol",
    "relative frequency",
    "remainder",
    "repeated addition",
    "repeating decimal",
    "representative",
    "revolution",
    "rhombus",
    "right angle",
    "right cone",
    "right cylinder",
    "right prism",
    "right pyramid",
    "right triangle",
    "rigid transformation",
    "roman numerals",
    "root",
    "rotate",
    "rotation symmetry",
    "rotation",
    "round off",
    "round-off",
    "ruler",
    "same change rule for subtraction",
    "sample",
    "scalar",
    "scale factor",
    "scale model",
    "scale of a map",
    "scale of a number line",
    "scale",
    "scaled graph",
    "scaled",
    "scalene triangle",
    "scalene",
    "scatter plot",
    "scattergram",
    "sector",
    "segment",
    "semi-circle",
    "semicircle",
    "sequence",
    "set",
    "sign",
    "significant digit",
    "significant figure",
    "similar figures",
    "similar",
    "simpler form",
    "simplify",
    "simulation",
    "situation diagram",
    "skew line",
    "slanted",
    "slide rule",
    "slope",
    "solid figure",
    "solution",
    "span",
    "speed",
    "sphere",
    "square root",
    "square unit",
    "square",
    "squared",
    "stacked bar graph",
    "standard form",
    "standard unit",
    "statistic",
    "stem and leaf plot",
    "step graph",
    "straight angle",
    "straightedge",
    "subset of",
    "substitute",
    "subtract",
    "subtrahend",
    "sum of",
    "sum",
    "supplementary angle",
    "surface area",
    "surface",
    "survey",
    "symmetric",
    "symmetry",
    "system of equation",
    "system of",
    "table",
    "take from ten",
    "tally",
    "tangent circle",
    "tangent",
    "tangram",
    "tape diagram",
    "temperature",
    "template",
    "tens place",
    "tenth",
    "term",
    "terminating decimal",
    "tessellat",
    "tessellate",
    "tessellation",
    "tetrahedron",
    "tetromino",
    "theorem",
    "thermometer",
    "thousand",
    "thousandth",
    "tile",
    "tiling",
    "time graph",
    "timeline",
    "top heavy fraction",
    "topological",
    "topology",
    "total area",
    "total of",
    "total surface",
    "total volume",
    "trade first subtraction",
    "transformation",
    "translation",
    "transversal",
    "trapezoid",
    "tree diagram",
    "triangle",
    "triangular",
    "true number sentence",
    "truncate",
    "twin prime",
    "two-way table",
    "unit cube",
    "unit form",
    "unit fraction",
    "unit interval",
    "unit price",
    "unit rate",
    "unit square",
    "unit",
    "unknown",
    "unlike denominator",
    "unlike fraction",
    "value",
    "vanishing ",  # NOTE(review): trailing space kept as-is; confirm intent
    "variability",
    "variable",
    "velocity",
    "venn diagram",
    "vernal equinox",
    "vertex",
    "vertical",
    "volume of",
    "volume",
    "weight",
    "whole number",
    "whole unit",
    "whole",
    "width",
    "withdrawal",
    "word form",
    "x axes",
    "x axis",
    "x intercept",
    "x-axes",
    "x-axis",
    "y axes",
    "y axis",
    "y intercept",
    "y-axes",
    "y-axis",
    "y-intercept",
    "yard",
    "zero property of multiplication",
    "zero",
]
|
|
# Words that is_plural always reports as singular, so plural_to_singular
# returns them untouched (several end in "s" without being plurals).
PLURAL_TO_SINGULAR_EXCLUSIONS = "axis continuous data minus miss plus yes".split()
|
|
# Shared inflect engine used by singular_to_plural and plural_to_singular.
p = inflect.engine()
|
|
def is_plural_regex(word):
    """Detect a plural from spelling alone.

    A word counts as plural when it ends in ``s`` but not ``ss``
    (case-insensitive): ``"cats"`` -> True, ``"class"`` -> False. Endings such
    as ``es``/``ies`` are already covered by the single trailing ``s`` test.

    Args:
        word: the word to inspect.

    Returns:
        bool: True if the spelling looks plural.
    """
    lowered = word.lower()
    return lowered.endswith("s") and not lowered.endswith("ss")
|
|
def is_plural_wordnet(word):
    """Guess whether ``word`` is a plural noun by comparing WordNet noun synsets.

    The word is treated as plural when dropping one trailing ``s`` gives a form
    with more noun synsets than the word itself.

    Args:
        word: the word to inspect.

    Returns:
        bool: True if the stripped form is the better-attested noun.
    """
    if not word.endswith("s"):
        # Nothing to strip: both lookups would be identical, so never plural.
        return False
    word_synsets = wordnet.synsets(word, pos=wordnet.NOUN)
    # Remove exactly one "s"; str.rstrip("s") would strip every trailing "s"
    # (e.g. "grass" -> "gra") and compare against an unrelated stem.
    stem_synsets = wordnet.synsets(word[:-1], pos=wordnet.NOUN)
    return len(stem_synsets) > len(word_synsets)
|
|
def is_plural_pos(word):
    """Determine if a word is plural using NLTK part-of-speech tagging.

    The input is tokenized and the tag of the first token is checked for a
    plural noun tag (``NNS`` or ``NNPS``).

    Args:
        word: the word to inspect.

    Returns:
        bool: True if the first token is tagged as a plural noun; False when
        tokenization yields no tokens (e.g. an empty string).
    """
    tokens = word_tokenize(word)
    if not tokens:
        # Indexing the tag list would raise IndexError on empty input.
        return False
    first_tag = pos_tag(tokens)[0][1]
    return first_tag in ("NNS", "NNPS")
|
|
def is_plural(word):
    """Check if a word is plural.

    Empty input and words listed in PLURAL_TO_SINGULAR_EXCLUSIONS (compared
    case-insensitively) are never plural. Otherwise the cheap spelling check
    runs first, then POS tagging, then the WordNet comparison; the first
    positive answer wins.

    Args:
        word: the word to inspect.

    Returns:
        bool: True if any heuristic considers the word plural.
    """
    if not word or word.lower() in PLURAL_TO_SINGULAR_EXCLUSIONS:
        return False
    return bool(is_plural_regex(word) or is_plural_pos(word) or is_plural_wordnet(word))
|
|
def singular_to_plural(word):
    """Return the inflect plural of ``word``.

    Falls back to ``word`` itself when inflect returns an empty/falsy result.
    """
    result = p.plural(word)
    if not result:
        return word
    return result
|
|
def plural_to_singular(word):
    """Return the singular form of ``word`` when it looks plural.

    Words that ``is_plural`` rejects are returned unchanged. inflect's
    ``singular_noun`` returns a falsy value when it cannot singularise, in
    which case the original word is kept.
    """
    if not is_plural(word):
        return word
    singular = p.singular_noun(word)
    return singular if singular else word
|
|
# Inflect plural of every vocabulary entry; these are then appended to
# MATH_WORDS in place so both forms are part of the vocabulary.
plural_MATH_WORDS = list(map(singular_to_plural, MATH_WORDS))
MATH_WORDS.extend(plural_MATH_WORDS)
|
|
def get_num_words(text):
    """Count the words in ``text``.

    Bracketed spans such as ``[inaudible]`` are removed, every character in
    ``punctuation`` is replaced by a space, and the remaining
    whitespace-separated tokens are counted.

    Args:
        text: the string to measure.

    Returns:
        int: number of words. Non-string input is reported on stdout and
        counted as 0 words instead of raising inside ``re``.
    """
    if not isinstance(text, str):
        # (text,) keeps %-formatting safe when text is itself a tuple.
        print("%s is not a string" % (text,))
        return 0
    # Must run before `replace`, which would otherwise delete the brackets
    # themselves. The negated class keeps it non-greedy, so "[a] word [b]"
    # keeps "word".
    text = re.sub(r"\[[^\]]*\]", " ", text)
    text = replace.sub(" ", text)
    return len(text.split())
|
|
def number_to_words(num):
    """Spell out a numeric string with num2words.

    Commas (thousands separators) are removed before conversion. If the value
    cannot be converted, ``num`` is returned unchanged.

    Args:
        num: numeric text, e.g. a regex match such as ``"1,000"``.

    Returns:
        The spelled-out number, or ``num`` itself on failure.
    """
    try:
        return num2words(num.replace(",", ""))
    except Exception:
        # Best-effort conversion: keep unconvertible tokens as they are, but
        # unlike a bare `except:` let KeyboardInterrupt/SystemExit propagate.
        return num
|
|
|
|
def _clean_text(s, no_punct):
    """Normalise ``s`` with cleantext using this module's shared settings.

    Fixes unicode, converts to ASCII, lowercases and removes line breaks.
    URLs, emails and phone numbers become ``<URL>``, ``<EMAIL>`` and
    ``<PHONE>``. Numbers are spelled out via ``number_to_words``. Digits and
    currency symbols are kept.

    Args:
        s: text to clean.
        no_punct: if True, punctuation is removed as well.

    Returns:
        The cleaned string produced by ``clean``.
    """
    return clean(
        s,
        fix_unicode=True,
        to_ascii=True,
        lower=True,
        no_line_breaks=True,
        no_urls=True,
        no_emails=True,
        no_phone_numbers=True,
        no_numbers=True,
        no_digits=False,
        no_currency_symbols=False,
        no_punct=no_punct,
        replace_with_url="<URL>",
        replace_with_email="<EMAIL>",
        replace_with_phone_number="<PHONE>",
        replace_with_number=lambda m: number_to_words(m.group()),
        replace_with_digit="0",
        replace_with_currency_symbol="<CUR>",
        lang="en",
    )


def clean_str(s):
    """Clean ``s`` with the shared settings, keeping punctuation."""
    return _clean_text(s, no_punct=False)


def clean_str_nopunct(s):
    """Clean ``s`` with the shared settings and strip punctuation."""
    return _clean_text(s, no_punct=True)
|
|
|
|
|
|
class MultiHeadModel(BertPreTrainedModel):
    """Pre-trained BERT encoder with one linear head per task.

    Every head maps the dropout-regularised pooled output to
    ``head2size[head_name]`` outputs. Single-output heads are trained with MSE
    (optionally masked); multi-output heads use cross-entropy.

    Args:
        config: BERT configuration used to build ``BertModel``.
        head2size: mapping ``{head_name: number_of_outputs}``.
    """

    def __init__(self, config, head2size):
        super(MultiHeadModel, self).__init__(config, head2size)
        config.num_labels = 1
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.heads = nn.ModuleDict(
            {
                head_name: nn.Linear(config.hidden_size, num_labels)
                for head_name, num_labels in head2size.items()
            }
        )
        self.init_weights()

    def forward(self, input_ids, token_type_ids=None, attention_mask=None,
                head2labels=None, return_pooler_output=False, head2mask=None,
                nsp_loss_weights=None):
        """Run BERT and every head; optionally compute per-head losses.

        Args:
            input_ids, token_type_ids, attention_mask: forwarded to ``BertModel``.
            head2labels: optional ``{head_name: labels}``; a loss is computed for
                every head named here. Its entries are also looked up as masks
                through ``head2mask``.
            return_pooler_output: also return the raw (pre-dropout) pooled
                output under ``"pooler_output"``.
            head2mask: optional ``{head_name: mask_key}``. For single-output
                heads, ``head2labels[mask_key]`` weights the per-example squared
                error and the loss is divided by the mask's sum.
            nsp_loss_weights: optional class-weight tensor for the cross-entropy
                of multi-output heads; ``None`` means unweighted.

        Returns:
            dict with ``"<head>_logits"`` for every head, ``"<head>_loss"`` for
            every head in ``head2labels`` and, if requested, ``"pooler_output"``.
        """
        output = self.bert(
            input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask,
            output_attentions=False, output_hidden_states=False, return_dict=True)
        # The pooled output is already on the model's device; forcing it to
        # "cuda" would break a CPU-resident model on a CUDA machine.
        pooled_output = self.dropout(output["pooler_output"])

        head2logits = {}
        return_dict = {}
        for head_name, head in self.heads.items():
            logits = head(pooled_output).float()
            head2logits[head_name] = logits
            return_dict[head_name + "_logits"] = logits

        if head2labels is not None:
            for head_name, labels in head2labels.items():
                logits = head2logits[head_name]
                num_classes = logits.shape[1]

                if num_classes == 1:
                    if head2mask is not None and head_name in head2mask:
                        mask = head2labels[head2mask[head_name]]
                        num_positives = mask.sum()
                        if num_positives == 0:
                            # Nothing selected by the mask: report a zero loss.
                            return_dict[head_name + "_loss"] = torch.zeros(1, device=logits.device)
                        else:
                            loss_fct = MSELoss(reduction="none")
                            loss = loss_fct(logits.view(-1), labels.float().view(-1))
                            # Average the squared error over masked-in examples only.
                            return_dict[head_name + "_loss"] = loss.dot(mask.float().view(-1)) / num_positives
                    else:
                        loss_fct = MSELoss()
                        return_dict[head_name + "_loss"] = loss_fct(logits.view(-1), labels.float().view(-1))
                else:
                    # The default None means unweighted; class weights must share the logits' device.
                    weight = None
                    if nsp_loss_weights is not None:
                        weight = nsp_loss_weights.to(device=logits.device, dtype=torch.float)
                    loss_fct = CrossEntropyLoss(weight=weight)
                    return_dict[head_name + "_loss"] = loss_fct(logits, labels.view(-1))

        if return_pooler_output:
            return_dict["pooler_output"] = output["pooler_output"]

        return return_dict
|
|
class InputBuilder:
    """Base class for building model inputs from segments.

    Args:
        tokenizer: object exposing ``mask_token_id`` (presumably a Hugging Face
            tokenizer — confirm against callers).
    """

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        # Single-token segment used by mask_seq to blank out a whole segment.
        self.mask = [tokenizer.mask_token_id]

    def build_inputs(self, history, reply, max_length):
        """Build model inputs from ``history`` and ``reply``; subclasses implement this."""
        raise NotImplementedError

    def mask_seq(self, sequence, seq_id):
        """Replace segment ``seq_id`` of ``sequence`` with the mask segment.

        The list is modified in place and also returned.
        """
        sequence[seq_id] = self.mask
        return sequence

    @classmethod
    def _combine_sequence(cls, history, reply, max_length, flipped=False):
        """Truncate each segment to ``max_length`` items and join them in order.

        The limit applies to each segment separately, not to the combined
        length.

        Returns:
            ``history + [reply]``, or ``[reply] + history`` when ``flipped``.
        """
        history = [segment[:max_length] for segment in history]
        reply = reply[:max_length]
        if flipped:
            return [reply] + history
        return history + [reply]
|
|
|
|
class BertInputBuilder(InputBuilder):
    """Input builder that lays segments out in BERT's [CLS] ... [SEP] format."""

    def __init__(self, tokenizer):
        InputBuilder.__init__(self, tokenizer)
        self.cls = [tokenizer.cls_token_id]
        self.sep = [tokenizer.sep_token_id]
        # Key lists presumably consumed by batching/padding code outside this class.
        self.model_inputs = ["input_ids", "token_type_ids", "attention_mask"]
        self.padded_inputs = ["input_ids", "token_type_ids"]
        # When True, the reply is placed before the history.
        self.flipped = False

    def build_inputs(self, history, reply, max_length, input_str=True):
        """Build ``input_ids`` and ``token_type_ids`` for one example.

        Args:
            history: earlier segments (text if ``input_str``, else token-id lists).
            reply: final segment, in the same form as ``history`` items.
            max_length: per-segment truncation length.
            input_str: when True, tokenize ``history`` and ``reply`` first.

        Returns:
            dict with ``"input_ids"`` (``[CLS] seg_1 [SEP] ... seg_n [SEP]``) and
            ``"token_type_ids"``, which alternate 0/1 per segment counting back
            from the last segment (always 0).
        """
        if input_str:
            def to_ids(text):
                return self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(text))

            history = [to_ids(turn) for turn in history]
            reply = to_ids(reply)

        segments = [
            segment + self.sep
            for segment in self._combine_sequence(history, reply, max_length, self.flipped)
        ]
        segments[0] = self.cls + segments[0]

        total = len(segments)
        token_type_ids = []
        for position, segment in enumerate(segments):
            type_id = 0 if (total - position) % 2 == 1 else 1
            token_type_ids.extend([type_id] * len(segment))

        return {
            "input_ids": list(chain.from_iterable(segments)),
            "token_type_ids": token_type_ids,
        }