This repository has been archived on 2024-01-13. You can view files and clone it, but cannot push or open issues or pull requests.
Kevin/pkg/llama/options.go
2023-04-18 15:00:44 +02:00

173 lines
3.6 KiB
Go

package llama
import "runtime"
// ModelOptions controls how a llama model is loaded.
// NOTE(review): the fields appear to mirror llama.cpp context parameters —
// confirm against the cgo layer that consumes them.
type ModelOptions struct {
	ContextSize int  // size of the token context window
	Parts       int  // number of model parts — presumably llama.cpp n_parts; TODO confirm
	Seed        int  // RNG seed used at model load time
	F16Memory   bool // use 16-bit floats for memory — presumably llama.cpp memory_f16; verify
	MLock       bool // lock the model in RAM — presumably llama.cpp use_mlock; verify
}
// PredictOptions controls a single text-generation (predict) call.
// See DefaultOptions for the default values applied by NewPredictOptions.
type PredictOptions struct {
	// Integer knobs: sampling seed, worker threads, tokens to generate,
	// top-K cutoff, repeat window, batch size, and tokens kept from the prompt.
	Seed, Threads, Tokens, TopK, Repeat, Batch, NKeep int
	// Float knobs: nucleus-sampling threshold, sampling temperature,
	// and repetition penalty.
	TopP, Temperature, Penalty float64
	F16KV      bool   // use 16-bit floats for the KV cache — TODO confirm against llama.cpp memory_f16
	IgnoreEOS  bool   // keep generating past an end-of-stream token
	StopPrompt string // stop generation when this string is produced; empty disables
}
// PredictOption mutates a PredictOptions value; pass to NewPredictOptions.
type PredictOption func(p *PredictOptions)

// ModelOption mutates a ModelOptions value; pass to NewModelOptions.
type ModelOption func(p *ModelOptions)
// DefaultModelOptions holds the defaults NewModelOptions starts from
// before applying any ModelOption overrides.
// Parts is left at its zero value (0) — NOTE(review): confirm 0 is a valid
// default for the underlying loader.
var DefaultModelOptions = ModelOptions{
	ContextSize: 512,
	Seed:        0,
	F16Memory:   false,
	MLock:       false,
}
// DefaultOptions holds the defaults NewPredictOptions starts from before
// applying any PredictOption overrides. Threads defaults to the number of
// logical CPUs on the host.
var DefaultOptions = PredictOptions{
	Seed:        -1, // -1: presumably "choose a random seed" — TODO confirm against the cgo layer
	Threads:     runtime.NumCPU(),
	Tokens:      128,
	TopK:        10000,
	TopP:        0.90,
	Temperature: 0.96,
	Penalty:     1,
	Repeat:      64,
	Batch:       8,
	NKeep:       64,
	StopPrompt:  "",
}
// SetContext sets the context size.
func SetContext(c int) ModelOption {
	return func(o *ModelOptions) { o.ContextSize = c }
}
// SetModelSeed sets the seed used when loading the model.
func SetModelSeed(c int) ModelOption {
	return func(o *ModelOptions) { o.Seed = c }
}
// SetParts sets the number of model parts.
func SetParts(c int) ModelOption {
	return func(o *ModelOptions) { o.Parts = c }
}
// EnableF16Memory is a ModelOption that turns on 16-bit float memory.
var EnableF16Memory ModelOption = func(p *ModelOptions) {
	p.F16Memory = true
}
// EnableF16KV is a PredictOption that turns on the F16KV flag.
var EnableF16KV PredictOption = func(p *PredictOptions) {
	p.F16KV = true
}
// EnableMLock is a ModelOption that requests locking the model in RAM.
var EnableMLock ModelOption = func(p *ModelOptions) {
	p.MLock = true
}
// NewModelOptions creates a ModelOptions value by copying
// DefaultModelOptions and applying the given options in order.
// (The previous comment incorrectly said it created a PredictOptions.)
func NewModelOptions(opts ...ModelOption) ModelOptions {
	p := DefaultModelOptions
	for _, opt := range opts {
		opt(&p)
	}
	return p
}
// IgnoreEOS is a PredictOption that makes generation continue past an
// end-of-stream token.
var IgnoreEOS PredictOption = func(p *PredictOptions) {
	p.IgnoreEOS = true
}
// SetSeed sets the random seed for sampling text generation.
func SetSeed(seed int) PredictOption {
	return func(o *PredictOptions) { o.Seed = seed }
}
// SetThreads sets the number of threads to use for text generation.
func SetThreads(threads int) PredictOption {
	return func(o *PredictOptions) { o.Threads = threads }
}
// SetTokens sets the number of tokens to generate.
func SetTokens(tokens int) PredictOption {
	return func(o *PredictOptions) { o.Tokens = tokens }
}
// SetTopK sets the value for top-K sampling.
func SetTopK(topk int) PredictOption {
	return func(o *PredictOptions) { o.TopK = topk }
}
// SetTopP sets the value for nucleus (top-P) sampling.
func SetTopP(topp float64) PredictOption {
	return func(o *PredictOptions) { o.TopP = topp }
}
// SetTemperature sets the temperature value for text generation.
func SetTemperature(temp float64) PredictOption {
	return func(o *PredictOptions) { o.Temperature = temp }
}
// SetPenalty sets the repetition penalty for text generation.
func SetPenalty(penalty float64) PredictOption {
	return func(o *PredictOptions) { o.Penalty = penalty }
}
// SetRepeat sets the repeat window for text generation —
// NOTE(review): presumably the last-N-tokens window the penalty applies to;
// confirm against the cgo layer.
func SetRepeat(repeat int) PredictOption {
	return func(o *PredictOptions) { o.Repeat = repeat }
}
// SetBatch sets the batch size.
func SetBatch(size int) PredictOption {
	return func(o *PredictOptions) { o.Batch = size }
}
// SetNKeep sets the number of tokens from the initial prompt to keep.
// (Comment previously named the function "SetKeep".)
func SetNKeep(n int) PredictOption {
	return func(o *PredictOptions) { o.NKeep = n }
}
// SetStopPrompt sets the prompt to stop generation.
func SetStopPrompt(prompt string) PredictOption {
	return func(o *PredictOptions) { o.StopPrompt = prompt }
}
// NewPredictOptions creates a PredictOptions value by copying
// DefaultOptions and applying each supplied PredictOption in order.
func NewPredictOptions(opts ...PredictOption) PredictOptions {
	out := DefaultOptions
	for _, apply := range opts {
		apply(&out)
	}
	return out
}