• 设为首页
  • 点击收藏
  • 手机版
    手机扫一扫访问
    迪恩网络手机版
  • 关注官方公众号
    微信扫一扫关注
    迪恩网络公众号

Go jaeger 的应用【logger+gorm+grpc+http】

原作者: [db:作者] 来自: [db:来源] 收藏 邀请

在之前介绍 Go 语言 jaeger 和 opentracing 的文章中,我们已经用它们来记录日志;但很多时候我们也希望把数据库操作记录下来,而且程序一般以 http 或 grpc 服务的形式运行,所以 grpc 和 http 也需要通过中间件来接入链路追踪。首先看程序的目录(只是一个简单的 demo):

 因为程序最终会部署到 k8s 上,计划由 docker 收集日志并导入 ELK 或 Graylog,所以这里直接把日志输出到控制台;程序通过切换数据库来实现简单的 SaaS(多租户)。

来看看主要的几个文件

logger.go

package logger

import (
    "context"
    "fmt"
    "io"
    "runtime"
    "strings"
    "time"

    "github.com/opentracing/opentracing-go"
    "github.com/uber/jaeger-client-go"
    "github.com/uber/jaeger-client-go/config"
    "github.com/uber/jaeger-client-go/log"
    "github.com/uber/jaeger-lib/metrics"
    "go.uber.org/zap"
    "go.uber.org/zap/zapcore"
)

var (
    // logTimeFormat is the layout used for JsonLogger.LogTime. The trailing
    // "-07:00" element renders the timestamp's real UTC offset. The previous
    // layout ended in a literal "+08:00" (not a layout element), so timestamps
    // produced on hosts or containers not running in UTC+8 (for example
    // UTC-configured pods) were mislabelled.
    logTimeFormat = "2006-01-02T15:04:05.000-07:00"
    // zapLogger is the shared zap logger; it is built in init.
    zapLogger *zap.Logger
)

// init builds the package-level zap logger used by jsonStdOut.
//
// The level, caller and time keys are blanked because JsonLogger carries its
// own level, call path and timestamp. zap's message key is renamed to
// "logModel". Every level from Debug upwards is enabled.
func init() {
    c := zap.NewProductionConfig()
    c.EncoderConfig.LevelKey = ""
    c.EncoderConfig.CallerKey = ""
    c.EncoderConfig.MessageKey = "logModel"
    c.EncoderConfig.TimeKey = ""
    c.Level = zap.NewAtomicLevelAt(zap.DebugLevel)

    l, err := c.Build()
    if err != nil {
        // Fail fast with a clear message. Ignoring the error would leave
        // zapLogger nil, and the failure would most likely surface later as a
        // nil-pointer panic on the first log call.
        panic(fmt.Sprintf("logger: building zap logger: %v", err))
    }
    zapLogger = l
}

// NewJaegerTracer creates a Jaeger tracer named serviceName that reports to the
// agent at agentHost. On success it also installs the tracer as the opentracing
// global tracer; on failure the global tracer is left untouched.
//
// Sampling uses jaeger's rate-limiting sampler with Param 10 (presumably
// traces per second; confirm against jaeger-client-go docs). Span logging is
// off and the reporter flushes its buffer every second. The returned closer is
// whatever cfg.NewTracer returns; closing it on shutdown is presumably needed
// to flush pending spans.
func NewJaegerTracer(serviceName string, agentHost string) (tracer opentracing.Tracer, closer io.Closer, err error) {
    sampler := &config.SamplerConfig{
        Type:  jaeger.SamplerTypeRateLimiting,
        Param: 10,
    }
    reporter := &config.ReporterConfig{
        LogSpans:            false,
        BufferFlushInterval: 1 * time.Second,
        LocalAgentHostPort:  agentHost,
    }
    cfg := config.Configuration{
        ServiceName: serviceName,
        Sampler:     sampler,
        Reporter:    reporter,
    }

    tracer, closer, err = cfg.NewTracer(
        config.Logger(log.StdLogger),
        config.Metrics(metrics.NullFactory),
    )
    if err != nil {
        return tracer, closer, err
    }

    opentracing.SetGlobalTracer(tracer)
    return tracer, closer, nil
}

