---
library_name: transformers
datasets:
- youngermax/text-tagging
---

## Model Details

### Model Description

This model generates multiple topic tags for a piece of natural-language text. It was fine-tuned on the youngermax/text-tagging dataset for 3.5 epochs, which took about 1.3 hours on a free Kaggle P100 GPU.

- **Developed by:** Lincoln Maxwell
- **Model type:** Generative Pretrained Transformer
- **Language(s) (NLP):** English
- **Finetuned from model:** DistilGPT2

## Uses


### Direct Use

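```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder: replace with this model's Hub repository ID or a local path.
model_id = "path/to/this-model"
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id).to(device)

prompt = "Your input text goes here."

# Append the '<|topic|>' marker so the model starts generating topics.
input_ids = tokenizer.encode(prompt + '<|topic|>', return_tensors='pt').to(device)

# Generate text
output = model.generate(
    input_ids,
    max_length=1024,
    num_return_sequences=1,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.eos_token_id,
    do_sample=True,  # sampling must be on for top_k, top_p, and temperature to take effect
    top_k=100,
    top_p=0.5,
    temperature=1.0,
)

# Decode without stripping special tokens so the markers remain in the text
text = tokenizer.decode(output[0], skip_special_tokens=False)

# Keep only the generated part, up to the end-of-text token if present
text = text[len(prompt):]
end = text.find('<|endoftext|>')
if end != -1:
    text = text[:end]

# Split on the topic marker, skip the piece before the first marker, and deduplicate
topics = list({t.strip() for t in text.split('<|topic|>')[1:] if t.strip()})
```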
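
The post-processing at the end of the snippet above can also be factored into a small helper that is testable without loading the model. The following is a minimal sketch, not part of the original card: the function name `parse_topics`, its parameters, and the sample string are hypothetical, and it assumes the decoded output uses the `<|topic|>` and `<|endoftext|>` markers as shown above. Unlike a plain `set`, it keeps topics in the order they were first generated.

```python
def parse_topics(decoded: str, prompt: str,
                 topic_marker: str = "<|topic|>",
                 eos_marker: str = "<|endoftext|>") -> list[str]:
    """Extract unique topics from decoded model output, keeping first-seen order."""
    # Drop the echoed prompt if the decoded text starts with it.
    generated = decoded[len(prompt):] if decoded.startswith(prompt) else decoded
    # Cut at the first end-of-text marker, if one was generated.
    generated = generated.split(eos_marker, 1)[0]
    # Text before the first topic marker is not a topic, so skip piece 0.
    pieces = [p.strip() for p in generated.split(topic_marker)[1:]]
    # dict.fromkeys removes duplicates while preserving insertion order.
    return list(dict.fromkeys(p for p in pieces if p))


# Made-up decoded output, for illustration only.
example = "Cats purr when happy.<|topic|>cats<|topic|> animals <|topic|>cats<|endoftext|>"
print(parse_topics(example, "Cats purr when happy."))  # ['cats', 'animals']
```

Splitting on the end-of-text marker rather than slicing with `str.find` avoids dropping the last character when the marker is absent, since `find` returns `-1` in that case.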