PooryaPiroozfar committed on
Commit
422794d
·
verified ·
1 Parent(s): f9286ff

Upload 4 files

Browse files
Files changed (4) hide show
  1. Dockerfile +21 -0
  2. app.py +290 -0
  3. frame_triples.xlsx +0 -0
  4. requirements.txt +8 -0
Dockerfile ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# Slim Python base image for the Gradio SRL app.
FROM python:3.10-slim

# Don't write .pyc files; flush stdout/stderr immediately (better container logs).
ENV PYTHONDONTWRITEBYTECODE=1 \
    PYTHONUNBUFFERED=1

WORKDIR /app

# git + build-essential are needed by some pip dependencies that build from source.
RUN apt-get update && apt-get install -y \
    git \
    build-essential \
    && rm -rf /var/lib/apt/lists/*

# Install Python dependencies first so this layer is cached across code changes.
COPY requirements.txt .
RUN pip install --upgrade pip \
    && pip install --no-cache-dir -r requirements.txt

# Copy the application code.
COPY . .

# Gradio listens on 7860 (see demo.launch in app.py).
EXPOSE 7860

CMD ["python", "app.py"]
app.py ADDED
@@ -0,0 +1,290 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# -*- coding: utf-8 -*-

import json
import os
import re
import zlib

import gradio as gr
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import snapshot_download
from transformers import (
    AutoTokenizer,
    AutoModel,
    AutoModelForTokenClassification,
)
# -------------------------
# Global configuration
# -------------------------
device = torch.device("cpu")

FRAME_DET_REPO = "PooryaPiroozfar/frame-detection-parsbert"
FE_REPO = "PooryaPiroozfar/srl-frame-elements-parsbert"

FRAME_DET_DIR = "models/frame_detection"
FE_BASE_DIR = "models/frame_elements"

TRIPLES_PATH = "frame_triples.xlsx"
THRESHOLD = 0.25  # minimum cosine similarity to accept a frame prediction

# Frame inventory; order must line up with the rows of the trained
# frame-embedding matrix loaded below.
frame_names = [
    "Activity_finish", "Activity_start", "Aging", "Attaching", "Attempt",
    "Becoming", "Being_born", "Borrowing", "Causation", "Chatting",
    "Choosing", "Closure", "Clothing", "Cutting", "Damaging", "Desiring",
    "Discussion", "Emphasizing", "Food", "Installing", "Locating", "Memory",
    "Morality_evaluation", "Motion", "Offering", "Practice", "Project",
    "Publishing", "Religious_belief", "Removing", "Request", "Residence",
    "Sharing", "Taking", "Telling", "Travel", "Using", "Visiting",
    "Waiting", "Work",
]
# -------------------------
# Download model snapshots on first run (skipped when already present).
# -------------------------
for _repo_id, _local_dir in (
    (FRAME_DET_REPO, FRAME_DET_DIR),
    (FE_REPO, FE_BASE_DIR),
):
    if not os.path.exists(_local_dir):
        snapshot_download(repo_id=_repo_id, local_dir=_local_dir)
# -------------------------
# Sentence encoder (ParsBERT) shared by the frame detector.
# -------------------------
encoder_name = "HooshvareLab/bert-base-parsbert-uncased"
sent_tokenizer = AutoTokenizer.from_pretrained(encoder_name)
sent_encoder = AutoModel.from_pretrained(encoder_name).to(device)
sent_encoder.eval()  # inference only: disables dropout
def get_embedding(text):
    """Return a mean-pooled ParsBERT embedding of *text* as a 1-D tensor.

    Padding positions are excluded from the average via the attention mask.
    """
    encoded = sent_tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=128,
    ).to(device)

    with torch.no_grad():
        hidden = sent_encoder(**encoded).last_hidden_state

    # Zero out padded positions, then average over the real tokens only;
    # the clamp guards against division by zero on a fully-masked input.
    attn = encoded["attention_mask"].unsqueeze(-1).float()
    pooled = (hidden * attn).sum(dim=1) / attn.sum(dim=1).clamp(min=1e-9)
    return pooled.squeeze(0)
79
+
80
+ # -------------------------
81
+ # Frame Detection Model
82
+ # -------------------------
83
+ class FrameSimilarityModel(nn.Module):
84
+ def __init__(self, emb_dim, frame_emb_init):
85
+ super().__init__()
86
+ self.proj = nn.Linear(emb_dim, emb_dim)
87
+ self.frame_embeddings = nn.Parameter(
88
+ torch.tensor(frame_emb_init, dtype=torch.float32)
89
+ )
90
+
91
+ def forward(self, sent_emb):
92
+ sent_proj = F.normalize(self.proj(sent_emb), dim=-1)
93
+ frames = F.normalize(self.frame_embeddings, dim=-1)
94
+ return torch.matmul(sent_proj, frames.T)
# Precomputed frame-embedding matrix used to initialize the model.
frame_embs = np.load(os.path.join(FRAME_DET_DIR, "trained_frame_embeddings.npy"))

frame_model = FrameSimilarityModel(emb_dim=768, frame_emb_init=frame_embs).to(device)

# Restore the fine-tuned projection + frame embeddings, then freeze for inference.
_state_path = os.path.join(FRAME_DET_DIR, "best_frame_margin_model.pt")
frame_model.load_state_dict(torch.load(_state_path, map_location="cpu"))
frame_model.eval()
def predict_frame(sentence):
    """Return (frame_name, similarity) for *sentence*.

    Returns (None, similarity) when the best cosine score falls below
    THRESHOLD, i.e. the sentence is considered out of domain.
    """
    embedded = get_embedding(sentence).unsqueeze(0)
    with torch.no_grad():
        scores = frame_model(embedded)
    best_score, best_idx = torch.max(scores, dim=1)

    similarity = best_score.item()
    if similarity < THRESHOLD:
        return None, similarity
    return frame_names[best_idx.item()], similarity
# -------------------------
# Frame Elements
# -------------------------
# Per-frame (tokenizer, model, id2label) cache: the original implementation
# reloaded the tokenizer and token-classification model from disk on every
# call, which dominates runtime for repeated sentences of the same frame.
_FE_CACHE = {}


def _load_fe_model(frame_dir):
    """Load (and memoize) the tokenizer, SRL model and label map for one frame."""
    if frame_dir not in _FE_CACHE:
        with open(os.path.join(frame_dir, "label2id.json"), encoding="utf-8") as f:
            label2id = json.load(f)
        id2label = {int(v): k for k, v in label2id.items()}

        tokenizer = AutoTokenizer.from_pretrained(frame_dir)
        model = AutoModelForTokenClassification.from_pretrained(
            frame_dir,
            num_labels=len(label2id),
            id2label=id2label,
            label2id=label2id,
        ).to(device)
        model.eval()
        _FE_CACHE[frame_dir] = (tokenizer, model, id2label)
    return _FE_CACHE[frame_dir]


def predict_frame_elements(sentence, frame_name):
    """Token-classify *sentence* with the frame-specific SRL model.

    Returns a list of (token, label) pairs for every non-special wordpiece
    whose predicted label is not "O". Returns an empty list when no model
    directory exists for *frame_name*.
    """
    frame_dir = os.path.join(FE_BASE_DIR, frame_name)
    if not os.path.exists(frame_dir):
        return []

    tokenizer, model, id2label = _load_fe_model(frame_dir)

    inputs = tokenizer(sentence, return_tensors="pt", truncation=True, max_length=128)

    with torch.no_grad():
        outputs = model(**inputs)

    preds = torch.argmax(outputs.logits, dim=-1).squeeze(0).numpy()
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"].squeeze(0))

    elements = []
    for tok, lab_id in zip(tokens, preds):
        # Skip BERT special tokens; they carry no frame-element information.
        if tok in {"[CLS]", "[SEP]", "[PAD]"}:
            continue
        label = id2label[int(lab_id)]
        if label != "O":
            elements.append((tok, label))

    return elements
# -------------------------
# Triple Extraction
# -------------------------
# Spreadsheet with Frame / Subject / Relation / Object columns mapping
# frame elements to relation triples.
triples_df = pd.read_excel(TRIPLES_PATH)
def group_elements(elements):
    """Group (token, label) pairs into a {label: [tokens]} mapping.

    Token order within each label is preserved.
    """
    grouped = {}
    for token, label in elements:
        bucket = grouped.setdefault(label, [])
        bucket.append(token)
    return grouped
def extract_relations(frame_name, elements):
    """Build relation triples for *frame_name* from predicted frame elements.

    For every triple template of this frame in ``triples_df``, emit one
    relation per (subject token, object token) pair found in *elements*.
    """
    fe_dict = group_elements(elements)
    templates = triples_df[triples_df["Frame"] == frame_name]

    relations = []
    for _, row in templates.iterrows():
        subj_fe = row["Subject"]
        obj_fe = row["Object"]
        # Both frame elements must actually occur in the sentence.
        if subj_fe not in fe_dict or obj_fe not in fe_dict:
            continue
        for subj_tok in fe_dict[subj_fe]:
            for obj_tok in fe_dict[obj_fe]:
                relations.append({
                    "subject": subj_tok,
                    "relation": row["Relation"],
                    "object": obj_tok,
                    "subject_fe": subj_fe,
                    "object_fe": obj_fe,
                })
    return relations
# -------------------------
# Sentence Utilities
# -------------------------
def split_sentences(text):
    """Split *text* into trimmed sentences on ., !, ؟ or newlines."""
    return [part.strip() for part in re.split(r'[.!؟\n]+', text) if part.strip()]
# Regex patterns that mark a Persian conditional clause
# (roughly: "if", "in case of", "whenever").
CONDITIONAL_PATTERNS = [
    r'^اگر\s', r'\sاگر\s', r'^در صورت\s',
    r'^چنانچه\s', r'^هرگاه\s'
]


def detect_conditional(sentence):
    """Return True when *sentence* matches any conditional marker pattern."""
    return any(re.search(pattern, sentence) for pattern in CONDITIONAL_PATTERNS)
def split_condition(sentence):
    """Split a conditional sentence into [condition, result] at the first
    Persian or Latin comma; return (None, None) when no separator exists."""
    for separator in ('،', ','):
        head, found, tail = sentence.partition(separator)
        if found:
            return [head, tail]
    return None, None
def build_spin_rule(if_triples, then_triples, rule_id):
    """Package condition/result triples as a SPIN-style rule dict."""
    rule = {"rule_id": f"Rule_{rule_id}", "type": "SPIN"}
    rule["if"] = if_triples
    rule["then"] = then_triples
    return rule
# -------------------------
# Pipeline
# -------------------------
def analyze_sentence(sentence):
    """Full per-sentence analysis: frame, frame elements, triples, SPIN rule.

    Returns a dict with keys: sentence, frame, similarity, elements,
    relations, is_conditional, rule.  When the best frame similarity is
    below THRESHOLD the sentence is flagged out-of-domain and nothing
    else is extracted.
    """
    frame, sim = predict_frame(sentence)

    if frame is None:
        # Out of domain: report the similarity score but extract nothing.
        return {
            "sentence": sentence,
            "frame": "خارج از دامنه",
            "similarity": round(sim, 3),
            "elements": [],
            "relations": [],
            "is_conditional": False,
            "rule": None,
        }

    elements = predict_frame_elements(sentence, frame)
    relations = extract_relations(frame, elements)

    is_cond = detect_conditional(sentence)
    rule = None

    if is_cond:
        cond_part, res_part = split_condition(sentence)
        if cond_part and res_part:
            # Analyze each half recursively; recursion terminates because
            # split_condition yields (None, None) once no separator remains.
            cond_res = analyze_sentence(cond_part)
            res_res = analyze_sentence(res_part)
            rule = build_spin_rule(
                cond_res["relations"],
                res_res["relations"],
                # crc32 is stable across processes, unlike built-in hash()
                # (which varies with PYTHONHASHSEED), so rule IDs are
                # reproducible between runs.
                rule_id=zlib.crc32(sentence.encode("utf-8")) % 10000,
            )

    return {
        "sentence": sentence,
        "frame": frame,
        "similarity": round(sim, 3),
        "elements": elements,
        "relations": relations,
        "is_conditional": is_cond,
        "rule": rule,
    }
def analyze_text(text):
    """Split *text* into sentences and run the full pipeline on each one."""
    return {
        "input_text": text,
        "sentences_analysis": [
            analyze_sentence(sentence) for sentence in split_sentences(text)
        ],
    }
# -------------------------
# Gradio UI
# -------------------------
demo = gr.Interface(
    fn=analyze_text,
    inputs=gr.Textbox(
        label="متن فارسی",
        placeholder="مثال: اگر علی از تهران به مشهد سفر کند، در هتل اقامت می‌کند."
    ),
    outputs=gr.JSON(label="خروجی"),
    title="Persian Semantic Frame, Triple & Rule Extractor",
    description="تشخیص فریم، عناصر فریم، استخراج triple و قوانین شرطی (SPIN)"
)

if __name__ == "__main__":
    # Bind to all interfaces on the port exposed by the Dockerfile.
    demo.launch(server_name="0.0.0.0", server_port=7860)
frame_triples.xlsx ADDED
Binary file (25.6 kB). View file
 
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
# Core ML stack
torch
transformers
sentencepiece
numpy
pandas

# Excel support for frame_triples.xlsx
openpyxl

# App UI and model hub access
gradio
huggingface_hub