Go fiber custom middleware


As a Go developer coming from a Node.js background, my framework of choice for REST APIs, like that of many others, is Go Fiber. Its simplicity and beginner-friendly, Express.js-like approach make it easy to start building performant APIs. Just like Express.js, it comes with middleware support, routing, template engines and more built in. It also ships with many ready-to-use middlewares, such as limiter and compress.
We will not dig into any of the built-in middlewares here, as there are already plenty of articles about them on the internet. Instead, we will build a custom Go Fiber middleware, a topic for which few resources are currently available.

Agenda

To keep things simple for demonstration purposes, we will create a custom JWT middleware and then use it in a demo Go Fiber API. It will be similar to what we implemented in one of our Node.js JWT tutorials. The middleware lets requests with a valid JWT token pass and blocks requests with an invalid one. So let’s see how we do it.

Note: this isn’t a JWT tutorial so we are taking a basic example without digging into proper JWT implementation for now.

Initialize a demo go fiber API project

We will need to set up a demo go fiber API to be able to test our custom middleware.

go mod init go-middleware
go get -u github.com/gofiber/fiber/v2

After that, we need to install a JWT package to generate and verify JWT tokens. We will use it in our custom JWT middleware.

go get github.com/dgrijalva/jwt-go

Project go fiber directory structure

go-middleware-tutorial/
┣ jwtchecker/
┃ ┗ jwtchecker.go
┣ go.mod
┣ go.sum
┗ main.go

Under the project root, we create a jwtchecker directory that holds our custom middleware file, jwtchecker.go. main.go is our API entry file.

main.go

package main

import (
	// import our middleware (which we will create shortly)
	"go-middleware/jwtchecker"
	"log"

	"github.com/dgrijalva/jwt-go"
	"github.com/gofiber/fiber/v2"
)

func main() {
	app := fiber.New()

	// route to generate jwt tokens
	app.Get("/generate-token", func(c *fiber.Ctx) error {
		// here we handle token generation
	})

	// here we will use our middleware
	app.Use(/*here we use our custom jwt middleware*/)

	// our JWT protected route
	app.Get("/protected", func(c *fiber.Ctx) error {
		// code here will only be executed if we have a valid jwt token
	})

	log.Fatal(app.Listen(":3000"))
}

Create our go fiber middleware

Below will be the structure of our go fiber custom jwtchecker middleware.

package jwtchecker

import (
	"errors"
	"fmt"
	"time"

	"github.com/dgrijalva/jwt-go"
	"github.com/gofiber/fiber/v2"
)

/*
	EXPECTED BY CONVENTION (every middleware should have this)

	Every middleware takes a config.
	In the config we also define a function that lets us skip the middleware when it returns true.
	We name it "Filter" here, but any other name will work too.
*/
type Config struct {
	Filter       func(c *fiber.Ctx) bool                    // skip function (convention)
	Unauthorized fiber.Handler                              // middleware specific
	Decode       func(c *fiber.Ctx) (*jwt.MapClaims, error) // middleware specific
	Secret       string                                     // middleware specific
	Expiry       int64                                      // middleware specific
}

/*
	Middleware specific

	Our middleware's config default values if not passed
*/
var ConfigDefault = Config{
	Filter:       nil,
	Decode:       nil,
	Unauthorized: nil,
	Secret:       "secret",
	Expiry:       60,
}

/*
	Middleware specific

	Function for generating default config
*/
func configDefault(config ...Config) Config {
}

/*
	Middleware specific
	
	Function to generate a jwt token
*/
func Encode(claims *jwt.MapClaims, expiryAfter int64) (string, error) {
}

/*
	REQUIRED (any middleware must have this)

	Our main middleware function used to initialize our middleware.
	By convention we name it "New" but any other name will work too.
*/
func New(config Config) fiber.Handler {
}

Config

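Our custom JWT middleware uses the config below. Apart from the Filter property, every property is specific to this middleware and will differ from one middleware to another. (Fiber v2's built-in middlewares name the equivalent skip function Next; we use Filter here.)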

type Config struct {
	// when it returns true, our middleware is skipped
	Filter func(c *fiber.Ctx) bool

	// handler to run when there is an error decoding the jwt
	Unauthorized fiber.Handler

	// function to decode our jwt token
	Decode func(c *fiber.Ctx) (*jwt.MapClaims, error)

	// set jwt secret
	Secret string

	// set jwt expiry in seconds
	Expiry int64
}

configDefault

This function is not strictly necessary, since everything it does could be done inside New itself. We split it out so that New stays uncluttered. It assigns default values to any fields left unset in the Config passed to New during middleware initialization. The code below is self-explanatory, so we will not go through it line by line. The most important part of this function is the definition of our default JWT Decode function.

The Decode function extracts the JWT token from the Authorization header, then parses and verifies it using our secret. On success it returns a pointer to the parsed claims; otherwise it returns an error.

func configDefault(config ...Config) Config {
	// Return default config if nothing provided
	if len(config) < 1 {
		return ConfigDefault
	}

	// Override default config
	cfg := config[0]

	// Set default values if not passed
	if cfg.Filter == nil {
		cfg.Filter = ConfigDefault.Filter
	}

	// Set default secret if not passed
	if cfg.Secret == "" {
		cfg.Secret = ConfigDefault.Secret
	}

	// Set default expiry if not passed
	if cfg.Expiry == 0 {
		cfg.Expiry = ConfigDefault.Expiry
	}

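	// this is the main jwt decode function of our middleware
	if cfg.Decode == nil {
		// Set default Decode function if not passed
		cfg.Decode = func(c *fiber.Ctx) (*jwt.MapClaims, error) {

			authHeader := c.Get("Authorization")

			if authHeader == "" {
				return nil, errors.New("Authorization header is required")
			}

			// the header must look like "Bearer <token>"; checking the prefix
			// first avoids a slice-out-of-range panic on short values
			if len(authHeader) <= 7 || authHeader[:7] != "Bearer " {
				return nil, errors.New("Authorization header must be of the form: Bearer <token>")
			}

			// we parse our jwt token and check for validity against our secret
			token, err := jwt.Parse(
				authHeader[7:],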
				func(token *jwt.Token) (interface{}, error) {
					// verifying our algo
					if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
						return nil, fmt.Errorf("Unexpected signing method: %v", token.Header["alg"])
					}
					return []byte(cfg.Secret), nil
				},
			)

			if err != nil {
				return nil, errors.New("Error parsing token")
			}

			claims, ok := token.Claims.(jwt.MapClaims)

			if !(ok && token.Valid) {
				return nil, errors.New("Invalid token")
			}

			// checked type assertion, so a non-numeric "exp" claim cannot cause a panic
			if expiresAt, ok := claims["exp"].(float64); ok && int64(expiresAt) < time.Now().UTC().Unix() {
				return nil, errors.New("jwt is expired")
			}

			return &claims, nil
		}
	}

	// Set default Unauthorized if not passed
	if cfg.Unauthorized == nil {
		cfg.Unauthorized = func(c *fiber.Ctx) error {
			return c.SendStatus(fiber.StatusUnauthorized)
		}
	}

	return cfg
}
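
The default Decode assumes the Authorization header has the form "Bearer <token>". As a rough sketch, that extraction step can be isolated into a small standalone function that returns an error instead of slicing blindly. The name bearerToken and its error messages are hypothetical; they are not part of Fiber or jwt-go.

```go
package main

import (
	"errors"
	"fmt"
	"strings"
)

// bearerToken is a hypothetical helper (not part of Fiber or jwt-go).
// It returns the token part of an "Authorization: Bearer <token>" header value.
func bearerToken(header string) (string, error) {
	const prefix = "Bearer "
	if header == "" {
		return "", errors.New("authorization header is required")
	}
	if !strings.HasPrefix(header, prefix) {
		return "", errors.New("authorization header must use the Bearer scheme")
	}
	token := strings.TrimSpace(header[len(prefix):])
	if token == "" {
		return "", errors.New("bearer token is empty")
	}
	return token, nil
}

func main() {
	// A well-formed header, a different scheme, a too-short value and an empty value.
	for _, h := range []string{"Bearer abc.def.ghi", "Basic dXNlcg==", "abc", ""} {
		token, err := bearerToken(h)
		fmt.Printf("%q -> token=%q err=%v\n", h, token, err)
	}
}
```

Inside Decode, the result of bearerToken(c.Get("Authorization")) could then be handed to jwt.Parse in place of the sliced header.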

New – go fiber custom middleware main function

This is the main function of our custom Go Fiber middleware. We have named it “New” following convention, but you can name it something else too. New expects a Config as a parameter and returns a fiber.Handler. The first thing it does is call the configDefault function discussed above.

Inside the returned fiber.Handler we first check the Filter property of the Config. By convention, this gives callers the option to skip the middleware when Filter returns true. After that, we call our Decode function and store the decoded claims with c.Locals, which keeps values scoped to the request. If decoding fails, we run the Unauthorized handler, which by default responds with the 401 Unauthorized status code; otherwise we let the request pass.

func New(config Config) fiber.Handler {

	// For setting default config
	cfg := configDefault(config)

	return func(c *fiber.Ctx) error {
		// Don't execute middleware if Filter returns true
		if cfg.Filter != nil && cfg.Filter(c) {
			fmt.Println("Middleware was skipped")
			return c.Next()
		}
		fmt.Println("Middleware was run")

		claims, err := cfg.Decode(c)

		if err == nil {
			c.Locals("jwtClaims", *claims)
			return c.Next()
		}

		return cfg.Unauthorized(c)
	}
}
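
New takes exactly one Config, so the len(config) < 1 branch in configDefault is never reached and callers have to write jwtchecker.Config{} even when they want only defaults. Fiber's built-in middlewares declare New as variadic so the argument can be omitted. The sketch below shows that pattern with a trimmed-down stand-in for our Config; the resolve function is hypothetical and only stands in for New so the example runs without Fiber.

```go
package main

import "fmt"

// Config is a trimmed-down stand-in for the article's Config (assumption:
// only the two scalar fields are needed to show the defaulting logic).
type Config struct {
	Secret string
	Expiry int64
}

// ConfigDefault mirrors the default values used in the article.
var ConfigDefault = Config{Secret: "secret", Expiry: 60}

// configDefault returns the defaults when no config is passed and fills in
// zero-valued fields otherwise.
func configDefault(config ...Config) Config {
	if len(config) < 1 {
		return ConfigDefault
	}
	cfg := config[0]
	if cfg.Secret == "" {
		cfg.Secret = ConfigDefault.Secret
	}
	if cfg.Expiry == 0 {
		cfg.Expiry = ConfigDefault.Expiry
	}
	return cfg
}

// resolve is a hypothetical stand-in for New: it accepts zero or one Config
// and forwards the variadic slice unchanged.
func resolve(config ...Config) Config {
	return configDefault(config...)
}

func main() {
	fmt.Printf("%+v\n", resolve())                                // {Secret:secret Expiry:60}
	fmt.Printf("%+v\n", resolve(Config{Secret: "my_own_custom"})) // {Secret:my_own_custom Expiry:60}
}
```

One consequence of defaulting on zero values is that an explicit Expiry of 0 cannot be told apart from "not set" and falls back to 60.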

Encode

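We also need a function that generates a JWT token, and that is what Encode does for us here. It takes a pointer to a jwt.MapClaims and an expiryAfter value (in seconds) as arguments and returns a signed JWT token string. Note that Encode always signs with ConfigDefault.Secret, so the tokens it produces only verify when the middleware is also using the default secret.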

func Encode(claims *jwt.MapClaims, expiryAfter int64) (string, error) {

	// setting default expiryAfter
	if expiryAfter == 0 {
		expiryAfter = ConfigDefault.Expiry
	}
        
	// or you can use time.Now().Add(time.Second * time.Duration(expiryAfter)).UTC().Unix()
	(*claims)["exp"] = time.Now().UTC().Unix() + expiryAfter

	token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)

	// our signed jwt token string
	signedToken, err := token.SignedString([]byte(ConfigDefault.Secret))

	if err != nil {
		return "", errors.New("Error creating a token")
	}

	return signedToken, nil
}
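
To make it clearer what a signed HS256 token is, here is a minimal sketch that builds and checks one using only the standard library. It is an illustration of the token format (base64url header, payload and HMAC-SHA256 signature joined by dots), not a replacement for the jwt package: the names signHS256 and verifyHS256 are hypothetical, and the verifier only checks the signature, not the "alg" header or the "exp" claim.

```go
package main

import (
	"crypto/hmac"
	"crypto/sha256"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"strings"
)

// JWTs use base64url without padding.
var b64 = base64.RawURLEncoding

// signHS256 (hypothetical) builds "header.payload.signature" for the claims.
func signHS256(claims map[string]interface{}, secret string) (string, error) {
	header, err := json.Marshal(map[string]string{"alg": "HS256", "typ": "JWT"})
	if err != nil {
		return "", err
	}
	payload, err := json.Marshal(claims)
	if err != nil {
		return "", err
	}
	signingInput := b64.EncodeToString(header) + "." + b64.EncodeToString(payload)

	mac := hmac.New(sha256.New, []byte(secret))
	mac.Write([]byte(signingInput))
	return signingInput + "." + b64.EncodeToString(mac.Sum(nil)), nil
}

// verifyHS256 (hypothetical) recomputes the signature with the given secret
// and compares it to the one in the token in constant time.
func verifyHS256(token, secret string) error {
	parts := strings.Split(token, ".")
	if len(parts) != 3 {
		return errors.New("token must have three dot-separated parts")
	}
	sig, err := b64.DecodeString(parts[2])
	if err != nil {
		return errors.New("signature is not valid base64url")
	}
	mac := hmac.New(sha256.New, []byte(secret))
	mac.Write([]byte(parts[0] + "." + parts[1]))
	if !hmac.Equal(sig, mac.Sum(nil)) {
		return errors.New("signature mismatch")
	}
	return nil
}

func main() {
	claims := map[string]interface{}{"role": "admin", "exp": 1700000000}
	token, err := signHS256(claims, "secret")
	if err != nil {
		fmt.Println("sign error:", err)
		return
	}
	fmt.Println(token)
	fmt.Println("verified with same secret: ", verifyHS256(token, "secret") == nil)
	fmt.Println("verified with other secret:", verifyHS256(token, "my_own_custom") == nil)
}
```

The last line shows why the signing secret and the verifying secret have to match: a token signed with one secret fails verification under any other.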

With that, we are done with our jwt middleware code. Now the only thing left is seeing our middleware in action. So let’s begin setting up our dummy API.

Setup our API for testing

Our dummy API will have the following routes:

  • generate-token – it does exactly what its name implies.
  • protected – our jwt protected route.

/generate-token

Here we are using our middleware’s Encode function to generate a jwt token string.

app.Get("/generate-token", func(c *fiber.Ctx) error {
	token, err := jwtchecker.Encode(&jwt.MapClaims{
		"email": "user@example.com", // placeholder address
		"role":  "admin",
		"id":    7,
	}, 1000)

	if err != nil {
		return c.SendStatus(500)
	}
	return c.SendString(token)
})

Using our go fiber custom middleware

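Just below our generate-token route, we register our middleware with app.Use(), as shown below. Fiber runs handlers in the order they are registered, so /generate-token stays public while every route registered after app.Use() goes through the middleware.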

// if we don't pass in a secret, the default secret is used
/*
* app.Use(jwtchecker.New(jwtchecker.Config{
*   Secret: "my_own_custom", // in production please consider using env variables here
* }))
*/
app.Use(jwtchecker.New(jwtchecker.Config{}))

We are not required to pass a handler to the Filter property of our Config, but we can do so whenever we want to skip the middleware based on a custom condition. To demonstrate, let’s say we want to skip JWT verification for any request whose X-Secret-Pass header has the value “c5cacd9002a3”. We can do so as shown below.

app.Use(jwtchecker.New(jwtchecker.Config{
	Filter: func(c *fiber.Ctx) bool {

		// if X-Secret-Pass header matches then we skip the jwt validation
		return c.Get("X-Secret-Pass") == "c5cacd9002a3"

	},
}))

Recall from above that when Filter returns true, our middleware is skipped.
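
Since this header value acts as a shared secret, it may be worth comparing it in constant time rather than with ==, so that response timing reveals less about how much of the value matched. Below is a minimal sketch using crypto/subtle from the standard library; the helper name secretMatches is hypothetical.

```go
package main

import (
	"crypto/subtle"
	"fmt"
)

// secretMatches is a hypothetical helper. It reports whether got equals want
// without short-circuiting on the first differing byte. Note that
// ConstantTimeCompare returns early when the lengths differ, so the length
// of the secret is not hidden.
func secretMatches(got, want string) bool {
	return subtle.ConstantTimeCompare([]byte(got), []byte(want)) == 1
}

func main() {
	const want = "c5cacd9002a3"
	fmt.Println(secretMatches("c5cacd9002a3", want)) // true
	fmt.Println(secretMatches("wrong", want))        // false
	fmt.Println(secretMatches("", want))             // false
}
```

Inside Filter, the comparison would then read secretMatches(c.Get("X-Secret-Pass"), "c5cacd9002a3").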

/protected

This route is protected by our JWT middleware, so a request only reaches it if its token is valid. Otherwise, our middleware does not let the request pass and returns the Unauthorized status code.

app.Get("/protected", func(c *fiber.Ctx) error {
	claimData := c.Locals("jwtClaims")

	// no claims in Locals means Filter skipped the middleware
	if claimData == nil {
		return c.SendString("Jwt was bypassed")
	}
	return c.JSON(claimData)
})
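
One detail to keep in mind when reading individual claims in a handler: the claims have been through JSON, so, assuming the parser's default settings, numeric values such as "id" or "exp" come back as float64 rather than int. The sketch below shows this with encoding/json alone; the helper claimInt64 is hypothetical.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// claimInt64 is a hypothetical helper that reads a numeric claim as int64.
// It uses a checked type assertion, so a missing or non-numeric claim
// yields ok == false instead of a panic.
func claimInt64(claims map[string]interface{}, key string) (int64, bool) {
	v, ok := claims[key].(float64)
	if !ok {
		return 0, false
	}
	return int64(v), true
}

func main() {
	// A payload like the one produced for the demo token.
	payload := []byte(`{"id": 7, "role": "admin", "exp": 1700000000}`)

	var claims map[string]interface{}
	if err := json.Unmarshal(payload, &claims); err != nil {
		panic(err)
	}

	fmt.Printf("%T\n", claims["id"]) // float64

	id, ok := claimInt64(claims, "id")
	fmt.Println(id, ok) // 7 true

	_, ok = claimInt64(claims, "role")
	fmt.Println(ok) // false
}
```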

Conclusion

We have reached the end of this tutorial, and we hope you liked it. If you have any doubts or questions, please feel free to leave a comment below. Otherwise, show us your love by sharing it with your friends and on social media.
