add stop prompt
parent 4b381f7073
commit 40639fbe0b

6 changed files with 22 additions and 5 deletions
Makefile (6 lines changed)

@@ -1,9 +1,13 @@
 BINARY_NAME=kevin
 
-build:
+build-llama:
+	cd pkg/llama && make libbinding.a
+
+build-kevin:
 	go build -o ${BINARY_NAME} main.go
 
+build: build-llama build-kevin
+
 run: build
 	./${BINARY_NAME}
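The split keeps the cgo dependency explicit: build-llama produces the static libbinding.a that the Go code under pkg/llama links against, build-kevin builds the bot binary itself, and the new aggregate build target chains the two so make build still works from a clean checkout.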
main.go (2 lines changed)

@@ -88,7 +88,7 @@ func messageCreate(s *discordgo.Session, m *discordgo.MessageCreate) {
 	message := structs.NewMessage(m.Author.Username, strings.TrimSpace(m.Content))
 	prompt := structs.GeneratePrompt(config, messages, *message)
 	var seed int = int(time.Now().Unix())
-	res, err := brain.Predict(prompt, llama.SetThreads(4), llama.SetTokens(128), llama.SetTopK(90), llama.SetTopP(0.86), llama.SetSeed(seed))
+	res, err := brain.Predict(prompt, llama.SetThreads(4), llama.SetTokens(128), llama.SetTopK(90), llama.SetTopP(0.86), llama.SetSeed(seed), llama.SetStopPrompt("###"))
 	if err != nil {
 		panic(err)
 	}
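The "###" stop prompt presumably matches the turn separator that structs.GeneratePrompt puts between messages; that code is not part of this diff, so the format is an assumption. The effect, sketched below in plain Go, is that generation halts at the marker instead of running on and inventing the next speaker's turn.

package main

import (
	"fmt"
	"strings"
)

func main() {
	// What the model might emit if generation were NOT stopped at "###"
	// (the separator format here is assumed, not taken from this repo).
	raw := "Sure, I can help with that!\n### user: and another question..."

	// SetStopPrompt("###") halts generation at the marker, which behaves
	// roughly like truncating the raw output there.
	reply, _, _ := strings.Cut(raw, "###")
	fmt.Println(strings.TrimSpace(reply)) // Sure, I can help with that!
}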
@@ -171,7 +171,7 @@ void *llama_allocate_params(const char *prompt, int seed, int threads,
                             int tokens, int top_k, float top_p, float temp,
                             float repeat_penalty, int repeat_last_n,
                             bool ignore_eos, bool memory_f16, int n_batch,
-                            int n_keep) {
+                            int n_keep, const char *stopprompt) {
     gpt_params *params = new gpt_params;
     params->seed = seed;
     params->n_threads = threads;

@@ -189,6 +189,10 @@ void *llama_allocate_params(const char *prompt, int seed, int threads,
     params->prompt = prompt;
     params->ignore_eos = ignore_eos;
 
+    std::string stopprompt_str(stopprompt);
+
+    params->antiprompt.push_back(stopprompt_str);
+
     return params;
 }
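params->antiprompt is the reverse-prompt list from upstream llama.cpp's gpt_params; the binding's prediction loop (not shown in this commit) is what actually watches for those strings, so pushing the stop prompt here is what makes generation halt once "###" is emitted. Copying the incoming const char * into a std::string gives params its own copy of the text rather than retaining the caller's buffer.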
@@ -11,7 +11,7 @@ void *llama_allocate_params(const char *prompt, int seed, int threads,
                             int tokens, int top_k, float top_p, float temp,
                             float repeat_penalty, int repeat_last_n,
                             bool ignore_eos, bool memory_f16, int n_batch,
-                            int n_keep);
+                            int n_keep, const char *stopprompt);
 
 void llama_free_params(void *params_ptr);
@@ -42,7 +42,7 @@ func (l *LLama) Predict(text string, opts ...PredictOption) (string, error) {
 	params := C.llama_allocate_params(input, C.int(po.Seed), C.int(po.Threads), C.int(po.Tokens), C.int(po.TopK),
 		C.float(po.TopP), C.float(po.Temperature), C.float(po.Penalty), C.int(po.Repeat),
 		C.bool(po.IgnoreEOS), C.bool(po.F16KV),
-		C.int(po.Batch), C.int(po.NKeep),
+		C.int(po.Batch), C.int(po.NKeep), C.CString(po.StopPrompt),
 	)
 	ret := C.llama_predict(params, l.state, (*C.char)(unsafe.Pointer(&out[0])))
 	if ret != 0 {
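One caveat with the new argument: C.CString copies the Go string into C memory with malloc, and nothing in the lines shown frees that copy, so unless the C side takes ownership each Predict call leaks a small allocation. A minimal standalone sketch of the usual pattern (not this repo's code):

package main

/*
#include <stdlib.h>
#include <string.h>
*/
import "C"

import (
	"fmt"
	"unsafe"
)

func main() {
	// C.CString allocates a copy with malloc; the caller must free it.
	stop := C.CString("###")
	defer C.free(unsafe.Pointer(stop))

	fmt.Println(C.GoString(stop), C.strlen(stop)) // ### 3
}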
@@ -15,6 +15,7 @@ type PredictOptions struct {
 	TopP, Temperature, Penalty float64
 	F16KV                      bool
 	IgnoreEOS                  bool
+	StopPrompt                 string
 }
 
 type PredictOption func(p *PredictOptions)

@@ -38,6 +39,7 @@ var DefaultOptions PredictOptions = PredictOptions{
 	Repeat:     64,
 	Batch:      8,
 	NKeep:      64,
+	StopPrompt: "",
 }
 
 // SetContext sets the context size.

@@ -154,6 +156,13 @@ func SetNKeep(n int) PredictOption {
 	}
 }
 
+// SetStopPrompt sets the prompt to stop generation.
+func SetStopPrompt(prompt string) PredictOption {
+	return func(p *PredictOptions) {
+		p.StopPrompt = prompt
+	}
+}
+
 // Create a new PredictOptions object with the given options.
 func NewPredictOptions(opts ...PredictOption) PredictOptions {
 	p := DefaultOptions
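SetStopPrompt slots into the existing functional-options pattern. Below is a self-contained sketch of how the pieces fit together; the names mirror the diff, but the option-applying loop is assumed, since the hunk cuts off right after p := DefaultOptions.

package main

import "fmt"

// Trimmed-down stand-ins for the types in the diff above.
type PredictOptions struct {
	Tokens     int
	StopPrompt string
}

type PredictOption func(p *PredictOptions)

// SetStopPrompt sets the prompt to stop generation, as in the diff.
func SetStopPrompt(prompt string) PredictOption {
	return func(p *PredictOptions) { p.StopPrompt = prompt }
}

// NewPredictOptions starts from defaults and applies each option in turn
// (the loop is the assumed part; the original hunk is truncated).
func NewPredictOptions(opts ...PredictOption) PredictOptions {
	p := PredictOptions{Tokens: 128}
	for _, opt := range opts {
		opt(&p)
	}
	return p
}

func main() {
	po := NewPredictOptions(SetStopPrompt("###"))
	fmt.Printf("%+v\n", po) // {Tokens:128 StopPrompt:###}
}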