| #pragma once |
|
|
| #include "llama.h" |
|
|
| #include "llama-cparams.h" |
|
|
| #include <array> |
| #include <vector> |
| #include <set> |
| #include <bitset> |
| #include <memory> |
| #include <unordered_map> |
|
|
| |
| struct llama_ubatch { |
| bool equal_seqs() const { |
| return b_equal_seqs != 0; |
| } |
|
|
| uint32_t b_equal_seqs; |
| |
| |
|
|
| uint32_t n_tokens; |
| uint32_t n_seq_tokens; |
| uint32_t n_seqs; |
| uint32_t n_seqs_unq; |
|
|
| |
| |
| |
|
|
| |
| llama_token * token; |
| float * embd; |
| llama_pos * pos; |
| int32_t * n_seq_id; |
| llama_seq_id ** seq_id; |
| llama_seq_id * seq_id_unq; |
| int32_t * seq_idx; |
| int8_t * output; |
|
|
| struct data_t { |
| std::vector<llama_token> token; |
| std::vector<float> embd; |
| std::vector<llama_pos> pos; |
| std::vector<int32_t> n_seq_id; |
| std::vector<llama_seq_id *> seq_id; |
| std::vector<llama_seq_id> seq_id_unq; |
| std::vector<int32_t> seq_idx; |
| std::vector<int8_t> output; |
| }; |
|
|
| |
| std::shared_ptr<data_t> data; |
| }; |
|
|
| |
| class llama_batch_allocr { |
| public: |
| llama_batch_allocr(uint32_t n_pos_per_embd); |
|
|
| |
| |
| bool init( |
| const llama_batch & batch_inp, |
| const llama_vocab & vocab, |
| const llama_memory_i * memory, |
| uint32_t n_embd, |
| uint32_t n_seq_max, |
| bool output_all); |
|
|
| const llama_batch & get_batch() const; |
|
|
| uint32_t get_n_tokens() const; |
| uint32_t get_n_outputs() const; |
| uint32_t get_n_used() const; |
|
|
| |
| std::vector<int32_t> & get_out_ids(); |
|
|
| |
| llama_pos seq_pos_min(llama_seq_id seq_id) const; |
| llama_pos seq_pos_max(llama_seq_id seq_id) const; |
|
|
| |
| void split_reset(); |
|
|
| |
| llama_ubatch split_simple(uint32_t n_ubatch); |
|
|
| |
| |
| llama_ubatch split_equal(uint32_t n_ubatch, bool sequential); |
|
|
| |
| llama_ubatch split_seq(uint32_t n_ubatch); |
|
|
| |
| |
| llama_ubatch ubatch_reserve(uint32_t n_seq_tokens, uint32_t n_seqs); |
|
|
| private: |
| void clear(); |
|
|
| |
| |
| llama_ubatch ubatch_add(const std::vector<int32_t> & idxs, uint32_t n_seqs, bool equal_seqs); |
|
|
| |
| void ubatch_print(const llama_ubatch & ubatch, int debug); |
|
|
| llama_batch batch; |
|
|
| |
| const llama_vocab * vocab; |
|
|
| |
| |
| const uint32_t n_pos_per_embd; |
|
|
| uint32_t n_embd; |
| uint32_t n_seq_max; |
| uint32_t n_outputs; |
|
|
| std::array<llama_seq_id, 1> seq_id_0 = { 0 }; |
|
|
| std::vector<llama_pos> pos; |
| std::vector<int32_t> n_seq_id; |
| std::vector<llama_seq_id *> seq_id; |
| std::vector<llama_seq_id> seq_id_unq; |
| std::vector<int32_t> seq_idx; |
| std::vector<int8_t> output; |
|
|
| using pos_set_t = std::set<llama_pos>; |
| using seq_cpl_t = std::vector<bool>; |
|
|
| |
| bool has_cpl = false; |
|
|
| std::vector<pos_set_t> seq_pos; |
| std::vector<seq_cpl_t> seq_cpl; |
|
|
| using idx_vec_t = std::vector<int32_t>; |
| using seq_set_t = std::bitset<LLAMA_MAX_SEQ>; |
|
|
| std::vector<seq_set_t> seq_set; |
|
|
| std::unordered_map<seq_set_t, idx_vec_t> seq_set_map; |
|
|
| |
| std::vector<int32_t> out_ids; |
|
|
| uint32_t n_used; |
|
|
| |
| std::vector<bool> used; |
|
|
| int debug; |
| }; |
|
|