diff --git a/symbol_inject/cmd/symbol_inject.go b/symbol_inject/cmd/symbol_inject.go index 09f444564..1397b37b1 100644 --- a/symbol_inject/cmd/symbol_inject.go +++ b/symbol_inject/cmd/symbol_inject.go @@ -88,7 +88,7 @@ func main() { os.Exit(4) } - err = symbol_inject.InjectSymbol(file, w, *symbol, *value, *from) + err = symbol_inject.InjectStringSymbol(file, w, *symbol, *value, *from) if err != nil { fmt.Fprintln(os.Stderr, err.Error()) os.Remove(*output) diff --git a/symbol_inject/symbol_inject.go b/symbol_inject/symbol_inject.go index ebf05c871..2a3d67e6d 100644 --- a/symbol_inject/symbol_inject.go +++ b/symbol_inject/symbol_inject.go @@ -16,6 +16,7 @@ package symbol_inject import ( "bytes" + "encoding/binary" "fmt" "io" "math" @@ -50,7 +51,7 @@ func OpenFile(r io.ReaderAt) (*File, error) { return file, err } -func InjectSymbol(file *File, w io.Writer, symbol, value, from string) error { +func InjectStringSymbol(file *File, w io.Writer, symbol, value, from string) error { offset, size, err := findSymbol(file, symbol) if err != nil { return err @@ -75,13 +76,29 @@ func InjectSymbol(file *File, w io.Writer, symbol, value, from string) error { } } - return copyAndInject(file.r, w, offset, size, value) -} - -func copyAndInject(r io.ReaderAt, w io.Writer, offset, size uint64, value string) (err error) { buf := make([]byte, size) copy(buf, value) + return copyAndInject(file.r, w, offset, buf) +} + +func InjectUint64Symbol(file *File, w io.Writer, symbol string, value uint64) error { + offset, size, err := findSymbol(file, symbol) + if err != nil { + return err + } + + if size != 8 { + return fmt.Errorf("symbol %q is not a uint64, it is %d bytes long", symbol, size) + } + + buf := make([]byte, 8) + binary.LittleEndian.PutUint64(buf, value) + + return copyAndInject(file.r, w, offset, buf) +} + +func copyAndInject(r io.ReaderAt, w io.Writer, offset uint64, buf []byte) (err error) { // Copy the first bytes up to the symbol offset _, err = io.Copy(w, io.NewSectionReader(r, 0, int64(offset))) @@ -91,7 +108,7 @@ func copyAndInject(r io.ReaderAt, w io.Writer, offset, size uint64, value string } // Write the remainder of the file - pos := int64(offset + size) + pos := int64(offset) + int64(len(buf)) if err == nil { _, err = io.Copy(w, io.NewSectionReader(r, pos, 1<<63-1-pos)) } diff --git a/symbol_inject/symbol_inject_test.go b/symbol_inject/symbol_inject_test.go index 77ec7d942..6607e6599 100644 --- a/symbol_inject/symbol_inject_test.go +++ b/symbol_inject/symbol_inject_test.go @@ -23,32 +23,23 @@ import ( func TestCopyAndInject(t *testing.T) { s := "abcdefghijklmnopqrstuvwxyz" testCases := []struct { - offset, size uint64 - value string - expected string + offset uint64 + buf string + expected string }{ { offset: 0, - size: 1, - value: "A", + buf: "A", expected: "Abcdefghijklmnopqrstuvwxyz", }, { offset: 1, - size: 1, - value: "B", - expected: "aBcdefghijklmnopqrstuvwxyz", - }, - { - offset: 1, - size: 1, - value: "BCD", + buf: "B", expected: "aBcdefghijklmnopqrstuvwxyz", }, { offset: 25, - size: 1, - value: "Z", + buf: "Z", expected: "abcdefghijklmnopqrstuvwxyZ", }, } @@ -57,7 +48,7 @@ func TestCopyAndInject(t *testing.T) { t.Run(strconv.Itoa(i), func(t *testing.T) { in := bytes.NewReader([]byte(s)) out := &bytes.Buffer{} - copyAndInject(in, out, testCase.offset, testCase.size, testCase.value) + copyAndInject(in, out, testCase.offset, []byte(testCase.buf)) if out.String() != testCase.expected { t.Errorf("expected %s, got %s", testCase.expected, out.String())