merchant-consumption-category-discriminator-v3
๋ชจ๋ธ ๊ฐ์
์ด ๋ชจ๋ธ์ ํ๊ตญ์ด ๊ฐ๋งน์ ๋ช
์ ์
๋ ฅ๋ฐ์ ์๋น ์นดํ
๊ณ ๋ฆฌ๋ฅผ ๋ถ๋ฅํ๊ธฐ ์ํ KoELECTRA ๊ธฐ๋ฐ ๋ถ๋ฅ ๋ชจ๋ธ์
๋๋ค.
๊ฐ๋งน์ ๋ช
์๋ ์
์ข
์ ์ค๋ช
ํ๋ ํต์ฌ ํ ํฐ๊ณผ ์ง์ ๋ช
, ์ง์ญ๋ช
, ์ซ์, ๊ดํธ ๊ฐ์ ๋ถ๊ฐ ์ ๋ณด๊ฐ ํจ๊ป ํฌํจ๋๋ ๊ฒฝ์ฐ๊ฐ ๋ง์, ์ด๋ฅผ ์์ ์ ์ผ๋ก ๊ตฌ๋ถํ ์ ์๋๋ก ์ค๊ณํ์ต๋๋ค.
๋ชจ๋ธ ์ ๋ณด
- Model ID:
kakao1513/merchant-consumption-category-discriminator-v3 - Base checkpoint:
monologg/koelectra-base-v3-discriminator - Architecture: KoELECTRA encoder + attention pooling classifier
- Max length: 64
- Num labels: 15
- Loading:
trust_remote_code=Trueํ์
์ ์ฒ๋ฆฌ
์ ๋ ฅ ํ ์คํธ์๋ ๋ค์ ์ ์ฒ๋ฆฌ๋ฅผ ์ ์ฉํ์ต๋๋ค.
NFKC์ ๊ทํ- ๊ณต๋ฐฑ/๊ฐํ ์ ๋ฆฌ
casefold๊ธฐ๋ฐ ์ ๊ทํ ํ ์คํธ ์์ฑ
์๋ฅผ ๋ค์ด ์๋์ ๊ฐ์ ํ๊ธฐ ์ฐจ์ด๋ฅผ ์ต๋ํ ์ผ๊ด๋ ํํ๋ก ๋ง์ถ๋๋ก ๊ตฌ์ฑํ์ต๋๋ค.
| ์๋ณธ ์ ๋ ฅ | ์ ๊ทํ ์์ |
|---|---|
๏ผก๏ผข๏ผฃ๋งํธ |
abc๋งํธ |
์คํ๋ฒ
์ค ๊ฐ๋จR์ |
์คํ๋ฒ
์ค ๊ฐ๋จr์ |
๋ฉ๊ฐ์ปคํผ\n(์ฃฝ์ ์ ) |
๋ฉ๊ฐ์ปคํผ (์ฃฝ์ ์ ) |
ํ์ต ๋ฐฉ์
๊ธฐ๋ณธ KoELECTRA์ [CLS] pooling ๋ฐฉ์ ๋์ , ํ ํฐ๋ณ ์ค์๋๋ฅผ ํ์ตํ๋ attention pooling head๋ฅผ ์ฌ์ฉํ์ต๋๋ค.
๋ํ ํ์ต ์์๋ ์ง์ ๋ช
, ์ธต์, ๊ดํธ, ์ง์ญ๋ช
๋ฑ ๋ค์ชฝ ์ ๋ฏธ ์ ๋ณด๋ฅผ ์ผ๋ถ๋ฌ ์๋ suffix/branch noise augmentation์ ์ ์ฉํด, ๊ฐ๋งน์ ๋ช
๋ณํ์ ๋ ๊ฐํ๊ฒ ๋์ํ๋๋ก ๊ตฌ์ฑํ์ต๋๋ค.
์ฑ๋ฅ
service_test f1_macro ๊ธฐ์ค ์ฑ๋ฅ์ ๋ค์๊ณผ ๊ฐ์ต๋๋ค.
| model_variant | f1_macro |
|---|---|
| baseline | 0.7825 |
| v3 | 0.8308 |
์ ๋ ฅ/์ถ๋ ฅ ์์
| ์ ๋ ฅ ๊ฐ๋งน์ ๋ช | ์์ธก ์์ |
|---|---|
์คํ๋ฒ
์ค ๊ฐ๋จR์ |
์นดํ |
๋ค์ด์ ์ฃฝ์ ์ |
์ํ |
๋ฒ๊ฑฐํน ํ๊ต์ญ์ |
์๋น |
์ฌ์ฉ ์์
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
model_id = "kakao1513/merchant-consumption-category-discriminator-v3"
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForSequenceClassification.from_pretrained(model_id, trust_remote_code=True)
merchant_text = "์คํ๋ฒ
์ค ๊ฐ๋จR์ "
encoded = tokenizer(
merchant_text,
return_tensors="pt",
truncation=True,
max_length=64,
)
with torch.no_grad():
logits = model(**encoded).logits
predicted_label_id = int(logits.argmax(dim=-1).item())
label_name = model.config.id2label.get(
predicted_label_id,
model.config.id2label.get(str(predicted_label_id))
)
print(label_name)
์ฐธ๊ณ ์ฌํญ
์ด ๋ชจ๋ธ์ ์ปค์คํ attention pooling ํด๋์ค๋ฅผ ํฌํจํ๋ฏ๋ก trust_remote_code=True๊ฐ ํ์ํฉ๋๋ค.
- Downloads last month
- 80