add stop prompt
This commit is contained in:
parent
4b381f7073
commit
40639fbe0b
6 changed files with 22 additions and 5 deletions
6
Makefile
6
Makefile
|
|
@@ -1,9 +1,13 @@
|
||||||
BINARY_NAME=kevin
|
BINARY_NAME=kevin
|
||||||
|
|
||||||
build:
|
build-llama:
|
||||||
cd pkg/llama && make libbinding.a
|
cd pkg/llama && make libbinding.a
|
||||||
|
|
||||||
|
build-kevin:
|
||||||
go build -o ${BINARY_NAME} main.go
|
go build -o ${BINARY_NAME} main.go
|
||||||
|
|
||||||
|
build: build-llama build-kevin
|
||||||
|
|
||||||
run: build
|
run: build
|
||||||
./${BINARY_NAME}
|
./${BINARY_NAME}
|
||||||
|
|
||||||
|
|
|
||||||
2
main.go
2
main.go
|
|
@@ -88,7 +88,7 @@ func messageCreate(s *discordgo.Session, m *discordgo.MessageCreate) {
|
||||||
message := structs.NewMessage(m.Author.Username, strings.TrimSpace(m.Content))
|
message := structs.NewMessage(m.Author.Username, strings.TrimSpace(m.Content))
|
||||||
prompt := structs.GeneratePrompt(config, messages, *message)
|
prompt := structs.GeneratePrompt(config, messages, *message)
|
||||||
var seed int = int(time.Now().Unix())
|
var seed int = int(time.Now().Unix())
|
||||||
res, err := brain.Predict(prompt, llama.SetThreads(4), llama.SetTokens(128), llama.SetTopK(90), llama.SetTopP(0.86), llama.SetSeed(seed))
|
res, err := brain.Predict(prompt, llama.SetThreads(4), llama.SetTokens(128), llama.SetTopK(90), llama.SetTopP(0.86), llama.SetSeed(seed), llama.SetStopPrompt("###"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@@ -171,7 +171,7 @@ void *llama_allocate_params(const char *prompt, int seed, int threads,
|
||||||
int tokens, int top_k, float top_p, float temp,
|
int tokens, int top_k, float top_p, float temp,
|
||||||
float repeat_penalty, int repeat_last_n,
|
float repeat_penalty, int repeat_last_n,
|
||||||
bool ignore_eos, bool memory_f16, int n_batch,
|
bool ignore_eos, bool memory_f16, int n_batch,
|
||||||
int n_keep) {
|
int n_keep, const char *stopprompt) {
|
||||||
gpt_params *params = new gpt_params;
|
gpt_params *params = new gpt_params;
|
||||||
params->seed = seed;
|
params->seed = seed;
|
||||||
params->n_threads = threads;
|
params->n_threads = threads;
|
||||||
|
|
@@ -189,6 +189,10 @@ void *llama_allocate_params(const char *prompt, int seed, int threads,
|
||||||
params->prompt = prompt;
|
params->prompt = prompt;
|
||||||
params->ignore_eos = ignore_eos;
|
params->ignore_eos = ignore_eos;
|
||||||
|
|
||||||
|
std::string stopprompt_str(stopprompt);
|
||||||
|
|
||||||
|
params->antiprompt.push_back(stopprompt_str);
|
||||||
|
|
||||||
return params;
|
return params;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@@ -11,7 +11,7 @@ void *llama_allocate_params(const char *prompt, int seed, int threads,
|
||||||
int tokens, int top_k, float top_p, float temp,
|
int tokens, int top_k, float top_p, float temp,
|
||||||
float repeat_penalty, int repeat_last_n,
|
float repeat_penalty, int repeat_last_n,
|
||||||
bool ignore_eos, bool memory_f16, int n_batch,
|
bool ignore_eos, bool memory_f16, int n_batch,
|
||||||
int n_keep);
|
int n_keep, const char *stopprompt);
|
||||||
|
|
||||||
void llama_free_params(void *params_ptr);
|
void llama_free_params(void *params_ptr);
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@@ -42,7 +42,7 @@ func (l *LLama) Predict(text string, opts ...PredictOption) (string, error) {
|
||||||
params := C.llama_allocate_params(input, C.int(po.Seed), C.int(po.Threads), C.int(po.Tokens), C.int(po.TopK),
|
params := C.llama_allocate_params(input, C.int(po.Seed), C.int(po.Threads), C.int(po.Tokens), C.int(po.TopK),
|
||||||
C.float(po.TopP), C.float(po.Temperature), C.float(po.Penalty), C.int(po.Repeat),
|
C.float(po.TopP), C.float(po.Temperature), C.float(po.Penalty), C.int(po.Repeat),
|
||||||
C.bool(po.IgnoreEOS), C.bool(po.F16KV),
|
C.bool(po.IgnoreEOS), C.bool(po.F16KV),
|
||||||
C.int(po.Batch), C.int(po.NKeep),
|
C.int(po.Batch), C.int(po.NKeep), C.CString(po.StopPrompt),
|
||||||
)
|
)
|
||||||
ret := C.llama_predict(params, l.state, (*C.char)(unsafe.Pointer(&out[0])))
|
ret := C.llama_predict(params, l.state, (*C.char)(unsafe.Pointer(&out[0])))
|
||||||
if ret != 0 {
|
if ret != 0 {
|
||||||
|
|
|
||||||
|
|
@@ -15,6 +15,7 @@ type PredictOptions struct {
|
||||||
TopP, Temperature, Penalty float64
|
TopP, Temperature, Penalty float64
|
||||||
F16KV bool
|
F16KV bool
|
||||||
IgnoreEOS bool
|
IgnoreEOS bool
|
||||||
|
StopPrompt string
|
||||||
}
|
}
|
||||||
|
|
||||||
type PredictOption func(p *PredictOptions)
|
type PredictOption func(p *PredictOptions)
|
||||||
|
|
@@ -38,6 +39,7 @@ var DefaultOptions PredictOptions = PredictOptions{
|
||||||
Repeat: 64,
|
Repeat: 64,
|
||||||
Batch: 8,
|
Batch: 8,
|
||||||
NKeep: 64,
|
NKeep: 64,
|
||||||
|
StopPrompt: "",
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetContext sets the context size.
|
// SetContext sets the context size.
|
||||||
|
|
@@ -154,6 +156,13 @@ func SetNKeep(n int) PredictOption {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetStopPrompt returns a PredictOption that stores prompt in
// PredictOptions.StopPrompt, overriding the DefaultOptions value of "".
//
// Predict passes StopPrompt to llama_allocate_params as a C string, and the
// native side appends it to the parameters' antiprompt list.
//
// NOTE(review): the append on the native side is unconditional, so the empty
// default is also added as an antiprompt entry; confirm that an empty entry is
// treated as "no stop prompt" rather than matching immediately.
|
||||||
|
func SetStopPrompt(prompt string) PredictOption {
|
||||||
|
return func(p *PredictOptions) {
|
||||||
|
p.StopPrompt = prompt
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Create a new PredictOptions object with the given options.
|
// Create a new PredictOptions object with the given options.
|
||||||
func NewPredictOptions(opts ...PredictOption) PredictOptions {
|
func NewPredictOptions(opts ...PredictOption) PredictOptions {
|
||||||
p := DefaultOptions
|
p := DefaultOptions
|
||||||
|
|
|
||||||
Reference in a new issue