package llama

import "runtime"

// ModelOptions holds the options used when loading a model.
type ModelOptions struct {
	ContextSize int
	Parts       int
	Seed        int
	F16Memory   bool
	MLock       bool
}

// PredictOptions holds the options used for text generation.
type PredictOptions struct {
	Seed, Threads, Tokens, TopK, Repeat, Batch, NKeep int
	TopP, Temperature, Penalty                        float64
	F16KV                                             bool
	IgnoreEOS                                         bool
}

// PredictOption is a functional option that mutates a PredictOptions value.
type PredictOption func(p *PredictOptions)

// ModelOption is a functional option that mutates a ModelOptions value.
type ModelOption func(p *ModelOptions)

// DefaultModelOptions are the defaults NewModelOptions starts from before
// applying any options.
var DefaultModelOptions ModelOptions = ModelOptions{
	ContextSize: 512,
	Seed:        0,
	F16Memory:   false,
	MLock:       false,
}

// DefaultOptions are the defaults NewPredictOptions starts from before
// applying any options.
var DefaultOptions PredictOptions = PredictOptions{
	Seed:        -1,
	Threads:     runtime.NumCPU(),
	Tokens:      128,
	TopK:        10000,
	TopP:        0.90,
	Temperature: 0.96,
	Penalty:     1,
	Repeat:      64,
	Batch:       8,
	NKeep:       64,
}

// SetContext sets the context size.
func SetContext(c int) ModelOption {
	return func(p *ModelOptions) {
		p.ContextSize = c
	}
}

// SetModelSeed sets the random seed used when loading the model.
func SetModelSeed(c int) ModelOption {
	return func(p *ModelOptions) {
		p.Seed = c
	}
}

// SetParts sets the number of model parts.
func SetParts(c int) ModelOption {
	return func(p *ModelOptions) {
		p.Parts = c
	}
}

// EnableF16Memory enables 16-bit floating point memory.
var EnableF16Memory ModelOption = func(p *ModelOptions) {
	p.F16Memory = true
}

// EnableF16KV enables a 16-bit floating point key/value cache.
var EnableF16KV PredictOption = func(p *PredictOptions) {
	p.F16KV = true
}

// EnableMLock enables mlock, keeping the model in memory.
var EnableMLock ModelOption = func(p *ModelOptions) {
	p.MLock = true
}

// NewModelOptions creates a new ModelOptions value, starting from
// DefaultModelOptions and applying the given options.
func NewModelOptions(opts ...ModelOption) ModelOptions {
	p := DefaultModelOptions
	for _, opt := range opts {
		opt(&p)
	}
	return p
}

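// Illustrative usage (not part of the package API): model options are built by
// composing the functional options above on top of DefaultModelOptions, e.g.
//
//	opts := NewModelOptions(SetContext(1024), SetModelSeed(42), EnableMLock)
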
// IgnoreEOS makes text generation ignore the end-of-sequence token.
var IgnoreEOS PredictOption = func(p *PredictOptions) {
	p.IgnoreEOS = true
}

// SetSeed sets the random seed for sampling text generation.
func SetSeed(seed int) PredictOption {
	return func(p *PredictOptions) {
		p.Seed = seed
	}
}

// SetThreads sets the number of threads to use for text generation.
func SetThreads(threads int) PredictOption {
	return func(p *PredictOptions) {
		p.Threads = threads
	}
}

// SetTokens sets the number of tokens to generate.
func SetTokens(tokens int) PredictOption {
	return func(p *PredictOptions) {
		p.Tokens = tokens
	}
}

// SetTopK sets the value for top-K sampling.
func SetTopK(topk int) PredictOption {
	return func(p *PredictOptions) {
		p.TopK = topk
	}
}

// SetTopP sets the value for nucleus (top-P) sampling.
func SetTopP(topp float64) PredictOption {
	return func(p *PredictOptions) {
		p.TopP = topp
	}
}

// SetTemperature sets the temperature value for text generation.
func SetTemperature(temp float64) PredictOption {
	return func(p *PredictOptions) {
		p.Temperature = temp
	}
}

// SetPenalty sets the repetition penalty for text generation.
func SetPenalty(penalty float64) PredictOption {
	return func(p *PredictOptions) {
		p.Penalty = penalty
	}
}

// SetRepeat sets how many recent tokens are considered for the repetition penalty.
func SetRepeat(repeat int) PredictOption {
	return func(p *PredictOptions) {
		p.Repeat = repeat
	}
}

// SetBatch sets the batch size.
func SetBatch(size int) PredictOption {
	return func(p *PredictOptions) {
		p.Batch = size
	}
}

// SetNKeep sets the number of tokens from the initial prompt to keep.
func SetNKeep(n int) PredictOption {
	return func(p *PredictOptions) {
		p.NKeep = n
	}
}

// NewPredictOptions creates a new PredictOptions value, starting from
// DefaultOptions and applying the given options.
func NewPredictOptions(opts ...PredictOption) PredictOptions {
	p := DefaultOptions
	for _, opt := range opts {
		opt(&p)
	}
	return p
}

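// Illustrative usage (not part of the package API): prediction options are
// built the same way, on top of DefaultOptions, e.g.
//
//	opts := NewPredictOptions(SetTokens(256), SetTemperature(0.8), SetTopP(0.95))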