Skip to content

Instantly share code, notes, and snippets.

@WalBeh
Last active October 3, 2024 07:53
Show Gist options
  • Save WalBeh/cca773617dc86f5b6171ebae93654ed9 to your computer and use it in GitHub Desktop.
Save WalBeh/cca773617dc86f5b6171ebae93654ed9 to your computer and use it in GitHub Desktop.
package main
import (
"bytes"
"context"
"crypto/tls"
"encoding/json"
"flag"
"fmt"
"io/ioutil"
"log"
"net/http"
"os"
"strings"
"syscall"
"time"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/client-go/kubernetes"
"k8s.io/client-go/rest"
"k8s.io/client-go/tools/clientcmd"
)
// sqlHTTPClient is shared by every sendSQLStatement call so the connection to
// the local CrateDB node can be reused instead of re-dialed per statement.
// Certificate verification is skipped because the endpoint is 127.0.0.1;
// NOTE(review): presumably the node serves a self-signed certificate — confirm.
// No client timeout is set on purpose-unknown grounds; NOTE(review): the
// decommission statement may be long-running, so confirm before adding one.
var sqlHTTPClient = &http.Client{
	Transport: &http.Transport{
		TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
	},
}

// sendSQLStatement POSTs stmt to the CrateDB SQL HTTP endpoint on
// 127.0.0.1:4200 using proto ("http" or "https") and logs the server response.
// Any failure (encoding, transport, non-200 status, unreadable body) is fatal.
func sendSQLStatement(proto string, stmt string) {
	url := proto + "://127.0.0.1:4200/_sql"
	respBody, err := postSQL(sqlHTTPClient, url, stmt)
	if err != nil {
		log.Fatalf("SQL statement failed: %v", err)
	}
	log.Printf("Response from server: %s", string(respBody))
}

