mirror of https://gitee.com/godoos/godoos.git
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
147 lines
2.8 KiB
147 lines
2.8 KiB
package sd
|
|
|
|
import (
|
|
"io"
|
|
"os"
|
|
"testing"
|
|
)
|
|
|
|
func TestNewStableDiffusionAutoModelPredict(t *testing.T) {
|
|
options := DefaultOptions
|
|
t.Log(options)
|
|
model, err := NewAutoModel(options)
|
|
if err != nil {
|
|
t.Error(err)
|
|
return
|
|
}
|
|
defer model.Close()
|
|
model.SetLogCallback(func(level LogLevel, msg string) {
|
|
t.Log(msg)
|
|
})
|
|
err = model.LoadFromFile("./models/miniSD.ckpt")
|
|
if err != nil {
|
|
t.Error(err)
|
|
return
|
|
}
|
|
var writers []io.Writer
|
|
filenames := []string{
|
|
"./assets/love_cat2.png",
|
|
}
|
|
for _, filename := range filenames {
|
|
file, err := os.Create(filename)
|
|
if err != nil {
|
|
t.Error(err)
|
|
return
|
|
}
|
|
defer file.Close()
|
|
writers = append(writers, file)
|
|
}
|
|
|
|
params := DefaultFullParams
|
|
params.BatchCount = 1
|
|
params.Width = 256
|
|
params.Height = 256
|
|
params.NegativePrompt = ""
|
|
err = model.Predict("british short hair cat, high quality", params, writers)
|
|
if err != nil {
|
|
t.Error(err)
|
|
return
|
|
}
|
|
}
|
|
|
|
func TestModel_ROCm(t *testing.T) {
|
|
options := DefaultOptions
|
|
options.GpuEnable = true
|
|
t.Log(options)
|
|
model, err := NewAutoModel(options)
|
|
if err != nil {
|
|
t.Error(err)
|
|
return
|
|
}
|
|
defer model.Close()
|
|
model.SetLogCallback(func(level LogLevel, msg string) {
|
|
t.Log(msg)
|
|
})
|
|
err = model.LoadFromFile("./models/miniSD.ckpt")
|
|
if err != nil {
|
|
t.Error(err)
|
|
return
|
|
}
|
|
var writers []io.Writer
|
|
filenames := []string{
|
|
"./assets/love_cat2.png",
|
|
}
|
|
for _, filename := range filenames {
|
|
file, err := os.Create(filename)
|
|
if err != nil {
|
|
t.Error(err)
|
|
return
|
|
}
|
|
defer file.Close()
|
|
writers = append(writers, file)
|
|
}
|
|
|
|
params := DefaultFullParams
|
|
params.BatchCount = 1
|
|
params.Width = 256
|
|
params.Height = 256
|
|
params.NegativePrompt = ""
|
|
err = model.Predict("british short hair cat, high quality", params, writers)
|
|
if err != nil {
|
|
t.Error(err)
|
|
return
|
|
}
|
|
}
|
|
|
|
func TestNewStableDiffusionAutoModelImagePredict(t *testing.T) {
|
|
options := DefaultOptions
|
|
options.VaeDecodeOnly = false
|
|
t.Log(options)
|
|
model, err := NewAutoModel(options)
|
|
if err != nil {
|
|
t.Error(err)
|
|
return
|
|
}
|
|
defer model.Close()
|
|
model.SetLogCallback(func(level LogLevel, msg string) {
|
|
t.Log(msg)
|
|
})
|
|
err = model.LoadFromFile("./models/mysafetensors")
|
|
if err != nil {
|
|
t.Error(err)
|
|
return
|
|
}
|
|
inFile, err := os.Open("./assets/love_cat0.png")
|
|
if err != nil {
|
|
t.Error(err)
|
|
return
|
|
}
|
|
defer inFile.Close()
|
|
|
|
var writers []io.Writer
|
|
filenames := []string{
|
|
"./assets/love_cat0_m.png",
|
|
//"./assets/love_cat1_m.png",
|
|
//"./assets/love_cat5.png",
|
|
//"./assets/love_cat6.png"
|
|
}
|
|
for _, filename := range filenames {
|
|
file, err := os.Create(filename)
|
|
if err != nil {
|
|
t.Error(err)
|
|
return
|
|
}
|
|
defer file.Close()
|
|
writers = append(writers, file)
|
|
}
|
|
params := DefaultFullParams
|
|
params.BatchCount = 1
|
|
params.Width = 256
|
|
params.Height = 256
|
|
params.NegativePrompt = ""
|
|
err = model.ImagePredict(inFile, "dogs", params, writers)
|
|
if err != nil {
|
|
t.Error(err)
|
|
return
|
|
}
|
|
}
|
|
|