package imitater

import (
	"errors"
	"fmt"
	"io/ioutil"
	"net/http"
	"os"
	"strings"
	"time"

	"474420502.top/eson/structure/circular_linked"
	yaml "gopkg.in/yaml.v2"
)

// YamlCurls is a list of curl command strings that exists to provide custom
// YAML (de)serialization. Each YAML entry may be an inline curl, "@path" to
// read one curl per line from a local file, or "#url" to fetch one curl per
// line over HTTP.
type YamlCurls []string

// UnmarshalYAML decodes either a single string or a list of strings and
// appends every expanded curl (see parseCurlLines) to curls.
//
// Decode failures, non-string list items and expansion failures are returned
// to the YAML decoder instead of being swallowed or panicking.
func (curls *YamlCurls) UnmarshalYAML(unmarshal func(interface{}) error) error {
	var buf interface{}
	if err := unmarshal(&buf); err != nil {
		return err
	}

	var entries []string
	switch tbuf := buf.(type) {
	case string:
		entries = []string{tbuf}
	case []interface{}:
		for i, item := range tbuf {
			s, ok := item.(string)
			if !ok {
				return fmt.Errorf("read curls is error, item %d is %T, want string", i, item)
			}
			entries = append(entries, s)
		}
	default:
		return fmt.Errorf("read curls is error, %T", buf)
	}

	for _, entry := range entries {
		lines, err := parseCurlLines(entry)
		if err != nil {
			return err
		}
		*curls = append(*curls, lines...)
	}
	return nil
}

// MarshalYAML encodes the curls as a YAML sequence of strings.
//
// Entries are emitted in their expanded form, because UnmarshalYAML stores
// the expanded lines; the original "@file" / "#url" sources are not kept.
// The conversion to []string matters: a plain []string is not a
// yaml.Marshaler, so the encoder does not recurse back into this method.
func (curls *YamlCurls) MarshalYAML() (interface{}, error) {
	return []string(*curls), nil
}

// YamlProxies is a circular linked list of proxy strings that exists to
// provide custom YAML (de)serialization.
type YamlProxies clinked.CircularLinked

// UnmarshalYAML decodes either a single string or a list of strings and
// appends each one to the proxy list.
//
// Every list element is validated before anything is appended, so a bad
// entry does not leave the list half-populated.
func (proxies *YamlProxies) UnmarshalYAML(unmarshal func(interface{}) error) error {
	var buf interface{}
	if err := unmarshal(&buf); err != nil {
		return err
	}

	p := (*clinked.CircularLinked)(proxies)
	switch tbuf := buf.(type) {
	case string:
		p.Append(tbuf)
	case []interface{}:
		values := make([]string, 0, len(tbuf))
		for i, item := range tbuf {
			s, ok := item.(string)
			if !ok {
				return fmt.Errorf("read proxies is error, item %d is %T, want string", i, item)
			}
			values = append(values, s)
		}
		for _, v := range values {
			p.Append(v)
		}
	default:
		return fmt.Errorf("read proxies is error, %T", buf)
	}
	return nil
}

// MarshalYAML encodes the proxies, in the order returned by GetLoopValues,
// as a YAML sequence of strings.
//
// It returns an error if a stored value is not a string.
func (proxies *YamlProxies) MarshalYAML() (interface{}, error) {
	p := (*clinked.CircularLinked)(proxies)
	var out []string
	for _, cnode := range p.GetLoopValues() {
		v := cnode.GetValue()
		s, ok := v.(string)
		if !ok {
			return nil, fmt.Errorf("write proxies is error, value is %T, want string", v)
		}
		out = append(out, s)
	}
	return out, nil
}

// ADInfo holds some attributes and basic information of an ad. It is
// inlined into Config, so its keys sit at the top level of the YAML file.
type ADInfo struct {
	Priority       int    `yaml:"priority"`
	Device         string `yaml:"device"`
	Platform       string `yaml:"platform"`
	AreaCC         string `yaml:"area_cc"`
	Channel        int    `yaml:"channel"`
	Media          int    `yaml:"media"`
	SpiderID       int    `yaml:"spider_id"`
	CatchAccountID int    `yaml:"catch_account_id"`
}

// Config is the default configuration loaded for a task.
type Config struct {
	// Session int `yaml:"session"`
	Mode    int          `yaml:"mode"`
	Proxies *YamlProxies `yaml:"proxies"`
	Retry   int          `yaml:"retry"`
	Curls   YamlCurls    `yaml:"curls"`
	Crontab string       `yaml:"crontab"`
	ITask   string       `yaml:"task"`
	ADInfo  `yaml:",inline"`
}

// newDefaultConfig creates a Config pre-filled with default values.
//
// The integer ADInfo IDs default to -1 rather than 0; presumably -1 means
// "unset" — NOTE(review): confirm with consumers of these fields.
func newDefaultConfig() *Config {
	conf := &Config{
		// Session: 1,
		Mode:    0,
		Retry:   0,
		Crontab: "",
		ADInfo: ADInfo{
			Device:         "",
			Platform:       "",
			AreaCC:         "",
			Channel:        -1,
			Media:          -1,
			SpiderID:       -1,
			CatchAccountID: -1,
		},
	}
	return conf
}

// NewConfig loads the YAML file at path p on top of newDefaultConfig and
// returns the result.
//
// It panics with the underlying error if the file cannot be opened or
// decoded.
func NewConfig(p string) *Config {
	f, err := os.Open(p)
	if err != nil {
		panic(err)
	}
	// Deferred only after a successful Open, so f is never nil here.
	defer f.Close()

	conf := newDefaultConfig()
	if err := yaml.NewDecoder(f).Decode(conf); err != nil {
		panic(err)
	}
	return conf
}

// parseCurl expands a single curl entry exactly like parseCurlLines.
//
// It panics on any error, keeping its original contract for existing
// callers.
func parseCurl(curl string) []string {
	result, err := parseCurlLines(curl)
	if err != nil {
		panic(err)
	}
	return result
}

// remoteCurlTimeout bounds the whole "#url" fetch, so a stalled server cannot
// hang config loading forever.
const remoteCurlTimeout = 30 * time.Second

// remoteCurlClient is the shared HTTP client used for "#url" curl sources.
var remoteCurlClient = &http.Client{Timeout: remoteCurlTimeout}

// parseCurlLines expands one curl entry based on its first byte.
//
//   - "@path": read the local file at path, one curl per line.
//   - "#url": HTTP GET url (non-2xx is an error), one curl per line of the body.
//   - anything else: the entry itself is a single curl.
//
// Every returned curl has surrounding spaces, '\r' and '\n' trimmed. Blank
// lines from file/URL sources are dropped. An empty entry is an error.
func parseCurlLines(curl string) ([]string, error) {
	if curl == "" {
		return nil, errors.New("read curls is error, empty curl entry")
	}

	switch curl[0] {
	case '@':
		data, err := ioutil.ReadFile(curl[1:])
		if err != nil {
			return nil, err
		}
		return splitCurlLines(string(data)), nil
	case '#':
		resp, err := remoteCurlClient.Get(curl[1:])
		if err != nil {
			return nil, err
		}
		defer resp.Body.Close()
		if resp.StatusCode < 200 || resp.StatusCode > 299 {
			return nil, fmt.Errorf("fetch curls from %q: unexpected status %s", curl[1:], resp.Status)
		}
		data, err := ioutil.ReadAll(resp.Body)
		if err != nil {
			return nil, err
		}
		return splitCurlLines(string(data)), nil
	default:
		return []string{strings.Trim(curl, "\r\n ")}, nil
	}
}

// splitCurlLines splits text on '\n', trims each line and drops blank ones.
func splitCurlLines(text string) []string {
	var result []string
	for _, line := range strings.Split(text, "\n") {
		if line = strings.Trim(line, "\r\n "); line != "" {
			result = append(result, line)
		}
	}
	return result
}