func Error(ctx context.Context, format interface{}, args ...interface{}) {
    msg := ""
    if e, ok := format.(error); ok {
        msg = fmt.Sprintf(e.Error(), args...)
    } else if e, ok := format.(string); ok {
        msg = fmt.Sprintf(e, args...)
    }

    jsonStdOut(ctx, zap.ErrorLevel, msg)
}

// Warn formats format/args with fmt.Sprintf and logs the result at warn level,
// tagged with the trace/span ids found in ctx.
func Warn(ctx context.Context, format string, args ...interface{}) {
    text := fmt.Sprintf(format, args...)
    jsonStdOut(ctx, zap.WarnLevel, text)
}

// Info formats format/args with fmt.Sprintf and logs the result at info level,
// tagged with the trace/span ids found in ctx.
func Info(ctx context.Context, format string, args ...interface{}) {
    text := fmt.Sprintf(format, args...)
    jsonStdOut(ctx, zap.InfoLevel, text)
}

// Debug formats format/args with fmt.Sprintf and logs the result at debug
// level, tagged with the trace/span ids found in ctx.
func Debug(ctx context.Context, format string, args ...interface{}) {
    text := fmt.Sprintf(format, args...)
    jsonStdOut(ctx, zap.DebugLevel, text)
}

// jsonStdOut writes one structured record through zapLogger if it accepts
// level. The record is a JsonLogger stored under the "message" field. The zap
// entry's own message is the constant "zap", which init maps to the "logModel"
// key.
//
// NOTE(review): zap's production config writes to stderr by default despite
// this function's name; confirm that is intended.
func jsonStdOut(ctx context.Context, level zapcore.Level, msg string) {
    traceId, spanId := getTraceId(ctx)

    entry := zapLogger.Check(level, "zap")
    if entry == nil {
        // zapLogger does not accept this level.
        return
    }

    // getCallPath's runtime.Caller depth assumes it is called directly from
    // this function, so keep the call here.
    record := JsonLogger{
        LogTime:  time.Now().Format(logTimeFormat),
        Level:    level,
        Content:  msg,
        CallPath: getCallPath(),
        TraceId:  traceId,
        SpanId:   spanId,
    }
    entry.Write(zap.Any("message", record))
}

// JsonLogger is the structured payload jsonStdOut writes under the "message"
// field. zap.Any serialises it by reflection (presumably via encoding/json), so
// the json tags and field declaration order shape the emitted JSON.
type JsonLogger struct {
    TraceId  string        `json:"traceId"`  // jaeger trace id; "" when ctx has no jaeger span
    SpanId   uint64        `json:"spanId"`   // jaeger span id; 0 when ctx has no jaeger span
    Content  interface{}   `json:"content"`  // the formatted log message
    CallPath interface{}   `json:"callPath"` // "file:line" produced by getCallPath
    LogTime  string        `json:"logDate"` // log time, formatted with logTimeFormat
    Level    zapcore.Level `json:"level"`   // log level
}

// getTraceId returns the trace id (formatted with %v) and span id of the
// jaeger span stored in ctx. It returns ("", 0) when ctx holds no span or the
// span's context is not a jaeger.SpanContext.
func getTraceId(ctx context.Context) (string, uint64) {
    span := opentracing.SpanFromContext(ctx)
    if span == nil {
        return "", 0
    }

    sc, isJaeger := span.Context().(jaeger.SpanContext)
    if !isJaeger {
        return "", 0
    }
    traceID := fmt.Sprintf("%v", sc.TraceID())
    return traceID, uint64(sc.SpanID())
}

func getCallPath() string {
    _, file, lineno, ok := runtime.Caller(2)
    if ok {
        return strings.Replace(fmt.Sprintf("%s:%d", stringTrim(file, ""), lineno), "%2e", ".", -1)
    }
    return ""
}

