215 lines
4.3 KiB
Go
215 lines
4.3 KiB
Go
package imitater
|
|
|
|
import (
|
|
"errors"
|
|
"io/ioutil"
|
|
"net/http"
|
|
"os"
|
|
"reflect"
|
|
"strings"
|
|
|
|
"474420502.top/eson/structure/circular_linked"
|
|
|
|
yaml "gopkg.in/yaml.v2"
|
|
)
|
|
|
|
// YamlCurls is a list of curl strings. It exists as its own named type so that
// custom YAML (un)marshal methods can be attached to it.
type YamlCurls []string
|
|
|
|
// UnmarshalYAML YamlCurls反序列化函数
|
|
func (curls *YamlCurls) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
|
|
|
var buf interface{}
|
|
err := unmarshal(&buf)
|
|
if err != nil {
|
|
return nil
|
|
}
|
|
|
|
switch tbuf := buf.(type) {
|
|
case string:
|
|
if tbuf != "" {
|
|
for _, curlinfo := range parseCurl(tbuf) {
|
|
*curls = append(*curls, curlinfo)
|
|
}
|
|
}
|
|
|
|
case []interface{}:
|
|
for _, ifa := range tbuf {
|
|
curlstr := ifa.(string)
|
|
if curlstr != "" {
|
|
for _, curlinfo := range parseCurl(curlstr) {
|
|
*curls = append(*curls, curlinfo)
|
|
}
|
|
}
|
|
|
|
}
|
|
default:
|
|
return errors.New("read curls is error, " + reflect.TypeOf(buf).String())
|
|
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// MarshalYAML YamlCurls序列化函数
|
|
func (curls *YamlCurls) MarshalYAML() (interface{}, error) {
|
|
content := "["
|
|
for _, curl := range []string(*curls) {
|
|
content += "\"" + curl + "\"" + ", "
|
|
}
|
|
content = strings.TrimRight(content, ", ")
|
|
content += "]"
|
|
return content, nil
|
|
}
|
|
|
|
// YamlProxies wraps clinked.CircularLinked so that custom YAML (un)marshal
// methods can be attached to it. Because it is a defined type (not an alias),
// it does not inherit CircularLinked's methods; the methods below convert the
// receiver to *clinked.CircularLinked before calling Append/GetLoopValues.
type YamlProxies clinked.CircularLinked
|
|
|
|
// UnmarshalYAML YamlProxies反序列化函数
|
|
func (proxies *YamlProxies) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
|
|
|
var buf interface{}
|
|
err := unmarshal(&buf)
|
|
if err != nil {
|
|
return nil
|
|
}
|
|
|
|
switch tbuf := buf.(type) {
|
|
case string:
|
|
p := (*clinked.CircularLinked)(proxies)
|
|
p.Append(tbuf)
|
|
case []interface{}:
|
|
p := (*clinked.CircularLinked)(proxies)
|
|
for _, ifa := range tbuf {
|
|
p.Append(ifa.(string))
|
|
}
|
|
default:
|
|
return errors.New("read curls is error, " + reflect.TypeOf(buf).String())
|
|
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// MarshalYAML YamlProxies 序列化函数
|
|
func (proxies *YamlProxies) MarshalYAML() (interface{}, error) {
|
|
content := "["
|
|
p := (*clinked.CircularLinked)(proxies)
|
|
|
|
for _, cnode := range p.GetLoopValues() {
|
|
content += "\"" + cnode.GetValue().(string) + "\"" + ", "
|
|
}
|
|
content = strings.TrimRight(content, ", ")
|
|
content += "]"
|
|
return content, nil
|
|
}
|
|
|
|
// ADInfo holds some attributes of an ad: basic information and the like.
// Config embeds it with `yaml:",inline"`, so these keys are read from the top
// level of the YAML document. Each yaml tag names the key its field maps to.
type ADInfo struct {
	Priority int    `yaml:"priority"`
	Device   string `yaml:"device"`
	Platform string `yaml:"platform"`
	AreaCC   string `yaml:"area_cc"`
	// The four integer IDs below are initialised to -1 by newDefaultConfig;
	// presumably -1 means "not configured" (TODO confirm with consumers).
	Channel        int `yaml:"channel"`
	Media          int `yaml:"media"`
	SpiderID       int `yaml:"spider_id"`
	CatchAccountID int `yaml:"catch_account_id"`
}
|
|
|
|
// Config is the default configuration loaded for a task. NewConfig fills it
// by decoding a YAML file on top of the values from newDefaultConfig.
type Config struct {
	// Session int `yaml:"session"`
	Mode int `yaml:"mode"`
	// Proxies accepts either a single string or a list of strings
	// (handled by YamlProxies.UnmarshalYAML).
	Proxies *YamlProxies `yaml:"proxies"`
	Retry   int          `yaml:"retry"`

	// Curls accepts either a single string or a list of strings; each
	// non-empty string is expanded by parseCurl, which treats a leading '@'
	// as a file path and a leading '#' as a URL to fetch.
	Curls YamlCurls `yaml:"curls"`

	Crontab string `yaml:"crontab"`
	ITask   string `yaml:"task"`

	// ADInfo's fields are inlined: their keys appear at the top level of the
	// YAML document rather than under a nested mapping.
	ADInfo `yaml:",inline"`
}
|
|
|
|
// newDefaultConfig create a default config
|
|
func newDefaultConfig() *Config {
|
|
conf := &Config{
|
|
// Session: 1,
|
|
Mode: 0,
|
|
Retry: 0,
|
|
Crontab: "",
|
|
|
|
ADInfo: ADInfo{
|
|
Device: "",
|
|
Platform: "",
|
|
AreaCC: "",
|
|
Channel: -1,
|
|
Media: -1,
|
|
SpiderID: -1,
|
|
CatchAccountID: -1,
|
|
},
|
|
}
|
|
return conf
|
|
}
|
|
|
|
// NewConfig 加载并返回Config
|
|
func NewConfig(p string) *Config {
|
|
f, err := os.Open(p)
|
|
defer f.Close()
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
conf := newDefaultConfig()
|
|
err = yaml.NewDecoder(f).Decode(conf)
|
|
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
return conf
|
|
}
|
|
|
|
// parseCurl expands one entry of the "curls" setting into individual curl
// strings. The first byte of curl selects the source:
//
//	'@'       the rest is a file path; each line of the file is a candidate
//	'#'       the rest is a URL; each line of the HTTP response body is a candidate
//	otherwise curl itself is the single candidate (it is not split on newlines)
//
// Each candidate is trimmed of surrounding '\r', '\n' and ' ' and kept only
// when at least minCurlLen bytes remain. An empty curl yields nil.
//
// It panics if the file cannot be read, if the request fails, if the response
// status is not 2xx, or if the body cannot be read. This matches how the rest
// of the configuration loading reports errors.
func parseCurl(curl string) []string {
	if curl == "" {
		// Indexing curl[0] would panic on an empty string.
		return nil
	}

	switch curl[0] {
	case '@':
		return splitCurlLines(readCurlFile(curl[1:]))
	case '#':
		return splitCurlLines(fetchCurlURL(curl[1:]))
	default:
		return appendCurl(nil, curl)
	}
}

// minCurlLen is the shortest trimmed entry parseCurl keeps; shorter lines
// (blank lines, stray characters) are dropped.
const minCurlLen = 4

// appendCurl trims s and appends it to dst when it is at least minCurlLen long.
func appendCurl(dst []string, s string) []string {
	s = strings.Trim(s, "\r\n ")
	if len(s) >= minCurlLen {
		dst = append(dst, s)
	}
	return dst
}

// splitCurlLines splits data on '\n' and keeps every sufficiently long line.
func splitCurlLines(data []byte) []string {
	var result []string
	for _, line := range strings.Split(string(data), "\n") {
		result = appendCurl(result, line)
	}
	return result
}

// readCurlFile returns the whole content of the file at path, panicking on error.
func readCurlFile(path string) []byte {
	data, err := ioutil.ReadFile(path)
	if err != nil {
		panic(err)
	}
	return data
}

// fetchCurlURL GETs rawURL and returns the response body. It panics on
// transport errors, on a non-2xx status (so an error page is never parsed as
// curls), and on body read errors. The body is always closed so the connection
// can be released.
func fetchCurlURL(rawURL string) []byte {
	resp, err := http.Get(rawURL)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		panic(errors.New("fetch curls is error, " + resp.Status))
	}

	data, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		panic(err)
	}
	return data
}
|