PCCoT

This is a research-purpose Llama-3.2-1B-Instruct model, finetuned with the method described in the paper "Parallel Continuous Chain-of-Thought with Jacobi Iteration".

About

Parallel Continuous Chain-of-Thought (PCCoT) parallelizes continuous chain-of-thought (CoT) reasoning through Jacobi iteration. Instead of generating the latent thought tokens one after another, it updates all latent thought tokens in parallel over several iterations. This significantly improves the training and inference efficiency of continuous CoT reasoning while also improving performance on reasoning tasks. See more details in our GitHub repo: https://github.com/whyNLP/PCCoT
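The toy sketch below illustrates the difference between sequential and Jacobi-style updates in plain Python. It is not the repo's implementation. The hypothetical `toy_step` function stands in for the model computation that produces a latent thought, and each "thought" is a single float rather than a hidden-state vector. Thought i depends only on earlier thoughts, so in this toy Jacobi iteration reproduces the sequential result exactly once the number of iterations equals the number of thoughts. With fewer iterations, only the leading thoughts are guaranteed to match.

```python
from typing import Callable, List

# Hypothetical stand-in for one update of latent thought i: it reads the
# question context and all *earlier* thoughts (a causal dependency).
def toy_step(context: float, previous: List[float]) -> float:
    return 0.5 * context + 0.25 * sum(previous) + 1.0

StepFn = Callable[[float, List[float]], float]

def sequential_thoughts(context: float, num_thoughts: int, step: StepFn = toy_step) -> List[float]:
    # Sequential continuous CoT: thought i is computed only after thoughts 0..i-1.
    thoughts: List[float] = []
    for _ in range(num_thoughts):
        thoughts.append(step(context, thoughts))
    return thoughts

def jacobi_thoughts(context: float, num_thoughts: int, num_iterations: int,
                    step: StepFn = toy_step, init: float = 0.0) -> List[float]:
    # Start from an arbitrary guess for every thought (here: all equal to `init`).
    thoughts = [init] * num_thoughts
    for _ in range(num_iterations):
        # Every new value reads only the previous iterate, so the updates are
        # independent of each other and could be computed in one batched call.
        thoughts = [step(context, thoughts[:i]) for i in range(num_thoughts)]
    return thoughts

print(sequential_thoughts(2.0, 3))                  # [2.0, 2.5, 3.125]
print(jacobi_thoughts(2.0, 3, num_iterations=1))    # [2.0, 2.0, 2.0]
print(jacobi_thoughts(2.0, 3, num_iterations=3))    # [2.0, 2.5, 3.125]
```

The parallelism comes from the list comprehension in `jacobi_thoughts`. None of its entries depends on another entry of the same iteration, so a model can refresh all latent tokens in a single forward pass.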

Quick Start

Using this model requires the code in our GitHub repo. See the repo for more detailed instructions.

Below is a quick-start example showing how to use this model to solve a math word problem.

```python
import models  # from the GitHub repo
from transformers import AutoTokenizer, AutoConfig, HfArgumentParser
from transformers.utils.hub import cached_file
from peft import AutoPeftModel

# Example model name
model_name_or_path = "whyNLP/pccot-llama1b"
# Example question
question = "Every day, Wendi feeds each of her chickens three cups of mixed chicken feed, containing seeds, mealworms and vegetables to help keep them healthy. She gives the chickens their feed in three separate meals. In the morning, she gives her flock of chickens 15 cups of feed. In the afternoon, she gives her chickens another 25 cups of feed. How many cups of feed does she need to give her chickens in the final meal of the day if the size of Wendi's flock is 20 chickens?"

# Load the model and tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
config = AutoConfig.from_pretrained(model_name_or_path)
model = AutoPeftModel.from_pretrained(model_name_or_path)

# We have to override the model config after loading the model: peft does not provide
# an interface for loading the base model with a custom config through AutoPeftModel.
model.get_base_model().config = config

# Load the PCCoT arguments
pccot_args_file = cached_file(model_name_or_path, models.PCCOT_ARGS_NAME)
parser = HfArgumentParser(models.PCCoTArguments)
(pccot_args, ) = parser.parse_json_file(json_file=pccot_args_file)

# Load the data processor and turn the question into model inputs
data_processor = models.COTDataProcessor(
    tokenizer=tokenizer,
    pccot_args=pccot_args,
)
collated = data_processor.process(question)

# Generation (greedy decoding)
decoded_tokens = model.generate(
    collated=collated,
    max_new_tokens=10,
    do_sample=False,
    # For more diverse answers, adjust the sampling parameters below
    # (typically together with do_sample=True):
    # temperature=0.1,
    # top_p=0.9,
    # top_k=50,
)

# Remove the prompt (input_ids) part, keeping only the newly generated tokens
decoded_tokens = decoded_tokens[:, collated["input_ids"].shape[1]:]
answers = tokenizer.batch_decode(decoded_tokens, skip_special_tokens=True)

# Print the answer
print("Question:", question)
print("Answer:", answers[0])  # 20
```
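The expected answer of 20 follows from simple arithmetic: 20 chickens × 3 cups = 60 cups per day, and 60 − 15 − 25 = 20 cups for the final meal. To check generations programmatically, you could use a small helper like the one below. Both functions are hypothetical illustrations and are not part of the PCCoT repo. Taking the last number in the decoded string is an assumption about the output format.

```python
import re
from typing import Optional

def expected_final_meal(num_chickens: int, cups_per_chicken: int,
                        morning: int, afternoon: int) -> int:
    # Total daily feed minus what was already given in the first two meals.
    return num_chickens * cups_per_chicken - morning - afternoon

def extract_number(text: str) -> Optional[float]:
    # Hypothetical parser: drop thousands separators, then take the last number in the text.
    matches = re.findall(r"-?\d+(?:\.\d+)?", text.replace(",", ""))
    return float(matches[-1]) if matches else None

print(expected_final_meal(20, 3, 15, 25))  # 20
print(extract_number(" 20"))               # 20.0
```

With the quick-start variables in scope, `extract_number(answers[0]) == expected_final_meal(20, 3, 15, 25)` checks the model's output against the hand computation.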

Hyperparameters

The model is finetuned using LoRA with the following hyperparameters:

| Hyperparameter | Value |
| --- | --- |
| LoRA rank r | 128 |
| LoRA α | 32 |
| LoRA dropout | 0.1 |
| LoRA bias | False |
| LoRA target modules | ["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "gate_proj", "down_proj"] |
| Learning rate | 8e-4 |
| Weight decay | 1e-1 |

See more details in the paper.
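The sketch below collects the LoRA values from the hyperparameter table into keyword arguments named after `peft.LoraConfig`. It also computes two quantities that follow from them: the standard LoRA scaling factor α / r and the number of adapter weights. Several parts are assumptions. The dataclass is hypothetical. Mapping "LoRA bias: False" to `bias="none"` is an interpretation. The Llama-3.2-1B layer shapes (hidden size 2048, MLP size 8192, 8 key/value heads of size 64, 16 layers) are taken from the public base-model config rather than from this card. The count covers only the LoRA matrices, not any other trainable parts PCCoT may add.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Tuple

@dataclass
class LoRAHyperparams:
    """Values from the hyperparameter table; the class itself is hypothetical."""
    r: int = 128
    lora_alpha: int = 32
    lora_dropout: float = 0.1
    bias: str = "none"  # assumed peft equivalent of "LoRA bias: False"
    target_modules: List[str] = field(default_factory=lambda: [
        "q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "gate_proj", "down_proj",
    ])
    learning_rate: float = 8e-4
    weight_decay: float = 1e-1

    def lora_kwargs(self) -> Dict[str, object]:
        # Keyword names match peft.LoraConfig; optimizer settings are kept separate.
        return {
            "r": self.r,
            "lora_alpha": self.lora_alpha,
            "lora_dropout": self.lora_dropout,
            "bias": self.bias,
            "target_modules": list(self.target_modules),
        }

    def scaling(self) -> float:
        # Standard LoRA multiplies the low-rank update B @ A by alpha / r.
        return self.lora_alpha / self.r

# Assumed Llama-3.2-1B projection shapes as (in_features, out_features);
# k/v project to 8 heads * 64 dims = 512.
LLAMA_3_2_1B_SHAPES: Dict[str, Tuple[int, int]] = {
    "q_proj": (2048, 2048),
    "k_proj": (2048, 512),
    "v_proj": (2048, 512),
    "o_proj": (2048, 2048),
    "gate_proj": (2048, 8192),
    "up_proj": (2048, 8192),
    "down_proj": (8192, 2048),
}
LLAMA_3_2_1B_LAYERS = 16

def lora_trainable_params(hp: LoRAHyperparams,
                          shapes: Dict[str, Tuple[int, int]] = LLAMA_3_2_1B_SHAPES,
                          num_layers: int = LLAMA_3_2_1B_LAYERS) -> int:
    # Each adapted linear layer gets A (r x in) and B (out x r): r * (in + out) weights.
    per_layer = sum(
        hp.r * (d_in + d_out)
        for name, (d_in, d_out) in shapes.items()
        if name in hp.target_modules
    )
    return per_layer * num_layers

hp = LoRAHyperparams()
print(hp.scaling())               # 0.25
print(lora_trainable_params(hp))  # 90177536
```

With peft installed, `peft.LoraConfig(task_type="CAUSAL_LM", **hp.lora_kwargs())` would build a matching adapter config. The learning rate and weight decay would go to the optimizer, for example through `transformers.TrainingArguments`. Whether the repo's training script uses exactly these calls is not stated here.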

Disclaimer: The weights of this model are not the original weights trained while writing the paper. They are a reproduction trained on a single A6000 GPU. The original weights were lost, and I would like to specifically thank my co-author Zhihao Teng for reproducing the model and sharing the reproduced weights.

Performance

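The model is finetuned on the GSM8K-AUG dataset. The table below reports the accuracy of PCCoT-llama1b alongside a CoT baseline and the GPT-2-based PCCoT variants:

| Model | Accuracy (%) |
| --- | --- |
| CoT (llama1b) | 61.6 |
| PCCoT-gpt2 | 49.48 ± 0.31 |
| PCCoT-gpt2-nl | 49.23 ± 0.80 |
| PCCoT-llama1b | 53.35 ± 0.18 |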