You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

402 lines
9.3 KiB
Go

package main
import (
"context"
"fmt"
"log"
"net/url"
"os"
"os/signal"
"strings"
"syscall"
"time"
"tzgyn/kon-redirect/db"
"math/rand"
"github.com/aws/aws-sdk-go-v2/aws"
v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/limiter"
"github.com/gofiber/fiber/v2/middleware/logger"
"github.com/jackc/pgx/v5/pgtype"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/joho/godotenv"
"github.com/mileusna/useragent"
"github.com/oschwald/geoip2-golang"
"github.com/patrickmn/go-cache"
"github.com/robfig/cron/v3"
"github.com/ua-parser/uap-go/uaparser"
)
// Client wraps an S3 presign client so that presigned requests can be created
// through its methods.
type Client struct {
	// PresignClient is the underlying SDK client used to presign requests.
	PresignClient *s3.PresignClient
}
// GetObject makes a presigned request that can be used to get the object key
// from bucket.
//
// expiry is passed to the SDK as the presign option Expires. On failure the
// returned request is nil and the error is wrapped with the bucket and key.
// The error is not logged here: reporting it is left to the caller so that a
// single failure is not written to the log twice.
func (p *Client) GetObject(
	ctx context.Context,
	bucket string,
	key string,
	expiry time.Duration,
) (*v4.PresignedHTTPRequest, error) {
	request, err := p.PresignClient.PresignGetObject(ctx, &s3.GetObjectInput{
		Bucket: aws.String(bucket),
		Key:    aws.String(key),
	}, func(opts *s3.PresignOptions) {
		opts.Expires = expiry
	})
	if err != nil {
		return nil, fmt.Errorf("presigning GET for %v:%v: %w", bucket, key, err)
	}
	return request, nil
}
// main starts the redirect service.
//
// Startup order: load .env (optional), load the user-agent regexes, make sure
// the GeoLite2 database file exists (downloading it if not) and open it,
// connect to Postgres, build the S3 presign client, register the HTTP routes,
// schedule the weekly GeoLite2 download and finally listen on :3000 until a
// SIGINT/SIGTERM triggers a graceful shutdown. Any failure during startup is
// fatal.
func main() {
	const geoDBPath = "./data/GeoLite2-City.mmdb"

	// Presigned URL lifetimes for file uploads.
	// NOTE(review): the app domain and custom domains use different lifetimes
	// (65 vs 10 minutes); kept as found, confirm whether that is intentional.
	const (
		appDomainPresignExpiry    = 65 * time.Minute
		customDomainPresignExpiry = 10 * time.Minute
	)

	// A missing .env file is not an error: the environment may be set directly.
	if err := godotenv.Load(); err != nil {
		fmt.Println("No .env found")
	}

	parser, err := uaparser.New("./regexes.yaml")
	if err != nil {
		log.Fatal(err)
	}

	fmt.Println("Initializing GeoLite2 DB...")
	if _, err := os.Stat(geoDBPath); err != nil {
		fmt.Println("GeoLite2 DB Not Found...")
		if err := downloadAndExtractDB(); err != nil {
			log.Fatal(err)
		}
	} else {
		fmt.Println("GeoLite2 DB Found, Skipping Download...")
	}
	geodb, err := geoip2.Open(geoDBPath)
	if err != nil {
		log.Fatalf("Failed to initialize GeoLite2 DB: %v", err)
	}
	fmt.Println("Finished initializing GeoLite2 DB")

	fmt.Println("Initializing postgres DB and cache...")
	ctx := context.Background()
	dbURL := os.Getenv("DATABASE_URL")
	if len(dbURL) == 0 {
		log.Fatal("DATABASE_URL not found")
	}
	pgconfig, err := pgxpool.ParseConfig(dbURL)
	if err != nil {
		log.Fatal(err)
	}
	pgconfig.MinConns = 4
	conn, err := pgxpool.NewWithConfig(ctx, pgconfig)
	if err != nil {
		log.Fatal(err)
	}
	defer conn.Close()
	queries := db.New(conn)
	// visitCache holds "<ip>_<shortener id>" keys of visits that were already
	// seen, so that repeat hits are not recorded again while the entry lives.
	visitCache := cache.New(1*time.Hour, 5*time.Minute)
	fmt.Println("Finished initializing postgres DB and cache")

	fmt.Println("Initializing S3...")
	awsregion := os.Getenv("AWS_REGION")
	if len(awsregion) == 0 {
		log.Fatal("AWS_REGION not found")
	}
	s3endpoint := os.Getenv("AWS_ENDPOINT_URL_S3")
	if len(s3endpoint) == 0 {
		log.Fatal("AWS_ENDPOINT_URL_S3 not found")
	}
	s3bucket := os.Getenv("AWS_BUCKET_NAME")
	if len(s3bucket) == 0 {
		log.Fatal("AWS_BUCKET_NAME not found")
	}
	sdkConfig, err := config.LoadDefaultConfig(ctx)
	if err != nil {
		// Fatal like every other startup failure, so the process does not
		// exit with status 0 when it could not start.
		log.Fatalf("Couldn't load default configuration. Here's why: %v\n", err)
	}
	svc := s3.NewFromConfig(sdkConfig, func(o *s3.Options) {
		o.BaseEndpoint = aws.String(s3endpoint)
		o.Region = awsregion
	})
	presigner := &Client{PresignClient: s3.NewPresignClient(svc)}
	fmt.Println("Finished initializing S3")

	app := fiber.New()
	app.Use(logger.New())
	app.Use(limiter.New(limiter.Config{
		Max:        20,
		Expiration: 1 * time.Minute,
	}))
	app.Static("/", "./public")

	fallbackurl := os.Getenv("FALLBACK_URL")
	if len(fallbackurl) == 0 {
		fallbackurl = "https://app.kon.sh"
	}
	appurl := os.Getenv("APP_URL")
	if len(appurl) == 0 {
		appurl = "kon.sh"
	}
	invalidURL := fallbackurl + "/qr/invalid"

	// destination groups the redirect targets of one shortener.
	type destination struct {
		link    string // default target
		ios     string // target for iOS visitors, may be empty
		android string // target for Android visitors, may be empty
	}

	// resolveDestination returns the redirect targets of a shortener. A file
	// upload resolves to one presigned download URL used for every platform;
	// anything else resolves to the stored links.
	resolveDestination := func(
		isFileUpload bool,
		filePath, link, iosLink, androidLink string,
		expiry time.Duration,
	) (destination, error) {
		if !isFileUpload {
			return destination{link: link, ios: iosLink, android: androidLink}, nil
		}
		presigned, err := presigner.GetObject(ctx, s3bucket, filePath, expiry)
		if err != nil {
			return destination{}, err
		}
		return destination{link: presigned.URL, ios: presigned.URL, android: presigned.URL}, nil
	}

	app.Get("/", func(c *fiber.Ctx) error {
		return c.Redirect(fallbackurl, 301)
	})

	app.Get("/:code", func(c *fiber.Ctx) error {
		code := c.Params("code")
		domain := c.Hostname()

		var (
			shortenerID    string
			iosEnabled     bool
			androidEnabled bool
			dest           destination
		)

		// Codes on the app's own domain are looked up by code alone; codes on
		// any other host are looked up together with that custom domain.
		if domain == appurl {
			shortener, err := queries.GetShortener(ctx, code)
			if err != nil {
				return c.Redirect(invalidURL)
			}
			shortenerID = shortener.ID
			iosEnabled = shortener.Ios
			androidEnabled = shortener.Android
			dest, err = resolveDestination(
				shortener.IsFileUpload,
				shortener.FilePath.String,
				shortener.Link,
				shortener.IosLink.String,
				shortener.AndroidLink.String,
				appDomainPresignExpiry,
			)
			if err != nil {
				log.Printf("Couldn't get a presigned request to get %s. Here's why: %v\n", shortener.FilePath.String, err)
				return c.Redirect(invalidURL)
			}
		} else {
			shortener, err := queries.GetShortenerWithDomain(ctx, db.GetShortenerWithDomainParams{
				Code:         code,
				CustomDomain: pgtype.Text{String: domain, Valid: true},
			})
			if err != nil {
				return c.Redirect(invalidURL)
			}
			shortenerID = shortener.ID
			iosEnabled = shortener.Ios
			androidEnabled = shortener.Android
			dest, err = resolveDestination(
				shortener.IsFileUpload,
				shortener.FilePath.String,
				shortener.Link,
				shortener.IosLink.String,
				shortener.AndroidLink.String,
				customDomainPresignExpiry,
			)
			if err != nil {
				log.Printf("Couldn't get a presigned request to get %s. Here's why: %v\n", shortener.FilePath.String, err)
				return c.Redirect(invalidURL)
			}
		}

		// Without a User-Agent there is nothing to tailor or record.
		uastrings := c.GetReqHeaders()["User-Agent"]
		if len(uastrings) == 0 {
			return c.Redirect(dest.link)
		}
		uastring := uastrings[0]
		ua := useragent.Parse(uastring)

		// Platform-specific links win only when enabled and non-empty.
		finalRedirectURL := dest.link
		switch {
		case ua.OS == "iOS" && iosEnabled && len(dest.ios) != 0:
			finalRedirectURL = dest.ios
		case ua.OS == "Android" && androidEnabled && len(dest.android) != 0:
			finalRedirectURL = dest.android
		}

		// Everything below is best-effort visit tracking: whatever happens,
		// the visitor is still redirected.
		// NOTE(review): c.IPs() presumably reflects the X-Forwarded-For header,
		// so no visit is recorded when that list is empty - confirm.
		ipList := c.IPs()
		if len(ipList) == 0 {
			return c.Redirect(finalRedirectURL)
		}
		ip := ipList[0]

		visitKey := ip + "_" + shortenerID
		if _, found := visitCache.Get(visitKey); found {
			return c.Redirect(finalRedirectURL)
		}
		visitCache.Set(visitKey, true, cache.DefaultExpiration)

		if ua.Bot {
			return c.Redirect(finalRedirectURL)
		}
		record, err := getCity(geodb, ip)
		if err != nil {
			return c.Redirect(finalRedirectURL)
		}

		devicetype := ""
		switch {
		case ua.Mobile:
			devicetype = "mobile"
		case ua.Tablet:
			devicetype = "tablet"
		case ua.Desktop:
			devicetype = "desktop"
		}

		referer := ""
		if referers := c.GetReqHeaders()["Referer"]; len(referers) > 0 {
			referer = referers[0]
		}
		referer = getDomainWithoutWWW(referer)

		// The second user-agent parse is only needed for the visitor row, so
		// it is done here rather than on every request.
		client := parser.Parse(uastring)
		if err := queries.CreateVisitor(ctx, db.CreateVisitorParams{
			ID:           generateRandomString(12),
			ShortenerID:  shortenerID,
			DeviceType:   devicetype,
			DeviceVendor: client.Device.Brand,
			Browser:      client.UserAgent.Family,
			Os:           client.Os.Family,
			Country:      record.Country.Names["en"],
			CountryCode:  record.Country.IsoCode,
			City:         record.City.Names["en"],
			Referer:      referer,
		}); err != nil {
			// Do not fail the redirect, but leave a trace of the lost record.
			log.Printf("Couldn't record visitor for shortener %s. Here's why: %v\n", shortenerID, err)
		}
		return c.Redirect(finalRedirectURL)
	})

	scheduler := cron.New()
	if _, err := scheduler.AddFunc("@weekly", func() {
		fmt.Println("Updating GeoLite2 DB...")
		// A failed refresh must not take the running server down; keep
		// serving and let the next run try again.
		// NOTE(review): the already-open reader geodb is not reopened here,
		// so the refreshed file presumably only takes effect after a restart
		// - confirm against downloadAndExtractDB.
		if err := downloadAndExtractDB(); err != nil {
			log.Printf("Failed to update GeoLite2 DB: %v\n", err)
			return
		}
		fmt.Println("Finished updating GeoLite2 DB")
	}); err != nil {
		log.Fatal(err)
	}
	scheduler.Start()

	done := make(chan os.Signal, 1)
	signal.Notify(done, os.Interrupt, syscall.SIGTERM)
	go func() {
		<-done
		fmt.Println("Gracefully shutting down...")
		if err := app.Shutdown(); err != nil {
			log.Printf("Couldn't shut down cleanly: %v\n", err)
		}
	}()

	listenErr := app.Listen(":3000")

	// Clean up before reporting a listen error so the scheduler and the
	// GeoLite2 reader are released on every exit path. The Postgres pool is
	// closed by the deferred conn.Close above.
	fmt.Println("Running cleanup tasks...")
	fmt.Println("Stopping cronjobs...")
	scheduler.Stop()
	if err := geodb.Close(); err != nil {
		log.Printf("Couldn't close GeoLite2 DB: %v\n", err)
	}
	if listenErr != nil {
		log.Panic(listenErr)
	}
}
// getDomainWithoutWWW returns the host name found in str with one leading
// "www." removed, or "" when no host can be extracted.
//
// str is first parsed as given. If that succeeds, the parsed host name is
// returned even when it is empty (for example for a bare path). Only when
// parsing fails, and str contains a dot and no space, is it parsed a second
// time with "https://" prepended.
func getDomainWithoutWWW(str string) string {
	hostOf := func(u *url.URL) string {
		return strings.TrimPrefix(u.Hostname(), "www.")
	}

	if parsed, err := url.ParseRequestURI(str); err == nil {
		return hostOf(parsed)
	}

	// Retry scheme-less input such as "example.com/path".
	looksLikeHost := strings.Contains(str, ".") && !strings.Contains(str, " ")
	if !looksLikeHost {
		return ""
	}
	parsed, err := url.ParseRequestURI("https://" + str)
	if err != nil {
		return ""
	}
	return hostOf(parsed)
}
// generateRandomString returns a string of length characters drawn uniformly
// from a-z, A-Z and 0-9. A length of zero or less yields "".
//
// It draws from the shared math/rand generator rather than building a new one
// seeded with the current time on every call: two calls landing on the same
// clock reading would otherwise produce identical strings. The file already
// relies on Go 1.20+ (strings.CutPrefix), where the shared generator is
// randomly seeded at startup. The result is not cryptographically secure and
// must not be used for secrets.
func generateRandomString(length int) string {
	const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
	if length <= 0 {
		return ""
	}
	result := make([]byte, length)
	for i := range result {
		result[i] = charset[rand.Intn(len(charset))]
	}
	return string(result)
}