// stringTrim returns the part of s after the first occurrence of cut. It
// returns s unchanged when cut is empty or does not occur in s.
//
// The empty-cut guard fixes a bug: strings.SplitN(s, "", 2) splits after the
// first UTF-8 sequence, so the previous implementation silently dropped the
// first character of s whenever cut was "" (as getCallPath passes it).
func stringTrim(s, cut string) string {
    if cut == "" {
        return s
    }
    if i := strings.Index(s, cut); i >= 0 {
        return s[i+len(cut):]
    }
    return s
}

db.go

package db

import (
    "context"
    "database/sql/driver"
    "fmt"
    "net/url"
    "reflect"
    "regexp"
    "strings"
    "tracedemo/logger"
    "unicode"

    "github.com/jinzhu/gorm"
    "github.com/pkg/errors"

    "sync"

    "time"

    _ "github.com/go-sql-driver/mysql"
    "github.com/opentracing/opentracing-go"
)

// Config holds the connection settings InitDb uses to build a MySQL DSN.
type Config struct {
    DbHost string // database host, placed in tcp(host:port)
    DbPort int    // database port
    DbUser string // login user
    DbPass string // login password
    DbName string // schema name, placed in the DSN path
    Debug  bool   // when true, InitDb registers the gorm tracing/logging callbacks
}

// Connection role suffix plus the gorm setting keys and callback-name prefix
// used by this package. NOTE(review): "jeager" is a misspelling of "jaeger",
// but these strings are runtime keys/callback names, so they are left as is.
const (
    dbMaster         string = "master"         // role suffix in connMap keys: "<siteCode>-master"
    jaegerContextKey        = "jeager:context" // gorm setting key holding the request context.Context
    callbackPrefix          = "jeager"         // prefix of the gorm callback names registered by registerCallbacks
    startTime               = "start:time"     // gorm setting key holding the query start time (UnixNano)
)

func init() {
    connMap = make(map[string]*gorm.DB)
}

var (
    // connMap maps "<siteCode>-master" to the gorm connection registered by InitDb.
    connMap  map[string]*gorm.DB
    // connLock guards connMap (written by InitDb, read by GetMaster).
    connLock sync.RWMutex
)

// InitDb opens a MySQL connection pool from cfg and registers it as the
// master connection of siteCode (map key "<siteCode>-master"), which GetMaster
// looks up later.
//
// When cfg.Debug is true the gorm tracing/logging callbacks are installed.
// After a successful ping, a background goroutine (mysqlHeart) pings the pool
// every three minutes for the lifetime of the process. Open and ping errors
// are wrapped and returned. On ping failure the pool is now closed instead of
// being leaked.
//
// NOTE(review): calling InitDb again for the same siteCode replaces the map
// entry without closing the previous pool, and that pool's mysqlHeart
// goroutine keeps running. Confirm re-initialisation never happens.
func InitDb(siteCode string, cfg *Config) (err error) {
    // Named params (not url) so the net/url package is not shadowed.
    params := url.Values{}
    params.Add("parseTime", "True")
    params.Add("loc", "Local")
    params.Add("charset", "utf8mb4")
    params.Add("collation", "utf8mb4_unicode_ci")
    // NOTE(review): "0s" presumably leaves the driver's read/write/dial timeouts
    // disabled; confirm against the go-sql-driver/mysql DSN documentation.
    params.Add("readTimeout", "0s")
    params.Add("writeTimeout", "0s")
    params.Add("timeout", "0s")

    dsn := fmt.Sprintf("%s:%s@tcp(%s:%v)/%s?%s", cfg.DbUser, cfg.DbPass, cfg.DbHost, cfg.DbPort, cfg.DbName, params.Encode())

    conn, err := gorm.Open("mysql", dsn)
    if err != nil {
        return errors.Wrap(err, "fail to connect db")
    }

    // Install the tracing callbacks only in debug mode.
    if cfg.Debug {
        registerCallbacks(conn)
    }
    // To additionally enable gorm's built-in SQL logging:
    // conn.LogMode(true)

    sqlDB := conn.DB()
    sqlDB.SetMaxIdleConns(30)
    sqlDB.SetMaxOpenConns(200)
    sqlDB.SetConnMaxLifetime(60 * time.Second)

    if err := sqlDB.Ping(); err != nil {
        // The pool is not registered anywhere yet, so nothing else could close
        // it. The close error is ignored because the ping error is the one
        // worth reporting.
        _ = conn.Close()
        return errors.Wrap(err, "fail to ping db")
    }

    dbName := fmt.Sprintf("%s-%s", siteCode, dbMaster)
    connLock.Lock()
    connMap[dbName] = conn
    connLock.Unlock()

    go mysqlHeart(conn)

    return nil
}

// GetMaster returns the master *gorm.DB registered for the site code stored in
// ctx under the "SiteCode" key. It first derives a child context that also
// carries "DbName" = "<siteCode>-master". That context is attached to the
// returned handle under jaegerContextKey, where the tracing callbacks read it.
//
// It panics if ctx carries no site code, or if no connection is registered for
// it. NOTE(review): ctx keys are plain strings shared with other files;
// staticcheck (SA1029) recommends typed keys.
func GetMaster(ctx context.Context) *gorm.DB {
    siteCode := fmt.Sprintf("%v", ctx.Value("SiteCode"))
    // fmt renders a missing (nil) value as "<nil>". Match that token rather
    // than the bare substring "nil", which also rejected legitimate site codes
    // such as "vanilla".
    if strings.Contains(siteCode, "<nil>") {
        panic(errors.New("当前上下文没有找到DB"))
    }

    dbName := fmt.Sprintf("%s-%s", siteCode, dbMaster)

    // The lock guards connMap only, so hold it just for the lookup.
    connLock.RLock()
    db := connMap[dbName]
    connLock.RUnlock()
    if db == nil {
        panic(errors.Errorf("当前上下文没有找到DB:%s", dbName))
    }

    ctx = context.WithValue(ctx, "DbName", dbName)
    return db.Set(jaegerContextKey, ctx)
}

func mysqlHeart(conn *gorm.DB) {
    for {
        if conn != nil {
            err := conn.DB().Ping()
            if err != nil {
                fmt.Println(fmt.Sprintf("mysqlHeart has err:%v", err))
            }
        }

        time.Sleep(3 * time.Minute)
    }
}

// registerCallbacks wraps gorm's create, delete, query, update and row_query
// callbacks on db with a tracing pair. A "jeager:before:<name>" callback opens
// a span, and a "jeager:after:<name>" callback logs the SQL and finishes the
// span. Span names are "gorm.db.<driver>.query" for query/row_query and
// "gorm.db.<driver>.exec" for create/delete/update.
func registerCallbacks(db *gorm.DB) {
    driverName := db.Dialect().GetName()
    // Report the "postgres" dialect as "postgresql" in span names.
    switch driverName {
    case "postgres":
        driverName = "postgresql"
    }
    spanTypePrefix := fmt.Sprintf("gorm.db.%s.", driverName)
    querySpanType := spanTypePrefix + "query"
    execSpanType := spanTypePrefix + "exec"

    // processor is a factory so that the Before and After registrations below
    // each obtain their own *gorm.CallbackProcessor. NOTE(review): presumably
    // required because Before/After record state on the processor they are
    // called on; verify against gorm v1's callback.go before sharing one.
    type params struct {
        spanType  string
        processor func() *gorm.CallbackProcessor
    }
    callbacks := map[string]params{
        "gorm:create": {
            spanType:  execSpanType,
            processor: func() *gorm.CallbackProcessor { return db.Callback().Create() },
        },
        "gorm:delete": {
            spanType:  execSpanType,
            processor: func() *gorm.CallbackProcessor { return db.Callback().Delete() },
        },
        "gorm:query": {
            spanType:  querySpanType,
            processor: func() *gorm.CallbackProcessor { return db.Callback().Query() },
        },
        "gorm:update": {
            spanType:  execSpanType,
            processor: func() *gorm.CallbackProcessor { return db.Callback().Update() },
        },
        "gorm:row_query": {
            spanType:  querySpanType,
            processor: func() *gorm.CallbackProcessor { return db.Callback().RowQuery() },
        },
    }
    // Map iteration order is unspecified, so registration order varies between
    // runs. Each callback is anchored with Before(name)/After(name) to the
    // named gorm callback.
    for name, params := range callbacks {
        params.processor().Before(name).Register(
            fmt.Sprintf("%s:before:%s", callbackPrefix, name),
            newBeforeCallback(params.spanType),
        )
        params.processor().After(name).Register(
            fmt.Sprintf("%s:after:%s", callbackPrefix, name),
            newAfterCallback(),
        )
    }
}

