Skip to content

Commit

Permalink
Added capability for only the whitelisted ips to connect to the subsc…
Browse files Browse the repository at this point in the history
…riber
  • Loading branch information
kpachhai committed Dec 2, 2024
1 parent 59530a5 commit 5c86ecf
Show file tree
Hide file tree
Showing 5 changed files with 120 additions and 3 deletions.
1 change: 1 addition & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ DB_USER=postgres
DB_PASSWORD=postgres
DB_NAME=nuklaivm
DB_SSLMODE=require # Or "disable" if you don't want to use SSL
GRPC_WHITELISTED_BLOCKCHAIN_NODES="192.168.1.100" # "127.0.0.1,localhost,::1" is already included by default. You can even include something like myblockchain.aws.com
59 changes: 59 additions & 0 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@ package config

import (
"fmt"
"net"
"net/url"
"os"
"strings"
)

// GetDatabaseURL retrieves the database connection string
Expand All @@ -26,6 +28,29 @@ func GetDatabaseURL() string {
)
}

// GetWhitelistIPs retrieves the list of whitelisted IPs from the environment variable
// and resolves domain names to IPs.
func GetWhitelistIPs() []string {
ipList := getEnv("GRPC_WHITELISTED_BLOCKCHAIN_NODES", "127.0.0.1,localhost,::1")
entries := strings.Split(ipList, ",") // Split by comma

whitelist := []string{}
defaultIPs := []string{"127.0.0.1", "localhost", "::1"} // Always include these

// Resolve domain names and add IPs to the whitelist
for _, entry := range append(defaultIPs, entries...) {
entry = strings.TrimSpace(entry)
ips, err := resolveToIPs(entry)
if err != nil {
// Log the error and skip unresolved entries
continue
}
whitelist = append(whitelist, ips...)
}

return uniqueStrings(whitelist) // Ensure no duplicates
}

// getEnv retrieves the value of the environment variable named by the key.
// If the variable is not present, it returns the default value.
func getEnv(key, defaultValue string) string {
Expand All @@ -34,3 +59,37 @@ func getEnv(key, defaultValue string) string {
}
return defaultValue
}

// resolveToIPs resolves a domain name to its IP addresses or directly returns the IP if it's already valid
func resolveToIPs(host string) ([]string, error) {
// Check if the host is already a valid IP
if net.ParseIP(host) != nil {
return []string{host}, nil
}

// Attempt to resolve the domain name
ips, err := net.LookupIP(host)
if err != nil {
return nil, err
}

// Convert net.IP to strings
ipStrings := []string{}
for _, ip := range ips {
ipStrings = append(ipStrings, ip.String())
}
return ipStrings, nil
}

// uniqueStrings removes duplicates from a slice of strings
func uniqueStrings(slice []string) []string {
unique := make(map[string]bool)
result := []string{}
for _, item := range slice {
if _, exists := unique[item]; !exists {
unique[item] = true
result = append(result, item)
}
}
return result
}
9 changes: 8 additions & 1 deletion grpc/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ type Server struct {

// startGRPCServer starts the gRPC server for receiving block data
func StartGRPCServer(db *sql.DB, port string) {
// Load the whitelist
LoadWhitelist()

// Ensure the port has a colon prefix
if !strings.HasPrefix(port, ":") {
port = ":" + port
Expand All @@ -44,7 +47,11 @@ func StartGRPCServer(db *sql.DB, port string) {
}

// Use insecure credentials to allow plaintext communication
grpcServer := grpc.NewServer(grpc.Creds(insecure.NewCredentials()))
serverOptions := []grpc.ServerOption{
grpc.Creds(insecure.NewCredentials()),
grpc.UnaryInterceptor(UnaryInterceptor), // Attach the interceptor
}
grpcServer := grpc.NewServer(serverOptions...)

// Register your ExternalSubscriber service
pb.RegisterExternalSubscriberServer(grpcServer, &Server{db: db})
Expand Down
46 changes: 46 additions & 0 deletions grpc/whitelist.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package grpc

import (
"context"
"fmt"
"log"
"strings"

"github.com/nuklai/nuklaivm-external-subscriber/config"
"google.golang.org/grpc"
"google.golang.org/grpc/peer"
)

var WhitelistedIPs = make(map[string]bool)

// LoadWhitelist loads the whitelist using the config package
func LoadWhitelist() {
ips := config.GetWhitelistIPs()
if len(ips) == 0 {
log.Println("No whitelisted IPs provided. The gRPC server will reject all connections.")
return
}

// Populate the whitelist map
for _, ip := range ips {
WhitelistedIPs[ip] = true
}

log.Printf("Loaded whitelisted IPs: %v\n", WhitelistedIPs)
}

// UnaryInterceptor checks the IP of the client and allows/denies the connection
func UnaryInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
peerInfo, ok := peer.FromContext(ctx)
if !ok {
return nil, fmt.Errorf("could not retrieve peer info")
}

clientIP := strings.Split(peerInfo.Addr.String(), ":")[0] // Extract IP address
if !WhitelistedIPs[clientIP] {
log.Printf("Unauthorized connection attempt from IP: %s", clientIP)
return nil, fmt.Errorf("unauthorized IP: %s", clientIP)
}

return handler(ctx, req)
}
8 changes: 6 additions & 2 deletions infra/aws/task-definition-subscriber.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"cpu": "${TASK_CPU}",
"memory": "${TASK_MEMORY}",
"cpu": "${TASK_CPU}",
"memory": "${TASK_MEMORY}",
"volumes": [],
"networkMode": "awsvpc",
"family": "${PRODUCT}-${APPLICATION}-${ENVIRONMENT}",
Expand Down Expand Up @@ -66,6 +66,10 @@
{
"name": "DB_PASSWORD",
"valueFrom": "arn:aws:ssm:${AWS_REGION}:${AWS_ACCOUNT_ID}:parameter/${ENVIRONMENT}/${PRODUCT}/${APPLICATION}/db_password"
},
{
"name": "GRPC_WHITELISTED_BLOCKCHAIN_NODES",
"valueFrom": "arn:aws:ssm:${AWS_REGION}:${AWS_ACCOUNT_ID}:parameter/${ENVIRONMENT}/${PRODUCT}/${APPLICATION}/grpc_whitelisted_blockchain_nodes"
}
],
"logConfiguration": {
Expand Down

0 comments on commit 5c86ecf

Please sign in to comment.