| """Load Data from a MediaWiki dump xml.""" |
| import ast |
| import glob |
| import pickle |
| import uuid |
| from typing import List, Optional |
| import os |
| import bz2 |
| import csv |
| import numpy as np |
| import pandas as pd |
| import pytest |
| from matplotlib import pyplot as plt |
|
|
| from langchain.docstore.document import Document |
| from langchain.document_loaders import MWDumpLoader |
|
|
| |
| root_path = "/data/jon/h2o-llm" |
|
|
|
|

def unescape(x):
    """Best-effort unescaping of a title that may contain a Python literal or escape sequences."""
    try:
        x = ast.literal_eval(x)
    except Exception:
        try:
            x = x.encode('ascii', 'ignore').decode('unicode_escape')
        except Exception:
            pass
    return x


def get_views():
    # mapping of page title -> monthly view count, from the CSV written by test_reduce_pageview()
    views = pd.read_csv('wiki_page_views_more_5000month.csv')
    views.index = views['title']
    views = views['views']
    views = views.to_dict()
    views = {str(unescape(str(k))): v for k, v in views.items()}
    # index by both the underscore form and the space-separated form of each title
    views2 = {k.replace('_', ' '): v for k, v in views.items()}
    views.update(views2)
    return views
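

# Illustrative sketch (not part of the pipeline): get_views() keys the mapping by both the
# underscore and the space-separated forms of each title, so either lookup below should work.
# 'Abstract (law)' is just an example title taken from the tests in this file.
def _example_views_lookup():
    views = get_views()
    title = 'Abstract (law)'
    return views.get(title), views.get(title.replace(' ', '_'))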


class MWDumpDirectLoader(MWDumpLoader):
    def __init__(self, data: str, encoding: Optional[str] = "utf8",
                 title_words_limit=None, use_views=True, verbose=True):
        """Initialize with the decompressed page XML content itself, not a file path."""
        self.data = data
        self.encoding = encoding
        self.title_words_limit = title_words_limit
        self.verbose = verbose
        if use_views:
            # module-level mapping of title -> monthly views, built once at import (see below)
            self.views = global_views
        else:
            self.views = None

    def load(self) -> List[Document]:
        """Parse the page XML and return one Document per kept page revision."""
        import mwparserfromhell
        import mwxml

        dump = mwxml.Dump.from_page_xml(self.data)

        docs = []
        for page in dump.pages:
            # skip unpopular pages that are not in the views mapping
            if self.views is not None and page.title not in self.views:
                if self.verbose:
                    print("Skipped %s low views" % page.title, flush=True)
                continue
            for revision in page:
                # optionally skip pages whose titles have too many words
                if self.title_words_limit is not None:
                    num_words = len(' '.join(page.title.split('_')).split(' '))
                    if num_words > self.title_words_limit:
                        if self.verbose:
                            print("Skipped %s" % page.title, flush=True)
                        continue
                if self.verbose:
                    if self.views is not None:
                        print("Kept %s views: %s" % (page.title, self.views[page.title]), flush=True)
                    else:
                        print("Kept %s" % page.title, flush=True)

                # strip wikitext markup down to plain text
                code = mwparserfromhell.parse(revision.text)
                text = code.strip_code(
                    normalize=True, collapse=True, keep_template_params=False
                )
                title_url = str(page.title).replace(' ', '_')
                metadata = dict(title=page.title,
                                source="https://en.wikipedia.org/wiki/" + title_url,
                                id=page.id,
                                redirect=page.redirect,
                                views=self.views[page.title] if self.views is not None else -1,
                                )
                metadata = {k: v for k, v in metadata.items() if v is not None}
                docs.append(Document(page_content=text, metadata=metadata))

        return docs


def search_index(search_term, index_filename):
    """Return the byte offset and compressed length of the bz2 block containing search_term."""
    byte_flag = False
    data_length = start_byte = 0
    with open(index_filename, 'r') as index_file:
        csv_reader = csv.reader(index_file, delimiter=':')
        for line in csv_reader:
            if not byte_flag and search_term == line[2]:
                start_byte = int(line[0])
                byte_flag = True
            elif byte_flag and int(line[0]) != start_byte:
                data_length = int(line[0]) - start_byte
                break
    return start_byte, data_length
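

# For reference, each line of the multistream index has the form
# "<byte offset of the bz2 block>:<page id>:<page title>"; search_index() takes the offset
# listed for the matching title plus the next distinct offset to compute the block length.
# A minimal self-contained sketch of parsing such a line (the sample values are
# illustrative, not taken from the real index):
def _example_parse_index_line():
    sample = "600:10:AccessibleComputing"
    offset, page_id, title = sample.split(':', 2)
    return int(offset), int(page_id), title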


def get_start_bytes(index_filename):
    """Return the sorted distinct bz2 block offsets listed in the multistream index."""
    with open(index_filename, 'r') as index_file:
        csv_reader = csv.reader(index_file, delimiter=':')
        start_bytes = set()
        for line in csv_reader:
            start_bytes.add(int(line[0]))
    return sorted(start_bytes)


def get_wiki_filenames():
    # paths to the manually downloaded enwiki 2023-04-01 multistream dump and its index
    base_path = os.path.join(root_path, 'enwiki-20230401-pages-articles-multistream')
    index_file = 'enwiki-20230401-pages-articles-multistream-index.txt'
    index_filename = os.path.join(base_path, index_file)
    wiki_filename = os.path.join(base_path, 'enwiki-20230401-pages-articles-multistream.xml.bz2')
    return index_filename, wiki_filename


def get_documents_by_search_term(search_term):
    """Decompress only the bz2 block containing search_term and load its pages as Documents."""
    index_filename, wiki_filename = get_wiki_filenames()
    start_byte, data_length = search_index(search_term, index_filename)
    with open(wiki_filename, 'rb') as wiki_file:
        wiki_file.seek(start_byte)
        data = bz2.BZ2Decompressor().decompress(wiki_file.read(data_length))

    loader = MWDumpDirectLoader(data.decode())
    documents = loader.load()
    return documents
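

# Usage sketch, assuming the 20230401 dump and index referenced by get_wiki_filenames()
# are present under root_path and that 'Apollo' is a title listed in the index:
def _example_search_term_usage():
    docs = get_documents_by_search_term('Apollo')
    print("Loaded %d documents; first title: %s" % (len(docs), docs[0].metadata.get('title')))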


def get_one_chunk(wiki_filename, start_byte, end_byte, return_file=True,
                  title_words_limit=None,
                  use_views=True):
    """Decompress one bz2 block [start_byte, end_byte) and load its pages as Documents."""
    data_length = end_byte - start_byte
    with open(wiki_filename, 'rb') as wiki_file:
        wiki_file.seek(start_byte)
        data = bz2.BZ2Decompressor().decompress(wiki_file.read(data_length))

    loader = MWDumpDirectLoader(data.decode(), title_words_limit=title_words_limit,
                                use_views=use_views)
    documents1 = loader.load()
    if return_file:
        # pickle the documents to a temp file so joblib only has to send a filename back
        base_tmp = "temp_wiki"
        if not os.path.isdir(base_tmp):
            os.makedirs(base_tmp, exist_ok=True)
        filename = os.path.join(base_tmp, str(uuid.uuid4()) + ".tmp.pickle")
        with open(filename, 'wb') as f:
            pickle.dump(documents1, f)
        return filename
    return documents1


from joblib import Parallel, delayed

# built once at module import so MWDumpDirectLoader instances can filter pages by view count
global_views = get_views()


def get_all_documents(small_test=2, n_jobs=None, use_views=True):
    print("DO get all wiki docs: %s" % small_test, flush=True)
    index_filename, wiki_filename = get_wiki_filenames()
    start_bytes = get_start_bytes(index_filename)
    # pair each block offset with the next one to form [start, end) byte ranges
    end_bytes = start_bytes[1:]
    start_bytes = start_bytes[:-1]

    if small_test:
        start_bytes = start_bytes[:small_test]
        end_bytes = end_bytes[:small_test]
        if n_jobs is None:
            n_jobs = 5
    else:
        if n_jobs is None:
            n_jobs = os.cpu_count() // 4

    # have workers return pickled temp files rather than the documents themselves,
    # to keep inter-process transfers small
    return_file = True
    documents = Parallel(n_jobs=n_jobs, verbose=10, backend='multiprocessing')(
        delayed(get_one_chunk)(wiki_filename, start_byte, end_byte,
                               return_file=return_file, use_views=use_views) for start_byte, end_byte in
        zip(start_bytes, end_bytes))
    if return_file:
        # read and delete each temp pickle, concatenating the documents
        files = documents.copy()
        documents = []
        for fil in files:
            with open(fil, 'rb') as f:
                documents.extend(pickle.load(f))
            os.remove(fil)
    else:
        from functools import reduce
        from operator import concat
        documents = reduce(concat, documents)
    assert isinstance(documents, list)

    print("DONE get all wiki docs", flush=True)
    return documents
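

# Usage sketch (the parameter values are illustrative): small_test limits how many bz2
# blocks are processed; passing small_test=0 would process the entire dump, which takes
# many hours and substantial disk/RAM.
def _example_get_all_documents():
    docs = get_all_documents(small_test=5, n_jobs=2, use_views=False)
    print("Got %d documents" % len(docs))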


def test_by_search_term():
    # each multistream bz2 block holds 100 pages, so a single-block lookup yields 100 documents
    search_term = 'Apollo'
    assert len(get_documents_by_search_term(search_term)) == 100

    search_term = 'Abstract (law)'
    assert len(get_documents_by_search_term(search_term)) == 100

    search_term = 'Artificial languages'
    assert len(get_documents_by_search_term(search_term)) == 100


def test_start_bytes():
    index_filename, wiki_filename = get_wiki_filenames()
    assert len(get_start_bytes(index_filename)) == 227850


def test_get_all_documents():
    small_test = 20
    n_jobs = os.cpu_count() // 4

    # without the views filter, every page in each block is kept (100 pages per block)
    assert len(get_all_documents(small_test=small_test, n_jobs=n_jobs, use_views=False)) == small_test * 100

    # with the views filter, only pages above the monthly view cutoff survive
    assert len(get_all_documents(small_test=small_test, n_jobs=n_jobs, use_views=True)) == 429


def get_one_pageviews(fil):
    """Extract English-wikipedia rows from one hourly pageviews file and write them to a temp CSV."""
    df1 = pd.read_csv(fil, sep=' ', header=None, names=['region', 'title', 'views', 'foo'], quoting=csv.QUOTE_NONE)
    df1.index = df1['title']
    df1 = df1[df1['region'] == 'en']
    df1 = df1.drop('region', axis=1)
    df1 = df1.drop('foo', axis=1)
    df1 = df1.drop('title', axis=1)

    base_tmp = "temp_wiki_pageviews"
    if not os.path.isdir(base_tmp):
        os.makedirs(base_tmp, exist_ok=True)
    filename = os.path.join(base_tmp, str(uuid.uuid4()) + ".tmp.csv")
    df1.to_csv(filename, index=True)
    return filename
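

# For reference, each hourly pageviews dump line is space-separated as
# "<domain_code> <page_title> <count_views> <total_response_size>", which is why
# get_one_pageviews() reads four columns and keeps only the 'en' domain rows.  A tiny
# self-contained check of that parsing (the sample lines are made up, not real data):
def _example_pageviews_parsing():
    import io
    sample = "en Apollo 42 0\nde Apollo 7 0\n"
    df = pd.read_csv(io.StringIO(sample), sep=' ', header=None,
                     names=['region', 'title', 'views', 'foo'], quoting=csv.QUOTE_NONE)
    return df[df['region'] == 'en']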


def test_agg_pageviews(gen_files=False):
    if gen_files:
        path = os.path.join(root_path, 'wiki_pageviews/dumps.wikimedia.org/other/pageviews/2023/2023-04')
        files = glob.glob(os.path.join(path, 'pageviews*.gz'))
        # convert each hourly file in parallel, writing per-file temp CSVs
        n_jobs = os.cpu_count() // 2
        csv_files = Parallel(n_jobs=n_jobs, verbose=10, backend='multiprocessing')(
            delayed(get_one_pageviews)(fil) for fil in files)
    else:
        # reuse temp CSVs left over from a previous run
        csv_files = glob.glob(os.path.join(root_path, 'temp_wiki_pageviews/*.csv'))

    df_list = []
    for csv_file in csv_files:
        print(csv_file)
        df1 = pd.read_csv(csv_file)
        df_list.append(df1)
    df = pd.concat(df_list, axis=0)
    # sum the hourly counts into one monthly total per title
    df = df.groupby('title')['views'].sum().reset_index()
    df.to_csv("wiki_page_views.csv", index=True)


def test_reduce_pageview():
    filename = "wiki_page_views.csv"
    df = pd.read_csv(filename)
    # drop extreme outliers before plotting the distribution
    df = df[df['views'] < 1e7]

    plt.hist(df['views'], bins=100, log=True)
    views_avg = np.mean(df['views'])
    views_median = np.median(df['views'])
    plt.title("Views avg: %s median: %s" % (views_avg, views_median))
    plt.savefig(filename.replace('.csv', '.png'))
    plt.close()

    # keep only pages with more than 5000 views for the month
    views_limit = 5000
    df = df[df['views'] > views_limit]
    filename = "wiki_page_views_more_5000month.csv"
    df.to_csv(filename, index=True)

    plt.hist(df['views'], bins=100, log=True)
    views_avg = np.mean(df['views'])
    views_median = np.median(df['views'])
    plt.title("Views avg: %s median: %s" % (views_avg, views_median))
    plt.savefig(filename.replace('.csv', '.png'))
    plt.close()


@pytest.mark.skip("Only if doing full processing again, some manual steps")
def test_do_wiki_full_all():
    # manual end-to-end steps to rebuild the wiki_full data; assumes the full
    # enwiki-20230401 multistream dump referenced by get_wiki_filenames() has already
    # been downloaded under root_path

    # fetch the multistream index (decompress the .bz2 afterwards so the .txt path exists)
    os.system("wget http://ftp.acc.umu.se/mirror/wikimedia.org/dumps/enwiki/20230401/enwiki-20230401-pages-articles-multistream-index.txt.bz2")

    # sanity-check document extraction
    test_get_all_documents()

    test_by_search_term()

    test_start_bytes()

    # mirror one month of hourly pageviews files
    os.system("wget -b -m -k -o wget.log -e robots=off https://dumps.wikimedia.org/other/pageviews/2023/2023-04/")

    # aggregate the hourly counts into monthly totals per title
    test_agg_pageviews(gen_files=True)

    # keep only pages above the monthly views cutoff used by get_views()
    test_reduce_pageview()


# then, e.g., ingest the wiki_full collection:
"""
python generate.py --langchain_mode='wiki_full' --langchain_modes="['wiki_full', 'UserData', 'MyData', 'github h2oGPT', 'DriverlessAI docs']" &> lc_out.log
"""