// newBeforeCallback returns a gorm callback that starts a span named spanType
// as a child of the context stored on the scope under jaegerContextKey. It
// then stores the span-carrying context back under that key and the current
// time (UnixNano) under startTime. Scopes with no stored context are skipped.
// If the span reports a nil tracer, the span is finished immediately and a nil
// context is stored, so the after callback ignores the scope.
func newBeforeCallback(spanType string) func(*gorm.Scope) {
    return func(scope *gorm.Scope) {
        parent, found := scopeContext(scope)
        if !found {
            return
        }

        span, spanCtx := opentracing.StartSpanFromContext(parent, spanType)
        if span.Tracer() == nil {
            span.Finish()
            spanCtx = nil
        }

        scope.Set(jaegerContextKey, spanCtx)
        scope.Set(startTime, time.Now().UnixNano())
    }
}

// newAfterCallback returns a gorm callback that closes out the span opened by
// the before callback. It does nothing when the scope carries no context or
// no span. Otherwise it:
//   - logs, at debug level, the elapsed milliseconds, RowsAffected and the SQL
//     with its bind values inlined;
//   - logs every scope error at error level and tags the span error=true,
//     except "record not found" / "no rows" results;
//   - finishes the span.
func newAfterCallback() func(*gorm.Scope) {
    return func(scope *gorm.Scope) {
        ctx, ok := scopeContext(scope)
        if !ok {
            return
        }
        span := opentracing.SpanFromContext(ctx)
        if span == nil {
            return
        }
        defer span.Finish()

        // Elapsed milliseconds since the before callback; 0 if no start time was stored.
        duration := int64(0)
        if t, ok := scopeStartTime(scope); ok {
            duration = (time.Now().UnixNano() - t) / 1e6
        }

        logger.Debug(ctx, "[gorm] [%vms] [RowsReturned(%v)] %v  ", duration, scope.DB().RowsAffected, gormSQL(scope.SQL, scope.SQLVars))

        for _, err := range scope.DB().GetErrors() {
            // Empty results are not failures. The previous check compared err
            // with a freshly allocated errors.New value, which is never equal,
            // so "no rows" errors were always logged. Compare by message (the
            // text of database/sql's ErrNoRows) after unwrapping pkg/errors
            // wrappers, since database/sql is not imported in this file.
            if gorm.IsRecordNotFoundError(err) || errors.Cause(err).Error() == "sql: no rows in result set" {
                continue
            }
            // Mark the span as failed (opentracing "error" tag) and log the error.
            span.SetTag("error", true)
            logger.Error(ctx, "%v", err.Error())
        }
    }
}

// scopeContext returns the context.Context stored on scope under
// jaegerContextKey. The bool is false when nothing is stored, or when the
// stored value is not a non-nil context.Context.
func scopeContext(scope *gorm.Scope) (context.Context, bool) {
    raw, present := scope.Get(jaegerContextKey)
    if !present {
        return nil, false
    }
    if ctx, isCtx := raw.(context.Context); isCtx && ctx != nil {
        return ctx, true
    }
    return nil, false
}

// scopeStartTime returns the UnixNano start time the before callback stored on
// scope under startTime. The bool is false when nothing is stored or the value
// is not an int64.
func scopeStartTime(scope *gorm.Scope) (int64, bool) {
    if raw, present := scope.Get(startTime); present {
        nanos, isInt := raw.(int64)
        return nanos, isInt
    }
    return 0, false
}