// postSQL sends {"stmt": stmt} as JSON to url with client and returns the raw
// response body. A non-200 status is reported as an error that includes the
// response body, so the server's error message is not lost.
func postSQL(client *http.Client, url string, stmt string) ([]byte, error) {
	payloadBytes, err := json.Marshal(map[string]string{"stmt": stmt})
	if err != nil {
		return nil, fmt.Errorf("marshaling JSON payload: %w", err)
	}
	log.Printf("Payload: %s", string(payloadBytes))

	resp, err := client.Post(url, "application/json", bytes.NewBuffer(payloadBytes))
	if err != nil {
		return nil, fmt.Errorf("making HTTP POST request: %w", err)
	}
	// Returning (rather than exiting) guarantees the body is closed.
	defer resp.Body.Close()

	// Read the body before checking the status: on failure it carries the
	// server's explanation, and draining it lets the connection be reused.
	respBody, readErr := ioutil.ReadAll(resp.Body)
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("HTTP request failed with status %s: %s", resp.Status, respBody)
	}
	if readErr != nil {
		return nil, fmt.Errorf("reading response body: %w", readErr)
	}
	return respBody, nil
}
func main() {
var config *rest.Config
var err error
var kubeconfig clientcmd.ClientConfig
var namespace string
var replicas int
log.SetPrefix("Decommissioner: ")
log.SetOutput(os.Stdout)
// Get the HOSTNAME environment variable
envHostname := os.Getenv("HOSTNAME")
// Define a command-line flag for the timeout value with a default of "7200s"
crateNodePrefix := flag.String("crate-node-prefix", "data-hot", "Prefix of the CrateDB node name")
decommissionTimeout := flag.String("timeout", "7200s", "Timeout for decommission statemtment")
pid := flag.Int("pid", 1, "PID of the process to check")
proto := flag.String("proto", "https", "Protocol to use for the HTTP server")
hostname := flag.String("hostname", envHostname, "Hostname of the pod")
scaledto := flag.Int("scaled-to", 0, "Number of replicas")
min_availability := flag.String("min-availability", "full", "Minimum availability during decommission")
flag.Parse()
// Determine if we are running in-cluster or using kubeconfig
kubeconfigPath := os.Getenv("KUBECONFIG")
inClusterConfig := kubeconfigPath == ""
if inClusterConfig {
log.Println("Using in-cluster configuration")
config, err = rest.InClusterConfig()
if err != nil {
log.Fatalf("Failed to create in-cluster config: %v", err)
}
namespaceBytes, err := os.ReadFile("/var/run/secrets/kubernetes.io/serviceaccount/namespace")
if err != nil {
log.Fatalf("Failed to read namespace: %v", err)
}
namespace = string(namespaceBytes)
} else {
log.Printf("Using kubeconfig from %s", kubeconfigPath)
kubeconfig = clientcmd.NewNonInteractiveDeferredLoadingClientConfig(
clientcmd.NewDefaultClientConfigLoadingRules(),
&clientcmd.ConfigOverrides{},
)
config, err = kubeconfig.ClientConfig()
if err != nil {
log.Fatalf("Failed to load kubeconfig: %v", err)
}
namespace, _, err = kubeconfig.Namespace()
if err != nil {
log.Fatalf("Failed to get namespace from kubeconfig: %v", err)
}
}
// Create a Kubernetes client
clientset, err := kubernetes.NewForConfig(config)
if err != nil {
log.Fatalf("Failed to create Kubernetes client: %v", err)
}
if inClusterConfig {
namespaceBytes, err := os.ReadFile("/var/run/secrets/kubernetes.io/serviceaccount/namespace")
if err != nil {
log.Fatalf("Failed to read namespace: %v", err)
}
namespace = string(namespaceBytes)
} else {
// Get the current namespace from the kubeconfig context
namespace, _, err = kubeconfig.Namespace()
if err != nil {
log.Fatalf("Failed to get namespace from kubeconfig: %v", err)
}
}
// Construct the StatefulSet name from HOSTNAME
parts := strings.Split(*hostname, "-")
if len(parts) < 2 {
log.Fatalf("Invalid HOSTNAME format")
}
// Get the StatefulSet
//TODO: We do not have the permission to get the StatefulSet object in cluster
if inClusterConfig {
replicas = *scaledto
log.Printf("Mocking number of replicas to %d", *scaledto)
} else {
log.Printf("Running out-of-cluster")
statefulSetName := strings.Join(parts[:len(parts)-1], "-")
statefulSet, err := clientset.AppsV1().StatefulSets(string(namespace)).Get(context.TODO(), statefulSetName, metav1.GetOptions{})
if err != nil {
log.Fatalf("Failed to get StatefulSet: %v", err)
}
replicas = int(*statefulSet.Spec.Replicas)
}
// Check the number of replicas configured
log.Printf("StatefulSet has %d replicas configured\n", replicas)
// If replicas > 0, we are probably doing a rolling restart and need to decommission the node
// instead of just stopping the process by kubelet with SIGTERM
time.Sleep(2 * time.Second) //TODO: Sleep for 2 seconds to allow scale down to settle (?) - not sure if this is needed
if replicas > 0 {
// Extract the last part of the hostname - which should be the pod number
hostnameParts := strings.Split(*hostname, "-")
if len(hostnameParts) < 2 {
log.Fatalf("Invalid HOSTNAME format")
}
podNumber := hostnameParts[len(hostnameParts)-1]
// Send the SQL statement to decommission the node
log.Printf("Decommissioning node %s with graceful_stop.timeout of %s", podNumber, *decommissionTimeout)
stmt := fmt.Sprintf(`set global transient "cluster.graceful_stop.timeout" = '%s';`, *decommissionTimeout)
sendSQLStatement(*proto, stmt)
stmt = fmt.Sprintf(`set global transient "cluster.graceful_stop.force" = True;`)
sendSQLStatement(*proto, stmt)
stmt = fmt.Sprintf(`set global transient "cluster.graceful_stop.min_availability"='%s';`, *min_availability)
sendSQLStatement(*proto, stmt)
stmt = fmt.Sprintf(`alter cluster decommission '%s-%s'`, *crateNodePrefix, podNumber)
sendSQLStatement(*proto, stmt)
log.Println("Decommission command sent successfully")
// Function to check if a process with a given PID is running
isProcessRunning := func(pid int) bool {
process, err := os.FindProcess(pid)
if err != nil {
return false
}
// Sending signal 0 to a process is a way to check if it is running
err = process.Signal(syscall.Signal(0))
return err == nil
}
// Loop to check if the process is running, which means the crate is still running
counter := 0
for isProcessRunning(*pid) {
if counter%10 == 0 || counter == 0 {
log.Printf("Process %d is still running (check count: %d)", *pid, counter)
}
counter++
time.Sleep(2 * time.Second)
}
log.Printf("Process %d has stopped", *pid)
} else {
log.Printf("No replicas are configured -- Skipping decommission")
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment