From 40639fbe0b0e288d6e303075585afddb8f285d10 Mon Sep 17 00:00:00 2001
From: ItzYanick
Date: Tue, 18 Apr 2023 15:00:44 +0200
Subject: [PATCH] add stop prompt

---
 Makefile              | 6 +++++-
 main.go               | 2 +-
 pkg/llama/binding.cpp | 6 +++++-
 pkg/llama/binding.h   | 2 +-
 pkg/llama/llama.go    | 2 +-
 pkg/llama/options.go  | 9 +++++++++
 6 files changed, 22 insertions(+), 5 deletions(-)

diff --git a/Makefile b/Makefile
index 1286b43..63349a7 100644
--- a/Makefile
+++ b/Makefile
@@ -1,9 +1,13 @@
 BINARY_NAME=kevin
 
-build:
+build-llama:
 	cd pkg/llama && make libbinding.a
+
+build-kevin:
 	go build -o ${BINARY_NAME} main.go
 
+build: build-llama build-kevin
+
 run: build
 	./${BINARY_NAME}
 
diff --git a/main.go b/main.go
index ae54aa1..637642f 100644
--- a/main.go
+++ b/main.go
@@ -88,7 +88,7 @@ func messageCreate(s *discordgo.Session, m *discordgo.MessageCreate) {
 	message := structs.NewMessage(m.Author.Username, strings.TrimSpace(m.Content))
 	prompt := structs.GeneratePrompt(config, messages, *message)
 	var seed int = int(time.Now().Unix())
-	res, err := brain.Predict(prompt, llama.SetThreads(4), llama.SetTokens(128), llama.SetTopK(90), llama.SetTopP(0.86), llama.SetSeed(seed))
+	res, err := brain.Predict(prompt, llama.SetThreads(4), llama.SetTokens(128), llama.SetTopK(90), llama.SetTopP(0.86), llama.SetSeed(seed), llama.SetStopPrompt("###"))
 	if err != nil {
 		panic(err)
 	}
diff --git a/pkg/llama/binding.cpp b/pkg/llama/binding.cpp
index 1a2b8d2..182d658 100644
--- a/pkg/llama/binding.cpp
+++ b/pkg/llama/binding.cpp
@@ -171,7 +171,7 @@ void *llama_allocate_params(const char *prompt, int seed, int threads,
                             int tokens, int top_k, float top_p, float temp,
                             float repeat_penalty, int repeat_last_n,
                             bool ignore_eos, bool memory_f16, int n_batch,
-                            int n_keep) {
+                            int n_keep, const char *stopprompt) {
   gpt_params *params = new gpt_params;
   params->seed = seed;
   params->n_threads = threads;
@@ -189,6 +189,10 @@ void *llama_allocate_params(const char *prompt, int seed, int threads,
   params->prompt = prompt;
   params->ignore_eos = ignore_eos;
 
+  std::string stopprompt_str(stopprompt);
+
+  params->antiprompt.push_back(stopprompt_str);
+
   return params;
 }
 
diff --git a/pkg/llama/binding.h b/pkg/llama/binding.h
index dc8a3f9..8a43735 100644
--- a/pkg/llama/binding.h
+++ b/pkg/llama/binding.h
@@ -11,7 +11,7 @@ void *llama_allocate_params(const char *prompt, int seed, int threads,
                             int tokens, int top_k, float top_p, float temp,
                             float repeat_penalty, int repeat_last_n,
                             bool ignore_eos, bool memory_f16, int n_batch,
-                            int n_keep);
+                            int n_keep, const char *stopprompt);
 
 void llama_free_params(void *params_ptr);
 
diff --git a/pkg/llama/llama.go b/pkg/llama/llama.go
index 1576ad4..2ce1b24 100644
--- a/pkg/llama/llama.go
+++ b/pkg/llama/llama.go
@@ -42,7 +42,7 @@ func (l *LLama) Predict(text string, opts ...PredictOption) (string, error) {
 	params := C.llama_allocate_params(input, C.int(po.Seed), C.int(po.Threads), C.int(po.Tokens), C.int(po.TopK),
 		C.float(po.TopP), C.float(po.Temperature), C.float(po.Penalty), C.int(po.Repeat),
 		C.bool(po.IgnoreEOS), C.bool(po.F16KV),
-		C.int(po.Batch), C.int(po.NKeep),
+		C.int(po.Batch), C.int(po.NKeep), C.CString(po.StopPrompt),
 	)
 	ret := C.llama_predict(params, l.state, (*C.char)(unsafe.Pointer(&out[0])))
 	if ret != 0 {
diff --git a/pkg/llama/options.go b/pkg/llama/options.go
index 1f2ae9b..9c1abbc 100644
--- a/pkg/llama/options.go
+++ b/pkg/llama/options.go
@@ -15,6 +15,7 @@ type PredictOptions struct {
 	TopP, Temperature, Penalty float64
 	F16KV                      bool
 	IgnoreEOS                  bool
+	StopPrompt                 string
 }
 
 type PredictOption func(p *PredictOptions)
@@ -38,6 +39,7 @@ var DefaultOptions PredictOptions = PredictOptions{
 	Repeat:      64,
 	Batch:       8,
 	NKeep:       64,
+	StopPrompt:  "",
 }
 
 // SetContext sets the context size.
@@ -154,6 +156,13 @@ func SetNKeep(n int) PredictOption {
 	}
 }
 
+// SetStopPrompt sets the prompt to stop generation.
+func SetStopPrompt(prompt string) PredictOption {
+	return func(p *PredictOptions) {
+		p.StopPrompt = prompt
+	}
+}
+
 // Create a new PredictOptions object with the given options.
 func NewPredictOptions(opts ...PredictOption) PredictOptions {
 	p := DefaultOptions
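
For context, a minimal usage sketch of the option this patch introduces. SetStopPrompt, SetTokens, and Predict come from pkg/llama in this change; the module import path, the llama.New constructor, and the model path are assumptions about the surrounding repository, not part of this diff.

// Sketch only: shows the new stop prompt composing with the existing
// functional options. Everything outside the llama.Set* calls is
// illustrative and not taken from this patch.
package main

import (
	"fmt"

	"kevin/pkg/llama" // assumed module path for this repo's binding
)

func main() {
	// Assumed constructor for loading a ggml model; not part of this diff.
	brain, err := llama.New("./models/ggml-model.bin")
	if err != nil {
		panic(err)
	}
	// With SetStopPrompt, generation halts as soon as the model emits
	// "###" (the same stop string main.go passes above) instead of
	// always running to the token limit.
	res, err := brain.Predict("### Instruction: say hi\n### Response:",
		llama.SetTokens(128),
		llama.SetStopPrompt("###"),
	)
	if err != nil {
		panic(err)
	}
	fmt.Println(res)
}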