This repository was archived by the owner on Jun 24, 2024. It is now read-only.

Implementation of prompt caching #14

Merged · 17 commits · Mar 17, 2023
31 changes: 31 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default.

10 changes: 10 additions & 0 deletions llama-cli/src/cli_args.rs
@@ -57,6 +57,16 @@ pub struct Args {
/// for sampling.
#[arg(long, default_value_t = 0.95)]
pub top_p: f32,

/// Stores a cached prompt at the given path. The same prompt can then be
/// loaded from disk using --restore-prompt
#[arg(long, default_value = None)]
pub cache_prompt: Option<String>,

/// Restores a cached prompt from the given path, previously stored using
/// --cache-prompt
#[arg(long, default_value = None)]
pub restore_prompt: Option<String>,
}

/// CLI args are stored in a lazy static variable so they're accessible from
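For illustration only, here is a minimal standalone sketch (not part of this PR) of how the two new optional flags behave with clap's derive API, assuming clap 4 with the derive feature enabled; the struct name is hypothetical and the field names mirror the diff above:

use clap::Parser;

// Hypothetical, minimal reproduction of the two new flags; not the PR's Args struct.
#[derive(Parser)]
struct PromptCacheArgs {
    /// Stores a cached prompt at the given path.
    #[arg(long)]
    cache_prompt: Option<String>,

    /// Restores a cached prompt from the given path.
    #[arg(long)]
    restore_prompt: Option<String>,
}

fn main() {
    let args = PromptCacheArgs::parse();
    if let Some(path) = &args.cache_prompt {
        println!("would cache the prompt to {path}");
    }
    if let Some(path) = &args.restore_prompt {
        println!("would restore a cached prompt from {path}");
    }
}

Because both fields are Option<String>, an omitted flag is simply None, which is what the branching in main.rs below keys off.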
85 changes: 74 additions & 11 deletions llama-cli/src/main.rs
@@ -1,7 +1,7 @@
use std::io::Write;

use cli_args::CLI_ARGS;
use llama_rs::InferenceParameters;
use llama_rs::{InferenceParameters, InferenceSnapshot};
use rand::thread_rng;

mod cli_args;
@@ -18,9 +18,8 @@ fn main() {
n_threads: args.num_threads as i32,
n_predict: args.num_predict,
n_batch: args.batch_size,
top_k: args.top_k as i32,
top_k: args.top_k,
top_p: args.top_p,
repeat_last_n: args.repeat_last_n,
repeat_penalty: args.repeat_penalty,
temp: args.temp,
};
@@ -29,18 +28,18 @@ fn main() {
match std::fs::read_to_string(path) {
Ok(prompt) => prompt,
Err(err) => {
eprintln!("Could not read prompt file at {path}. Error {err}");
log::error!("Could not read prompt file at {path}. Error {err}");
std::process::exit(1);
}
}
} else if let Some(prompt) = &args.prompt {
prompt.clone()
} else {
eprintln!("No prompt or prompt file was provided. See --help");
log::error!("No prompt or prompt file was provided. See --help");
std::process::exit(1);
};

let (model, vocab) =
let (mut model, vocab) =
llama_rs::Model::load(&args.model_path, args.num_ctx_tokens as i32, |progress| {
use llama_rs::LoadProgress;
match progress {
@@ -97,9 +96,73 @@ fn main() {
log::info!("Model fully loaded!");

let mut rng = thread_rng();
model.inference_with_prompt(&vocab, &inference_params, &prompt, &mut rng, |t| {
print!("{t}");
std::io::stdout().flush().unwrap();
});
println!();

let mut session = if let Some(restore_path) = &args.restore_prompt {
let snapshot = InferenceSnapshot::load_from_disk(restore_path);
match snapshot.and_then(|snapshot| model.session_from_snapshot(snapshot)) {
Ok(session) => {
log::info!("Restored cached memory from {restore_path}");
session
}
Err(err) => {
eprintln!("Could not restore prompt. Error: {err}");
std::process::exit(1);
}
}
} else {
model.start_session(args.repeat_last_n)
};

if let Some(cache_path) = &args.cache_prompt {
let res = session.feed_prompt(&model, &vocab, &inference_params, &prompt, |t| {
print!("{t}");
std::io::stdout().flush().unwrap();
});
println!();
match res {
Ok(_) => (),
Err(llama_rs::Error::ContextFull) => {
log::warn!(
"Context is not large enough to fit the prompt. Saving intermediate state."
);
}
err => unreachable!("{err:?}"),
}

// Write the memory to the cache file
// SAFETY: no other model functions used inside the block
unsafe {
let memory = session.get_snapshot();
match memory.write_to_disk(cache_path) {
Ok(_) => {
log::info!("Successfully written prompt cache to {cache_path}");
}
Err(err) => {
log::error!("Could not write prompt cache at {cache_path}: {err}");
std::process::exit(1);
}
}
}
} else {
let res = session.inference_with_prompt(
&model,
&vocab,
&inference_params,
&prompt,
&mut rng,
|t| {
print!("{t}");
std::io::stdout().flush().unwrap();
},
);
println!();

match res {
Ok(_) => (),
Err(llama_rs::Error::ContextFull) => {
log::warn!("Context window full, stopping inference.")
}
err => unreachable!("{err:?}"),
}
}
}
4 changes: 3 additions & 1 deletion llama-rs/Cargo.toml
@@ -11,4 +11,6 @@ ggml-raw = { path = "../ggml-raw" }
partial_sort = "0.2.0"
thiserror = "1.0"

rand = { workspace = true }
rand = { workspace = true }
serde = { version = "1.0.156", features = ["derive"] }
bincode = "1.3.3"
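The new serde and bincode dependencies suggest the inference snapshot is serialized with bincode. As a rough sketch only, under the assumption that the snapshot is a serde-derived struct (the real InferenceSnapshot's fields are not shown in this diff), a write/load round trip could look like this:

use serde::{Deserialize, Serialize};
use std::fs::File;
use std::io::{BufReader, BufWriter};

// Hypothetical snapshot type with invented fields.
#[derive(Serialize, Deserialize)]
struct Snapshot {
    memory_k: Vec<u8>,
    memory_v: Vec<u8>,
    tokens: Vec<i32>,
}

fn write_to_disk(snapshot: &Snapshot, path: &str) -> Result<(), Box<dyn std::error::Error>> {
    // bincode 1.x streams the serialized bytes straight into the writer.
    let writer = BufWriter::new(File::create(path)?);
    bincode::serialize_into(writer, snapshot)?;
    Ok(())
}

fn load_from_disk(path: &str) -> Result<Snapshot, Box<dyn std::error::Error>> {
    let reader = BufReader::new(File::open(path)?);
    Ok(bincode::deserialize_from(reader)?)
}

A binary encoding keeps the cache file compact, but ties it to the exact struct layout, so a cache written by one build may not load in a different one.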
5 changes: 5 additions & 0 deletions llama-rs/src/ggml.rs
@@ -243,6 +243,11 @@ impl Tensor {
std::ptr::copy_nonoverlapping(src.as_ptr(), self.data() as *mut u8, src.len())
}

#[allow(unused)]
pub fn zero_data(&self) {
unsafe { std::ptr::write_bytes(self.data() as *mut u8, 0, self.nbytes()) }
}

pub unsafe fn read_data(&self, offset: usize, dst: &mut [u8]) {
let data = unsafe { ggml_raw::ggml_get_data(self.ptr.as_ptr()).add(offset) };
std::ptr::copy_nonoverlapping(data, dst as *mut _ as _, dst.len())
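As a hypothetical helper that is not part of this PR but could sit next to these methods in ggml.rs, here is one way read_data might be used to copy a tensor's bytes into an owned buffer, for example while assembling a snapshot of the session's memory tensors:

fn tensor_to_bytes(tensor: &Tensor) -> Vec<u8> {
    let mut buf = vec![0u8; tensor.nbytes()];
    // SAFETY: the destination length equals the tensor's byte size, so a copy
    // starting at offset 0 stays within the tensor's data.
    unsafe { tensor.read_data(0, &mut buf) };
    buf
}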