package llama

// #cgo CXXFLAGS: -I./llama.cpp/examples -I./llama.cpp
// #cgo LDFLAGS: -L./ -lbinding -lm -lstdc++
// #cgo darwin LDFLAGS: -framework Accelerate
// #include <stdlib.h>
// #include "binding.h"
import "C"

import (
	"fmt"
	"strings"
	"unsafe"
)

// LLama holds an opaque pointer to the model state allocated by the C
// binding.
type LLama struct {
	state unsafe.Pointer
}

// New loads the model at the given path through the C binding. The
// returned handle must be released with Free.
func New(model string, opts ...ModelOption) (*LLama, error) {
	mo := NewModelOptions(opts...)
	modelPath := C.CString(model)
	defer C.free(unsafe.Pointer(modelPath)) // C.CString allocates with malloc

	result := C.load_model(modelPath, C.int(mo.ContextSize), C.int(mo.Parts), C.int(mo.Seed), C.bool(mo.F16Memory), C.bool(mo.MLock))
	if result == nil {
		return nil, fmt.Errorf("failed to load model %s", model)
	}

	return &LLama{state: result}, nil
}

// Free releases the model state on the C side. The handle must not be
// used afterwards.
func (l *LLama) Free() {
	C.llama_free_model(l.state)
}

// Predict runs inference on the prompt and returns the generated text.
func (l *LLama) Predict(text string, opts ...PredictOption) (string, error) {
	po := NewPredictOptions(opts...)

	input := C.CString(text)
	defer C.free(unsafe.Pointer(input)) // C.CString allocates with malloc
	if po.Tokens == 0 {
		po.Tokens = 99999999 // no explicit limit; this also sizes the output buffer below
	}
	out := make([]byte, po.Tokens)

	params := C.llama_allocate_params(input, C.int(po.Seed), C.int(po.Threads), C.int(po.Tokens), C.int(po.TopK),
		C.float(po.TopP), C.float(po.Temperature), C.float(po.Penalty), C.int(po.Repeat),
		C.bool(po.IgnoreEOS), C.bool(po.F16KV),
		C.int(po.Batch), C.int(po.NKeep),
	)
	defer C.llama_free_params(params) // ensure params are freed on the error path too

	ret := C.llama_predict(params, l.state, (*C.char)(unsafe.Pointer(&out[0])))
	if ret != 0 {
		return "", fmt.Errorf("inference failed")
	}
	res := C.GoString((*C.char)(unsafe.Pointer(&out[0])))

	// The C side echoes the prompt back; strip it along with leading whitespace.
	res = strings.TrimPrefix(res, " ")
	res = strings.TrimPrefix(res, text)
	res = strings.TrimPrefix(res, "\n")

	return res, nil
}

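// Example usage from a consuming package (a minimal sketch: it assumes
// option helpers such as SetContext, SetThreads and SetTokens are defined
// in this package's options file, which is not shown here):
//
//	l, err := llama.New("./models/7B/ggml-model.bin", llama.SetContext(512))
//	if err != nil {
//		log.Fatal(err)
//	}
//	defer l.Free()
//
//	out, err := l.Predict("Hello, world", llama.SetThreads(4), llama.SetTokens(128))
//	if err != nil {
//		log.Fatal(err)
//	}
//	fmt.Println(out)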