/*===============Log=======================================*/
var (
    // sqlRegexp matches "?" bind placeholders.
    sqlRegexp = regexp.MustCompile(`\?`)
    // numericPlaceHolderRegexp matches "$n" bind placeholders.
    numericPlaceHolderRegexp = regexp.MustCompile(`\$\d+`)
)

// gormSQL renders inputSql (expected to be a string) with the bind values in
// value (expected to be a []interface{}) inlined. The output is for logging
// only and is not safe to execute.
//
// "$n" placeholders are used when the statement contains any; otherwise each
// "?" is filled in order. See formatSQLValue for how values are rendered.
//
// Fixes over the previous version:
//   - a statement without bind values is returned as-is instead of as "";
//   - "$n" placeholders are replaced in a single pass, so a value containing
//     "$" (e.g. '$10') is not expanded as a regexp group reference, inserted
//     values are never re-scanned for later placeholders, and no regexp is
//     compiled per placeholder;
//   - arguments of unexpected type give best-effort output instead of a panic.
func gormSQL(inputSql interface{}, value interface{}) string {
    query, _ := inputSql.(string)
    vars, _ := value.([]interface{})

    formatted := make([]string, 0, len(vars))
    for _, v := range vars {
        formatted = append(formatted, formatSQLValue(v))
    }
    if len(formatted) == 0 {
        return query
    }

    // differentiate between $n placeholders or else treat like ?
    if numericPlaceHolderRegexp.MatchString(query) {
        byPlaceholder := make(map[string]string, len(formatted))
        for i, v := range formatted {
            byPlaceholder[fmt.Sprintf("$%d", i+1)] = v
        }
        // Placeholders without a matching value are left untouched.
        return numericPlaceHolderRegexp.ReplaceAllStringFunc(query, func(ph string) string {
            if v, ok := byPlaceholder[ph]; ok {
                return v
            }
            return ph
        })
    }

    var b strings.Builder
    for i, part := range sqlRegexp.Split(query, -1) {
        b.WriteString(part)
        if i < len(formatted) {
            b.WriteString(formatted[i])
        }
    }
    return b.String()
}

// formatSQLValue renders one bind value for gormSQL. A single pointer level
// is dereferenced first.
//   - nil -> NULL
//   - time.Time -> quoted "2006-01-02 15:04:05" ('0000-00-00 00:00:00' for the zero time)
//   - printable []byte -> quoted text; other []byte -> '<binary>'
//   - driver.Valuer -> quoted Value(), or NULL on error or nil
//   - numbers and bools -> bare
//   - anything else -> quoted with %v
func formatSQLValue(v interface{}) string {
    indirect := reflect.Indirect(reflect.ValueOf(v))
    if !indirect.IsValid() {
        return "NULL"
    }
    switch x := indirect.Interface().(type) {
    case time.Time:
        if x.IsZero() {
            return fmt.Sprintf("'%v'", "0000-00-00 00:00:00")
        }
        return fmt.Sprintf("'%v'", x.Format("2006-01-02 15:04:05"))
    case []byte:
        if str := string(x); isPrintable(str) {
            return fmt.Sprintf("'%v'", str)
        }
        return "'<binary>'"
    case driver.Valuer:
        if dv, err := x.Value(); err == nil && dv != nil {
            return fmt.Sprintf("'%v'", dv)
        }
        return "NULL"
    case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64, bool:
        return fmt.Sprintf("%v", x)
    default:
        return fmt.Sprintf("'%v'", x)
    }
}

// isPrintable reports whether every rune of s is printable per unicode.IsPrint.
func isPrintable(s string) bool {
    for _, r := range s {
        if !unicode.IsPrint(r) {
            return false
        }
    }
    return true
}

server.go

package apiserver

