aboutsummaryrefslogtreecommitdiffstats
path: root/pkg/data/data-handler.go
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/data/data-handler.go')
-rw-r--r--pkg/data/data-handler.go179
1 files changed, 111 insertions, 68 deletions
diff --git a/pkg/data/data-handler.go b/pkg/data/data-handler.go
index b571010..673f247 100644
--- a/pkg/data/data-handler.go
+++ b/pkg/data/data-handler.go
@@ -21,8 +21,10 @@ package data
import (
"context"
"encoding/json"
+ "fmt"
"github.com/google/uuid"
openapi_types "github.com/oapi-codegen/runtime/types"
+ "github.com/open-policy-agent/opa/storage"
"net/http"
"path/filepath"
"policy-opa-pdp/consts"
@@ -30,11 +32,9 @@ import (
"policy-opa-pdp/pkg/metrics"
"policy-opa-pdp/pkg/model/oapicodegen"
"policy-opa-pdp/pkg/opasdk"
+ "policy-opa-pdp/pkg/policymap"
"policy-opa-pdp/pkg/utils"
"strings"
-
- "github.com/open-policy-agent/opa/storage"
- "policy-opa-pdp/pkg/policymap"
)
var (
@@ -78,60 +78,30 @@ func createOPADataUpdateExceptionResponse(statusCode int, errorMessage string, p
}
}
-// Validate OPADataUpdateRequest function
-func validateOPADataUpdateRequest(request *oapicodegen.OPADataUpdateRequest) []string {
- var validationErrors []string
-
- // Check if required fields are populated
- dateString := (request.CurrentDate).String()
- if !(utils.IsValidCurrentDate(&dateString)) {
- validationErrors = append(validationErrors, "CurrentDate is required")
- }
-
- // Validate CurrentDateTime format
- if !(utils.IsValidTime(request.CurrentDateTime)) {
- validationErrors = append(validationErrors, "CurrentDateTime is invalid or missing")
- }
-
- // Validate CurrentTime format
- if !(utils.IsValidCurrentTime(request.CurrentTime)) {
- validationErrors = append(validationErrors, "CurrentTime is invalid or missing")
- }
-
- // Validate Data field (ensure it's not nil and has items)
- if !(utils.IsValidData(request.Data)) {
- validationErrors = append(validationErrors, "Data is required and cannot be empty")
- }
-
- // Validate TimeOffset format (e.g., +02:00 or -05:00)
- if !(utils.IsValidTimeOffset(request.TimeOffset)) {
- validationErrors = append(validationErrors, "TimeOffset is invalid or missing")
- }
-
- // Validate TimeZone format (e.g., 'America/New_York')
- if !(utils.IsValidTimeZone(request.TimeZone)) {
- validationErrors = append(validationErrors, "TimeZone is invalid or missing")
- }
-
- // Optionally, check if 'OnapComponent', 'OnapInstance', 'OnapName', and 'PolicyName' are provided
- if !(utils.IsValidString(request.OnapComponent)) {
- validationErrors = append(validationErrors, "OnapComponent is required")
- }
+type Policy struct {
+ Data []string `json:"data"`
+ Policy []string `json:"policy"`
+ PolicyID string `json:"policy-id"`
+ PolicyVersion string `json:"policy-version"`
+}
- if !(utils.IsValidString(request.OnapInstance)) {
- validationErrors = append(validationErrors, "OnapInstance is required")
+// Function to extract the policy by policyId
+func getPolicyByID(policiesMap string, policyId string) (*Policy, error) {
+ var policies struct {
+ DeployedPolicies []Policy `json:"deployed_policies_dict"`
}
- if !(utils.IsValidString(request.OnapName)) {
- validationErrors = append(validationErrors, "OnapName is required")
+ if err := json.Unmarshal([]byte(policiesMap), &policies); err != nil {
+ return nil, fmt.Errorf("failed to unmarshal policies: %v", err)
}
- if !(utils.IsValidString(request.PolicyName)) {
- validationErrors = append(validationErrors, "PolicyName is required and cannot be empty")
+ for _, policy := range policies.DeployedPolicies {
+ if policy.PolicyID == policyId {
+ return &policy, nil
+ }
}
- // Return all validation errors (if any)
- return validationErrors
+ return nil, fmt.Errorf("policy '%s' not found", policyId)
}
func patchHandler(res http.ResponseWriter, req *http.Request) {
@@ -144,13 +114,18 @@ func patchHandler(res http.ResponseWriter, req *http.Request) {
log.Errorf(errMsg)
return
}
- path := strings.TrimPrefix(req.URL.Path, "/policy/pdpo/v1/data/")
+ path := strings.TrimPrefix(req.URL.Path, "/policy/pdpo/v1/data")
dirParts := strings.Split(path, "/")
dataDir := filepath.Join(dirParts...)
log.Infof("dataDir : %s", dataDir)
// Validate the request
- validationErrors := validateOPADataUpdateRequest(&requestBody)
+ validationErrors := utils.ValidateOPADataRequest(&requestBody)
+
+ // Validate Data field (ensure it's not nil and has items)
+ if !(utils.IsValidData(requestBody.Data)) {
+ validationErrors = append(validationErrors, "Data is required and cannot be empty")
+ }
// Print validation errors
if len(validationErrors) > 0 {
@@ -159,7 +134,7 @@ func patchHandler(res http.ResponseWriter, req *http.Request) {
sendErrorResponse(res, errMsg, http.StatusBadRequest)
return
} else {
- log.Errorf("All fields are valid!")
+ log.Debug("All fields are valid!")
// Access the data part
data := requestBody.Data
log.Infof("data : %s", data)
@@ -172,10 +147,46 @@ func patchHandler(res http.ResponseWriter, req *http.Request) {
log.Errorf(errMsg)
return
}
+
+ // Checking if the data operation is performed for a deployed policy with policymap.CheckIfPolicyAlreadyExists and getPolicyByID
+ // if a match is found, we will join the url path with dots and check for the data key from the policiesMap whether utl path is a
+ // prefix of data key. we will proceed for Patch Operation if this matches, else return error
+ if len(dirParts) > 0 && dirParts[0] == "" {
+ dirParts = dirParts[1:]
+ }
+ finalDirParts := strings.Join(dirParts, ".")
+
+ policiesMap := policymap.LastDeployedPolicies
+
+ matchedPolicy, err := getPolicyByID(policiesMap, *policyId)
+ if err != nil {
+ sendErrorResponse(res, err.Error(), http.StatusBadRequest)
+ log.Errorf(err.Error())
+ return
+ }
+
+ log.Infof("Matched policy: %+v", matchedPolicy)
+
+ // Check if finalDirParts starts with any data key
+ matchFound := false
+ for _, dataKey := range matchedPolicy.Data {
+ if strings.HasPrefix(finalDirParts, dataKey) {
+ matchFound = true
+ break
+ }
+ }
+
+ if !matchFound {
+ errMsg := fmt.Sprintf("Dynamic Data add/replace/remove for policy '%s' expected under url path '%v'", *policyId, matchedPolicy.Data)
+ sendErrorResponse(res, errMsg, http.StatusBadRequest)
+ log.Errorf(errMsg)
+ return
+ }
+
if err := patchData(dataDir, data, res); err != nil {
- // Handle the error, for example, log it or return an appropriate response
- log.Errorf("Error encoding JSON response: %s", err)
- }
+ // Handle the error, for example, log it or return an appropriate response
+ log.Errorf("Error encoding JSON response: %s", err)
+ }
}
}
@@ -215,7 +226,7 @@ func extractPatchInfo(res http.ResponseWriter, ops *[]map[string]interface{}, ro
// PATCH request with add or replace opType, MUST contain a "value" member whose content specifies the value to be added / replaced. For remove opType, value does not required
if optypeString == "add" || optypeString == "replace" {
value, valueErr = op["value"]
- if !valueErr {
+ if !valueErr || isEmpty(value) {
valueErrMsg := "Error in getting data value. Value is not given in request body"
sendErrorResponse(res, valueErrMsg, http.StatusInternalServerError)
log.Errorf(valueErrMsg)
@@ -225,7 +236,7 @@ func extractPatchInfo(res http.ResponseWriter, ops *[]map[string]interface{}, ro
impl.Value = value
opPath, opPathErr := op["path"].(string)
- if !opPathErr {
+ if !opPathErr || len(opPath) == 0 {
opPathErrMsg := "Error in getting data path. Path is not given in request body"
sendErrorResponse(res, opPathErrMsg, http.StatusInternalServerError)
log.Errorf(opPathErrMsg)
@@ -243,6 +254,35 @@ func extractPatchInfo(res http.ResponseWriter, ops *[]map[string]interface{}, ro
return result
}
+func isEmpty(data interface{}) bool {
+ if data == nil {
+ return true // Nil values are considered empty
+ }
+
+ switch v := data.(type) {
+ case string:
+ return len(v) == 0 // Check if string is empty
+ case []interface{}:
+ return len(v) == 0 // Check if slice is empty
+ case map[string]interface{}:
+ return len(v) == 0 // Check if map is empty
+ case []byte:
+ return len(v) == 0 // Check if byte slice is empty
+ case int, int8, int16, int32, int64:
+ return v == 0 // Zero integers are considered empty
+ case uint, uint8, uint16, uint32, uint64:
+ return v == 0 // Zero unsigned integers are considered empty
+ case float32, float64:
+ return v == 0.0 // Zero floats are considered empty
+ case bool:
+ return !v // `false` is considered empty
+ case nil:
+ return true // Explicitly checking nil again
+ default:
+ return false // Other data types are not considered empty
+ }
+}
+
func constructPath(opPath string, opType string, root string, res http.ResponseWriter) (storagePath storage.Path) {
// Construct patch path.
log.Debugf("root: %s", root)
@@ -269,13 +309,10 @@ func constructPath(opPath string, opType string, root string, res http.ResponseW
path = root + "/" + path
}
} else {
- if opType == "remove" {
- valueErrMsg := "Error in getting data path - Invalid path (/) is used."
- sendErrorResponse(res, valueErrMsg, http.StatusInternalServerError)
- log.Errorf(valueErrMsg)
- return nil
- }
- path = root
+ valueErrMsg := "Error in getting data path - Invalid path (/) is used."
+ sendErrorResponse(res, valueErrMsg, http.StatusInternalServerError)
+ log.Errorf(valueErrMsg)
+ return nil
}
log.Infof("calling ParsePatchPathEscaped to check the path")
@@ -386,7 +423,13 @@ func getDataInfo(res http.ResponseWriter, req *http.Request) {
constructResponseHeader(res, req)
urlPath := req.URL.Path
- dataPath := strings.ReplaceAll(urlPath, "/policy/pdpo/v1/data", "")
+
+ dataPath := strings.TrimPrefix(urlPath, "/policy/pdpo/v1/data")
+
+ if len(strings.TrimSpace(dataPath)) == 0 {
+ // dataPath "/" is used to get entire data
+ dataPath = "/"
+ }
log.Debugf("datapath to get Data : %s\n", dataPath)
getData(res, dataPath)
@@ -420,7 +463,7 @@ func getData(res http.ResponseWriter, dataPath string) {
res.WriteHeader(http.StatusOK)
if err := json.NewEncoder(res).Encode(dataResponse); err != nil {
- // Handle the error, for example, log it or return an appropriate response
- log.Errorf("Error encoding JSON response: %s", err)
+ // Handle the error, for example, log it or return an appropriate response
+ log.Errorf("Error encoding JSON response: %s", err)
}
}