add stop prompt

This commit is contained in:
ItzYanick 2023-04-18 15:00:44 +02:00
parent 4b381f7073
commit 40639fbe0b
No known key found for this signature in database
GPG key ID: 0E3DB1F28A357B8A
6 changed files with 22 additions and 5 deletions

View file

@ -1,9 +1,13 @@
BINARY_NAME=kevin BINARY_NAME=kevin
build: build-llama:
cd pkg/llama && make libbinding.a cd pkg/llama && make libbinding.a
build-kevin:
go build -o ${BINARY_NAME} main.go go build -o ${BINARY_NAME} main.go
build: build-llama build-kevin
run: build run: build
./${BINARY_NAME} ./${BINARY_NAME}

View file

@ -88,7 +88,7 @@ func messageCreate(s *discordgo.Session, m *discordgo.MessageCreate) {
message := structs.NewMessage(m.Author.Username, strings.TrimSpace(m.Content)) message := structs.NewMessage(m.Author.Username, strings.TrimSpace(m.Content))
prompt := structs.GeneratePrompt(config, messages, *message) prompt := structs.GeneratePrompt(config, messages, *message)
var seed int = int(time.Now().Unix()) var seed int = int(time.Now().Unix())
res, err := brain.Predict(prompt, llama.SetThreads(4), llama.SetTokens(128), llama.SetTopK(90), llama.SetTopP(0.86), llama.SetSeed(seed)) res, err := brain.Predict(prompt, llama.SetThreads(4), llama.SetTokens(128), llama.SetTopK(90), llama.SetTopP(0.86), llama.SetSeed(seed), llama.SetStopPrompt("###"))
if err != nil { if err != nil {
panic(err) panic(err)
} }

View file

@ -171,7 +171,7 @@ void *llama_allocate_params(const char *prompt, int seed, int threads,
int tokens, int top_k, float top_p, float temp, int tokens, int top_k, float top_p, float temp,
float repeat_penalty, int repeat_last_n, float repeat_penalty, int repeat_last_n,
bool ignore_eos, bool memory_f16, int n_batch, bool ignore_eos, bool memory_f16, int n_batch,
int n_keep) { int n_keep, const char *stopprompt) {
gpt_params *params = new gpt_params; gpt_params *params = new gpt_params;
params->seed = seed; params->seed = seed;
params->n_threads = threads; params->n_threads = threads;
@ -189,6 +189,10 @@ void *llama_allocate_params(const char *prompt, int seed, int threads,
params->prompt = prompt; params->prompt = prompt;
params->ignore_eos = ignore_eos; params->ignore_eos = ignore_eos;
std::string stopprompt_str(stopprompt);
params->antiprompt.push_back(stopprompt_str);
return params; return params;
} }

View file

@ -11,7 +11,7 @@ void *llama_allocate_params(const char *prompt, int seed, int threads,
int tokens, int top_k, float top_p, float temp, int tokens, int top_k, float top_p, float temp,
float repeat_penalty, int repeat_last_n, float repeat_penalty, int repeat_last_n,
bool ignore_eos, bool memory_f16, int n_batch, bool ignore_eos, bool memory_f16, int n_batch,
int n_keep); int n_keep, const char *stopprompt);
void llama_free_params(void *params_ptr); void llama_free_params(void *params_ptr);

View file

@ -42,7 +42,7 @@ func (l *LLama) Predict(text string, opts ...PredictOption) (string, error) {
params := C.llama_allocate_params(input, C.int(po.Seed), C.int(po.Threads), C.int(po.Tokens), C.int(po.TopK), params := C.llama_allocate_params(input, C.int(po.Seed), C.int(po.Threads), C.int(po.Tokens), C.int(po.TopK),
C.float(po.TopP), C.float(po.Temperature), C.float(po.Penalty), C.int(po.Repeat), C.float(po.TopP), C.float(po.Temperature), C.float(po.Penalty), C.int(po.Repeat),
C.bool(po.IgnoreEOS), C.bool(po.F16KV), C.bool(po.IgnoreEOS), C.bool(po.F16KV),
C.int(po.Batch), C.int(po.NKeep), C.int(po.Batch), C.int(po.NKeep), C.CString(po.StopPrompt),
) )
ret := C.llama_predict(params, l.state, (*C.char)(unsafe.Pointer(&out[0]))) ret := C.llama_predict(params, l.state, (*C.char)(unsafe.Pointer(&out[0])))
if ret != 0 { if ret != 0 {

View file

@ -15,6 +15,7 @@ type PredictOptions struct {
TopP, Temperature, Penalty float64 TopP, Temperature, Penalty float64
F16KV bool F16KV bool
IgnoreEOS bool IgnoreEOS bool
StopPrompt string
} }
type PredictOption func(p *PredictOptions) type PredictOption func(p *PredictOptions)
@ -38,6 +39,7 @@ var DefaultOptions PredictOptions = PredictOptions{
Repeat: 64, Repeat: 64,
Batch: 8, Batch: 8,
NKeep: 64, NKeep: 64,
StopPrompt: "",
} }
// SetContext sets the context size. // SetContext sets the context size.
@ -154,6 +156,13 @@ func SetNKeep(n int) PredictOption {
} }
} }
// SetStopPrompt sets the prompt to stop generation.
func SetStopPrompt(prompt string) PredictOption {
return func(p *PredictOptions) {
p.StopPrompt = prompt
}
}
// Create a new PredictOptions object with the given options. // Create a new PredictOptions object with the given options.
func NewPredictOptions(opts ...PredictOption) PredictOptions { func NewPredictOptions(opts ...PredictOption) PredictOptions {
p := DefaultOptions p := DefaultOptions