import (
    contextV2 "context"
    "fmt"
    "runtime/debug"
    "tracedemo/apiserver/userinfo"
    "tracedemo/logger"

    "github.com/kataras/iris/v12"
    "github.com/kataras/iris/v12/context"
    "github.com/opentracing/opentracing-go"
)

// StartApiServerr builds the iris application and serves it on :8080.
// Middleware is installed in this order: tracing, site code, panic recovery.
// Routes are GET / (answers "pong") plus the /user routes. It blocks until
// the server stops; a Run error is logged, not returned.
func StartApiServerr() {
    const addr = ":8080"

    app := iris.New()
    for _, mw := range []context.Handler{openTracing(), withSiteCode(), withRecover()} {
        app.Use(mw)
    }

    app.Get("/", func(c context.Context) {
        c.WriteString("pong")
    })
    initIris(app)

    logger.Info(contextV2.Background(), "[apiServer]开始监听%s,", addr)
    if err := app.Run(iris.Addr(addr), iris.WithoutInterruptHandler); err != nil {
        logger.Error(contextV2.Background(), "[apiServer]开始监听%s 错误%v,", addr, err)
    }
}

// initIris mounts the userinfo API under /user: GET /user/test and GET /user/rpc.
func initIris(app *iris.Application) {
    api := userinfo.ApiServer{}
    users := app.Party("/user")
    users.Get("/test", api.TestUserInfo)
    users.Get("/rpc", api.TestRpc)
}

// openTracing returns middleware that traces each request with a span named
// "apiServer". The span continues the caller's trace when the request headers
// carry one, and is tagged with the HTTP method and URL. It is stored in the
// request context, the request URL is logged, and the span is finished once
// the remaining handlers have run.
//
// Fixes over the previous version: the span was never finished, so it was
// never reported, and incoming trace headers were ignored, so every request
// started a new trace.
func openTracing() context.Handler {
    return func(c iris.Context) {
        tracer := opentracing.GlobalTracer()
        req := c.Request()

        var opts []opentracing.StartSpanOption
        // Extract fails when the request carries no (or malformed) trace
        // headers; a root span is started in that case.
        if parent, err := tracer.Extract(opentracing.HTTPHeaders, opentracing.HTTPHeadersCarrier(req.Header)); err == nil {
            opts = append(opts, opentracing.ChildOf(parent))
        }
        span := tracer.StartSpan("apiServer", opts...)
        defer span.Finish()
        span.SetTag("http.method", req.Method)
        span.SetTag("http.url", req.URL.String())

        c.ResetRequest(req.WithContext(opentracing.ContextWithSpan(req.Context(), span)))
        logger.Info(c.Request().Context(), "Api请求地址%v", c.Request().URL)
        c.Next()
    }
}

// withSiteCode returns middleware that copies the "SiteCode" request header
// into the request context under the "SiteCode" key. It defaults to "001"
// when the header is missing or empty.
func withSiteCode() context.Handler {
    return func(c iris.Context) {
        code := c.GetHeader("SiteCode")
        if code == "" {
            code = "001"
        }

        req := c.Request()
        c.ResetRequest(req.WithContext(contextV2.WithValue(req.Context(), "SiteCode", code)))

        c.Next()
    }
}

// withRecover returns middleware that recovers a panic raised by the remaining
// handlers. It logs the panic value and stack trace at error level and
// answers with HTTP 500, so the client is not left with an implicit success
// status.
func withRecover() context.Handler {
    return func(c iris.Context) {
        defer func() {
            e := recover()
            if e == nil {
                return
            }
            stack := debug.Stack()
            // Pass the message as an argument, not as the format: panic values
            // and stack traces are arbitrary text and may contain '%'.
            msg := fmt.Sprintf("Api has err:%v, stack:%v", e, string(stack))
            logger.Error(c.Request().Context(), "%s", msg)

            c.StatusCode(iris.StatusInternalServerError)
            c.StopExecution()
        }()

        c.Next()
    }
}

grpc的中间件middleware.go

package middleware

