| #pragma once |
|
|
| #include "llama.h" |
|
|
| #include <array> |
| #include <vector> |
|
|
| |
| |
| struct llama_ubatch { |
| bool equal_seqs; |
| |
|
|
| uint32_t n_tokens; |
| uint32_t n_seq_tokens; |
| uint32_t n_seqs; |
|
|
| llama_token * token; |
| float * embd; |
| llama_pos * pos; |
| int32_t * n_seq_id; |
| llama_seq_id ** seq_id; |
| int8_t * output; |
| }; |
|
|
| struct llama_sbatch_seq { |
| int32_t n_seq_id; |
|
|
| llama_seq_id * seq_id; |
|
|
| size_t offset; |
| size_t length; |
| }; |
|
|
| |
| struct llama_sbatch { |
| |
| size_t n_tokens; |
|
|
| size_t n_embd; |
|
|
| bool logits_all; |
|
|
| |
| std::vector<size_t> ids; |
| |
| std::vector<size_t> out_ids; |
| std::vector<llama_sbatch_seq> seq; |
|
|
| const llama_batch * batch = nullptr; |
|
|
| |
| std::vector<llama_token> ubatch_token; |
| std::vector<float> ubatch_embd; |
| std::vector<llama_pos> ubatch_pos; |
| std::vector<int32_t> ubatch_n_seq_id; |
| std::vector<llama_seq_id *> ubatch_seq_id; |
| std::vector<int8_t> ubatch_output; |
|
|
| llama_ubatch reserve_ubatch(size_t n_ubatch, bool has_embd = false); |
|
|
| void add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & seq, size_t length); |
|
|
| |
| llama_ubatch split_simple(size_t n_ubatch); |
|
|
| |
| llama_ubatch split_equal(size_t n_ubatch); |
|
|
| |
| llama_ubatch split_seq(size_t n_ubatch); |
|
|
| void from_batch(const llama_batch & batch, size_t n_embd, bool simple_split = false, bool logits_all = false); |
| }; |
|
|
| |
| struct llama_batch_allocr { |
| struct llama_batch batch; |
|
|
| std::array<llama_seq_id, 1> seq_id_0 = { 0 }; |
| std::vector<llama_pos> pos; |
| std::vector<int32_t> n_seq_id; |
| std::vector<llama_seq_id *> seq_id; |
| std::vector<int8_t> logits; |
|
|
| |
| llama_batch_allocr(struct llama_batch in_batch, llama_pos p0); |
| }; |
|
|