import (
    "context"
    "encoding/json"
    "fmt"
    "github.com/opentracing/opentracing-go"
    "github.com/opentracing/opentracing-go/ext"
    "google.golang.org/grpc"
    "google.golang.org/grpc/metadata"
    "runtime/debug"
    "strings"
    "time"
    "tracedemo/logger"
)

// MDCarrier adapts gRPC metadata to opentracing's TextMap carrier interfaces:
// ForeachKey serves extraction and Set serves injection.
type MDCarrier struct {
    metadata.MD
}

// ForeachKey calls handler once for every (key, value) pair in the metadata.
// It stops at, and returns, the first error handler returns.
func (m MDCarrier) ForeachKey(handler func(key, val string) error) error {
    for key, values := range m.MD {
        for i := range values {
            if err := handler(key, values[i]); err != nil {
                return err
            }
        }
    }
    return nil
}

// Set appends val under key, lower-casing key first. gRPC metadata keys are
// expected to be lowercase, and metadata.New/MD.Set normalise them the same
// way. Writing the map directly would otherwise bypass that normalisation for
// tracer-supplied keys such as baggage names.
func (m MDCarrier) Set(key, val string) {
    k := strings.ToLower(key)
    m.MD[k] = append(m.MD[k], val)
}

// ClientInterceptor 客户端拦截器
func ClientTracing(tracer opentracing.Tracer) grpc.UnaryClientInterceptor {
    return func(ctx context.Context, method string, request, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
        //一个RPC调用的服务端的span,和RPC服务客户端的span构成ChildOf关系
        var parentCtx opentracing.SpanContext
        parentSpan := opentracing.SpanFromContext(ctx)
        if parentSpan != nil {
            parentCtx = parentSpan.Context()
        }
        span := tracer.StartSpan(
            method,
            opentracing.ChildOf(parentCtx),
            opentracing.Tag{Key: string(ext.Component), Value: "gRPC Client"},
            ext.SpanKindRPCClient,
        )

        defer span.Finish()
        md, ok := metadata.FromOutgoingContext(ctx)
        if !ok {
            md = metadata.New(nil)
        } else {
            md = md.Copy()
        }

        err := tracer.Inject(
            span.Context(),
            opentracing.TextMap,
            MDCarrier{md}, // 自定义 carrier
        )

        if err != nil {
            logger.Error(ctx, "ClientTracing inject span error :%v", err.Error())
        }

        ///SiteCode
        siteCode := fmt.Sprintf("%v", ctx.Value("SiteCode"))
        if len(siteCode) < 1 || strings.Contains(siteCode, "nil") {
            siteCode = "001"
        }
        md.Set("SiteCode", siteCode)
        //
        newCtx := metadata.NewOutgoingContext(ctx, md)
        err = invoker(newCtx, method, request, reply, cc, opts...)

        if err != nil {
            logger.Error(ctx, "ClientTracing call error : %v", err.Error())
        }
        return err
    }
}

func ClientSiteCode() grpc.UnaryClientInterceptor {
    return func(ctx context.Context, method string, request, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
        md, ok := metadata.FromOutgoingContext(ctx)
        if !ok {
            md = metadata.New(nil)
        } else {
            md = md.Copy()
        }

        ///SiteCode
        siteCode := fmt.Sprintf("%v", ctx.Value("SiteCode"))
        if len(siteCode) < 1 || strings.Contains(siteCode, "nil") {
            siteCode = "< 

鲜花

握手

雷人

路过

鸡蛋
该文章已有0人参与评论

请发表评论

全部评论

专题导读
上一篇:
Go Lang发布时间:2022-07-10
下一篇:
Go语言学习之匿名函数发布时间:2022-07-10
热门推荐
热门话题
阅读排行榜

扫描微信二维码

查看手机版网站

随时了解更新最新资讯

139-2527-9053

在线客服(服务时间 9:00~18:00)

在线QQ客服
地址:深圳市南山区西丽大学城创智工业园
电邮:jeky_zhao#qq.com
移动电话:139-2527-9053

Powered by 互联科技 X3.4© 2001-2213 极客世界.|Sitemap